@@ -2,8 +2,9 @@ import Foundation
22import OpenAIService
33
44public final class QAInformationRetrievalChain : Chain {
5- let vectorStore : VectorStore
5+ let vectorStores : [ VectorStore ]
66 let embedding : Embeddings
7+ let maxCount : Int
78
89 public struct Output {
910 public var information : String
@@ -12,24 +13,51 @@ public final class QAInformationRetrievalChain: Chain {
1213
1314 public init (
1415 vectorStore: VectorStore ,
15- embedding: Embeddings
16+ embedding: Embeddings ,
17+ maxCount: Int = 5
1618 ) {
17- self . vectorStore = vectorStore
19+ vectorStores = [ vectorStore]
1820 self . embedding = embedding
21+ self . maxCount = maxCount
22+ }
23+
24+ public init (
25+ vectorStores: [ VectorStore ] ,
26+ embedding: Embeddings ,
27+ maxCount: Int = 5
28+ ) {
29+ self . vectorStores = vectorStores
30+ self . embedding = embedding
31+ self . maxCount = maxCount
1932 }
2033
2134 public func callLogic(
2235 _ input: String ,
2336 callbackManagers: [ CallbackManager ]
2437 ) async throws -> Output {
2538 let embeddedQuestion = try await embedding. embed ( query: input)
26- let documents = try await vectorStore. searchWithDistance (
27- embeddings: embeddedQuestion,
28- count: 5
29- ) . filter { item in
30- item. distance < 0.31
31- }
39+ let documentsSlice = await withTaskGroup (
40+ of: [ ( document: Document , distance: Float ) ] . self
41+ ) { group in
42+ for vectorStore in vectorStores {
43+ group. addTask {
44+ ( try ? await vectorStore. searchWithDistance (
45+ embeddings: embeddedQuestion,
46+ count: 5
47+ ) . filter { item in
48+ item. distance < 0.31
49+ } ) ?? [ ]
50+ }
51+ }
52+ var result = [ ( document: Document, distance: Float) ] ( )
53+ for await items in group {
54+ result. append ( contentsOf: items)
55+ }
56+ return result
57+ } . sorted { $0. distance < $1. distance } . prefix ( maxCount)
3258
59+ let documents = Array ( documentsSlice)
60+
3361 callbackManagers. send ( CallbackEvents . RetrievalQADidExtractRelevantContent ( info: documents) )
3462
3563 let relevantInformationChain = RelevantInformationExtractionChain ( )
0 commit comments