@@ -220,7 +220,11 @@ actor OpenAIChatCompletionsService: ChatCompletionsStreamAPI, ChatCompletionsAPI
220220 ) {
221221 self . apiKey = apiKey
222222 self . endpoint = endpoint
223- self . requestBody = . init( requestBody)
223+ self . requestBody = . init(
224+ requestBody,
225+ enforceMessageOrder: model. info. openAICompatibleInfo. enforceMessageOrder,
226+ canUseTool: model. info. supportsFunctionCalling
227+ )
224228 self . model = model
225229 }
226230
@@ -468,36 +472,94 @@ extension OpenAIChatCompletionsService.StreamDataChunk {
468472}
469473
470474extension OpenAIChatCompletionsService . RequestBody {
471- init ( _ body: ChatCompletionsRequestBody ) {
475+ init ( _ body: ChatCompletionsRequestBody , enforceMessageOrder : Bool , canUseTool : Bool ) {
472476 model = body. model
473- messages = body. messages. map { message in
474- . init(
475- role: {
476- switch message. role {
477- case . user:
478- return . user
479- case . assistant:
480- return . assistant
481- case . system:
482- return . system
483- case . tool:
484- return . tool
477+ if enforceMessageOrder {
478+ var systemPrompts = [ String] ( )
479+ var nonSystemMessages = [ Message] ( )
480+
481+ for message in body. messages {
482+ switch ( message. role, canUseTool) {
483+ case ( . system, _) :
484+ systemPrompts. append ( message. content)
485+ case ( . tool, true ) :
486+ if let last = nonSystemMessages. last, last. role == . tool {
487+ nonSystemMessages [ nonSystemMessages. endIndex - 1 ] . content
488+ += " \n \n \( message. content) "
489+ } else {
490+ nonSystemMessages. append ( . init(
491+ role: . tool,
492+ content: message. content,
493+ tool_calls: message. toolCalls? . map { tool in
494+ MessageToolCall (
495+ id: tool. id,
496+ type: tool. type,
497+ function: MessageFunctionCall (
498+ name: tool. function. name,
499+ arguments: tool. function. arguments
500+ )
501+ )
502+ }
503+ ) )
504+ }
505+ case ( . assistant, _) , ( . tool, false ) :
506+ if let last = nonSystemMessages. last, last. role == . assistant {
507+ nonSystemMessages [ nonSystemMessages. endIndex - 1 ] . content
508+ += " \n \n \( message. content) "
509+ } else {
510+ nonSystemMessages. append ( . init( role: . assistant, content: message. content) )
511+ }
512+ case ( . user, _) :
513+ if let last = nonSystemMessages. last, last. role == . user {
514+ nonSystemMessages [ nonSystemMessages. endIndex - 1 ] . content
515+ += " \n \n \( message. content) "
516+ } else {
517+ nonSystemMessages. append ( . init(
518+ role: . user,
519+ content: message. content,
520+ name: message. name,
521+ tool_call_id: message. toolCallId
522+ ) )
485523 }
486- } ( ) ,
487- content: message. content,
488- name: message. name,
489- tool_calls: message. toolCalls? . map { tool in
490- MessageToolCall (
491- id: tool. id,
492- type: tool. type,
493- function: MessageFunctionCall (
494- name: tool. function. name,
495- arguments: tool. function. arguments
524+ }
525+ }
526+ messages = [
527+ . init(
528+ role: . system,
529+ content: systemPrompts. joined ( separator: " \n \n " )
530+ . trimmingCharacters ( in: . whitespacesAndNewlines)
531+ ) ,
532+ ] + nonSystemMessages
533+ } else {
534+ messages = body. messages. map { message in
535+ . init(
536+ role: {
537+ switch message. role {
538+ case . user:
539+ return . user
540+ case . assistant:
541+ return . assistant
542+ case . system:
543+ return . system
544+ case . tool:
545+ return . tool
546+ }
547+ } ( ) ,
548+ content: message. content,
549+ name: message. name,
550+ tool_calls: message. toolCalls? . map { tool in
551+ MessageToolCall (
552+ id: tool. id,
553+ type: tool. type,
554+ function: MessageFunctionCall (
555+ name: tool. function. name,
556+ arguments: tool. function. arguments
557+ )
496558 )
497- )
498- } ,
499- tool_call_id : message . toolCallId
500- )
559+ } ,
560+ tool_call_id : message . toolCallId
561+ )
562+ }
501563 }
502564 temperature = body. temperature
503565 stream = body. stream
0 commit comments