Skip to content

Commit 1587c50

Browse files
committed
Merge branch 'feature/retrieval-qa-to-get-information-only' into develop
2 parents 384d44c + 3881faf commit 1587c50

10 files changed

Lines changed: 421 additions & 225 deletions

File tree

Core/Sources/ChatContextCollectors/WebChatContextCollector/QueryWebsiteFunction.swift

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,8 @@ struct QueryWebsiteFunction: ChatGPTFunction {
6464

6565
if let database = await TemporaryUSearch.view(identifier: urlString) {
6666
await reportProgress("Generating answers..")
67-
let qa = RetrievalQAChain(vectorStore: database, embedding: embedding)
68-
return try await qa.call(.init(arguments.query)).answer
67+
let qa = QAInformationRetrievalChain(vectorStore: database, embedding: embedding)
68+
return try await qa.call(.init(arguments.query)).information
6969
}
7070
let loader = WebLoader(urls: [url])
7171
let documents = try await loader.load()
@@ -83,9 +83,9 @@ struct QueryWebsiteFunction: ChatGPTFunction {
8383
try await database.set(embeddedDocuments)
8484
// 4. generate answer
8585
await reportProgress("Generating answers..")
86-
let qa = RetrievalQAChain(vectorStore: database, embedding: embedding)
86+
let qa = QAInformationRetrievalChain(vectorStore: database, embedding: embedding)
8787
let result = try await qa.call(.init(arguments.query))
88-
return result.answer
88+
return result.information
8989
}
9090
}
9191

Playground.playground/Pages/RetrievalQAChain.xcplaygroundpage/Contents.swift

Lines changed: 15 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,14 @@ import LangChain
44
import OpenAIService
55
import PlaygroundSupport
66
import SwiftUI
7+
import TokenEncoder
78

