Skip to content

Commit a3aff20

Browse files
committed
Merge branch 'feature/fix-gemini-as-utility-model' into develop
2 parents 861e462 + 70ceefc commit a3aff20

12 files changed

Lines changed: 340 additions & 163 deletions

Tool/Sources/OpenAIService/APIs/GoogleAICompletionAPI.swift

Lines changed: 122 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ struct GoogleCompletionAPI: CompletionAPI {
77
let apiKey: String
88
let model: ChatModel
99
var requestBody: CompletionRequestBody
10+
let prompt: ChatGPTPrompt
1011

1112
func callAsFunction() async throws -> CompletionResponseBody {
1213
let aiModel = GenerativeModel(
@@ -17,22 +18,22 @@ struct GoogleCompletionAPI: CompletionAPI {
1718
topP: requestBody.top_p.map(Float.init)
1819
))
1920
)
20-
let history = requestBody.messages.map { message in
21+
let history = prompt.googleAICompatible.history.map { message in
2122
ModelContent(
2223
ChatMessage(
2324
role: message.role,
2425
content: message.content,
2526
name: message.name,
26-
functionCall: message.function_call.map {
27-
.init(name: $0.name, arguments: $0.arguments ?? "")
27+
functionCall: message.functionCall.map {
28+
.init(name: $0.name, arguments: $0.arguments)
2829
}
2930
)
3031
)
3132
}
3233

3334
do {
3435
let response = try await aiModel.generateContent(history)
35-
36+
3637
return .init(
3738
object: "chat.completion",
3839
model: model.info.modelName,
@@ -64,7 +65,7 @@ struct GoogleCompletionAPI: CompletionAPI {
6465
return "Internal Error: \(s)"
6566
}
6667
}
67-
68+
6869
switch error {
6970
case let .internalError(underlying):
7071
throw ErrorWrapper(error: underlying)
@@ -79,3 +80,119 @@ struct GoogleCompletionAPI: CompletionAPI {
7980
}
8081
}
8182

83+
extension ChatGPTPrompt {
    /// A copy of this prompt whose history is rearranged so that no two neighbouring
    /// messages map to the same role under `ModelContent.convertRole(_:)`.
    ///
    /// The rewrite works as follows:
    /// - A trailing `.user` message is set aside first and re-appended at the very end,
    ///   so it is never merged with older messages.
    /// - A leading `.system` message becomes a `.user` message (text produced by
    ///   `ModelContent.convertContent(of:)`) followed by a canned assistant reply.
    /// - Any message whose converted role equals that of the message before it is folded
    ///   into that message, with the two texts separated by a `======` line.
    /// - If the trailing user message would directly follow another user-side message,
    ///   an assistant message containing "OK" is inserted in between.
    ///
    /// `references` and `remainingTokenCount` are carried over unchanged.
    ///
    /// NOTE(review): presumably needed because the Google AI API expects strictly
    /// alternating user/model turns; confirm against the API documentation.
    var googleAICompatible: ChatGPTPrompt {
        var earlierMessages = history

        // We don't want to combine the new user message with others.
        let pendingUserMessage: ChatMessage? = earlierMessages.last?.role == .user
            ? earlierMessages.removeLast()
            : nil

        var result = [ChatMessage]()

        for message in earlierMessages {
            guard let previous = result.last else {
                // Nothing has been emitted yet, so this is the first message.
                if message.role == .system {
                    let systemAsUser = ChatMessage(
                        id: message.id,
                        role: .user,
                        content: ModelContent.convertContent(of: message)
                    )
                    result.append(systemAsUser)

                    let acknowledgement = ChatMessage(
                        role: .assistant,
                        content: "Got it. Let's start our conversation."
                    )
                    result.append(acknowledgement)
                } else {
                    result.append(message)
                }
                continue
            }

            let previousRole = ModelContent.convertRole(previous.role)
            let currentRole = ModelContent.convertRole(message.role)

            // Different sides of the conversation: keep the message as a separate turn.
            guard previousRole == currentRole else {
                result.append(message)
                continue
            }

            // Same side twice in a row: replace the previous entry with a combined one.
            let combinedRole: ChatMessage.Role = (message.role == .assistant) ? .assistant : .user
            let previousText = ModelContent.convertContent(of: previous)
            let currentText = ModelContent.convertContent(of: message)

            result[result.endIndex - 1] = ChatMessage(
                id: message.id,
                role: combinedRole,
                content: "\(previousText)\n\n======\n\n\(currentText)"
            )
        }

        if let pendingUserMessage {
            let needsSeparator: Bool
            if let tail = result.last {
                needsSeparator = ModelContent.convertRole(tail.role)
                    == ModelContent.convertRole(pendingUserMessage.role)
            } else {
                needsSeparator = false
            }

            if needsSeparator {
                // Add dummy message
                let placeholder = ChatMessage(
                    role: .assistant,
                    content: "OK"
                )
                result.append(placeholder)
            }

            result.append(pendingUserMessage)
        }

        return ChatGPTPrompt(
            history: result,
            references: references,
            remainingTokenCount: remainingTokenCount
        )
    }
}
159+
160+
extension ModelContent {
    /// Maps a chat message role to the role string used when building a `ModelContent`.
    ///
    /// - Parameter role: The role of a `ChatMessage`.
    /// - Returns: `"model"` for `.assistant`; `"user"` for `.user`, `.system` and `.function`.
    static func convertRole(_ role: ChatMessage.Role) -> String {
        switch role {
        case .assistant:
            return "model"
        case .user, .system, .function:
            return "user"
        }
    }

    /// Flattens a chat message into a single plain-text string.
    ///
    /// - System messages are prefixed with `System Prompt:` on its own line.
    /// - Function messages are rendered as `Result of <name>: <content>`.
    /// - Assistant messages carrying a function call are rendered as the call's name
    ///   and arguments; otherwise their content is used.
    ///
    /// Missing content falls back to a single space (or `N/A` for function results).
    ///
    /// - Parameter message: The message to convert.
    /// - Returns: The text representation of `message`.
    static func convertContent(of message: ChatMessage) -> String {
        switch message.role {
        case .system:
            let text = message.content ?? " "
            return "System Prompt:\n\(text)"
        case .user:
            return message.content ?? " "
        case .function:
            let functionName = message.name ?? "function"
            let output = message.content ?? "N/A"
            return "Result of \(functionName): \(output)"
        case .assistant:
            guard let call = message.functionCall else {
                return message.content ?? " "
            }
            return "Call function: \(call.name)\nArguments: \(call.arguments)"
        }
    }

    /// Creates a text-only `ModelContent` from a chat message, using
    /// `convertRole(_:)` for the role and `convertContent(of:)` for the single text part.
    ///
    /// - Parameter message: The chat message to wrap.
    init(_ message: ChatMessage) {
        let text = Self.convertContent(of: message)
        let parts = [ModelContent.Part.text(text)]
        self = .init(role: Self.convertRole(message.role), parts: parts)
    }
}
198+

Tool/Sources/OpenAIService/APIs/GoogleAICompletionStreamAPI.swift

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ struct GoogleCompletionStreamAPI: CompletionStreamAPI {
77
let apiKey: String
88
let model: ChatModel
99
var requestBody: CompletionRequestBody
10+
let prompt: ChatGPTPrompt
1011

1112
func callAsFunction() async throws -> AsyncThrowingStream<CompletionStreamDataChunk, Error> {
1213
let aiModel = GenerativeModel(
@@ -17,14 +18,14 @@ struct GoogleCompletionStreamAPI: CompletionStreamAPI {
1718
topP: requestBody.top_p.map(Float.init)
1819
))
1920
)
20-
let history = requestBody.messages.map { message in
21+
let history = prompt.googleAICompatible.history.map { message in
2122
ModelContent(
2223
ChatMessage(
2324
role: message.role,
2425
content: message.content,
2526
name: message.name,
26-
functionCall: message.function_call.map {
27-
.init(name: $0.name, arguments: $0.arguments ?? "")
27+
functionCall: message.functionCall.map {
28+
.init(name: $0.name, arguments: $0.arguments)
2829
}
2930
)
3031
)

Tool/Sources/OpenAIService/APIs/OpenAICompletionAPI.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ import AIModel
22
import Foundation
33
import Preferences
44

5-
typealias CompletionAPIBuilder = (String, ChatModel, URL, CompletionRequestBody)
5+
typealias CompletionAPIBuilder = (String, ChatModel, URL, CompletionRequestBody, ChatGPTPrompt)
66
-> CompletionAPI
77

88
protocol CompletionAPI {

Tool/Sources/OpenAIService/APIs/OpenAICompletionStreamAPI.swift

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,13 @@ import AsyncAlgorithms
33
import Foundation
44
import Preferences
55

6-
typealias CompletionStreamAPIBuilder = (String, ChatModel, URL, CompletionRequestBody)
7-
-> any CompletionStreamAPI
6+
typealias CompletionStreamAPIBuilder = (
7+
String,
8+
ChatModel,
9+
URL,
10+
CompletionRequestBody,
11+
ChatGPTPrompt
12+
) -> any CompletionStreamAPI
813

914
protocol CompletionStreamAPI {
1015
func callAsFunction() async throws -> AsyncThrowingStream<CompletionStreamDataChunk, Error>

Tool/Sources/OpenAIService/ChatGPTService.swift

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -70,10 +70,15 @@ public class ChatGPTService: ChatGPTServiceType {
7070

7171
var runningTask: Task<Void, Never>?
7272
var buildCompletionStreamAPI: CompletionStreamAPIBuilder = {
73-
apiKey, model, endpoint, requestBody in
73+
apiKey, model, endpoint, requestBody, prompt in
7474
switch model.format {
7575
case .googleAI:
76-
return GoogleCompletionStreamAPI(apiKey: apiKey, model: model, requestBody: requestBody)
76+
return GoogleCompletionStreamAPI(
77+
apiKey: apiKey,
78+
model: model,
79+
requestBody: requestBody,
80+
prompt: prompt
81+
)
7782
case .openAI, .openAICompatible, .azureOpenAI:
7883
return OpenAICompletionStreamAPI(
7984
apiKey: apiKey,
@@ -85,10 +90,15 @@ public class ChatGPTService: ChatGPTServiceType {
8590
}
8691

8792
var buildCompletionAPI: CompletionAPIBuilder = {
88-
apiKey, model, endpoint, requestBody in
93+
apiKey, model, endpoint, requestBody, prompt in
8994
switch model.format {
9095
case .googleAI:
91-
return GoogleCompletionAPI(apiKey: apiKey, model: model, requestBody: requestBody)
96+
return GoogleCompletionAPI(
97+
apiKey: apiKey,
98+
model: model,
99+
requestBody: requestBody,
100+
prompt: prompt
101+
)
92102
case .openAI, .openAICompatible, .azureOpenAI:
93103
return OpenAICompletionAPI(
94104
apiKey: apiKey,
@@ -305,7 +315,8 @@ extension ChatGPTService {
305315
configuration.apiKey,
306316
model,
307317
url,
308-
requestBody
318+
requestBody,
319+
prompt
309320
)
310321

311322
#if DEBUG
@@ -432,7 +443,8 @@ extension ChatGPTService {
432443
configuration.apiKey,
433444
model,
434445
url,
435-
requestBody
446+
requestBody,
447+
prompt
436448
)
437449

438450
#if DEBUG

Tool/Sources/OpenAIService/Memory/AutoManagedChatGPTMemory.swift

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ public enum AutoManagedChatGPTMemoryActor: GlobalActor {
1212
protocol AutoManagedChatGPTMemoryStrategy {
1313
func countToken(_ message: ChatMessage) async -> Int
1414
func countToken<F: ChatGPTFunction>(_ function: F) async -> Int
15-
func reformat(_ prompt: ChatGPTPrompt) async -> ChatGPTPrompt
1615
}
1716

1817
/// A memory that automatically manages the history according to max tokens and max message count.
@@ -172,12 +171,10 @@ extension AutoManagedChatGPTMemory {
172171
""")
173172
#endif
174173

175-
let reformattedPrompt = await strategy.reformat(.init(
174+
return .init(
176175
history: allMessages,
177176
references: retrievedContent
178-
))
179-
180-
return reformattedPrompt
177+
)
181178
}
182179

183180
func generateMandatoryMessages(strategy: AutoManagedChatGPTMemoryStrategy) async -> (

0 commit comments

Comments
 (0)