@@ -55,6 +55,7 @@ public struct ChatGPTError: Error, Codable, LocalizedError {
5555public class ChatGPTService : ChatGPTServiceType {
5656 public var memory : ChatGPTMemory
5757 public var configuration : ChatGPTConfiguration
58+ public var functionProvider : ChatGPTFunctionProvider
5859
5960 var uuidGenerator : ( ) -> String = { UUID ( ) . uuidString }
6061 var cancelTask : Cancellable ?
@@ -66,19 +67,19 @@ public class ChatGPTService: ChatGPTServiceType {
6667 systemPrompt: " " ,
6768 configuration: UserPreferenceChatGPTConfiguration ( )
6869 ) ,
69- configuration: ChatGPTConfiguration = UserPreferenceChatGPTConfiguration ( )
70+ configuration: ChatGPTConfiguration = UserPreferenceChatGPTConfiguration ( ) ,
71+ functionProvider: ChatGPTFunctionProvider = NoChatGPTFunctionProvider ( )
7072 ) {
7173 self . memory = memory
7274 self . configuration = configuration
75+ self . functionProvider = functionProvider
7376 }
7477
78+ /// Send a message and stream the reply.
7579 public func send(
7680 content: String ,
7781 summary: String ? = nil
7882 ) async throws -> AsyncThrowingStream < String , Error > {
79- guard let url = URL ( string: configuration. endpoint)
80- else { throw ChatGPTServiceError . endpointIncorrect }
81-
8283 if !content. isEmpty || summary != nil {
8384 let newMessage = ChatMessage (
8485 id: uuidGenerator ( ) ,
@@ -91,6 +92,93 @@ public class ChatGPTService: ChatGPTServiceType {
9192 await memory. appendMessage ( newMessage)
9293 }
9394
95+ return AsyncThrowingStream < String , Error > { continuation in
96+ Task ( priority: . userInitiated) {
97+ do {
98+ let stream = try await sendMemory ( )
99+ var functionCall : ChatMessage . FunctionCall ?
100+ var functionCallMessageID = uuidGenerator ( )
101+ for try await content in stream {
102+ switch content {
103+ case let . text( text) :
104+ continuation. yield ( text)
105+ case let . functionCall( call) :
106+ functionCall = call
107+ await prepareFunctionCall ( call, messageId: functionCallMessageID)
108+ }
109+ }
110+
111+ while let call = functionCall {
112+ functionCall = nil
113+ await runFunctionCall ( call)
114+ functionCallMessageID = uuidGenerator ( )
115+ let nextStream = try await sendMemory ( )
116+ for try await content in nextStream {
117+ switch content {
118+ case let . text( text) :
119+ continuation. yield ( text)
120+ case let . functionCall( call) :
121+ functionCall = call
122+ await prepareFunctionCall ( call, messageId: functionCallMessageID)
123+ }
124+ }
125+ }
126+ continuation. finish ( )
127+ } catch {
128+ continuation. finish ( throwing: error)
129+ }
130+ }
131+ }
132+ }
133+
134+ /// Send a message and get the reply in return.
135+ public func sendAndWait(
136+ content: String ,
137+ summary: String ? = nil
138+ ) async throws -> String ? {
139+ if !content. isEmpty || summary != nil {
140+ let newMessage = ChatMessage (
141+ id: uuidGenerator ( ) ,
142+ role: . user,
143+ content: content,
144+ summary: summary
145+ )
146+ await memory. appendMessage ( newMessage)
147+ }
148+
149+ let message = try await sendMemoryAndWait ( )
150+ var finalResult = message? . content
151+ var functionCall = message? . functionCall
152+ while let call = functionCall {
153+ functionCall = nil
154+ await runFunctionCall ( call)
155+ guard let nextMessage = try await sendMemoryAndWait ( ) else { break }
156+ finalResult = nextMessage. content
157+ functionCall = nextMessage. functionCall
158+ }
159+
160+ return finalResult
161+ }
162+
163+ public func stopReceivingMessage( ) {
164+ cancelTask ? ( )
165+ cancelTask = nil
166+ }
167+ }
168+
169+ // - MARK: Internal
170+
171+ extension ChatGPTService {
172+ enum StreamContent {
173+ case text( String )
174+ case functionCall( ChatMessage . FunctionCall )
175+ }
176+
177+ /// Send the memory as prompt to ChatGPT, with stream enabled.
178+ func sendMemory( ) async throws -> AsyncThrowingStream < StreamContent , Error > {
179+ guard let url = URL ( string: configuration. endpoint)
180+ else { throw ChatGPTServiceError . endpointIncorrect }
181+
94182 let messages = await memory. messages. map {
95183 CompletionRequestBody . Message ( role: $0. role, content: $0. content)
96184 }
@@ -107,7 +195,7 @@ public class ChatGPTService: ChatGPTServiceType {
107195 remainingTokens: remainingTokens
108196 ) ,
109197 function_call: nil ,
110- functions: [ ]
198+ functions: functionProvider . functionSchemas
111199 )
112200
113201 let api = buildCompletionStreamAPI (
@@ -117,51 +205,41 @@ public class ChatGPTService: ChatGPTServiceType {
117205 requestBody
118206 )
119207
120- return AsyncThrowingStream < String , Error > { continuation in
208+ return AsyncThrowingStream < StreamContent , Error > { continuation in
121209 Task {
122210 do {
123211 let ( trunks, cancel) = try await api ( )
124212 cancelTask = cancel
125- var id = " "
126- var functionCallRawString = " "
127213 for try await trunk in trunks {
128- id = trunk. id
129-
130214 guard let delta = trunk. choices. first? . delta else { continue }
131215
216+ // The api will always return a function call with correct JSON format.
217+ // The first round will contain the function name and an empty argument.
218+ // e.g. {"name":"weather","arguments":""}
219+ let functionCall : ChatMessage . FunctionCall ? = delta. function_call. flatMap {
220+ guard let data = $0. data ( using: . utf8) else { return nil }
221+ return try ? JSONDecoder ( )
222+ . decode ( ChatMessage . FunctionCall. self, from: data)
223+ }
224+
132225 await memory. streamMessage (
133226 id: trunk. id,
134227 role: delta. role,
135228 content: delta. content,
136- functionCall: nil
229+ functionCall: functionCall
137230 )
138231
139- if let call = delta . function_call {
140- functionCallRawString . append ( call )
232+ if let functionCall {
233+ continuation . yield ( . functionCall ( functionCall ) )
141234 }
142235
143236 if let content = delta. content {
144- continuation. yield ( content)
237+ continuation. yield ( . text ( content) )
145238 }
146239
147240 try await Task . sleep ( nanoseconds: 3_000_000 )
148241 }
149242
150- if !functionCallRawString. isEmpty,
151- let data = functionCallRawString. data ( using: . utf8)
152- {
153- let function = try JSONDecoder ( ) . decode (
154- ChatMessage . FunctionCall. self,
155- from: data
156- )
157- await memory. streamMessage (
158- id: id,
159- role: nil ,
160- content: nil ,
161- functionCall: function
162- )
163- }
164-
165243 continuation. finish ( )
166244 } catch let error as CancellationError {
167245 continuation. finish ( throwing: error)
@@ -178,25 +256,21 @@ public class ChatGPTService: ChatGPTServiceType {
178256 }
179257 }
180258
181- public func sendAndWait(
182- content: String ,
183- summary: String ? = nil
184- ) async throws -> String ? {
259+ /// Send the memory as prompt to ChatGPT, with stream disabled.
260+ func sendMemoryAndWait( ) async throws -> ChatMessage ? {
185261 guard let url = URL ( string: configuration. endpoint)
186262 else { throw ChatGPTServiceError . endpointIncorrect }
187263
188- if !content. isEmpty || summary != nil {
189- let newMessage = ChatMessage (
190- id: uuidGenerator ( ) ,
191- role: . user,
192- content: content,
193- summary: summary
194- )
195- await memory. appendMessage ( newMessage)
196- }
197-
198264 let messages = await memory. messages. map {
199- CompletionRequestBody . Message ( role: $0. role, content: $0. content)
265+ CompletionRequestBody . Message (
266+ role: $0. role,
267+ content: $0. content,
268+ name: $0. name,
269+ function_call: $0. functionCall. map {
270+ CompletionRequestBody
271+ . MessageFunctionCall ( name: $0. name, arguments: $0. arguments)
272+ }
273+ )
200274 }
201275 let remainingTokens = await memory. remainingTokens
202276
@@ -211,7 +285,7 @@ public class ChatGPTService: ChatGPTServiceType {
211285 remainingTokens: remainingTokens
212286 ) ,
213287 function_call: nil ,
214- functions: [ ]
288+ functions: functionProvider . functionSchemas
215289 )
216290
217291 let api = buildCompletionAPI (
@@ -222,22 +296,89 @@ public class ChatGPTService: ChatGPTServiceType {
222296 )
223297 let response = try await api ( )
224298
225- if let choice = response. choices. first {
226- await memory. appendMessage ( . init(
227- id: response. id,
228- role: choice. message. role,
229- content: choice. message. content
230- ) )
299+ guard let choice = response. choices. first else { return nil }
300+ let message = ChatMessage (
301+ id: response. id,
302+ role: choice. message. role,
303+ content: choice. message. content,
304+ name: choice. message. name,
305+ functionCall: choice. message. function_call. map {
306+ ChatMessage . FunctionCall ( name: $0. name, arguments: $0. arguments)
307+ }
308+ )
309+ await memory. appendMessage ( message)
310+ return message
311+ }
312+
313+ /// When a function call is detected, but arguments are not yet ready, we can call this
314+ /// to insert a message placeholder in memory.
315+ func prepareFunctionCall( _ call: ChatMessage . FunctionCall , messageId: String ) async {
316+ guard let function = functionProvider. function ( named: call. name) else { return }
317+ let responseMessage = ChatMessage (
318+ id: messageId,
319+ role: . function,
320+ content: nil ,
321+ summary: function. message ( at: . detected)
322+ )
323+ await memory. appendMessage ( responseMessage)
324+ }
231325
232- return choice. message. content
326+ /// Run a function call from the bot, and insert the result in memory.
327+ @discardableResult
328+ func runFunctionCall(
329+ _ call: ChatMessage . FunctionCall ,
330+ messageId: String ? = nil
331+ ) async -> String {
332+ let messageId = messageId ?? uuidGenerator ( )
333+
334+ guard let function = functionProvider. function ( named: call. name) else {
335+ let content = " Error: function not found "
336+ let responseMessage = ChatMessage (
337+ id: messageId,
338+ role: . function,
339+ content: content,
340+ summary: " Function ` \( call. name) ` not found. "
341+ )
342+ await memory. appendMessage ( responseMessage)
343+ return content
233344 }
234345
235- return nil
236- }
237-
238- public func stopReceivingMessage( ) {
239- cancelTask ? ( )
240- cancelTask = nil
346+ // Insert the chat message into memory to indicate the start of the function.
347+ let responseMessage = ChatMessage (
348+ id: messageId,
349+ role: . function,
350+ content: nil ,
351+ summary: function
352+ . message ( at: . processing( argumentsJsonString: call. arguments ?? " " ) )
353+ )
354+ await memory. appendMessage ( responseMessage)
355+
356+ do {
357+ // Run the function
358+ let response = try await function
359+ . call ( argumentsJsonString: call. arguments ?? " " )
360+
361+ // Update the message to display the finish state of the function.
362+ await memory. updateMessage ( id: messageId) { message in
363+ message. content = response
364+ message. summary = function. message ( at: . ended(
365+ argumentsJsonString: call. arguments ?? " " ,
366+ result: response
367+ ) )
368+ }
369+ return response
370+ } catch {
371+ // For errors, use the error message as the result.
372+ let content = " Error: \( error. localizedDescription) "
373+ await memory. updateMessage ( id: messageId) { message in
374+ message. content = content
375+ message. summary = function. message ( at: . error(
376+ argumentsJsonString: call. arguments ?? " " ,
377+ result: error
378+ ) )
379+ }
380+ return content
381+ }
241382 }
242383}
243384
0 commit comments