Skip to content

Commit 324b918

Browse files
committed
Support function call
1 parent 8a3a275 commit 324b918

File tree

19 files changed

+681
-171
lines changed

19 files changed

+681
-171
lines changed

Core/Sources/ChatPlugin/AskChatGPT.swift

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,7 @@ public func askChatGPT(
1212
let memory = AutoManagedChatGPTMemory(systemPrompt: systemPrompt, configuration: configuration)
1313
let service = ChatGPTService(
1414
memory: memory,
15-
configuration: UserPreferenceChatGPTConfiguration()
16-
.overriding(.init(temperature: temperature))
15+
configuration: configuration
1716
)
1817
return try await service.sendAndWait(content: question)
1918
}

Core/Sources/ChatPlugins/ShortcutChatPlugin/ShortcutInputChatPlugin.swift

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,13 +88,17 @@ public actor ShortcutInputChatPlugin: ChatPlugin {
8888
if let text = String(data: data, encoding: .utf8) {
8989
if text.isEmpty { return }
9090
let stream = try await chatGPTService.send(content: text, summary: nil)
91-
for try await _ in stream {}
91+
do {
92+
for try await _ in stream {}
93+
} catch {}
9294
} else {
9395
let text = """
9496
[View File](\(temporaryOutputFileURL))
9597
"""
9698
let stream = try await chatGPTService.send(content: text, summary: nil)
97-
for try await _ in stream {}
99+
do {
100+
for try await _ in stream {}
101+
} catch {}
98102
}
99103

100104
return

Core/Sources/ChatService/ChatService.swift

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,14 +60,13 @@ public final class ChatService: ObservableObject {
6060
\(extraSystemPrompt)
6161
""", content: content)
6262

63+
let stream = try await chatGPTService.send(content: content, summary: nil)
64+
isReceivingMessage = true
6365
do {
64-
let stream = try await chatGPTService.send(content: content, summary: nil)
65-
isReceivingMessage = true
6666
for try await _ in stream {}
6767
isReceivingMessage = false
6868
} catch {
6969
isReceivingMessage = false
70-
throw error
7170
}
7271
}
7372

Core/Sources/HostApp/AccountSettings/AzureView.swift

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,8 @@ struct AzureView: View {
4646
do {
4747
let reply =
4848
try await ChatGPTService(
49-
configuration: OverridingUserPreferenceChatGPTConfiguration(
50-
overriding: .init(featureProvider: .azureOpenAI)
51-
)
49+
configuration: UserPreferenceChatGPTConfiguration()
50+
.overriding(.init(featureProvider: .azureOpenAI))
5251
)
5352
.sendAndWait(content: "Hello", summary: nil)
5453
toast(Text("ChatGPT replied: \(reply ?? "N/A")"), .info)

Core/Sources/HostApp/AccountSettings/OpenAIView.swift

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,8 @@ struct OpenAIView: View {
5151
do {
5252
let reply =
5353
try await ChatGPTService(
54-
configuration: OverridingUserPreferenceChatGPTConfiguration(
55-
overriding: .init(featureProvider: .openAI)
56-
)
54+
configuration: UserPreferenceChatGPTConfiguration()
55+
.overriding(.init(featureProvider: .openAI))
5756
)
5857
.sendAndWait(content: "Hello", summary: nil)
5958
toast(Text("ChatGPT replied: \(reply ?? "N/A")"), .info)

Tool/Package.swift

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ let package = Package(
1919
.package(url: "https://github.com/alfianlosari/GPTEncoder", from: "1.0.4"),
2020
.package(url: "https://github.com/apple/swift-async-algorithms", from: "0.1.0"),
2121
.package(url: "https://github.com/pointfreeco/swift-parsing", from: "0.12.1"),
22+
.package(url: "https://github.com/ChimeHQ/JSONRPC", from: "0.6.0"),
2223
],
2324
targets: [
2425
// MARK: - Helpers
@@ -61,6 +62,7 @@ let package = Package(
6162
"GPTEncoder",
6263
"Logger",
6364
"Preferences",
65+
.product(name: "JSONRPC", package: "JSONRPC"),
6466
.product(name: "AsyncAlgorithms", package: "swift-async-algorithms"),
6567
]
6668
),

Tool/Sources/OpenAIService/ChatGPTService.swift

Lines changed: 59 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -95,30 +95,28 @@ public class ChatGPTService: ChatGPTServiceType {
9595
return AsyncThrowingStream<String, Error> { continuation in
9696
Task(priority: .userInitiated) {
9797
do {
98-
let stream = try await sendMemory()
9998
var functionCall: ChatMessage.FunctionCall?
100-
var functionCallMessageID = uuidGenerator()
101-
for try await content in stream {
102-
switch content {
103-
case let .text(text):
104-
continuation.yield(text)
105-
case let .functionCall(call):
106-
functionCall = call
107-
await prepareFunctionCall(call, messageId: functionCallMessageID)
99+
var functionCallMessageID = ""
100+
var isInitialCall = true
101+
while functionCall != nil || isInitialCall {
102+
isInitialCall = false
103+
if let call = functionCall {
104+
functionCall = nil
105+
await runFunctionCall(call, messageId: functionCallMessageID)
108106
}
109-
}
110-
111-
while let call = functionCall {
112-
functionCall = nil
113-
await runFunctionCall(call)
114-
functionCallMessageID = uuidGenerator()
115-
let nextStream = try await sendMemory()
116-
for try await content in nextStream {
107+
let stream = try await sendMemory()
108+
for try await content in stream {
117109
switch content {
118110
case let .text(text):
119111
continuation.yield(text)
120112
case let .functionCall(call):
121-
functionCall = call
113+
if functionCall == nil {
114+
functionCallMessageID = uuidGenerator()
115+
functionCall = call
116+
} else {
117+
functionCall?.name.append(call.name)
118+
functionCall?.arguments.append(call.arguments)
119+
}
122120
await prepareFunctionCall(call, messageId: functionCallMessageID)
123121
}
124122
}
@@ -180,7 +178,14 @@ extension ChatGPTService {
180178
else { throw ChatGPTServiceError.endpointIncorrect }
181179

182180
let messages = await memory.messages.map {
183-
CompletionRequestBody.Message(role: $0.role, content: $0.content)
181+
CompletionRequestBody.Message(
182+
role: $0.role,
183+
content: $0.content ?? "",
184+
name: $0.name,
185+
function_call: $0.functionCall.map {
186+
.init(name: $0.name, arguments: $0.arguments)
187+
}
188+
)
184189
}
185190
let remainingTokens = await memory.remainingTokens
186191

@@ -195,7 +200,13 @@ extension ChatGPTService {
195200
remainingTokens: remainingTokens
196201
),
197202
function_call: nil,
198-
functions: functionProvider.functionSchemas
203+
functions: functionProvider.functions.map {
204+
ChatGPTFunctionSchema(
205+
name: $0.name,
206+
description: $0.description,
207+
parameters: $0.argumentSchema
208+
)
209+
}
199210
)
200211

201212
let api = buildCompletionStreamAPI(
@@ -213,13 +224,15 @@ extension ChatGPTService {
213224
for try await trunk in trunks {
214225
guard let delta = trunk.choices.first?.delta else { continue }
215226

216-
// The api will always return a function call with correct JSON format.
227+
// The api will always return a function call with JSON object.
217228
// The first round will contain the function name and an empty argument.
218229
// e.g. {"name":"weather","arguments":""}
219-
let functionCall: ChatMessage.FunctionCall? = delta.function_call.flatMap {
220-
guard let data = $0.data(using: .utf8) else { return nil }
221-
return try? JSONDecoder()
222-
.decode(ChatMessage.FunctionCall.self, from: data)
230+
// The other rounds will contain part of the arguments.
231+
let functionCall = delta.function_call.map {
232+
ChatMessage.FunctionCall(
233+
name: $0.name ?? "",
234+
arguments: $0.arguments ?? ""
235+
)
223236
}
224237

225238
await memory.streamMessage(
@@ -264,11 +277,10 @@ extension ChatGPTService {
264277
let messages = await memory.messages.map {
265278
CompletionRequestBody.Message(
266279
role: $0.role,
267-
content: $0.content,
280+
content: $0.content ?? "",
268281
name: $0.name,
269282
function_call: $0.functionCall.map {
270-
CompletionRequestBody
271-
.MessageFunctionCall(name: $0.name, arguments: $0.arguments)
283+
.init(name: $0.name, arguments: $0.arguments)
272284
}
273285
)
274286
}
@@ -285,7 +297,13 @@ extension ChatGPTService {
285297
remainingTokens: remainingTokens
286298
),
287299
function_call: nil,
288-
functions: functionProvider.functionSchemas
300+
functions: functionProvider.functions.map {
301+
ChatGPTFunctionSchema(
302+
name: $0.name,
303+
description: $0.description,
304+
parameters: $0.argumentSchema
305+
)
306+
}
289307
)
290308

291309
let api = buildCompletionAPI(
@@ -303,13 +321,13 @@ extension ChatGPTService {
303321
content: choice.message.content,
304322
name: choice.message.name,
305323
functionCall: choice.message.function_call.map {
306-
ChatMessage.FunctionCall(name: $0.name, arguments: $0.arguments)
324+
ChatMessage.FunctionCall(name: $0.name, arguments: $0.arguments ?? "")
307325
}
308326
)
309327
await memory.appendMessage(message)
310328
return message
311329
}
312-
330+
313331
/// When a function call is detected, but arguments are not yet ready, we can call this
314332
/// to insert a message placeholder in memory.
315333
func prepareFunctionCall(_ call: ChatMessage.FunctionCall, messageId: String) async {
@@ -318,6 +336,7 @@ extension ChatGPTService {
318336
id: messageId,
319337
role: .function,
320338
content: nil,
339+
name: call.name,
321340
summary: function.message(at: .detected)
322341
)
323342
await memory.appendMessage(responseMessage)
@@ -348,32 +367,32 @@ extension ChatGPTService {
348367
id: messageId,
349368
role: .function,
350369
content: nil,
351-
summary: function
352-
.message(at: .processing(argumentsJsonString: call.arguments ?? ""))
370+
name: call.name,
371+
summary: function.message(at: .processing(argumentsJsonString: call.arguments))
353372
)
354373
await memory.appendMessage(responseMessage)
355374

356375
do {
357376
// Run the function
358-
let response = try await function
359-
.call(argumentsJsonString: call.arguments ?? "")
377+
let result = try await function
378+
.call(argumentsJsonString: call.arguments)
360379

361380
// Update the message to display the finish state of the function.
362381
await memory.updateMessage(id: messageId) { message in
363-
message.content = response
382+
message.content = result.botReadableContent
364383
message.summary = function.message(at: .ended(
365-
argumentsJsonString: call.arguments ?? "",
366-
result: response
384+
argumentsJsonString: call.arguments,
385+
result: result
367386
))
368387
}
369-
return response
388+
return result.botReadableContent
370389
} catch {
371390
// For errors, use the error message as the result.
372391
let content = "Error: \(error.localizedDescription)"
373392
await memory.updateMessage(id: messageId) { message in
374393
message.content = content
375394
message.summary = function.message(at: .error(
376-
argumentsJsonString: call.arguments ?? "",
395+
argumentsJsonString: call.arguments,
377396
result: error
378397
))
379398
}

Tool/Sources/OpenAIService/CompletionStreamAPI.swift

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ struct CompletionRequestBody: Encodable, Equatable {
1717
/// The role of the message.
1818
var role: ChatMessage.Role
1919
/// The content of the message.
20-
var content: String?
20+
var content: String
2121
/// When we want to reply to a function call with the result, we have to provide the
2222
/// name of the function call, and include the result in `content`.
2323
///
@@ -86,7 +86,7 @@ struct CompletionRequestBody: Encodable, Equatable {
8686
var user: String?
8787
/// Pass nil to let the bot decide.
8888
var function_call: FunctionCallStrategy?
89-
var functions: [String] = []
89+
var functions: [ChatGPTFunctionSchema]? = nil
9090
}
9191

9292
struct CompletionStreamDataTrunk: Codable {
@@ -102,9 +102,14 @@ struct CompletionStreamDataTrunk: Codable {
102102
var finish_reason: String?
103103

104104
struct Delta: Codable {
105+
struct FunctionCall: Codable {
106+
var name: String?
107+
var arguments: String?
108+
}
109+
105110
var role: ChatMessage.Role?
106111
var content: String?
107-
var function_call: String?
112+
var function_call: FunctionCall?
108113
}
109114
}
110115
}

Tool/Sources/OpenAIService/FucntionCall/ChatGPTFunction.swift

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,29 +3,53 @@ import Foundation
33
public enum ChatGPTFunctionCallPhase {
44
case detected
55
case processing(argumentsJsonString: String)
6-
case ended(argumentsJsonString: String, result: String)
6+
case ended(argumentsJsonString: String, result: ChatGPTFunctionResult)
77
case error(argumentsJsonString: String, result: Error)
88
}
99

10+
public protocol ChatGPTFunctionResult {
11+
var botReadableContent: String { get }
12+
}
13+
14+
extension String: ChatGPTFunctionResult {
15+
public var botReadableContent: String { self }
16+
}
17+
1018
public protocol ChatGPTFunction {
1119
associatedtype Arguments: Decodable
20+
associatedtype Result: ChatGPTFunctionResult
1221

13-
/// The name of the function.
22+
/// The name of this function.
23+
/// May contain a-z, A-Z, 0-9, and underscores, with a maximum length of 64 characters.
1424
var name: String { get }
25+
/// A short description telling the bot when it should use this function.
26+
var description: String { get }
1527
/// The arguments schema that the function take in [JSON schema](https://json-schema.org).
16-
var argumentsSchema: String { get }
28+
var argumentSchema: JSONSchemaValue { get }
1729
/// Call the function with the given arguments.
18-
func call(arguments: Arguments) async throws -> String
30+
func call(arguments: Arguments) async throws -> Result
1931
/// The message to present in different phases.
2032
func message(at phase: ChatGPTFunctionCallPhase) -> String
2133
}
2234

2335
public extension ChatGPTFunction {
2436
/// Call the function with the given arguments in JSON.
25-
func call(argumentsJsonString: String) async throws -> String {
37+
func call(argumentsJsonString: String) async throws -> Result {
2638
let arguments = try JSONDecoder()
2739
.decode(Arguments.self, from: argumentsJsonString.data(using: .utf8) ?? Data())
2840
return try await call(arguments: arguments)
2941
}
3042
}
3143

44+
struct ChatGPTFunctionSchema: Codable, Equatable {
45+
var name: String
46+
var description: String
47+
var parameters: JSONSchemaValue
48+
49+
init(name: String, description: String, parameters: JSONSchemaValue) {
50+
self.name = name
51+
self.description = description
52+
self.parameters = parameters
53+
}
54+
}
55+
Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
11
import Foundation
22

33
public protocol ChatGPTFunctionProvider {
4-
var functionSchemas: [String] { get }
5-
func function(named: String) -> (any ChatGPTFunction)?
4+
var functions: [any ChatGPTFunction] { get }
5+
}
6+
7+
extension ChatGPTFunctionProvider {
8+
func function(named: String) -> (any ChatGPTFunction)? {
9+
functions.first(where: { $0.name == named })
10+
}
611
}
712

813
public struct NoChatGPTFunctionProvider: ChatGPTFunctionProvider {
14+
public var functions: [any ChatGPTFunction] { [] }
915
public init() {}
10-
11-
public var functionSchemas: [String] { [] }
12-
public func function(named: String) -> (any ChatGPTFunction)? { nil }
1316
}

0 commit comments

Comments
 (0)