Skip to content

Commit 4fba6f3

Browse files
committed
Use RelevantInformationExtractionChain to replace RefineDocumentChain
1 parent 384d44c commit 4fba6f3

File tree

9 files changed

+412
-215
lines changed

9 files changed

+412
-215
lines changed

Playground.playground/Pages/RetrievalQAChain.xcplaygroundpage/Contents.swift

Lines changed: 14 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,14 @@ import LangChain
44
import OpenAIService
55
import PlaygroundSupport
66
import SwiftUI
7+
import TokenEncoder
78

89
struct QAForm: View {
9-
@State var intermediateAnswers = [RefineDocumentChain.IntermediateAnswer]()
10+
@State var relevantInformation = [String]()
1011
@State var relevantDocuments = [(document: Document, distance: Float)]()
1112
@State var duration: TimeInterval = 0
1213
@State var answer: String = ""
14+
@State var tokenCount: Int = 0
1315
@State var question: String = "What is Swift macros?"
1416
@State var isProcessing: Bool = false
1517
@State var url: String = "https://developer.apple.com/documentation/swift/applying-macros"
@@ -36,23 +38,14 @@ struct QAForm: View {
3638
Text("\(duration) seconds")
3739
}
3840
}
39-
Section(header: Text("Answer")) {
41+
Section(header: Text("All Relevant Information (\(tokenCount) words)")) {
4042
Text(answer)
4143
}
42-
Section(header: Text("Intermediate Answers")) {
43-
ForEach(0..<intermediateAnswers.endIndex, id: \.self) { index in
44-
let answer = intermediateAnswers[index]
44+
Section(header: Text("Relevant Information")) {
45+
ForEach(0..<relevantInformation.endIndex, id: \.self) { index in
46+
let information = relevantInformation[index]
4547
VStack(alignment: .leading) {
46-
Text(answer.answer)
47-
VStack(alignment: .leading) {
48-
Text("Usefulness: \(answer.usefulness)")
49-
Text("Needs more context: \(answer.more ? "Yes" : "No")")
50-
}
51-
.padding()
52-
.background {
53-
RoundedRectangle(cornerRadius: 8)
54-
.fill(Color(NSColor.textBackgroundColor))
55-
}
48+
Text(information)
5649
Divider()
5750
}
5851
.textSelection(.enabled)
@@ -84,8 +77,9 @@ struct QAForm: View {
8477
let start = Date().timeIntervalSince1970
8578
answer = ""
8679
relevantDocuments = []
87-
intermediateAnswers = []
80+
relevantInformation = []
8881
duration = 0
82+
tokenCount = 0
8983
isProcessing = true
9084
defer { isProcessing = false }
9185
guard let url = URL(string: url) else {
@@ -120,15 +114,16 @@ struct QAForm: View {
120114
question,
121115
callbackManagers: [
122116
.init {
123-
$0.on(CallbackEvents.RetrievalQADidGenerateIntermediateAnswer.self) {
124-
intermediateAnswers.append($0)
117+
$0.on(\.relevantInformationExtractionChainDidExtractPartialRelevantContent) {
118+
relevantInformation.append($0)
125119
}
126-
$0.on(CallbackEvents.RetrievalQADidExtractRelevantContent.self) {
120+
$0.on(\.retrievalQADidExtractRelevantContent) {
127121
relevantDocuments = $0
128122
}
129123
},
130124
]
131125
)
126+
tokenCount = answer.split(separator: " ").count
132127
duration = Date().timeIntervalSince1970 - start
133128
}
134129
}
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
<?xml version="1.0" encoding="UTF-8"?>
2+
<Timeline
3+
version = "3.0">
4+
<TimelineItems>
5+
</TimelineItems>
6+
</Timeline>

Tool/Sources/LangChain/Agent.swift

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,14 +24,26 @@ public extension CallbackEvents {
2424
struct AgentDidFinish: CallbackEvent {
2525
public let info: AgentFinish
2626
}
27+
28+
var agentDidFinish: AgentDidFinish.Type {
29+
AgentDidFinish.self
30+
}
2731

2832
struct AgentActionDidStart: CallbackEvent {
2933
public let info: AgentAction
3034
}
35+
36+
var agentActionDidStart: AgentActionDidStart.Type {
37+
AgentActionDidStart.self
38+
}
3139

3240
struct AgentActionDidEnd: CallbackEvent {
3341
public let info: AgentAction
3442
}
43+
44+
var agentActionDidEnd: AgentActionDidEnd.Type {
45+
AgentActionDidEnd.self
46+
}
3547
}
3648

3749
public struct AgentFinish: Equatable {

Tool/Sources/LangChain/Callback.swift

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@ public protocol CallbackEvent {
55
var info: Info { get }
66
}
77

8-
public enum CallbackEvents {}
8+
public struct CallbackEvents {
9+
private init() {}
10+
}
911

1012
public struct CallbackManager {
1113
fileprivate var observers = [Any]()
@@ -25,19 +27,39 @@ public struct CallbackManager {
2527
observers.append(handler)
2628
}
2729

30+
public mutating func on<Event: CallbackEvent>(
31+
_: KeyPath<CallbackEvents, Event.Type>,
32+
_ handler: @escaping (Event.Info) -> Void
33+
) {
34+
observers.append(handler)
35+
}
36+
2837
public func send<Event: CallbackEvent>(_ event: Event) {
2938
for case let observer as ((Event.Info) -> Void) in observers {
3039
observer(event.info)
3140
}
3241
}
42+
43+
func send<Event: CallbackEvent>(
44+
_: KeyPath<CallbackEvents, Event.Type>,
45+
_ info: Event.Info
46+
) {
47+
for case let observer as ((Event.Info) -> Void) in observers {
48+
observer(info)
49+
}
50+
}
3351
}
3452

3553
public extension [CallbackManager] {
3654
func send<Event: CallbackEvent>(_ event: Event) {
37-
for cb in self {
38-
for case let observer as ((Event.Info) -> Void) in cb.observers {
39-
observer(event.info)
40-
}
41-
}
55+
for cb in self { cb.send(event) }
56+
}
57+
58+
func send<Event: CallbackEvent>(
59+
_ keyPath: KeyPath<CallbackEvents, Event.Type>,
60+
_ info: Event.Info
61+
) {
62+
for cb in self { cb.send(keyPath, info) }
4263
}
4364
}
65+

Tool/Sources/LangChain/Chain.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ public protocol Chain {
1010
public extension Chain {
1111
typealias ChainDidStart = CallbackEvents.ChainDidStart<Self>
1212
typealias ChainDidEnd = CallbackEvents.ChainDidEnd<Self>
13-
13+
1414
func run(_ input: Input, callbackManagers: [CallbackManager] = []) async throws -> String {
1515
let output = try await call(input, callbackManagers: callbackManagers)
1616
return parseOutput(output)
Lines changed: 200 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,200 @@
1+
import Foundation
2+
import OpenAIService
3+
4+
public final class RefineDocumentChain: Chain {
5+
public struct Input {
6+
var question: String
7+
var documents: [(document: Document, distance: Float)]
8+
}
9+
10+
struct RefinementInput {
11+
var index: Int
12+
var totalCount: Int
13+
var question: String
14+
var previousAnswer: String?
15+
var document: String
16+
var distance: Float
17+
}
18+
19+
public struct IntermediateAnswer: Decodable {
20+
public var answer: String
21+
public var usefulness: Double
22+
public var more: Bool
23+
24+
public enum CodingKeys: String, CodingKey {
25+
case answer
26+
case usefulness
27+
case more
28+
}
29+
30+
init(answer: String, usefulness: Double, more: Bool) {
31+
self.answer = answer
32+
self.usefulness = usefulness
33+
self.more = more
34+
}
35+
36+
public init(from decoder: Decoder) throws {
37+
let container = try decoder.container(keyedBy: CodingKeys.self)
38+
answer = try container.decode(String.self, forKey: .answer)
39+
usefulness = (try? container.decode(Double.self, forKey: .usefulness)) ?? 0
40+
more = (try? container.decode(Bool.self, forKey: .more)) ?? true
41+
}
42+
}
43+
44+
class FunctionProvider: ChatGPTFunctionProvider {
45+
var functionCallStrategy: FunctionCallStrategy? = .name("respond")
46+
var functions: [any ChatGPTFunction] = [RespondFunction()]
47+
}
48+
49+
struct RespondFunction: ChatGPTFunction {
50+
typealias Arguments = IntermediateAnswer
51+
52+
struct Result: ChatGPTFunctionResult {
53+
var botReadableContent: String { "" }
54+
}
55+
56+
var reportProgress: (String) async -> Void = { _ in }
57+
58+
var name: String = "respond"
59+
var description: String = "Respond with the refined answer"
60+
var argumentSchema: JSONSchemaValue {
61+
return [
62+
.type: "object",
63+
.properties: [
64+
"answer": [
65+
.type: "string",
66+
.description: "The refined answer",
67+
],
68+
"usefulness": [
69+
.type: "number",
70+
.description: "How useful the page of document is in generating the answer, the higher the better. 0 to 10",
71+
],
72+
"more": [
73+
.type: "boolean",
74+
.description: "Whether you want to read the next page. The next page maybe less relevant to the question",
75+
],
76+
],
77+
.required: ["answer", "more", "usefulness"],
78+
]
79+
}
80+
81+
func prepare() async {}
82+
83+
func call(arguments: Arguments) async throws -> Result {
84+
return Result()
85+
}
86+
}
87+
88+
func buildChatModel() -> ChatModelChain<RefinementInput> {
89+
.init(
90+
chatModel: OpenAIChat(
91+
configuration: UserPreferenceChatGPTConfiguration().overriding {
92+
$0.temperature = 0
93+
$0.runFunctionsAutomatically = false
94+
},
95+
memory: EmptyChatGPTMemory(),
96+
functionProvider: FunctionProvider(),
97+
stream: false
98+
),
99+
promptTemplate: { input in [
100+
.init(
101+
role: .system,
102+
content: {
103+
if let previousAnswer = input.previousAnswer {
104+
return """
105+
The user will send you a question about a document, you must refine your previous answer to it only according to the document.
106+
Previous answer:###
107+
\(previousAnswer)
108+
###
109+
Page \(input.index) of \(input.totalCount) of the document:###
110+
\(input.document)
111+
###
112+
"""
113+
} else {
114+
return """
115+
The user will send you a question about a document, you must answer it only according to the document.
116+
Page \(input.index) of \(input.totalCount) of the document:###
117+
\(input.document)
118+
###
119+
"""
120+
}
121+
}()
122+
123+
),
124+
.init(role: .user, content: input.question),
125+
] }
126+
)
127+
}
128+
129+
public init() {}
130+
131+
public func callLogic(
132+
_ input: Input,
133+
callbackManagers: [CallbackManager]
134+
) async throws -> String {
135+
var intermediateAnswer: IntermediateAnswer?
136+
137+
for (index, document) in input.documents.enumerated() {
138+
if let intermediateAnswer, !intermediateAnswer.more { break }
139+
140+
let output = try await buildChatModel().call(
141+
.init(
142+
index: index,
143+
totalCount: input.documents.count,
144+
question: input.question,
145+
previousAnswer: intermediateAnswer?.answer,
146+
document: document.document.pageContent,
147+
distance: document.distance
148+
),
149+
callbackManagers: callbackManagers
150+
)
151+
intermediateAnswer = extractAnswer(output)
152+
153+
if let intermediateAnswer {
154+
callbackManagers.send(
155+
\.refineDocumentChainDidGenerateIntermediateAnswer,
156+
intermediateAnswer
157+
)
158+
}
159+
}
160+
161+
return intermediateAnswer?.answer ?? "None"
162+
}
163+
164+
public func parseOutput(_ output: String) -> String {
165+
return output
166+
}
167+
168+
func extractAnswer(_ chatMessage: ChatMessage) -> IntermediateAnswer {
169+
if let functionCall = chatMessage.functionCall {
170+
do {
171+
let intermediateAnswer = try JSONDecoder().decode(
172+
IntermediateAnswer.self,
173+
from: functionCall.arguments.data(using: .utf8) ?? Data()
174+
)
175+
return intermediateAnswer
176+
} catch {
177+
let intermediateAnswer = IntermediateAnswer(
178+
answer: functionCall.arguments,
179+
usefulness: 0,
180+
more: true
181+
)
182+
return intermediateAnswer
183+
}
184+
}
185+
return .init(answer: chatMessage.content ?? "", usefulness: 0, more: true)
186+
}
187+
}
188+
189+
public extension CallbackEvents {
190+
struct RefineDocumentChainDidGenerateIntermediateAnswer: CallbackEvent {
191+
public let info: RefineDocumentChain.IntermediateAnswer
192+
}
193+
194+
var refineDocumentChainDidGenerateIntermediateAnswer:
195+
RefineDocumentChainDidGenerateIntermediateAnswer.Type
196+
{
197+
RefineDocumentChainDidGenerateIntermediateAnswer.self
198+
}
199+
}
200+

0 commit comments

Comments
 (0)