Skip to content

Commit d72a7fe

Browse files
committed
Merge branch 'feature/retrieval-qa-improvement' into develop
2 parents e011246 + eb6821f commit d72a7fe

15 files changed

Lines changed: 337 additions & 184 deletions

File tree

Core/Sources/ChatContextCollectors/WebChatContextCollector/QueryWebsiteFunction.swift

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -64,13 +64,7 @@ struct QueryWebsiteFunction: ChatGPTFunction {
6464

6565
if let database = await TemporaryUSearch.view(identifier: urlString) {
6666
await reportProgress("Generating answers..")
67-
let qa = RetrievalQAChain(vectorStore: database, embedding: embedding) {
68-
OpenAIChat(
69-
configuration: UserPreferenceChatGPTConfiguration()
70-
.overriding(.init(temperature: 0)),
71-
stream: true
72-
)
73-
}
67+
let qa = RetrievalQAChain(vectorStore: database, embedding: embedding)
7468
return try await qa.call(.init(arguments.query)).answer
7569
}
7670
let loader = WebLoader(urls: [url])
@@ -89,13 +83,7 @@ struct QueryWebsiteFunction: ChatGPTFunction {
8983
try await database.set(embeddedDocuments)
9084
// 4. generate answer
9185
await reportProgress("Generating answers..")
92-
let qa = RetrievalQAChain(vectorStore: database, embedding: embedding) {
93-
OpenAIChat(
94-
configuration: UserPreferenceChatGPTConfiguration()
95-
.overriding(.init(temperature: 0)),
96-
stream: true
97-
)
98-
}
86+
let qa = RetrievalQAChain(vectorStore: database, embedding: embedding)
9987
let result = try await qa.call(.init(arguments.query))
10088
return result.answer
10189
}

Core/Sources/ChatService/ChatFunctionProvider.swift

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,5 +15,9 @@ final class ChatFunctionProvider {
1515
}
1616
}
1717

18-
extension ChatFunctionProvider: ChatGPTFunctionProvider {}
18+
extension ChatFunctionProvider: ChatGPTFunctionProvider {
19+
var functionCallStrategy: OpenAIService.FunctionCallStrategy? {
20+
nil
21+
}
22+
}
1923

Playground.playground/Pages/RetrievalQAChain.xcplaygroundpage/Contents.swift

Lines changed: 71 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,55 +1,97 @@
11
import AppKit
2+
import Foundation
23
import LangChain
34
import OpenAIService
45
import PlaygroundSupport
56
import SwiftUI
67

