Skip to content

Commit 0de4dae

Browse files
committed
New method to report function progress
1 parent 75d3f26 commit 0de4dae

File tree

5 files changed

+148
-94
lines changed

5 files changed

+148
-94
lines changed
Lines changed: 92 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,36 @@
11
import Foundation
2+
import LangChain
23
import OpenAIService
34
import Preferences
45

56
struct QueryWebsiteFunction: ChatGPTFunction {
67
struct Arguments: Codable {
78
var query: String
8-
var urlString: String
9+
var urls: [String]
910
}
10-
11+
1112
struct Result: ChatGPTFunctionResult {
12-
var relevantTrunks: [String]
13-
13+
var relevantDocuments: [Document]
14+
1415
var botReadableContent: String {
1516
// don't forget to remove overlaps
16-
return ""
17+
if relevantDocuments.isEmpty {
18+
return "No relevant information found"
19+
}
20+
return relevantDocuments.map(\.pageContent).joined(separator: "\n\n")
1721
}
1822
}
19-
23+
24+
var reportProgress: (String) async -> Void = { _ in }
25+
2026
var name: String {
2127
"queryWebsite"
2228
}
23-
29+
2430
var description: String {
2531
"Useful for when you need to answer a question using information from a website."
2632
}
27-
33+
2834
var argumentSchema: JSONSchemaValue {
2935
return [
3036
.type: "object",
@@ -33,26 +39,88 @@ struct QueryWebsiteFunction: ChatGPTFunction {
3339
.type: "string",
3440
.description: "things you want to know about the website",
3541
],
36-
"urlString": [
37-
.type: "string",
38-
.description: "the url of the website"
39-
]
42+
"urls": [
43+
.type: "array",
44+
.description: "urls of the website, you can use urls appearing in the conversation",
45+
.items: [
46+
.type: "string",
47+
],
48+
],
4049
],
41-
.required: ["query", "urlString"]
50+
.required: ["query", "urls"],
4251
]
4352
}
44-
45-
func message(at phase: OpenAIService.ChatGPTFunctionCallPhase) -> String {
46-
return ""
53+
54+
func prepare() async {
55+
await reportProgress("Reading..")
4756
}
48-
57+
4958
func call(arguments: Arguments) async throws -> Result {
50-
// 1. grab the website content
51-
// 2. trunk the content
52-
// 3. embedding and store in memory
53-
// 4. embedding on the query, then search for relevant trunks, choose the 3 most relevant
54-
// 5. return the thunks
55-
56-
return .init(relevantTrunks: [])
59+
do {
60+
let embedding = OpenAIEmbedding(
61+
configuration: UserPreferenceEmbeddingConfiguration()
62+
)
63+
64+
let queryEmbeddings = try await embedding.embed(query: arguments.query)
65+
let searchCount = UserDefaults.shared.value(for: \.chatGPTMaxToken) > 5000 ? 3 : 2
66+
67+
let result = try await withThrowingTaskGroup(
68+
of: [(document: Document, distance: Float)].self
69+
) { group in
70+
for urlString in arguments.urls {
71+
guard let url = URL(string: urlString) else { continue }
72+
group.addTask {
73+
if let database = await TemporaryUSearch.view(identifier: urlString) {
74+
return try await database.searchWithDistance(
75+
embeddings: queryEmbeddings,
76+
count: searchCount
77+
)
78+
}
79+
// 1. grab the website content
80+
await reportProgress("Loading \(url)..")
81+
print("== load \(url)")
82+
let loader = WebLoader(urls: [url])
83+
let documents = try await loader.load()
84+
await reportProgress("Processing \(url)..")
85+
print("== loaded \(url), documents: \(documents.count)")
86+
// 2. split the content
87+
let splitter = RecursiveCharacterTextSplitter(
88+
chunkSize: 1000,
89+
chunkOverlap: 100
90+
)
91+
let splitDocuments = try await splitter.transformDocuments(documents)
92+
print("== split \(url), documents: \(splitDocuments.count)")
93+
// 3. embedding and store in db
94+
await reportProgress("Embedding \(url)..")
95+
let embeddedDocuments = try await embedding.embed(documents: splitDocuments)
96+
print("== embedded \(url)")
97+
let database = TemporaryUSearch(identifier: urlString)
98+
try await database.set(embeddedDocuments)
99+
print("== save to database \(url)")
100+
let result = try await database.searchWithDistance(
101+
embeddings: queryEmbeddings,
102+
count: searchCount
103+
)
104+
print("== result of \(url): \(result)")
105+
return result
106+
}
107+
}
108+
109+
var all = [(document: Document, distance: Float)]()
110+
for try await result in group {
111+
all.append(contentsOf: result)
112+
}
113+
await reportProgress("Finish reading websites.")
114+
return all
115+
.sorted { $0.distance < $1.distance }
116+
.prefix(searchCount)
117+
}
118+
119+
return .init(relevantDocuments: result.map(\.document))
120+
} catch {
121+
await reportProgress("Failed reading websites.")
122+
throw error
123+
}
57124
}
58125
}
126+

Core/Sources/ChatContextCollectors/WebChatContextCollector/SearchFunction.swift

Lines changed: 30 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ struct SearchFunction: ChatGPTFunction {
2929
}
3030
}
3131

32+
var reportProgress: (String) async -> Void = { _ in }
33+
3234
var name: String {
3335
"searchWeb"
3436
}
@@ -58,60 +60,38 @@ struct SearchFunction: ChatGPTFunction {
5860
]
5961
}
6062

61-
func message(at phase: ChatGPTFunctionCallPhase) -> String {
62-
func parseArgument(_ string: String) throws -> Arguments {
63-
try JSONDecoder().decode(Arguments.self, from: string.data(using: .utf8) ?? Data())
64-
}
65-
66-
switch phase {
67-
case .detected:
68-
return "Searching.."
69-
case let .processing(argumentsJsonString):
70-
do {
71-
let arguments = try parseArgument(argumentsJsonString)
72-
return "Searching \(arguments.query)"
73-
} catch {
74-
return "Searching.."
75-
}
76-
case let .ended(argumentsJsonString, result):
77-
do {
78-
let arguments = try parseArgument(argumentsJsonString)
79-
if let result = result as? Result {
80-
return """
81-
Finish searching \(arguments.query)
82-
\(
83-
result.result.webPages.value
84-
.map { "- [\($0.name)](\($0.url))" }
85-
.joined(separator: "\n")
86-
)
87-
"""
88-
}
89-
return "Finish searching \(arguments.query)"
90-
} catch {
91-
return "Finish searching"
92-
}
93-
case let .error(argumentsJsonString, _):
94-
do {
95-
let arguments = try parseArgument(argumentsJsonString)
96-
return "Failed searching \(arguments.query)"
97-
} catch {
98-
return "Failed searching"
99-
}
100-
}
63+
func prepare() async {
64+
await reportProgress("Searching..")
10165
}
10266

10367
func call(arguments: Arguments) async throws -> Result {
104-
let bingSearch = BingSearchService(
105-
subscriptionKey: UserDefaults.shared.value(for: \.bingSearchSubscriptionKey),
106-
searchURL: UserDefaults.shared.value(for: \.bingSearchEndpoint)
107-
)
108-
let result = try await bingSearch.search(
109-
query: arguments.query,
110-
numberOfResult: UserDefaults.shared.value(for: \.chatGPTMaxToken) > 5000 ? 5 : 3,
111-
freshness: arguments.freshness
112-
)
68+
await reportProgress("Searching \(arguments.query)")
11369

114-
return .init(result: result)
70+
do {
71+
let bingSearch = BingSearchService(
72+
subscriptionKey: UserDefaults.shared.value(for: \.bingSearchSubscriptionKey),
73+
searchURL: UserDefaults.shared.value(for: \.bingSearchEndpoint)
74+
)
75+
let result = try await bingSearch.search(
76+
query: arguments.query,
77+
numberOfResult: UserDefaults.shared.value(for: \.chatGPTMaxToken) > 5000 ? 5 : 3,
78+
freshness: arguments.freshness
79+
)
80+
81+
await reportProgress("""
82+
Finish searching \(arguments.query)
83+
\(
84+
result.webPages.value
85+
.map { "- [\($0.name)](\($0.url))" }
86+
.joined(separator: "\n")
87+
)
88+
""")
89+
90+
return .init(result: result)
91+
} catch {
92+
await reportProgress("Failed searching: \(error.localizedDescription)")
93+
throw error
94+
}
11595
}
11696
}
11797

Tool/Sources/OpenAIService/ChatGPTService.swift

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -331,15 +331,20 @@ extension ChatGPTService {
331331
/// When a function call is detected, but arguments are not yet ready, we can call this
332332
/// to insert a message placeholder in memory.
333333
func prepareFunctionCall(_ call: ChatMessage.FunctionCall, messageId: String) async {
334-
guard let function = functionProvider.function(named: call.name) else { return }
334+
guard var function = functionProvider.function(named: call.name) else { return }
335335
let responseMessage = ChatMessage(
336336
id: messageId,
337337
role: .function,
338338
content: nil,
339-
name: call.name,
340-
summary: function.message(at: .detected)
339+
name: call.name
341340
)
342341
await memory.appendMessage(responseMessage)
342+
function.reportProgress = { [weak self] summary in
343+
await self?.memory.updateMessage(id: messageId) { message in
344+
message.summary = summary
345+
}
346+
}
347+
await function.prepare()
343348
}
344349

345350
/// Run a function call from the bot, and insert the result in memory.
@@ -350,7 +355,7 @@ extension ChatGPTService {
350355
) async -> String {
351356
let messageId = messageId ?? uuidGenerator()
352357

353-
guard let function = functionProvider.function(named: call.name) else {
358+
guard var function = functionProvider.function(named: call.name) else {
354359
let content = "Error: function not found"
355360
let responseMessage = ChatMessage(
356361
id: messageId,
@@ -368,34 +373,31 @@ extension ChatGPTService {
368373
id: messageId,
369374
role: .function,
370375
content: nil,
371-
name: call.name,
372-
summary: function.message(at: .processing(argumentsJsonString: call.arguments))
376+
name: call.name
373377
)
378+
374379
await memory.appendMessage(responseMessage)
375380

381+
function.reportProgress = { [weak self] summary in
382+
await self?.memory.updateMessage(id: messageId) { message in
383+
message.summary = summary
384+
}
385+
}
386+
376387
do {
377388
// Run the function
378-
let result = try await function
379-
.call(argumentsJsonString: call.arguments)
389+
let result = try await function.call(argumentsJsonString: call.arguments)
380390

381-
// Update the message to display the finish state of the function.
382391
await memory.updateMessage(id: messageId) { message in
383392
message.content = result.botReadableContent
384-
message.summary = function.message(at: .ended(
385-
argumentsJsonString: call.arguments,
386-
result: result
387-
))
388393
}
394+
389395
return result.botReadableContent
390396
} catch {
391397
// For errors, use the error message as the result.
392398
let content = "Error: \(error.localizedDescription)"
393399
await memory.updateMessage(id: messageId) { message in
394400
message.content = content
395-
message.summary = function.message(at: .error(
396-
argumentsJsonString: call.arguments,
397-
result: error
398-
))
399401
}
400402
return content
401403
}

Tool/Sources/OpenAIService/FucntionCall/ChatGPTFunction.swift

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,12 @@ public protocol ChatGPTFunction {
2626
var description: String { get }
2727
/// The arguments schema that the function take in [JSON schema](https://json-schema.org).
2828
var argumentSchema: JSONSchemaValue { get }
29+
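/// Called when a call to this function is detected, before its arguments are ready, so the function can prepare (e.g. report an initial progress message).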
30+
func prepare() async
2931
/// Call the function with the given arguments.
3032
func call(arguments: Arguments) async throws -> Result
3133
/// A closure the function calls to report a human-readable progress message while it prepares and runs.
32-
func message(at phase: ChatGPTFunctionCallPhase) -> String
34+
var reportProgress: (String) async -> Void { get set }
3335
}
3436

3537
public extension ChatGPTFunction {

Tool/Tests/OpenAIServiceTests/ChatGPTStreamTests.swift

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -347,9 +347,11 @@ extension ChatGPTStreamTests {
347347
.type: ["null"]
348348
]
349349
}
350+
351+
var reportProgress: (String) async -> Void = { print($0) }
350352

351-
func message(at phase: ChatGPTFunctionCallPhase) -> String {
352-
return "running"
353+
func prepare() async {
354+
print("Function will be called")
353355
}
354356

355357
func call(arguments: Parameters) async throws -> String {

0 commit comments

Comments
 (0)