Skip to content

Commit 9162154

Browse files
committed
Support enforcing message order
1 parent 9e49591 commit 9162154

4 files changed

Lines changed: 118 additions & 32 deletions

File tree

Core/Sources/HostApp/AccountSettings/ChatModelManagement/ChatModelEdit.swift

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ struct ChatModelEdit {
2828
var suggestedMaxTokens: Int?
2929
var apiKeySelection: APIKeySelection.State = .init()
3030
var baseURLSelection: BaseURLSelection.State = .init()
31+
var enforceMessageOrder: Bool = false
3132
}
3233

3334
enum Action: Equatable, BindableAction {
@@ -197,11 +198,12 @@ extension ChatModel {
197198
}(),
198199
modelName: state.modelName.trimmingCharacters(in: .whitespacesAndNewlines),
199200
ollamaInfo: .init(keepAlive: state.ollamaKeepAlive),
200-
googleGenerativeAIInfo: .init(apiVersion: state.apiVersion)
201+
googleGenerativeAIInfo: .init(apiVersion: state.apiVersion),
202+
openAICompatibleInfo: .init(enforceMessageOrder: state.enforceMessageOrder)
201203
)
202204
)
203205
}
204-
206+
205207
func toState() -> ChatModelEdit.State {
206208
.init(
207209
id: id,
@@ -216,7 +218,8 @@ extension ChatModel {
216218
apiKeyName: info.apiKeyName,
217219
apiKeyManagement: .init(availableAPIKeyNames: [info.apiKeyName])
218220
),
219-
baseURLSelection: .init(baseURL: info.baseURL, isFullURL: info.isFullURL)
221+
baseURLSelection: .init(baseURL: info.baseURL, isFullURL: info.isFullURL),
222+
enforceMessageOrder: info.openAICompatibleInfo.enforceMessageOrder
220223
)
221224
}
222225
}

Core/Sources/HostApp/AccountSettings/ChatModelManagement/ChatModelEditView.swift

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,10 @@ struct ChatModelEditView: View {
308308

309309
MaxTokensTextField(store: store)
310310
SupportsFunctionCallingToggle(store: store)
311+
312+
Toggle(isOn: $store.enforceMessageOrder) {
313+
Text("Enforce message order to be user/assistant alternated")
314+
}
311315
}
312316
}
313317
}

Tool/Sources/AIModel/ChatModel.swift

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,15 @@ public struct ChatModel: Codable, Equatable, Identifiable {
4444
}
4545
}
4646

