@@ -220,7 +220,11 @@ actor OpenAIChatCompletionsService: ChatCompletionsStreamAPI, ChatCompletionsAPI
220220 ) {
221221 self . apiKey = apiKey
222222 self . endpoint = endpoint
223- self . requestBody = . init( requestBody)
223+ self . requestBody = . init(
224+ requestBody,
225+ enforceMessageOrder: model. info. openAICompatibleInfo. enforceMessageOrder,
226+ canUseTool: model. info. supportsFunctionCalling
227+ )
224228 self . model = model
225229 }
226230
@@ -278,34 +282,38 @@ actor OpenAIChatCompletionsService: ChatCompletionsStreamAPI, ChatCompletionsAPI
278282 }
279283
/// Non-streaming completion call, implemented by draining the streaming
/// variant of `callAsFunction` and folding every chunk into one
/// `ChatCompletionResponseBody`.
///
/// - Returns: The accumulated response once the stream finishes.
/// - Throws: Whatever the underlying streaming call throws.
func callAsFunction() async throws -> ChatCompletionResponseBody {
    // Delegate to the streaming overload; the explicit type annotation
    // selects the AsyncThrowingStream-returning variant.
    let chunks: AsyncThrowingStream<ChatCompletionsStreamDataChunk, Error> =
        try await callAsFunction()

    // Start from an empty assistant message and fill fields in as
    // chunks arrive; scalar fields take the latest non-nil value,
    // while message text is concatenated in arrival order.
    var accumulated = ChatCompletionResponseBody(
        id: nil,
        object: "",
        model: "",
        message: .init(role: .assistant, content: ""),
        otherChoices: [],
        finishReason: ""
    )

    for try await chunk in chunks {
        if let id = chunk.id { accumulated.id = id }
        if let object = chunk.object { accumulated.object = object }
        if let model = chunk.model { accumulated.model = model }
        if let finishReason = chunk.finishReason { accumulated.finishReason = finishReason }
        if let role = chunk.message?.role { accumulated.message.role = role }
        if let text = chunk.message?.content { accumulated.message.content += text }
    }

    return accumulated
}
310318
311319 static func setupAppInformation( _ request: inout URLRequest ) {
@@ -464,36 +472,94 @@ extension OpenAIChatCompletionsService.StreamDataChunk {
464472}
465473
466474extension OpenAIChatCompletionsService . RequestBody {
467- init ( _ body: ChatCompletionsRequestBody ) {
475+ init ( _ body: ChatCompletionsRequestBody , enforceMessageOrder : Bool , canUseTool : Bool ) {
468476 model = body. model
469- messages = body. messages. map { message in
470- . init(
471- role: {
472- switch message. role {
473- case . user:
474- return . user
475- case . assistant:
476- return . assistant
477- case . system:
478- return . system
479- case . tool:
480- return . tool
477+ if enforceMessageOrder {
478+ var systemPrompts = [ String] ( )
479+ var nonSystemMessages = [ Message] ( )
480+
481+ for message in body. messages {
482+ switch ( message. role, canUseTool) {
483+ case ( . system, _) :
484+ systemPrompts. append ( message. content)
485+ case ( . tool, true ) :
486+ if let last = nonSystemMessages. last, last. role == . tool {
487+ nonSystemMessages [ nonSystemMessages. endIndex - 1 ] . content
488+ += " \n \n \( message. content) "
489+ } else {
490+ nonSystemMessages. append ( . init(
491+ role: . tool,
492+ content: message. content,
493+ tool_calls: message. toolCalls? . map { tool in
494+ MessageToolCall (
495+ id: tool. id,
496+ type: tool. type,
497+ function: MessageFunctionCall (
498+ name: tool. function. name,
499+ arguments: tool. function. arguments
500+ )
501+ )
502+ }
503+ ) )
481504 }
482- } ( ) ,
483- content: message. content,
484- name: message. name,
485- tool_calls: message. toolCalls? . map { tool in
486- MessageToolCall (
487- id: tool. id,
488- type: tool. type,
489- function: MessageFunctionCall (
490- name: tool. function. name,
491- arguments: tool. function. arguments
505+ case ( . assistant, _) , ( . tool, false ) :
506+ if let last = nonSystemMessages. last, last. role == . assistant {
507+ nonSystemMessages [ nonSystemMessages. endIndex - 1 ] . content
508+ += " \n \n \( message. content) "
509+ } else {
510+ nonSystemMessages. append ( . init( role: . assistant, content: message. content) )
511+ }
512+ case ( . user, _) :
513+ if let last = nonSystemMessages. last, last. role == . user {
514+ nonSystemMessages [ nonSystemMessages. endIndex - 1 ] . content
515+ += " \n \n \( message. content) "
516+ } else {
517+ nonSystemMessages. append ( . init(
518+ role: . user,
519+ content: message. content,
520+ name: message. name,
521+ tool_call_id: message. toolCallId
522+ ) )
523+ }
524+ }
525+ }
526+ messages = [
527+ . init(
528+ role: . system,
529+ content: systemPrompts. joined ( separator: " \n \n " )
530+ . trimmingCharacters ( in: . whitespacesAndNewlines)
531+ ) ,
532+ ] + nonSystemMessages
533+ } else {
534+ messages = body. messages. map { message in
535+ . init(
536+ role: {
537+ switch message. role {
538+ case . user:
539+ return . user
540+ case . assistant:
541+ return . assistant
542+ case . system:
543+ return . system
544+ case . tool:
545+ return . tool
546+ }
547+ } ( ) ,
548+ content: message. content,
549+ name: message. name,
550+ tool_calls: message. toolCalls? . map { tool in
551+ MessageToolCall (
552+ id: tool. id,
553+ type: tool. type,
554+ function: MessageFunctionCall (
555+ name: tool. function. name,
556+ arguments: tool. function. arguments
557+ )
492558 )
493- )
494- } ,
495- tool_call_id : message . toolCallId
496- )
559+ } ,
560+ tool_call_id : message . toolCallId
561+ )
562+ }
497563 }
498564 temperature = body. temperature
499565 stream = body. stream
0 commit comments