Skip to content

Commit 1a84d2f

Browse files
committed
Merge branch 'feature/hotfix-for-tabby-chat' into develop
2 parents 06cb361 + 9162154 commit 1a84d2f

4 files changed

Lines changed: 149 additions & 59 deletions

File tree

Core/Sources/HostApp/AccountSettings/ChatModelManagement/ChatModelEdit.swift

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ struct ChatModelEdit {
2828
var suggestedMaxTokens: Int?
2929
var apiKeySelection: APIKeySelection.State = .init()
3030
var baseURLSelection: BaseURLSelection.State = .init()
31+
var enforceMessageOrder: Bool = false
3132
}
3233

3334
enum Action: Equatable, BindableAction {
@@ -197,11 +198,12 @@ extension ChatModel {
197198
}(),
198199
modelName: state.modelName.trimmingCharacters(in: .whitespacesAndNewlines),
199200
ollamaInfo: .init(keepAlive: state.ollamaKeepAlive),
200-
googleGenerativeAIInfo: .init(apiVersion: state.apiVersion)
201+
googleGenerativeAIInfo: .init(apiVersion: state.apiVersion),
202+
openAICompatibleInfo: .init(enforceMessageOrder: state.enforceMessageOrder)
201203
)
202204
)
203205
}
204-
206+
205207
func toState() -> ChatModelEdit.State {
206208
.init(
207209
id: id,
@@ -216,7 +218,8 @@ extension ChatModel {
216218
apiKeyName: info.apiKeyName,
217219
apiKeyManagement: .init(availableAPIKeyNames: [info.apiKeyName])
218220
),
219-
baseURLSelection: .init(baseURL: info.baseURL, isFullURL: info.isFullURL)
221+
baseURLSelection: .init(baseURL: info.baseURL, isFullURL: info.isFullURL),
222+
enforceMessageOrder: info.openAICompatibleInfo.enforceMessageOrder
220223
)
221224
}
222225
}

Core/Sources/HostApp/AccountSettings/ChatModelManagement/ChatModelEditView.swift

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,10 @@ struct ChatModelEditView: View {
308308

309309
MaxTokensTextField(store: store)
310310
SupportsFunctionCallingToggle(store: store)
311+
312+
Toggle(isOn: $store.enforceMessageOrder) {
313+
Text("Enforce message order to be user/assistant alternated")
314+
}
311315
}
312316
}
313317
}

Tool/Sources/AIModel/ChatModel.swift

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,15 @@ public struct ChatModel: Codable, Equatable, Identifiable {
4444
}
4545
}
4646

