Skip to content

Commit 821a2d2

Browse files
committed
Add initial implementation of RetrievalQAChain
1 parent de7c2fb commit 821a2d2

File tree

1 file changed

+107
-9
lines changed

1 file changed

+107
-9
lines changed
Lines changed: 107 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,128 @@
11
import Foundation
22

3-
final class RetrievalQA: Chain {
3+
public final class RetrievalQAChain: Chain {
44
let vectorStore: VectorStore
55
let embedding: Embeddings
6+
let chatModelFactory: () -> ChatModel
67

7-
struct Output {
8+
public struct Output {
89
var answer: String
910
var sourceDocuments: [Document]
1011
}
1112

12-
init(vectorStore: VectorStore, embedding: Embeddings) {
13+
public init(
14+
vectorStore: VectorStore,
15+
embedding: Embeddings,
16+
chatModelFactory: @escaping () -> ChatModel
17+
) {
1318
self.vectorStore = vectorStore
1419
self.embedding = embedding
20+
self.chatModelFactory = chatModelFactory
1521
}
1622

17-
func callLogic(
23+
public func callLogic(
1824
_ input: String,
19-
callbackManagers: [ChainCallbackManager]
25+
callbackManagers: [CallbackManager]
2026
) async throws -> Output {
21-
let embeddedQuestion = try awa
22-
23-
return .init(answer: "", sourceDocuments: [])
27+
let embeddedQuestion = try await embedding.embed(query: input)
28+
let documents = try await vectorStore.searchWithDistance(
29+
embeddings: embeddedQuestion,
30+
count: 10
31+
)
32+
let refinementChain = RefineDocumentChain(chatModelFactory: chatModelFactory)
33+
let answer = try await refinementChain.run(.init(question: input, documents: documents))
34+
35+
return .init(answer: answer, sourceDocuments: documents.map(\.document))
2436
}
2537

26-
func parseOutput(_ output: Output) -> String {
38+
public func parseOutput(_ output: Output) -> String {
2739
return output.answer
2840
}
2941
}
3042

43+
public final class RefineDocumentChain: Chain {
44+
public struct Input {
45+
var question: String
46+
var documents: [(document: Document, distance: Float)]
47+
}
48+
49+
struct InitialInput {
50+
var question: String
51+
var document: String
52+
var distance: Float
53+
}
54+
55+
struct RefinementInput {
56+
var question: String
57+
var previousAnswer: String
58+
var document: String
59+
var distance: Float
60+
}
61+
62+
let initialChatModel: ChatModelChain<InitialInput>
63+
let refinementChatModel: ChatModelChain<RefinementInput>
64+
65+
public init(chatModelFactory: () -> ChatModel) {
66+
initialChatModel = .init(
67+
chatModel: chatModelFactory(),
68+
promptTemplate: { input in [
69+
.init(role: .system, content: """
70+
The user will send you a question, you must answer it at your best.
71+
You can use the following document as a reference:###
72+
\(input.document)
73+
###
74+
"""),
75+
.init(role: .user, content: input.question),
76+
] }
77+
)
78+
refinementChatModel = .init(
79+
chatModel: chatModelFactory(),
80+
promptTemplate: { input in [
81+
.init(role: .system, content: """
82+
The user will send you a question, you must update your previous answer to it at your best.
83+
Previous answer:###
84+
\(input.previousAnswer)
85+
###
86+
You can use the following document as a reference:###
87+
\(input.document)
88+
###
89+
"""),
90+
.init(role: .user, content: input.question),
91+
] }
92+
)
93+
}
94+
95+
public func callLogic(
96+
_ input: Input,
97+
callbackManagers: [CallbackManager]
98+
) async throws -> String {
99+
guard let firstDocument = input.documents.first else {
100+
return ""
101+
}
102+
var output = try await initialChatModel.call(
103+
.init(
104+
question: input.question,
105+
document: firstDocument.document.pageContent,
106+
distance: firstDocument.distance
107+
),
108+
callbackManagers: callbackManagers
109+
)
110+
for document in input.documents.dropFirst(1) {
111+
output = try await refinementChatModel.call(
112+
.init(
113+
question: input.question,
114+
previousAnswer: output,
115+
document: document.document.pageContent,
116+
distance: document.distance
117+
),
118+
callbackManagers: callbackManagers
119+
)
120+
}
121+
return output
122+
}
123+
124+
public func parseOutput(_ output: String) -> String {
125+
return output
126+
}
127+
}
128+

0 commit comments

Comments
 (0)