Skip to content

Commit 74df8f7

Browse files
committed
Adjust implementation of QAInformationRetrievalChain
1 parent ba3b360 commit 74df8f7

2 files changed

Lines changed: 73 additions & 13 deletions

File tree

Tool/Sources/LangChain/Chains/RelevantInformationExtractionChain.swift

Lines changed: 56 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,44 +15,74 @@ public final class RelevantInformationExtractionChain: Chain {
1515
public typealias Output = String
1616

1717
class FunctionProvider: ChatGPTFunctionProvider {
18-
var functionCallStrategy: FunctionCallStrategy? = .auto
19-
var functions: [any ChatGPTFunction] = [NoneFunction()]
18+
var functionCallStrategy: FunctionCallStrategy? = .name("saveFinalAnswer")
19+
var functions: [any ChatGPTFunction] = [FinalAnswer()]
2020
}
2121

22-
struct NoneFunction: ChatGPTArgumentsCollectingFunction {
23-
typealias Arguments = NoArguments
24-
var name: String = "noInformationFound"
25-
var description: String = "Call when you can't find any relevant information from the document, or the question was not mentioned in the document"
22+
struct FinalAnswer: ChatGPTArgumentsCollectingFunction {
23+
struct Arguments: Decodable {
24+
var relevantInformation: String
25+
var noRelevantInformationFound: Bool?
26+
}
27+
28+
var name: String = "saveFinalAnswer"
29+
var description: String =
30+
"save the relevant information"
31+
var argumentSchema: JSONSchemaValue {
32+
[
33+
.type: "object",
34+
.properties: [
35+
"relevantInformation": [.type: "string"],
36+
"noRelevantInformationFound": [.type: "boolean"],
37+
],
38+
.required: ["relevantInformation", "noRelevantInformationFound"],
39+
]
40+
}
41+
}
42+
43+
let filterMetadata: (String) -> Bool
44+
let hint: String
45+
46+
init(filterMetadata: @escaping (String) -> Bool = { _ in true }, hint: String) {
47+
self.filterMetadata = filterMetadata
48+
self.hint = hint
2649
}
2750

2851
func buildChatModel() -> ChatModelChain<TaskInput> {
2952
.init(
3053
chatModel: OpenAIChat(
3154
configuration: UserPreferenceChatGPTConfiguration().overriding {
32-
$0.temperature = 0
55+
$0.temperature = 0.5
3356
$0.runFunctionsAutomatically = false
3457
},
3558
memory: EmptyChatGPTMemory(),
3659
functionProvider: FunctionProvider(),
3760
stream: false
3861
)
39-
) { input in [
62+
) { [filterMetadata, hint] input in [
4063
.init(
4164
role: .system,
4265
content: """
4366
Extract the relevant information from the Document according to the Question.
67+
The information may not directly answer the question, but it should be relevant to the question, \
68+
please think carefully and make you decision.
4469
Make the information clear, concise and short.
4570
If found code, wrap it in markdown code block.
71+
\(hint)
4672
"""
4773
),
4874
.init(
4975
role: .user,
5076
content: """
5177
Question:###
78+
(how, when, what or why)
5279
\(input.question)
5380
###
5481
Document:###
55-
\(input.document)
82+
\(input.document.metadata.filter { key, _ in
83+
filterMetadata(key)
84+
})
85+
\(input.document.pageContent)
5686
###
5787
"""
5888
),
@@ -73,6 +103,22 @@ public final class RelevantInformationExtractionChain: Chain {
73103
taskInput,
74104
callbackManagers: callbackManagers
75105
)
106+
107+
if let functionCall = output.functionCall {
108+
do {
109+
let arguments = try JSONDecoder().decode(
110+
FinalAnswer.Arguments.self,
111+
from: functionCall.arguments.data(using: .utf8) ?? Data()
112+
)
113+
if arguments.noRelevantInformationFound ?? false {
114+
return ""
115+
}
116+
return arguments.relevantInformation
117+
} catch {
118+
return output.content ?? ""
119+
}
120+
}
121+
76122
return output.content ?? ""
77123
}
78124

@@ -119,3 +165,4 @@ public extension CallbackEvents {
119165
RelevantInformationExtractionChainDidExtractPartialRelevantContent.self
120166
}
121167
}
168+

Tool/Sources/LangChain/Chains/RetrievalQA.swift

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ public final class QAInformationRetrievalChain: Chain {
55
let vectorStores: [VectorStore]
66
let embedding: Embeddings
77
let maxCount: Int
8+
let filterMetadata: (String) -> Bool
9+
let hint: String
810

911
public struct Output {
1012
public var information: String
@@ -14,21 +16,29 @@ public final class QAInformationRetrievalChain: Chain {
1416
public init(
1517
vectorStore: VectorStore,
1618
embedding: Embeddings,
17-
maxCount: Int = 5
19+
maxCount: Int = 5,
20+
filterMetadata: @escaping (String) -> Bool = { _ in true },
21+
hint: String = ""
1822
) {
1923
vectorStores = [vectorStore]
2024
self.embedding = embedding
2125
self.maxCount = maxCount
26+
self.filterMetadata = filterMetadata
27+
self.hint = hint
2228
}
2329

2430
public init(
2531
vectorStores: [VectorStore],
2632
embedding: Embeddings,
27-
maxCount: Int = 5
33+
maxCount: Int = 5,
34+
filterMetadata: @escaping (String) -> Bool = { _ in true },
35+
hint: String = ""
2836
) {
2937
self.vectorStores = vectorStores
3038
self.embedding = embedding
3139
self.maxCount = maxCount
40+
self.filterMetadata = filterMetadata
41+
self.hint = hint
3242
}
3343

3444
public func callLogic(
@@ -57,10 +67,13 @@ public final class QAInformationRetrievalChain: Chain {
5767
}.sorted { $0.distance < $1.distance }.prefix(maxCount)
5868

5969
let documents = Array(documentsSlice)
60-
70+
6171
callbackManagers.send(CallbackEvents.RetrievalQADidExtractRelevantContent(info: documents))
6272

63-
let relevantInformationChain = RelevantInformationExtractionChain()
73+
let relevantInformationChain = RelevantInformationExtractionChain(
74+
filterMetadata: filterMetadata,
75+
hint: hint
76+
)
6477
let relevantInformation = try await relevantInformationChain.run(
6578
.init(question: input, documents: documents),
6679
callbackManagers: callbackManagers

0 commit comments

Comments
 (0)