78
struct QAForm: View {
8-
@State var intermediateAnswers = [String]()
9+
@State var intermediateAnswers = [RefineDocumentChain.IntermediateAnswer]()
10+
@State var relevantDocuments = [(document: Document, distance: Float)]()
11+
@State var duration: TimeInterval = 0
912
@State var answer: String = ""
1013
@State var question: String = "What is Swift macros?"
1114
@State var isProcessing: Bool = false
1215
@State var url: String = "https://developer.apple.com/documentation/swift/applying-macros"
1316

1417
var body: some View {
15-
Form {
16-
Section(header: Text("Input")) {
17-
TextField("URL", text: $url)
18-
TextField("Question", text: $question)
19-
Button("Ask") {
20-
Task {
21-
do {
22-
try await ask()
23-
} catch {
24-
answer = error.localizedDescription
18+
HStack(spacing: 0) {
19+
ScrollView {
20+
Form {
21+
Section(header: Text("Input")) {
22+
TextField("URL", text: $url)
23+
TextField("Question", text: $question)
24+
HStack {
25+
Button("Ask") {
26+
Task {
27+
do {
28+
try await ask()
29+
} catch {
30+
answer = error.localizedDescription
31+
}
32+
}
33+
}
34+
.disabled(isProcessing)
35+
36+
Text("\(duration) seconds")
37+
}
38+
}
39+
Section(header: Text("Answer")) {
40+
Text(answer)
41+
}
42+
Section(header: Text("Intermediate Answers")) {
43+
ForEach(0..<intermediateAnswers.endIndex, id: \.self) { index in
44+
let answer = intermediateAnswers[index]
45+
VStack(alignment: .leading) {
46+
Text(answer.answer)
47+
VStack(alignment: .leading) {
48+
Text("Usefulness: \(answer.usefulness)")
49+
Text("Needs more context: \(answer.more ? "Yes" : "No")")
50+
}
51+
.padding()
52+
.background {
53+
RoundedRectangle(cornerRadius: 8)
54+
.fill(Color(NSColor.textBackgroundColor))
55+
}
56+
Divider()
57+
}
58+
.textSelection(.enabled)
2559
}
2660
}
2761
}
28-
.disabled(isProcessing)
62+
.formStyle(.grouped)
2963
}
30-
Section(header: Text("Answer")) {
31-
Text(answer)
32-
}
33-
Section(header: Text("Intermediate Answers")) {
34-
ForEach(intermediateAnswers, id: \.self) { answer in
35-
Text(answer)
36-
Divider()
37-
}
64+
65+
ScrollView {
66+
Form {
67+
Section(header: Text("Relevant Documents")) {
68+
ForEach(0..<relevantDocuments.endIndex, id: \.self) { index in
69+
let document = relevantDocuments[index]
70+
VStack(alignment: .leading) {
71+
Text("\(document.distance)")
72+
Text(document.document.pageContent)
73+
Divider()
74+
}
75+
.textSelection(.enabled)
76+
}
77+
}
78+
}.formStyle(.grouped)
3879
}
3980
}
40-
.formStyle(.grouped)
4181
}
4282

4383
func ask() async throws {
84+
let start = Date().timeIntervalSince1970
85+
answer = ""
86+
relevantDocuments = []
4487
intermediateAnswers = []
88+
duration = 0
4589
isProcessing = true
4690
defer { isProcessing = false }
4791
guard let url = URL(string: url) else {
4892
answer = "Invalid URL"
4993
return
5094
}
51-
let chatGPTConfiguration = UserPreferenceChatGPTConfiguration()
52-
.overriding { $0.temperature = 0 }
5395
let embeddingConfiguration = UserPreferenceEmbeddingConfiguration().overriding()
5496
let embedding = OpenAIEmbedding(configuration: embeddingConfiguration)
5597
let store: VectorStore = try await {
@@ -72,8 +114,7 @@ struct QAForm: View {
72114

73115
let qa = RetrievalQAChain(
74116
vectorStore: store,
75-
embedding: embedding,
76-
chatModelFactory: { OpenAIChat(configuration: chatGPTConfiguration, stream: false) }
117+
embedding: embedding
77118
)
78119
answer = try await qa.run(
79120
question,
@@ -82,15 +123,19 @@ struct QAForm: View {
82123
$0.on(CallbackEvents.RetrievalQADidGenerateIntermediateAnswer.self) {
83124
intermediateAnswers.append($0)
84125
}
126+
$0.on(CallbackEvents.RetrievalQADidExtractRelevantContent.self) {
127+
relevantDocuments = $0
128+
}
85129
},
86130
]
87131
)
132+
duration = Date().timeIntervalSince1970 - start
88133
}
89134
}
90135

91136
let hostingView = NSHostingController(
92137
rootView: QAForm()
93-
.frame(width: 600, height: 800)
138+
.frame(width: 800, height: 800)
94139
)
95140

96141
PlaygroundPage.current.needsIndefiniteExecution = true

Playground.playground/Pages/RetrievalQAChain.xcplaygroundpage/timeline.xctimeline

Lines changed: 0 additions & 12 deletions
This file was deleted.

Tool/Sources/LangChain/Agent.swift

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ public extension Agent {
104104
) async throws -> AgentNextStep {
105105
let input = getFullInputs(input: input, intermediateSteps: intermediateSteps)
106106
let output = try await chatModelChain.call(input, callbackManagers: callbackManagers)
107-
return parseOutput(output)
107+
return parseOutput(output.content ?? "")
108108
}
109109

110110
func returnStoppedResponse(
@@ -128,12 +128,13 @@ public extension Agent {
128128
"""
129129
let input = AgentInput(input: input, thoughts: .text(thoughts))
130130
let output = try await chatModelChain.call(input, callbackManagers: callbackManagers)
131-
let nextAction = parseOutput(output)
131+
let reply = output.content ?? ""
132+
let nextAction = parseOutput(reply)
132133
switch nextAction {
133134
case let .finish(finish):
134135
return finish
135136
case .actions:
136-
return AgentFinish(returnValue: output, log: output)
137+
return AgentFinish(returnValue: reply, log: reply)
137138
}
138139
}
139140
}

Tool/Sources/LangChain/Chains/LLMChain.swift

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import Foundation
22

33
public class ChatModelChain<Input>: Chain {
4-
public typealias Output = String
4+
public typealias Output = ChatMessage
55

66
var chatModel: ChatModel
77
var promptTemplate: (Input) -> [ChatMessage]
@@ -31,7 +31,13 @@ public class ChatModelChain<Input>: Chain {
3131
}
3232

3333
public func parseOutput(_ output: Output) -> String {
34-
output
34+
if let content = output.content {
35+
return content
36+
} else if let functionCall = output.functionCall {
37+
return "\(functionCall.name): \(functionCall.arguments)"
38+
}
39+
40+
return ""
3541
}
3642
}
3743

0 commit comments

Comments
 (0)