Skip to content

Commit 5f602f6

Browse files
committed
Update RetrievalQA to support early return when the answer is good enough
1 parent 174abe5 commit 5f602f6

File tree

8 files changed

+125
-68
lines changed


Core/Sources/ChatService/ChatFunctionProvider.swift

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,5 +15,9 @@ final class ChatFunctionProvider {
1515
}
1616
}
1717

18-
extension ChatFunctionProvider: ChatGPTFunctionProvider {}
18+
extension ChatFunctionProvider: ChatGPTFunctionProvider {
19+
var functionCallStrategy: OpenAIService.FunctionCallStrategy? {
20+
nil
21+
}
22+
}
1923

Playground.playground/Pages/RetrievalQAChain.xcplaygroundpage/Contents.swift

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ import PlaygroundSupport
55
import SwiftUI
66

77
struct QAForm: View {
8-
@State var intermediateAnswers = [String]()
8+
@State var intermediateAnswers = [RefineDocumentChain.IntermediateAnswer]()
99
@State var answer: String = ""
1010
@State var question: String = "What is Swift macros?"
1111
@State var isProcessing: Bool = false
@@ -31,9 +31,14 @@ struct QAForm: View {
3131
Text(answer)
3232
}
3333
Section(header: Text("Intermediate Answers")) {
34-
ForEach(intermediateAnswers, id: \.self) { answer in
35-
Text(answer)
36-
Divider()
34+
ForEach(0..<intermediateAnswers.endIndex, id: \.self) { index in
35+
let answer = intermediateAnswers[index]
36+
VStack {
37+
Text(answer.answer)
38+
Text("Score: \(answer.score)")
39+
Text("Needs more context: \(answer.more ? "Yes" : "No")")
40+
Divider()
41+
}
3742
}
3843
}
3944
}
@@ -48,8 +53,6 @@ struct QAForm: View {
4853
answer = "Invalid URL"
4954
return
5055
}
51-
let chatGPTConfiguration = UserPreferenceChatGPTConfiguration()
52-
.overriding { $0.temperature = 0 }
5356
let embeddingConfiguration = UserPreferenceEmbeddingConfiguration().overriding()
5457
let embedding = OpenAIEmbedding(configuration: embeddingConfiguration)
5558
let store: VectorStore = try await {
@@ -72,8 +75,7 @@ struct QAForm: View {
7275

7376
let qa = RetrievalQAChain(
7477
vectorStore: store,
75-
embedding: embedding,
76-
chatModelFactory: { OpenAIChat(configuration: chatGPTConfiguration, stream: false) }
78+
embedding: embedding
7779
)
7880
answer = try await qa.run(
7981
question,

Playground.playground/Pages/RetrievalQAChain.xcplaygroundpage/timeline.xctimeline

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
version = "3.0">
44
<TimelineItems>
55
</TimelineItems>
6+
<TimelineItems>
7+
</TimelineItems>
68
<TimelineItems>
79
<LoggerValueHistoryTimelineItem
810
documentLocation = "documentLocation"

Tool/Sources/LangChain/Chains/RetrievalQA.swift

Lines changed: 76 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ public final class RetrievalQAChain: Chain {
4343

4444
public extension CallbackEvents {
4545
struct RetrievalQADidGenerateIntermediateAnswer: CallbackEvent {
46-
public let info: String
46+
public let info: RefineDocumentChain.IntermediateAnswer
4747
}
4848
}
4949

@@ -66,16 +66,38 @@ public final class RefineDocumentChain: Chain {
6666
var distance: Float
6767
}
6868

69+
public struct IntermediateAnswer: Decodable {
70+
public var answer: String
71+
public var score: Double
72+
public var more: Bool
73+
74+
public enum CodingKeys: String, CodingKey {
75+
case answer
76+
case score
77+
case more
78+
}
79+
80+
init(answer: String, score: Double, more: Bool) {
81+
self.answer = answer
82+
self.score = score
83+
self.more = more
84+
}
85+
86+
public init(from decoder: Decoder) throws {
87+
let container = try decoder.container(keyedBy: CodingKeys.self)
88+
answer = try container.decode(String.self, forKey: .answer)
89+
score = (try? container.decode(Double.self, forKey: .score)) ?? 0
90+
more = (try? container.decode(Bool.self, forKey: .more)) ?? (score < 6)
91+
}
92+
}
93+
6994
class FunctionProvider: ChatGPTFunctionProvider {
70-
var functions: [any ChatGPTFunction] = []
95+
var functionCallStrategy: FunctionCallStrategy? = .name("respond")
96+
var functions: [any ChatGPTFunction] = [RespondFunction()]
7197
}
7298

7399
struct RespondFunction: ChatGPTFunction {
74-
struct Arguments: Codable {
75-
var answer: String
76-
var score: Double
77-
var more: Bool
78-
}
100+
typealias Arguments = IntermediateAnswer
79101

80102
struct Result: ChatGPTFunctionResult {
81103
var botReadableContent: String { "" }
@@ -91,17 +113,18 @@ public final class RefineDocumentChain: Chain {
91113
.properties: [
92114
"answer": [
93115
.type: "string",
94-
.description: "The answer",
116+
.description: "The refined answer",
95117
],
96118
"score": [
97119
.type: "number",
98-
.description: "The score of the answer, the higher the better",
120+
.description: "The score of the answer, the higher the better. 0 to 10.",
99121
],
100122
"more": [
101123
.type: "boolean",
102124
.description: "Whether more information is needed to complete the answer",
103125
],
104126
],
127+
.required: ["answer", "score", "more"],
105128
]
106129
}
107130

@@ -114,18 +137,16 @@ public final class RefineDocumentChain: Chain {
114137

115138
let initialChatModel: ChatModelChain<InitialInput>
116139
let refinementChatModel: ChatModelChain<RefinementInput>
117-
let initialChatMemory: ChatGPTMemory
118-
let refinementChatMemory: ChatGPTMemory
119140

120141
public init() {
121-
initialChatMemory = ConversationChatGPTMemory(systemPrompt: "")
122-
refinementChatMemory = ConversationChatGPTMemory(systemPrompt: "")
123-
124142
initialChatModel = .init(
125143
chatModel: OpenAIChat(
126-
configuration: UserPreferenceChatGPTConfiguration()
127-
.overriding(.init(temperature: 0)),
128-
memory: initialChatMemory,
144+
configuration: UserPreferenceChatGPTConfiguration().overriding {
145+
$0.temperature = 0
146+
$0.runFunctionsAutomatically = false
147+
},
148+
memory: EmptyChatGPTMemory(),
149+
functionProvider: FunctionProvider(),
129150
stream: false
130151
),
131152
promptTemplate: { input in [
@@ -140,9 +161,12 @@ public final class RefineDocumentChain: Chain {
140161
)
141162
refinementChatModel = .init(
142163
chatModel: OpenAIChat(
143-
configuration: UserPreferenceChatGPTConfiguration()
144-
.overriding(.init(temperature: 0)),
145-
memory: refinementChatMemory,
164+
configuration: UserPreferenceChatGPTConfiguration().overriding {
165+
$0.temperature = 0
166+
$0.runFunctionsAutomatically = false
167+
},
168+
memory: EmptyChatGPTMemory(),
169+
functionProvider: FunctionProvider(),
146170
stream: false
147171
),
148172
promptTemplate: { input in [
@@ -168,6 +192,26 @@ public final class RefineDocumentChain: Chain {
168192
guard let firstDocument = input.documents.first else {
169193
return ""
170194
}
195+
196+
func extractAnswer(_ chatMessage: ChatMessage) -> IntermediateAnswer {
197+
if let functionCall = chatMessage.functionCall {
198+
do {
199+
let intermediateAnswer = try JSONDecoder().decode(
200+
IntermediateAnswer.self,
201+
from: functionCall.arguments.data(using: .utf8) ?? Data()
202+
)
203+
return intermediateAnswer
204+
} catch {
205+
let intermediateAnswer = IntermediateAnswer(
206+
answer: functionCall.arguments,
207+
score: 0,
208+
more: true
209+
)
210+
return intermediateAnswer
211+
}
212+
}
213+
return .init(answer: chatMessage.content ?? "", score: 0, more: true)
214+
}
171215
var output = try await initialChatModel.call(
172216
.init(
173217
question: input.question,
@@ -176,24 +220,27 @@ public final class RefineDocumentChain: Chain {
176220
),
177221
callbackManagers: callbackManagers
178222
)
179-
guard var content = output.content else { return "" }
180-
callbackManagers
181-
.send(CallbackEvents.RetrievalQADidGenerateIntermediateAnswer(info: content))
182-
for document in input.documents.dropFirst(1) {
223+
var intermediateAnswer = extractAnswer(output)
224+
callbackManagers.send(
225+
CallbackEvents.RetrievalQADidGenerateIntermediateAnswer(info: intermediateAnswer)
226+
)
227+
228+
for document in input.documents.dropFirst(1) where intermediateAnswer.more {
183229
output = try await refinementChatModel.call(
184230
.init(
185231
question: input.question,
186-
previousAnswer: content,
232+
previousAnswer: intermediateAnswer.answer,
187233
document: document.document.pageContent,
188234
distance: document.distance
189235
),
190236
callbackManagers: callbackManagers
191237
)
192-
content = output.content ?? ""
193-
callbackManagers
194-
.send(CallbackEvents.RetrievalQADidGenerateIntermediateAnswer(info: content))
238+
intermediateAnswer = extractAnswer(output)
239+
callbackManagers.send(
240+
CallbackEvents.RetrievalQADidGenerateIntermediateAnswer(info: intermediateAnswer)
241+
)
195242
}
196-
return content
243+
return intermediateAnswer.answer
197244
}
198245

199246
public func parseOutput(_ output: String) -> String {

Tool/Sources/LangChain/ChatModel/OpenAIChat.swift

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,7 @@ public struct OpenAIChat: ChatModel {
3838
var message = ""
3939
for try await trunk in stream {
4040
message.append(trunk)
41-
callbackManagers
42-
.forEach { $0.send(CallbackEvents.LLMDidProduceNewToken(info: trunk)) }
41+
callbackManagers.send(CallbackEvents.LLMDidProduceNewToken(info: trunk))
4342
}
4443
return await memory.messages.last ?? .init(role: .assistant, content: "")
4544
} else {

Tool/Sources/OpenAIService/ChatGPTService.swift

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,7 @@ extension ChatGPTService {
205205
model: configuration.model,
206206
remainingTokens: remainingTokens
207207
),
208-
function_call: nil,
208+
function_call: functionProvider.functionCallStrategy,
209209
functions: functionProvider.functions.map {
210210
ChatGPTFunctionSchema(
211211
name: $0.name,
@@ -302,7 +302,7 @@ extension ChatGPTService {
302302
model: configuration.model,
303303
remainingTokens: remainingTokens
304304
),
305-
function_call: nil,
305+
function_call: functionProvider.functionCallStrategy,
306306
functions: functionProvider.functions.map {
307307
ChatGPTFunctionSchema(
308308
name: $0.name,
@@ -318,6 +318,7 @@ extension ChatGPTService {
318318
url,
319319
requestBody
320320
)
321+
321322
let response = try await api()
322323

323324
guard let choice = response.choices.first else { return nil }

Tool/Sources/OpenAIService/CompletionStreamAPI.swift

Lines changed: 26 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,31 @@ protocol CompletionStreamAPI {
1212
)
1313
}
1414

15+
public enum FunctionCallStrategy: Encodable, Equatable {
16+
/// Forbid the bot to call any function.
17+
case none
18+
/// Let the bot choose what function to call.
19+
case auto
20+
/// Force the bot to call a function with the given name.
21+
case name(String)
22+
23+
struct CallFunctionNamed: Codable {
24+
var name: String
25+
}
26+
27+
public func encode(to encoder: Encoder) throws {
28+
var container = encoder.singleValueContainer()
29+
switch self {
30+
case .none:
31+
try container.encode("none")
32+
case .auto:
33+
try container.encode("auto")
34+
case let .name(name):
35+
try container.encode(CallFunctionNamed(name: name))
36+
}
37+
}
38+
}
39+
1540
/// https://platform.openai.com/docs/api-reference/chat/create
1641
struct CompletionRequestBody: Encodable, Equatable {
1742
struct Message: Codable, Equatable {
@@ -31,7 +56,7 @@ struct CompletionRequestBody: Encodable, Equatable {
3156
/// "arguments": "{ \"location\": \"earth\" }"
3257
/// }
3358
/// ```
34-
var function_call: MessageFunctionCall?
59+
var function_call: CompletionRequestBody.MessageFunctionCall?
3560
}
3661

3762
struct MessageFunctionCall: Codable, Equatable {
@@ -41,31 +66,6 @@ struct CompletionRequestBody: Encodable, Equatable {
4166
var arguments: String?
4267
}
4368

44-
enum FunctionCallStrategy: Encodable, Equatable {
45-
/// Forbid the bot to call any function.
46-
case none
47-
/// Let the bot choose what function to call.
48-
case auto
49-
/// Force the bot to call a function with the given name.
50-
case name(String)
51-
52-
struct CallFunctionNamed: Codable {
53-
var name: String
54-
}
55-
56-
func encode(to encoder: Encoder) throws {
57-
var container = encoder.singleValueContainer()
58-
switch self {
59-
case .none:
60-
try container.encode("none")
61-
case .auto:
62-
try container.encode("auto")
63-
case let .name(name):
64-
try container.encode(CallFunctionNamed(name: name))
65-
}
66-
}
67-
}
68-
6969
struct Function: Codable {
7070
var name: String
7171
var description: String

Tool/Sources/OpenAIService/FucntionCall/ChatGPTFuntionProvider.swift

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ import Foundation
22

33
public protocol ChatGPTFunctionProvider {
44
var functions: [any ChatGPTFunction] { get }
5+
var functionCallStrategy: FunctionCallStrategy? { get }
56
}
67

78
extension ChatGPTFunctionProvider {
@@ -11,6 +12,7 @@ extension ChatGPTFunctionProvider {
1112
}
1213

1314
public struct NoChatGPTFunctionProvider: ChatGPTFunctionProvider {
15+
public var functionCallStrategy: FunctionCallStrategy?
1416
public var functions: [any ChatGPTFunction] { [] }
1517
public init() {}
1618
}

0 commit comments

Comments
 (0)