47+
public struct OpenAICompatibleInfo: Codable, Equatable {
48+
@FallbackDecoding<EmptyBool>
49+
public var enforceMessageOrder: Bool
50+
51+
public init(enforceMessageOrder: Bool = false) {
52+
self.enforceMessageOrder = enforceMessageOrder
53+
}
54+
}
55+
4756
public struct GoogleGenerativeAIInfo: Codable, Equatable {
4857
@FallbackDecoding<EmptyString>
4958
public var apiVersion: String
@@ -72,6 +81,8 @@ public struct ChatModel: Codable, Equatable, Identifiable {
7281
public var ollamaInfo: OllamaInfo
7382
@FallbackDecoding<EmptyChatModelGoogleGenerativeAIInfo>
7483
public var googleGenerativeAIInfo: GoogleGenerativeAIInfo
84+
@FallbackDecoding<EmptyChatModelOpenAICompatibleInfo>
85+
public var openAICompatibleInfo: OpenAICompatibleInfo
7586

7687
public init(
7788
apiKeyName: String = "",
@@ -82,7 +93,8 @@ public struct ChatModel: Codable, Equatable, Identifiable {
8293
modelName: String = "",
8394
openAIInfo: OpenAIInfo = OpenAIInfo(),
8495
ollamaInfo: OllamaInfo = OllamaInfo(),
85-
googleGenerativeAIInfo: GoogleGenerativeAIInfo = GoogleGenerativeAIInfo()
96+
googleGenerativeAIInfo: GoogleGenerativeAIInfo = GoogleGenerativeAIInfo(),
97+
openAICompatibleInfo: OpenAICompatibleInfo = OpenAICompatibleInfo()
8698
) {
8799
self.apiKeyName = apiKeyName
88100
self.baseURL = baseURL
@@ -93,6 +105,7 @@ public struct ChatModel: Codable, Equatable, Identifiable {
93105
self.openAIInfo = openAIInfo
94106
self.ollamaInfo = ollamaInfo
95107
self.googleGenerativeAIInfo = googleGenerativeAIInfo
108+
self.openAICompatibleInfo = openAICompatibleInfo
96109
}
97110
}
98111

@@ -148,3 +161,7 @@ public struct EmptyChatModelOpenAIInfo: FallbackValueProvider {
148161
public struct EmptyChatModelGoogleGenerativeAIInfo: FallbackValueProvider {
149162
public static var defaultValue: ChatModel.Info.GoogleGenerativeAIInfo { .init() }
150163
}
164+
165+
public struct EmptyChatModelOpenAICompatibleInfo: FallbackValueProvider {
166+
public static var defaultValue: ChatModel.Info.OpenAICompatibleInfo { .init() }
167+
}

Tool/Sources/OpenAIService/APIs/OpenAIChatCompletionsService.swift

Lines changed: 121 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,11 @@ actor OpenAIChatCompletionsService: ChatCompletionsStreamAPI, ChatCompletionsAPI
220220
) {
221221
self.apiKey = apiKey
222222
self.endpoint = endpoint
223-
self.requestBody = .init(requestBody)
223+
self.requestBody = .init(
224+
requestBody,
225+
enforceMessageOrder: model.info.openAICompatibleInfo.enforceMessageOrder,
226+
canUseTool: model.info.supportsFunctionCalling
227+
)
224228
self.model = model
225229
}
226230

@@ -278,34 +282,38 @@ actor OpenAIChatCompletionsService: ChatCompletionsStreamAPI, ChatCompletionsAPI
278282
}
279283

280284
func callAsFunction() async throws -> ChatCompletionResponseBody {
281-
requestBody.stream = false
282-
var request = URLRequest(url: endpoint)
283-
request.httpMethod = "POST"
284-
let encoder = JSONEncoder()
285-
request.httpBody = try encoder.encode(requestBody)
286-
request.setValue("application/json", forHTTPHeaderField: "Content-Type")
287-
288-
Self.setupAppInformation(&request)
289-
Self.setupAPIKey(&request, model: model, apiKey: apiKey)
290-
291-
let (result, response) = try await URLSession.shared.data(for: request)
292-
guard let response = response as? HTTPURLResponse else {
293-
throw ChatGPTServiceError.responseInvalid
294-
}
295-
296-
guard response.statusCode == 200 else {
297-
let error = try? JSONDecoder().decode(CompletionAPIError.self, from: result)
298-
throw error ?? ChatGPTServiceError
299-
.otherError(String(data: result, encoding: .utf8) ?? "Unknown Error")
300-
}
301-
302-
do {
303-
let body = try JSONDecoder().decode(ResponseBody.self, from: result)
304-
return body.formalized()
305-
} catch {
306-
dump(error)
307-
throw error
285+
let stream: AsyncThrowingStream<ChatCompletionsStreamDataChunk, Error> =
286+
try await callAsFunction()
287+
288+
var body = ChatCompletionResponseBody(
289+
id: nil,
290+
object: "",
291+
model: "",
292+
message: .init(role: .assistant, content: ""),
293+
otherChoices: [],
294+
finishReason: ""
295+
)
296+
for try await chunk in stream {
297+
if let id = chunk.id {
298+
body.id = id
299+
}
300+
if let finishReason = chunk.finishReason {
301+
body.finishReason = finishReason
302+
}
303+
if let model = chunk.model {
304+
body.model = model
305+
}
306+
if let object = chunk.object {
307+
body.object = object
308+
}
309+
if let role = chunk.message?.role {
310+
body.message.role = role
311+
}
312+
if let text = chunk.message?.content {
313+
body.message.content += text
314+
}
308315
}
316+
return body
309317
}
310318

311319
static func setupAppInformation(_ request: inout URLRequest) {
@@ -464,36 +472,94 @@ extension OpenAIChatCompletionsService.StreamDataChunk {
464472
}
465473

466474
extension OpenAIChatCompletionsService.RequestBody {
467-
init(_ body: ChatCompletionsRequestBody) {
475+
init(_ body: ChatCompletionsRequestBody, enforceMessageOrder: Bool, canUseTool: Bool) {
468476
model = body.model
469-
messages = body.messages.map { message in
470-
.init(
471-
role: {
472-
switch message.role {
473-
case .user:
474-
return .user
475-
case .assistant:
476-
return .assistant
477-
case .system:
478-
return .system
479-
case .tool:
480-
return .tool
477+
if enforceMessageOrder {
478+
var systemPrompts = [String]()
479+
var nonSystemMessages = [Message]()
480+
481+
for message in body.messages {
482+
switch (message.role, canUseTool) {
483+
case (.system, _):
484+
systemPrompts.append(message.content)
485+
case (.tool, true):
486+
if let last = nonSystemMessages.last, last.role == .tool {
487+
nonSystemMessages[nonSystemMessages.endIndex - 1].content
488+
+= "\n\n\(message.content)"
489+
} else {
490+
nonSystemMessages.append(.init(
491+
role: .tool,
492+
content: message.content,
493+
tool_calls: message.toolCalls?.map { tool in
494+
MessageToolCall(
495+
id: tool.id,
496+
type: tool.type,
497+
function: MessageFunctionCall(
498+
name: tool.function.name,
499+
arguments: tool.function.arguments
500+
)
501+
)
502+
}
503+
))
481504
}
482-
}(),
483-
content: message.content,
484-
name: message.name,
485-
tool_calls: message.toolCalls?.map { tool in
486-
MessageToolCall(
487-
id: tool.id,
488-
type: tool.type,
489-
function: MessageFunctionCall(
490-
name: tool.function.name,
491-
arguments: tool.function.arguments
505+
case (.assistant, _), (.tool, false):
506+
if let last = nonSystemMessages.last, last.role == .assistant {
507+
nonSystemMessages[nonSystemMessages.endIndex - 1].content
508+
+= "\n\n\(message.content)"
509+
} else {
510+
nonSystemMessages.append(.init(role: .assistant, content: message.content))
511+
}
512+
case (.user, _):
513+
if let last = nonSystemMessages.last, last.role == .user {
514+
nonSystemMessages[nonSystemMessages.endIndex - 1].content
515+
+= "\n\n\(message.content)"
516+
} else {
517+
nonSystemMessages.append(.init(
518+
role: .user,
519+
content: message.content,
520+
name: message.name,
521+
tool_call_id: message.toolCallId
522+
))
523+
}
524+
}
525+
}
526+
messages = [
527+
.init(
528+
role: .system,
529+
content: systemPrompts.joined(separator: "\n\n")
530+
.trimmingCharacters(in: .whitespacesAndNewlines)
531+
),
532+
] + nonSystemMessages
533+
} else {
534+
messages = body.messages.map { message in
535+
.init(
536+
role: {
537+
switch message.role {
538+
case .user:
539+
return .user
540+
case .assistant:
541+
return .assistant
542+
case .system:
543+
return .system
544+
case .tool:
545+
return .tool
546+
}
547+
}(),
548+
content: message.content,
549+
name: message.name,
550+
tool_calls: message.toolCalls?.map { tool in
551+
MessageToolCall(
552+
id: tool.id,
553+
type: tool.type,
554+
function: MessageFunctionCall(
555+
name: tool.function.name,
556+
arguments: tool.function.arguments
557+
)
492558
)
493-
)
494-
},
495-
tool_call_id: message.toolCallId
496-
)
559+
},
560+
tool_call_id: message.toolCallId
561+
)
562+
}
497563
}
498564
temperature = body.temperature
499565
stream = body.stream

0 commit comments

Comments (0)