Skip to content

Commit da55011

Browse files
committed
WIP
1 parent 2d92790 commit da55011

File tree

6 files changed

+102
-56
lines changed

6 files changed

+102
-56
lines changed

Core/Sources/ChatContextCollectors/WebChatContextCollector/QueryWebsiteFunction.swift

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -64,13 +64,7 @@ struct QueryWebsiteFunction: ChatGPTFunction {
6464

6565
if let database = await TemporaryUSearch.view(identifier: urlString) {
6666
await reportProgress("Generating answers..")
67-
let qa = RetrievalQAChain(vectorStore: database, embedding: embedding) {
68-
OpenAIChat(
69-
configuration: UserPreferenceChatGPTConfiguration()
70-
.overriding(.init(temperature: 0)),
71-
stream: true
72-
)
73-
}
67+
let qa = RetrievalQAChain(vectorStore: database, embedding: embedding)
7468
return try await qa.call(.init(arguments.query)).answer
7569
}
7670
let loader = WebLoader(urls: [url])
@@ -89,13 +83,7 @@ struct QueryWebsiteFunction: ChatGPTFunction {
8983
try await database.set(embeddedDocuments)
9084
// 4. generate answer
9185
await reportProgress("Generating answers..")
92-
let qa = RetrievalQAChain(vectorStore: database, embedding: embedding) {
93-
OpenAIChat(
94-
configuration: UserPreferenceChatGPTConfiguration()
95-
.overriding(.init(temperature: 0)),
96-
stream: true
97-
)
98-
}
86+
let qa = RetrievalQAChain(vectorStore: database, embedding: embedding)
9987
let result = try await qa.call(.init(arguments.query))
10088
return result.answer
10189
}

Tool/Sources/LangChain/Chains/LLMChain.swift

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import Foundation
22

33
public class ChatModelChain<Input>: Chain {
4-
public typealias Output = String
4+
public typealias Output = ChatMessage
55

66
var chatModel: ChatModel
77
var promptTemplate: (Input) -> [ChatMessage]
@@ -31,7 +31,13 @@ public class ChatModelChain<Input>: Chain {
3131
}
3232

3333
public func parseOutput(_ output: Output) -> String {
34-
output
34+
if let content = output.content {
35+
return content
36+
} else if let functionCall = output.functionCall {
37+
return "\(functionCall.name): \(functionCall.arguments)"
38+
}
39+
40+
return ""
3541
}
3642
}
3743

Tool/Sources/LangChain/Chains/RetrievalQA.swift

Lines changed: 70 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
import Foundation
2+
import OpenAIService
23

34
public final class RetrievalQAChain: Chain {
45
let vectorStore: VectorStore
56
let embedding: Embeddings
6-
let chatModelFactory: () -> ChatModel
77

88
public struct Output {
99
public var answer: String
@@ -12,12 +12,10 @@ public final class RetrievalQAChain: Chain {
1212

1313
public init(
1414
vectorStore: VectorStore,
15-
embedding: Embeddings,
16-
chatModelFactory: @escaping () -> ChatModel
15+
embedding: Embeddings
1716
) {
1817
self.vectorStore = vectorStore
1918
self.embedding = embedding
20-
self.chatModelFactory = chatModelFactory
2119
}
2220

2321
public func callLogic(
@@ -29,7 +27,7 @@ public final class RetrievalQAChain: Chain {
2927
embeddings: embeddedQuestion,
3028
count: 5
3129
)
32-
let refinementChain = RefineDocumentChain(chatModelFactory: chatModelFactory)
30+
let refinementChain = RefineDocumentChain()
3331
let answer = try await refinementChain.run(
3432
.init(question: input, documents: documents),
3533
callbackManagers: callbackManagers
@@ -68,12 +66,68 @@ public final class RefineDocumentChain: Chain {
6866
var distance: Float
6967
}
7068

69+
class FunctionProvider: ChatGPTFunctionProvider {
70+
var functions: [any ChatGPTFunction] = []
71+
}
72+
73+
struct RespondFunction: ChatGPTFunction {
74+
struct Arguments: Codable {
75+
var answer: String
76+
var score: Double
77+
var more: Bool
78+
}
79+
80+
struct Result: ChatGPTFunctionResult {
81+
var botReadableContent: String { "" }
82+
}
83+
84+
var reportProgress: (String) async -> Void = { _ in }
85+
86+
var name: String = "respond"
87+
var description: String = "Respond with the refined answer"
88+
var argumentSchema: JSONSchemaValue {
89+
return [
90+
.type: "object",
91+
.properties: [
92+
"answer": [
93+
.type: "string",
94+
.description: "The answer",
95+
],
96+
"score": [
97+
.type: "number",
98+
.description: "The score of the answer, the higher the better",
99+
],
100+
"more": [
101+
.type: "boolean",
102+
.description: "Whether more information is needed to complete the answer",
103+
],
104+
],
105+
]
106+
}
107+
108+
func prepare() async {}
109+
110+
func call(arguments: Arguments) async throws -> Result {
111+
return Result()
112+
}
113+
}
114+
71115
let initialChatModel: ChatModelChain<InitialInput>
72116
let refinementChatModel: ChatModelChain<RefinementInput>
117+
let initialChatMemory: ChatGPTMemory
118+
let refinementChatMemory: ChatGPTMemory
119+
120+
public init() {
121+
initialChatMemory = ConversationChatGPTMemory(systemPrompt: "")
122+
refinementChatMemory = ConversationChatGPTMemory(systemPrompt: "")
73123

74-
public init(chatModelFactory: () -> ChatModel) {
75124
initialChatModel = .init(
76-
chatModel: chatModelFactory(),
125+
chatModel: OpenAIChat(
126+
configuration: UserPreferenceChatGPTConfiguration()
127+
.overriding(.init(temperature: 0)),
128+
memory: initialChatMemory,
129+
stream: false
130+
),
77131
promptTemplate: { input in [
78132
.init(role: .system, content: """
79133
The user will send you a question, you must answer it at your best.
@@ -85,7 +139,12 @@ public final class RefineDocumentChain: Chain {
85139
] }
86140
)
87141
refinementChatModel = .init(
88-
chatModel: chatModelFactory(),
142+
chatModel: OpenAIChat(
143+
configuration: UserPreferenceChatGPTConfiguration()
144+
.overriding(.init(temperature: 0)),
145+
memory: refinementChatMemory,
146+
stream: false
147+
),
89148
promptTemplate: { input in [
90149
.init(role: .system, content: """
91150
The user will send you a question, you must refine your previous answer to it at your best.
@@ -117,7 +176,9 @@ public final class RefineDocumentChain: Chain {
117176
),
118177
callbackManagers: callbackManagers
119178
)
120-
callbackManagers.send(CallbackEvents.RetrievalQADidGenerateIntermediateAnswer(info: output))
179+
guard var content = output.content else { return "" }
180+
callbackManagers
181+
.send(CallbackEvents.RetrievalQADidGenerateIntermediateAnswer(info: content))
121182
for document in input.documents.dropFirst(1) {
122183
output = try await refinementChatModel.call(
123184
.init(

Tool/Sources/LangChain/ChatModel/ChatModel.swift

Lines changed: 4 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,16 @@
11
import Foundation
2+
import OpenAIService
23

34
public protocol ChatModel {
45
func generate(
56
prompt: [ChatMessage],
67
stops: [String],
78
callbackManagers: [CallbackManager]
8-
) async throws -> String
9-
}
10-
11-
public struct ChatMessage {
12-
public enum Role {
13-
case system
14-
case user
15-
case assistant
16-
}
17-
18-
public var role: Role
19-
public var content: String
20-
21-
public init(role: Role, content: String) {
22-
self.role = role
23-
self.content = content
24-
}
9+
) async throws -> ChatMessage
2510
}
2611

12+
public typealias ChatMessage = OpenAIService.ChatMessage
13+
2714
public extension CallbackEvents {
2815
struct LLMDidProduceNewToken: CallbackEvent {
2916
public let info: String

Tool/Sources/LangChain/ChatModel/OpenAIChat.swift

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -23,24 +23,14 @@ public struct OpenAIChat: ChatModel {
2323
prompt: [ChatMessage],
2424
stops: [String],
2525
callbackManagers: [CallbackManager]
26-
) async throws -> String {
26+
) async throws -> ChatMessage {
2727
let service = ChatGPTService(
2828
memory: memory,
2929
configuration: configuration,
3030
functionProvider: functionProvider
3131
)
3232
for message in prompt {
33-
let role: OpenAIService.ChatMessage.Role = {
34-
switch message.role {
35-
case .system:
36-
return .system
37-
case .user:
38-
return .user
39-
case .assistant:
40-
return .assistant
41-
}
42-
}()
43-
await memory.appendMessage(.init(role: role, content: message.content))
33+
await memory.appendMessage(message)
4434
}
4535

4636
if stream {
@@ -51,9 +41,10 @@ public struct OpenAIChat: ChatModel {
5141
callbackManagers
5242
.forEach { $0.send(CallbackEvents.LLMDidProduceNewToken(info: trunk)) }
5343
}
54-
return message
44+
return await memory.messages.last ?? .init(role: .assistant, content: "")
5545
} else {
56-
return try await service.sendAndWait(content: "") ?? ""
46+
let _ = try await service.sendAndWait(content: "")
47+
return await memory.messages.last ?? .init(role: .assistant, content: "")
5748
}
5849
}
5950
}
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
import Foundation
2+
3+
/// A `ChatGPTMemory` that starts with no history and keeps everything in memory.
///
/// Useful when a chat model needs a throwaway conversation context that is not
/// persisted anywhere. All access is serialized through actor isolation.
public actor EmptyChatGPTMemory: ChatGPTMemory {
    /// The conversation history. Begins empty; mutated only via `mutateHistory`.
    public var messages: [ChatMessage] = []

    /// This memory imposes no token budget, so there is never a remaining count.
    public var remainingTokens: Int? { nil }

    public init() {}

    /// Runs `mutate` against the stored history in place.
    /// - Parameter mutate: A closure that edits the message array directly.
    public func mutateHistory(_ mutate: (inout [ChatMessage]) -> Void) {
        mutate(&messages)
    }
}
13+

0 commit comments

Comments
 (0)