47+
public struct OpenAICompatibleInfo: Codable, Equatable {
48+
@FallbackDecoding<EmptyBool>
49+
public var enforceMessageOrder: Bool
50+
51+
public init(enforceMessageOrder: Bool = false) {
52+
self.enforceMessageOrder = enforceMessageOrder
53+
}
54+
}
55+
4756
public struct GoogleGenerativeAIInfo: Codable, Equatable {
4857
@FallbackDecoding<EmptyString>
4958
public var apiVersion: String
@@ -72,6 +81,8 @@ public struct ChatModel: Codable, Equatable, Identifiable {
7281
public var ollamaInfo: OllamaInfo
7382
@FallbackDecoding<EmptyChatModelGoogleGenerativeAIInfo>
7483
public var googleGenerativeAIInfo: GoogleGenerativeAIInfo
84+
@FallbackDecoding<EmptyChatModelOpenAICompatibleInfo>
85+
public var openAICompatibleInfo: OpenAICompatibleInfo
7586

7687
public init(
7788
apiKeyName: String = "",
@@ -82,7 +93,8 @@ public struct ChatModel: Codable, Equatable, Identifiable {
8293
modelName: String = "",
8394
openAIInfo: OpenAIInfo = OpenAIInfo(),
8495
ollamaInfo: OllamaInfo = OllamaInfo(),
85-
googleGenerativeAIInfo: GoogleGenerativeAIInfo = GoogleGenerativeAIInfo()
96+
googleGenerativeAIInfo: GoogleGenerativeAIInfo = GoogleGenerativeAIInfo(),
97+
openAICompatibleInfo: OpenAICompatibleInfo = OpenAICompatibleInfo()
8698
) {
8799
self.apiKeyName = apiKeyName
88100
self.baseURL = baseURL
@@ -93,6 +105,7 @@ public struct ChatModel: Codable, Equatable, Identifiable {
93105
self.openAIInfo = openAIInfo
94106
self.ollamaInfo = ollamaInfo
95107
self.googleGenerativeAIInfo = googleGenerativeAIInfo
108+
self.openAICompatibleInfo = openAICompatibleInfo
96109
}
97110
}
98111

@@ -148,3 +161,7 @@ public struct EmptyChatModelOpenAIInfo: FallbackValueProvider {
148161
public struct EmptyChatModelGoogleGenerativeAIInfo: FallbackValueProvider {
149162
public static var defaultValue: ChatModel.Info.GoogleGenerativeAIInfo { .init() }
150163
}
164+
165+
public struct EmptyChatModelOpenAICompatibleInfo: FallbackValueProvider {
166+
public static var defaultValue: ChatModel.Info.OpenAICompatibleInfo { .init() }
167+
}

Tool/Sources/OpenAIService/APIs/OpenAIChatCompletionsService.swift

Lines changed: 90 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,11 @@ actor OpenAIChatCompletionsService: ChatCompletionsStreamAPI, ChatCompletionsAPI
220220
) {
221221
self.apiKey = apiKey
222222
self.endpoint = endpoint
223-
self.requestBody = .init(requestBody)
223+
self.requestBody = .init(
224+
requestBody,
225+
enforceMessageOrder: model.info.openAICompatibleInfo.enforceMessageOrder,
226+
canUseTool: model.info.supportsFunctionCalling
227+
)
224228
self.model = model
225229
}
226230

@@ -468,36 +472,94 @@ extension OpenAIChatCompletionsService.StreamDataChunk {
468472
}
469473

470474
extension OpenAIChatCompletionsService.RequestBody {
471-
init(_ body: ChatCompletionsRequestBody) {
475+
init(_ body: ChatCompletionsRequestBody, enforceMessageOrder: Bool, canUseTool: Bool) {
472476
model = body.model
473-
messages = body.messages.map { message in
474-
.init(
475-
role: {
476-
switch message.role {
477-
case .user:
478-
return .user
479-
case .assistant:
480-
return .assistant
481-
case .system:
482-
return .system
483-
case .tool:
484-
return .tool
477+
if enforceMessageOrder {
478+
var systemPrompts = [String]()
479+
var nonSystemMessages = [Message]()
480+
481+
for message in body.messages {
482+
switch (message.role, canUseTool) {
483+
case (.system, _):
484+
systemPrompts.append(message.content)
485+
case (.tool, true):
486+
if let last = nonSystemMessages.last, last.role == .tool {
487+
nonSystemMessages[nonSystemMessages.endIndex - 1].content
488+
+= "\n\n\(message.content)"
489+
} else {
490+
nonSystemMessages.append(.init(
491+
role: .tool,
492+
content: message.content,
493+
tool_calls: message.toolCalls?.map { tool in
494+
MessageToolCall(
495+
id: tool.id,
496+
type: tool.type,
497+
function: MessageFunctionCall(
498+
name: tool.function.name,
499+
arguments: tool.function.arguments
500+
)
501+
)
502+
}
503+
))
504+
}
505+
case (.assistant, _), (.tool, false):
506+
if let last = nonSystemMessages.last, last.role == .assistant {
507+
nonSystemMessages[nonSystemMessages.endIndex - 1].content
508+
+= "\n\n\(message.content)"
509+
} else {
510+
nonSystemMessages.append(.init(role: .assistant, content: message.content))
511+
}
512+
case (.user, _):
513+
if let last = nonSystemMessages.last, last.role == .user {
514+
nonSystemMessages[nonSystemMessages.endIndex - 1].content
515+
+= "\n\n\(message.content)"
516+
} else {
517+
nonSystemMessages.append(.init(
518+
role: .user,
519+
content: message.content,
520+
name: message.name,
521+
tool_call_id: message.toolCallId
522+
))
485523
}
486-
}(),
487-
content: message.content,
488-
name: message.name,
489-
tool_calls: message.toolCalls?.map { tool in
490-
MessageToolCall(
491-
id: tool.id,
492-
type: tool.type,
493-
function: MessageFunctionCall(
494-
name: tool.function.name,
495-
arguments: tool.function.arguments
524+
}
525+
}
526+
messages = [
527+
.init(
528+
role: .system,
529+
content: systemPrompts.joined(separator: "\n\n")
530+
.trimmingCharacters(in: .whitespacesAndNewlines)
531+
),
532+
] + nonSystemMessages
533+
} else {
534+
messages = body.messages.map { message in
535+
.init(
536+
role: {
537+
switch message.role {
538+
case .user:
539+
return .user
540+
case .assistant:
541+
return .assistant
542+
case .system:
543+
return .system
544+
case .tool:
545+
return .tool
546+
}
547+
}(),
548+
content: message.content,
549+
name: message.name,
550+
tool_calls: message.toolCalls?.map { tool in
551+
MessageToolCall(
552+
id: tool.id,
553+
type: tool.type,
554+
function: MessageFunctionCall(
555+
name: tool.function.name,
556+
arguments: tool.function.arguments
557+
)
496558
)
497-
)
498-
},
499-
tool_call_id: message.toolCallId
500-
)
559+
},
560+
tool_call_id: message.toolCallId
561+
)
562+
}
501563
}
502564
temperature = body.temperature
503565
stream = body.stream

0 commit comments

Comments
 (0)