89
struct QAForm: View {
9-
@State var intermediateAnswers = [RefineDocumentChain.IntermediateAnswer]()
10+
@State var relevantInformation = [String]()
1011
@State var relevantDocuments = [(document: Document, distance: Float)]()
1112
@State var duration: TimeInterval = 0
1213
@State var answer: String = ""
14+
@State var tokenCount: Int = 0
1315
@State var question: String = "What is Swift macros?"
1416
@State var isProcessing: Bool = false
1517
@State var url: String = "https://developer.apple.com/documentation/swift/applying-macros"
@@ -36,23 +38,14 @@ struct QAForm: View {
3638
Text("\(duration) seconds")
3739
}
3840
}
39-
Section(header: Text("Answer")) {
41+
Section(header: Text("All Relevant Information (\(tokenCount) words)")) {
4042
Text(answer)
4143
}
42-
Section(header: Text("Intermediate Answers")) {
43-
ForEach(0..<intermediateAnswers.endIndex, id: \.self) { index in
44-
let answer = intermediateAnswers[index]
44+
Section(header: Text("Relevant Information")) {
45+
ForEach(0..<relevantInformation.endIndex, id: \.self) { index in
46+
let information = relevantInformation[index]
4547
VStack(alignment: .leading) {
46-
Text(answer.answer)
47-
VStack(alignment: .leading) {
48-
Text("Usefulness: \(answer.usefulness)")
49-
Text("Needs more context: \(answer.more ? "Yes" : "No")")
50-
}
51-
.padding()
52-
.background {
53-
RoundedRectangle(cornerRadius: 8)
54-
.fill(Color(NSColor.textBackgroundColor))
55-
}
48+
Text(information)
5649
Divider()
5750
}
5851
.textSelection(.enabled)
@@ -84,8 +77,9 @@ struct QAForm: View {
8477
let start = Date().timeIntervalSince1970
8578
answer = ""
8679
relevantDocuments = []
87-
intermediateAnswers = []
80+
relevantInformation = []
8881
duration = 0
82+
tokenCount = 0
8983
isProcessing = true
9084
defer { isProcessing = false }
9185
guard let url = URL(string: url) else {
@@ -112,23 +106,24 @@ struct QAForm: View {
112106
}
113107
}()
114108

115-
let qa = RetrievalQAChain(
109+
let qa = QAInformationRetrievalChain(
116110
vectorStore: store,
117111
embedding: embedding
118112
)
119113
answer = try await qa.run(
120114
question,
121115
callbackManagers: [
122116
.init {
123-
$0.on(CallbackEvents.RetrievalQADidGenerateIntermediateAnswer.self) {
124-
intermediateAnswers.append($0)
117+
$0.on(\.relevantInformationExtractionChainDidExtractPartialRelevantContent) {
118+
relevantInformation.append($0)
125119
}
126-
$0.on(CallbackEvents.RetrievalQADidExtractRelevantContent.self) {
120+
$0.on(\.retrievalQADidExtractRelevantContent) {
127121
relevantDocuments = $0
128122
}
129123
},
130124
]
131125
)
126+
tokenCount = answer.split(separator: " ").count
132127
duration = Date().timeIntervalSince1970 - start
133128
}
134129
}
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
<?xml version="1.0" encoding="UTF-8"?>
2+
<Timeline
3+
version = "3.0">
4+
<TimelineItems>
5+
</TimelineItems>
6+
</Timeline>

Tool/Sources/LangChain/Agent.swift

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,14 +24,26 @@ public extension CallbackEvents {
2424
struct AgentDidFinish: CallbackEvent {
2525
public let info: AgentFinish
2626
}
27+
28+
var agentDidFinish: AgentDidFinish.Type {
29+
AgentDidFinish.self
30+
}
2731

2832
struct AgentActionDidStart: CallbackEvent {
2933
public let info: AgentAction
3034
}
35+
36+
var agentActionDidStart: AgentActionDidStart.Type {
37+
AgentActionDidStart.self
38+
}
3139

3240
struct AgentActionDidEnd: CallbackEvent {
3341
public let info: AgentAction
3442
}
43+
44+
var agentActionDidEnd: AgentActionDidEnd.Type {
45+
AgentActionDidEnd.self
46+
}
3547
}
3648

3749
public struct AgentFinish: Equatable {

Tool/Sources/LangChain/Callback.swift

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@ public protocol CallbackEvent {
55
var info: Info { get }
66
}
77

8-
public enum CallbackEvents {}
8+
public struct CallbackEvents {
9+
private init() {}
10+
}
911

1012
public struct CallbackManager {
1113
fileprivate var observers = [Any]()
@@ -25,19 +27,39 @@ public struct CallbackManager {
2527
observers.append(handler)
2628
}
2729

30+
public mutating func on<Event: CallbackEvent>(
31+
_: KeyPath<CallbackEvents, Event.Type>,
32+
_ handler: @escaping (Event.Info) -> Void
33+
) {
34+
observers.append(handler)
35+
}
36+
2837
public func send<Event: CallbackEvent>(_ event: Event) {
2938
for case let observer as ((Event.Info) -> Void) in observers {
3039
observer(event.info)
3140
}
3241
}
42+
43+
func send<Event: CallbackEvent>(
44+
_: KeyPath<CallbackEvents, Event.Type>,
45+
_ info: Event.Info
46+
) {
47+
for case let observer as ((Event.Info) -> Void) in observers {
48+
observer(info)
49+
}
50+
}
3351
}
3452

3553
public extension [CallbackManager] {
3654
func send<Event: CallbackEvent>(_ event: Event) {
37-
for cb in self {
38-
for case let observer as ((Event.Info) -> Void) in cb.observers {
39-
observer(event.info)
40-
}
41-
}
55+
for cb in self { cb.send(event) }
56+
}
57+
58+
func send<Event: CallbackEvent>(
59+
_ keyPath: KeyPath<CallbackEvents, Event.Type>,
60+
_ info: Event.Info
61+
) {
62+
for cb in self { cb.send(keyPath, info) }
4263
}
4364
}
65+

Tool/Sources/LangChain/Chain.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ public protocol Chain {
1010
public extension Chain {
1111
typealias ChainDidStart = CallbackEvents.ChainDidStart<Self>
1212
typealias ChainDidEnd = CallbackEvents.ChainDidEnd<Self>
13-
13+
1414
func run(_ input: Input, callbackManagers: [CallbackManager] = []) async throws -> String {
1515
let output = try await call(input, callbackManagers: callbackManagers)
1616
return parseOutput(output)

0 commit comments

Comments
 (0)