Skip to content

Commit bf94948

Browse files
committed
Update QAInformationRetrievalChain to support multiple vector store
1 parent b30d42b commit bf94948

1 file changed

Lines changed: 37 additions & 9 deletions

File tree

Tool/Sources/LangChain/Chains/RetrievalQA.swift

Lines changed: 37 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@ import Foundation
22
import OpenAIService
33

44
public final class QAInformationRetrievalChain: Chain {
5-
let vectorStore: VectorStore
5+
let vectorStores: [VectorStore]
66
let embedding: Embeddings
7+
let maxCount: Int
78

89
public struct Output {
910
public var information: String
@@ -12,24 +13,51 @@ public final class QAInformationRetrievalChain: Chain {
1213

1314
public init(
1415
vectorStore: VectorStore,
15-
embedding: Embeddings
16+
embedding: Embeddings,
17+
maxCount: Int = 5
1618
) {
17-
self.vectorStore = vectorStore
19+
vectorStores = [vectorStore]
1820
self.embedding = embedding
21+
self.maxCount = maxCount
22+
}
23+
24+
public init(
25+
vectorStores: [VectorStore],
26+
embedding: Embeddings,
27+
maxCount: Int = 5
28+
) {
29+
self.vectorStores = vectorStores
30+
self.embedding = embedding
31+
self.maxCount = maxCount
1932
}
2033

2134
public func callLogic(
2235
_ input: String,
2336
callbackManagers: [CallbackManager]
2437
) async throws -> Output {
2538
let embeddedQuestion = try await embedding.embed(query: input)
26-
let documents = try await vectorStore.searchWithDistance(
27-
embeddings: embeddedQuestion,
28-
count: 5
29-
).filter { item in
30-
item.distance < 0.31
31-
}
39+
let documentsSlice = await withTaskGroup(
40+
of: [(document: Document, distance: Float)].self
41+
) { group in
42+
for vectorStore in vectorStores {
43+
group.addTask {
44+
(try? await vectorStore.searchWithDistance(
45+
embeddings: embeddedQuestion,
46+
count: 5
47+
).filter { item in
48+
item.distance < 0.31
49+
}) ?? []
50+
}
51+
}
52+
var result = [(document: Document, distance: Float)]()
53+
for await items in group {
54+
result.append(contentsOf: items)
55+
}
56+
return result
57+
}.sorted { $0.distance < $1.distance }.prefix(maxCount)
3258

59+
let documents = Array(documentsSlice)
60+
3361
callbackManagers.send(CallbackEvents.RetrievalQADidExtractRelevantContent(info: documents))
3462

3563
let relevantInformationChain = RelevantInformationExtractionChain()

0 commit comments

Comments
 (0)