Skip to content

Commit 6349fa4

Browse files
committed
Add function call support
1 parent 99ce70f commit 6349fa4

7 files changed

Lines changed: 284 additions & 69 deletions

File tree

Tool/Sources/OpenAIService/ChatGPTService.swift

Lines changed: 199 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ public struct ChatGPTError: Error, Codable, LocalizedError {
5555
public class ChatGPTService: ChatGPTServiceType {
5656
public var memory: ChatGPTMemory
5757
public var configuration: ChatGPTConfiguration
58+
public var functionProvider: ChatGPTFunctionProvider
5859

5960
var uuidGenerator: () -> String = { UUID().uuidString }
6061
var cancelTask: Cancellable?
@@ -66,19 +67,19 @@ public class ChatGPTService: ChatGPTServiceType {
6667
systemPrompt: "",
6768
configuration: UserPreferenceChatGPTConfiguration()
6869
),
69-
configuration: ChatGPTConfiguration = UserPreferenceChatGPTConfiguration()
70+
configuration: ChatGPTConfiguration = UserPreferenceChatGPTConfiguration(),
71+
functionProvider: ChatGPTFunctionProvider = NoChatGPTFunctionProvider()
7072
) {
7173
self.memory = memory
7274
self.configuration = configuration
75+
self.functionProvider = functionProvider
7376
}
7477

78+
/// Send a message and stream the reply.
7579
public func send(
7680
content: String,
7781
summary: String? = nil
7882
) async throws -> AsyncThrowingStream<String, Error> {
79-
guard let url = URL(string: configuration.endpoint)
80-
else { throw ChatGPTServiceError.endpointIncorrect }
81-
8283
if !content.isEmpty || summary != nil {
8384
let newMessage = ChatMessage(
8485
id: uuidGenerator(),
@@ -91,6 +92,93 @@ public class ChatGPTService: ChatGPTServiceType {
9192
await memory.appendMessage(newMessage)
9293
}
9394

95+
return AsyncThrowingStream<String, Error> { continuation in
96+
Task(priority: .userInitiated) {
97+
do {
98+
let stream = try await sendMemory()
99+
var functionCall: ChatMessage.FunctionCall?
100+
var functionCallMessageID = uuidGenerator()
101+
for try await content in stream {
102+
switch content {
103+
case let .text(text):
104+
continuation.yield(text)
105+
case let .functionCall(call):
106+
functionCall = call
107+
await prepareFunctionCall(call, messageId: functionCallMessageID)
108+
}
109+
}
110+
111+
while let call = functionCall {
112+
functionCall = nil
113+
await runFunctionCall(call)
114+
functionCallMessageID = uuidGenerator()
115+
let nextStream = try await sendMemory()
116+
for try await content in nextStream {
117+
switch content {
118+
case let .text(text):
119+
continuation.yield(text)
120+
case let .functionCall(call):
121+
functionCall = call
122+
await prepareFunctionCall(call, messageId: functionCallMessageID)
123+
}
124+
}
125+
}
126+
continuation.finish()
127+
} catch {
128+
continuation.finish(throwing: error)
129+
}
130+
}
131+
}
132+
}
133+
134+
/// Send a message and get the reply in return.
135+
public func sendAndWait(
136+
content: String,
137+
summary: String? = nil
138+
) async throws -> String? {
139+
if !content.isEmpty || summary != nil {
140+
let newMessage = ChatMessage(
141+
id: uuidGenerator(),
142+
role: .user,
143+
content: content,
144+
summary: summary
145+
)
146+
await memory.appendMessage(newMessage)
147+
}
148+
149+
let message = try await sendMemoryAndWait()
150+
var finalResult = message?.content
151+
var functionCall = message?.functionCall
152+
while let call = functionCall {
153+
functionCall = nil
154+
await runFunctionCall(call)
155+
guard let nextMessage = try await sendMemoryAndWait() else { break }
156+
finalResult = nextMessage.content
157+
functionCall = nextMessage.functionCall
158+
}
159+
160+
return finalResult
161+
}
162+
163+
public func stopReceivingMessage() {
164+
cancelTask?()
165+
cancelTask = nil
166+
}
167+
}
168+
169+
// - MARK: Internal
170+
171+
extension ChatGPTService {
172+
enum StreamContent {
173+
case text(String)
174+
case functionCall(ChatMessage.FunctionCall)
175+
}
176+
177+
/// Send the memory as prompt to ChatGPT, with stream enabled.
178+
func sendMemory() async throws -> AsyncThrowingStream<StreamContent, Error> {
179+
guard let url = URL(string: configuration.endpoint)
180+
else { throw ChatGPTServiceError.endpointIncorrect }
181+
94182
let messages = await memory.messages.map {
95183
CompletionRequestBody.Message(role: $0.role, content: $0.content)
96184
}
@@ -107,7 +195,7 @@ public class ChatGPTService: ChatGPTServiceType {
107195
remainingTokens: remainingTokens
108196
),
109197
function_call: nil,
110-
functions: []
198+
functions: functionProvider.functionSchemas
111199
)
112200

113201
let api = buildCompletionStreamAPI(
@@ -117,51 +205,41 @@ public class ChatGPTService: ChatGPTServiceType {
117205
requestBody
118206
)
119207

120-
return AsyncThrowingStream<String, Error> { continuation in
208+
return AsyncThrowingStream<StreamContent, Error> { continuation in
121209
Task {
122210
do {
123211
let (trunks, cancel) = try await api()
124212
cancelTask = cancel
125-
var id = ""
126-
var functionCallRawString = ""
127213
for try await trunk in trunks {
128-
id = trunk.id
129-
130214
guard let delta = trunk.choices.first?.delta else { continue }
131215

216+
// The api will always return a function call with correct JSON format.
217+
// The first round will contain the function name and an empty argument.
218+
// e.g. {"name":"weather","arguments":""}
219+
let functionCall: ChatMessage.FunctionCall? = delta.function_call.flatMap {
220+
guard let data = $0.data(using: .utf8) else { return nil }
221+
return try? JSONDecoder()
222+
.decode(ChatMessage.FunctionCall.self, from: data)
223+
}
224+
132225
await memory.streamMessage(
133226
id: trunk.id,
134227
role: delta.role,
135228
content: delta.content,
136-
functionCall: nil
229+
functionCall: functionCall
137230
)
138231

139-
if let call = delta.function_call {
140-
functionCallRawString.append(call)
232+
if let functionCall {
233+
continuation.yield(.functionCall(functionCall))
141234
}
142235

143236
if let content = delta.content {
144-
continuation.yield(content)
237+
continuation.yield(.text(content))
145238
}
146239

147240
try await Task.sleep(nanoseconds: 3_000_000)
148241
}
149242

150-
if !functionCallRawString.isEmpty,
151-
let data = functionCallRawString.data(using: .utf8)
152-
{
153-
let function = try JSONDecoder().decode(
154-
ChatMessage.FunctionCall.self,
155-
from: data
156-
)
157-
await memory.streamMessage(
158-
id: id,
159-
role: nil,
160-
content: nil,
161-
functionCall: function
162-
)
163-
}
164-
165243
continuation.finish()
166244
} catch let error as CancellationError {
167245
continuation.finish(throwing: error)
@@ -178,25 +256,21 @@ public class ChatGPTService: ChatGPTServiceType {
178256
}
179257
}
180258

181-
public func sendAndWait(
182-
content: String,
183-
summary: String? = nil
184-
) async throws -> String? {
259+
/// Send the memory as prompt to ChatGPT, with stream disabled.
260+
func sendMemoryAndWait() async throws -> ChatMessage? {
185261
guard let url = URL(string: configuration.endpoint)
186262
else { throw ChatGPTServiceError.endpointIncorrect }
187263

188-
if !content.isEmpty || summary != nil {
189-
let newMessage = ChatMessage(
190-
id: uuidGenerator(),
191-
role: .user,
192-
content: content,
193-
summary: summary
194-
)
195-
await memory.appendMessage(newMessage)
196-
}
197-
198264
let messages = await memory.messages.map {
199-
CompletionRequestBody.Message(role: $0.role, content: $0.content)
265+
CompletionRequestBody.Message(
266+
role: $0.role,
267+
content: $0.content,
268+
name: $0.name,
269+
function_call: $0.functionCall.map {
270+
CompletionRequestBody
271+
.MessageFunctionCall(name: $0.name, arguments: $0.arguments)
272+
}
273+
)
200274
}
201275
let remainingTokens = await memory.remainingTokens
202276

@@ -211,7 +285,7 @@ public class ChatGPTService: ChatGPTServiceType {
211285
remainingTokens: remainingTokens
212286
),
213287
function_call: nil,
214-
functions: []
288+
functions: functionProvider.functionSchemas
215289
)
216290

217291
let api = buildCompletionAPI(
@@ -222,22 +296,89 @@ public class ChatGPTService: ChatGPTServiceType {
222296
)
223297
let response = try await api()
224298

225-
if let choice = response.choices.first {
226-
await memory.appendMessage(.init(
227-
id: response.id,
228-
role: choice.message.role,
229-
content: choice.message.content
230-
))
299+
guard let choice = response.choices.first else { return nil }
300+
let message = ChatMessage(
301+
id: response.id,
302+
role: choice.message.role,
303+
content: choice.message.content,
304+
name: choice.message.name,
305+
functionCall: choice.message.function_call.map {
306+
ChatMessage.FunctionCall(name: $0.name, arguments: $0.arguments)
307+
}
308+
)
309+
await memory.appendMessage(message)
310+
return message
311+
}
312+
313+
/// When a function call is detected, but arguments are not yet ready, we can call this
314+
/// to insert a message placeholder in memory.
315+
func prepareFunctionCall(_ call: ChatMessage.FunctionCall, messageId: String) async {
316+
guard let function = functionProvider.function(named: call.name) else { return }
317+
let responseMessage = ChatMessage(
318+
id: messageId,
319+
role: .function,
320+
content: nil,
321+
summary: function.message(at: .detected)
322+
)
323+
await memory.appendMessage(responseMessage)
324+
}
231325

232-
return choice.message.content
326+
/// Run a function call from the bot, and insert the result in memory.
327+
@discardableResult
328+
func runFunctionCall(
329+
_ call: ChatMessage.FunctionCall,
330+
messageId: String? = nil
331+
) async -> String {
332+
let messageId = messageId ?? uuidGenerator()
333+
334+
guard let function = functionProvider.function(named: call.name) else {
335+
let content = "Error: function not found"
336+
let responseMessage = ChatMessage(
337+
id: messageId,
338+
role: .function,
339+
content: content,
340+
summary: "Function `\(call.name)` not found."
341+
)
342+
await memory.appendMessage(responseMessage)
343+
return content
233344
}
234345

235-
return nil
236-
}
237-
238-
public func stopReceivingMessage() {
239-
cancelTask?()
240-
cancelTask = nil
346+
// Insert the chat message into memory to indicate the start of the function.
347+
let responseMessage = ChatMessage(
348+
id: messageId,
349+
role: .function,
350+
content: nil,
351+
summary: function
352+
.message(at: .processing(argumentsJsonString: call.arguments ?? ""))
353+
)
354+
await memory.appendMessage(responseMessage)
355+
356+
do {
357+
// Run the function
358+
let response = try await function
359+
.call(argumentsJsonString: call.arguments ?? "")
360+
361+
// Update the message to display the finish state of the function.
362+
await memory.updateMessage(id: messageId) { message in
363+
message.content = response
364+
message.summary = function.message(at: .ended(
365+
argumentsJsonString: call.arguments ?? "",
366+
result: response
367+
))
368+
}
369+
return response
370+
} catch {
371+
// For errors, use the error message as the result.
372+
let content = "Error: \(error.localizedDescription)"
373+
await memory.updateMessage(id: messageId) { message in
374+
message.content = content
375+
message.summary = function.message(at: .error(
376+
argumentsJsonString: call.arguments ?? "",
377+
result: error
378+
))
379+
}
380+
return content
381+
}
241382
}
242383
}
243384

Tool/Sources/OpenAIService/CompletionStreamAPI.swift

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import AsyncAlgorithms
22
import Foundation
33
import Preferences
4-
import JSONRPC
54

65
typealias CompletionStreamAPIBuilder = (String, ChatFeatureProvider, URL, CompletionRequestBody) -> CompletionStreamAPI
76

@@ -38,7 +37,7 @@ struct CompletionRequestBody: Encodable, Equatable {
3837
/// The name of the
3938
var name: String
4039
/// A JSON string.
41-
var arguments: String
40+
var arguments: String?
4241
}
4342

4443
enum FunctionCallStrategy: Encodable, Equatable {
@@ -85,8 +84,9 @@ struct CompletionRequestBody: Encodable, Equatable {
8584
var frequency_penalty: Double?
8685
var logit_bias: [String: Double]?
8786
var user: String?
87+
/// Pass nil to let the bot decide.
8888
var function_call: FunctionCallStrategy?
89-
var functions: [Int] = []
89+
var functions: [String] = []
9090
}
9191

9292
struct CompletionStreamDataTrunk: Codable {

0 commit comments

Comments
 (0)