Skip to content

Commit 70ceefc

Browse files
committed
Add tests for prompt reformat for Google Gemini
1 parent 20f5de9 commit 70ceefc

4 files changed

Lines changed: 270 additions & 105 deletions

File tree

Tool/Sources/OpenAIService/APIs/GoogleAICompletionAPI.swift

Lines changed: 111 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,83 @@ import Foundation
33
import GoogleGenerativeAI
44
import Preferences
55

6+
/// A `CompletionAPI` implementation backed by Google Generative AI (Gemini),
/// adapting OpenAI-style requests/responses to the Gemini SDK.
struct GoogleCompletionAPI: CompletionAPI {
    let apiKey: String
    let model: ChatModel
    var requestBody: CompletionRequestBody
    let prompt: ChatGPTPrompt

    /// Sends the Gemini-compatible prompt to the model and converts the
    /// response into an OpenAI-style `CompletionResponseBody`.
    ///
    /// - Returns: A response whose `choices` mirror the Gemini candidates.
    ///   The SDK does not report token usage, so usage fields are zeroed.
    /// - Throws: Errors from the SDK; `GenerateContentError.internalError`
    ///   is wrapped so its underlying error dump is visible to the user.
    func callAsFunction() async throws -> CompletionResponseBody {
        let aiModel = GenerativeModel(
            name: model.info.modelName,
            apiKey: apiKey,
            // Pass the config directly; the previous `.init(GenerationConfig(...))`
            // merely re-wrapped the value through `Optional.init`.
            generationConfig: GenerationConfig(
                temperature: requestBody.temperature.map(Float.init),
                topP: requestBody.top_p.map(Float.init)
            )
        )
        // `googleAICompatible` reformats the history (e.g. folding system and
        // function messages into user turns) before conversion to `ModelContent`.
        let history = prompt.googleAICompatible.history.map { message in
            ModelContent(
                ChatMessage(
                    role: message.role,
                    content: message.content,
                    name: message.name,
                    functionCall: message.functionCall.map {
                        .init(name: $0.name, arguments: $0.arguments)
                    }
                )
            )
        }

        do {
            let response = try await aiModel.generateContent(history)

            return .init(
                object: "chat.completion",
                model: model.info.modelName,
                // Token counts are not exposed by the SDK; report zeros.
                usage: .init(prompt_tokens: 0, completion_tokens: 0, total_tokens: 0),
                choices: response.candidates.enumerated().map {
                    let (index, candidate) = $0
                    return .init(
                        message: .init(
                            role: .assistant,
                            // Use the first non-empty text part, if any.
                            content: candidate.content.parts.first(where: { part in
                                if let text = part.text {
                                    return !text.isEmpty
                                } else {
                                    return false
                                }
                            })?.text ?? ""
                        ),
                        index: index,
                        finish_reason: candidate.finishReason?.rawValue ?? ""
                    )
                }
            )
        } catch let error as GenerateContentError {
            // `.internalError` hides the underlying failure behind an opaque
            // description; wrap it so the full error dump reaches the user.
            struct ErrorWrapper: Error, LocalizedError {
                let error: Error
                var errorDescription: String? {
                    var s = ""
                    dump(error, to: &s)
                    return "Internal Error: \(s)"
                }
            }

            switch error {
            case let .internalError(underlying):
                throw ErrorWrapper(error: underlying)
            case .promptBlocked, .responseStoppedEarly:
                throw error
            }
        }
        // Note: no catch-all `catch { throw error }` is needed — in a throwing
        // function, errors that don't match a catch pattern propagate automatically.
    }
}
82+
683
extension ChatGPTPrompt {
784
var googleAICompatible: ChatGPTPrompt {
885
var history = self.history
@@ -20,6 +97,7 @@ extension ChatGPTPrompt {
2097
guard lastIndex >= 0 else { // first message
2198
if message.role == .system {
2299
reformattedHistory.append(.init(
100+
id: message.id,
23101
role: .user,
24102
content: ModelContent.convertContent(of: message)
25103
))
@@ -40,6 +118,7 @@ extension ChatGPTPrompt {
40118
.convertRole(message.role)
41119
{
42120
let newMessage = ChatMessage(
121+
id: message.id,
43122
role: message.role == .assistant ? .assistant : .user,
44123
content: """
45124
\(ModelContent.convertContent(of: lastMessage))
@@ -78,80 +157,42 @@ extension ChatGPTPrompt {
78157
}
79158
}
80159

81-
struct GoogleCompletionAPI: CompletionAPI {
82-
let apiKey: String
83-
let model: ChatModel
84-
var requestBody: CompletionRequestBody
85-
let prompt: ChatGPTPrompt
86-
87-
func callAsFunction() async throws -> CompletionResponseBody {
88-
let aiModel = GenerativeModel(
89-
name: model.info.modelName,
90-
apiKey: apiKey,
91-
generationConfig: .init(GenerationConfig(
92-
temperature: requestBody.temperature.map(Float.init),
93-
topP: requestBody.top_p.map(Float.init)
94-
))
95-
)
96-
let history = prompt.googleAICompatible.history.map { message in
97-
ModelContent(
98-
ChatMessage(
99-
role: message.role,
100-
content: message.content,
101-
name: message.name,
102-
functionCall: message.functionCall.map {
103-
.init(name: $0.name, arguments: $0.arguments)
104-
}
105-
)
106-
)
160+
// MARK: - ChatMessage → Gemini ModelContent conversion

extension ModelContent {
    /// Maps a ChatGPT role onto the role names Gemini accepts.
    /// Gemini only distinguishes "user" and "model", so system and
    /// function messages are folded into the user role.
    static func convertRole(_ role: ChatMessage.Role) -> String {
        switch role {
        case .assistant:
            return "model"
        case .user, .system, .function:
            return "user"
        }
    }

    /// Flattens a chat message into the plain-text form sent to Gemini,
    /// annotating system prompts, function results, and function calls
    /// so their origin remains visible in the text.
    static func convertContent(of message: ChatMessage) -> String {
        switch message.role {
        case .system:
            return "System Prompt:\n\(message.content ?? " ")"
        case .user:
            return message.content ?? " "
        case .function:
            return """
            Result of \(message.name ?? "function"): \(message.content ?? "N/A")
            """
        case .assistant:
            guard let functionCall = message.functionCall else {
                return message.content ?? " "
            }
            return """
            Call function: \(functionCall.name)
            Arguments: \(functionCall.arguments)
            """
        }
    }

    /// Builds a `ModelContent` carrying a single text part derived from
    /// the given chat message.
    init(_ message: ChatMessage) {
        self.init(
            role: Self.convertRole(message.role),
            parts: [ModelContent.Part.text(Self.convertContent(of: message))]
        )
    }
}
157198

Tool/Sources/OpenAIService/Memory/AutoManagedChatGPTMemoryStrategy/AutoManagedChatGPTMemoryGoogleAIStrategy.swift

Lines changed: 0 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -29,38 +29,4 @@ extension AutoManagedChatGPTMemory {
2929
}
3030
}
3131

32-
extension ModelContent {
33-
static func convertRole(_ role: ChatMessage.Role) -> String {
34-
switch role {
35-
case .user, .system, .function:
36-
return "user"
37-
case .assistant:
38-
return "model"
39-
}
40-
}
41-
42-
static func convertContent(of message: ChatMessage) -> String {
43-
switch message.role {
44-
case .system:
45-
return "System Prompt: \n\(message.content ?? " ")"
46-
case .user, .function:
47-
return message.content ?? " "
48-
case .assistant:
49-
if let functionCall = message.functionCall {
50-
return """
51-
call function: \(functionCall.name)
52-
arguments: \(functionCall.arguments)
53-
"""
54-
} else {
55-
return message.content ?? " "
56-
}
57-
}
58-
}
59-
60-
init(_ message: ChatMessage) {
61-
let role = Self.convertRole(message.role)
62-
let parts = [ModelContent.Part.text(Self.convertContent(of: message))]
63-
self = .init(role: role, parts: parts)
64-
}
65-
}
6632

Tool/Sources/OpenAIService/Memory/ChatGPTMemory.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import Foundation
22

3-
public struct ChatGPTPrompt {
3+
public struct ChatGPTPrompt: Equatable {
44
public var history: [ChatMessage]
55
public var references: [ChatMessage.Reference]
66
public var remainingTokenCount: Int?

0 commit comments

Comments
 (0)