Skip to content

Commit 69442b5

Browse files
committed
Move prompt reformat to ChatGPTService
So that we don't have to implement it for each memory
1 parent 861e462 commit 69442b5

9 files changed

Lines changed: 117 additions & 105 deletions

Tool/Sources/OpenAIService/APIs/GoogleAICompletionAPI.swift

Lines changed: 81 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,86 @@ import Foundation
33
import GoogleGenerativeAI
44
import Preferences
55

6+
extension ChatGPTPrompt {
    /// A copy of this prompt whose history satisfies Gemini's turn-based
    /// conversation requirement: messages mapping to the same Gemini role
    /// must not appear back to back.
    ///
    /// - A leading system message becomes a user turn followed by a canned
    ///   assistant acknowledgement.
    /// - Consecutive messages that convert to the same Gemini role are merged
    ///   into one message, separated by `======`.
    /// - The trailing user message (the new input) is never merged; a dummy
    ///   assistant "OK" turn is inserted before it when needed.
    var googleAICompatible: ChatGPTPrompt {
        var workingHistory = self.history
        var merged = [ChatMessage]()

        // Pull the newest user message out of the merging pass so it always
        // stays a standalone turn.
        let pendingUserMessage: ChatMessage?
        if workingHistory.last?.role == .user {
            pendingUserMessage = workingHistory.removeLast()
        } else {
            pendingUserMessage = nil
        }

        for message in workingHistory {
            guard let previous = merged.last else {
                // First message: a system prompt is rewritten as a user turn
                // plus a canned assistant reply so the conversation opens on a
                // user/model pair.
                if message.role == .system {
                    merged.append(.init(
                        role: .user,
                        content: ModelContent.convertContent(of: message)
                    ))
                    merged.append(.init(
                        role: .assistant,
                        content: "Got it. Let's start our conversation."
                    ))
                } else {
                    merged.append(message)
                }
                continue
            }

            let sameGeminiRole = ModelContent.convertRole(previous.role)
                == ModelContent.convertRole(message.role)

            if sameGeminiRole {
                // Collapse same-role neighbors into a single message.
                merged[merged.count - 1] = ChatMessage(
                    role: message.role == .assistant ? .assistant : .user,
                    content: """
                    \(ModelContent.convertContent(of: previous))

                    ======

                    \(ModelContent.convertContent(of: message))
                    """
                )
            } else {
                merged.append(message)
            }
        }

        if let pendingUserMessage {
            let endsWithUserRole = merged.last.map {
                ModelContent.convertRole($0.role)
                    == ModelContent.convertRole(pendingUserMessage.role)
            } ?? false
            if endsWithUserRole {
                // Insert a filler assistant turn so the new user message does
                // not directly follow another user-role message.
                merged.append(ChatMessage(role: .assistant, content: "OK"))
            }
            merged.append(pendingUserMessage)
        }

        return .init(
            history: merged,
            references: references,
            remainingTokenCount: remainingTokenCount
        )
    }
}
80+
681
struct GoogleCompletionAPI: CompletionAPI {
782
let apiKey: String
883
let model: ChatModel
984
var requestBody: CompletionRequestBody
85+
let prompt: ChatGPTPrompt
1086

1187
func callAsFunction() async throws -> CompletionResponseBody {
1288
let aiModel = GenerativeModel(
@@ -17,22 +93,22 @@ struct GoogleCompletionAPI: CompletionAPI {
1793
topP: requestBody.top_p.map(Float.init)
1894
))
1995
)
20-
let history = requestBody.messages.map { message in
96+
let history = prompt.googleAICompatible.history.map { message in
2197
ModelContent(
2298
ChatMessage(
2399
role: message.role,
24100
content: message.content,
25101
name: message.name,
26-
functionCall: message.function_call.map {
27-
.init(name: $0.name, arguments: $0.arguments ?? "")
102+
functionCall: message.functionCall.map {
103+
.init(name: $0.name, arguments: $0.arguments)
28104
}
29105
)
30106
)
31107
}
32108

33109
do {
34110
let response = try await aiModel.generateContent(history)
35-
111+
36112
return .init(
37113
object: "chat.completion",
38114
model: model.info.modelName,
@@ -64,7 +140,7 @@ struct GoogleCompletionAPI: CompletionAPI {
64140
return "Internal Error: \(s)"
65141
}
66142
}
67-
143+
68144
switch error {
69145
case let .internalError(underlying):
70146
throw ErrorWrapper(error: underlying)

Tool/Sources/OpenAIService/APIs/GoogleAICompletionStreamAPI.swift

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ struct GoogleCompletionStreamAPI: CompletionStreamAPI {
77
let apiKey: String
88
let model: ChatModel
99
var requestBody: CompletionRequestBody
10+
let prompt: ChatGPTPrompt
1011

1112
func callAsFunction() async throws -> AsyncThrowingStream<CompletionStreamDataChunk, Error> {
1213
let aiModel = GenerativeModel(
@@ -17,14 +18,14 @@ struct GoogleCompletionStreamAPI: CompletionStreamAPI {
1718
topP: requestBody.top_p.map(Float.init)
1819
))
1920
)
20-
let history = requestBody.messages.map { message in
21+
let history = prompt.googleAICompatible.history.map { message in
2122
ModelContent(
2223
ChatMessage(
2324
role: message.role,
2425
content: message.content,
2526
name: message.name,
26-
functionCall: message.function_call.map {
27-
.init(name: $0.name, arguments: $0.arguments ?? "")
27+
functionCall: message.functionCall.map {
28+
.init(name: $0.name, arguments: $0.arguments)
2829
}
2930
)
3031
)

Tool/Sources/OpenAIService/APIs/OpenAICompletionAPI.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ import AIModel
22
import Foundation
33
import Preferences
44

5-
typealias CompletionAPIBuilder = (String, ChatModel, URL, CompletionRequestBody)
5+
typealias CompletionAPIBuilder = (String, ChatModel, URL, CompletionRequestBody, ChatGPTPrompt)
66
-> CompletionAPI
77

88
protocol CompletionAPI {

Tool/Sources/OpenAIService/APIs/OpenAICompletionStreamAPI.swift

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,13 @@ import AsyncAlgorithms
33
import Foundation
44
import Preferences
55

6-
typealias CompletionStreamAPIBuilder = (String, ChatModel, URL, CompletionRequestBody)
7-
-> any CompletionStreamAPI
6+
typealias CompletionStreamAPIBuilder = (
7+
String,
8+
ChatModel,
9+
URL,
10+
CompletionRequestBody,
11+
ChatGPTPrompt
12+
) -> any CompletionStreamAPI
813

914
protocol CompletionStreamAPI {
1015
func callAsFunction() async throws -> AsyncThrowingStream<CompletionStreamDataChunk, Error>

Tool/Sources/OpenAIService/ChatGPTService.swift

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -70,10 +70,15 @@ public class ChatGPTService: ChatGPTServiceType {
7070

7171
var runningTask: Task<Void, Never>?
7272
var buildCompletionStreamAPI: CompletionStreamAPIBuilder = {
73-
apiKey, model, endpoint, requestBody in
73+
apiKey, model, endpoint, requestBody, prompt in
7474
switch model.format {
7575
case .googleAI:
76-
return GoogleCompletionStreamAPI(apiKey: apiKey, model: model, requestBody: requestBody)
76+
return GoogleCompletionStreamAPI(
77+
apiKey: apiKey,
78+
model: model,
79+
requestBody: requestBody,
80+
prompt: prompt
81+
)
7782
case .openAI, .openAICompatible, .azureOpenAI:
7883
return OpenAICompletionStreamAPI(
7984
apiKey: apiKey,
@@ -85,10 +90,15 @@ public class ChatGPTService: ChatGPTServiceType {
8590
}
8691

8792
var buildCompletionAPI: CompletionAPIBuilder = {
88-
apiKey, model, endpoint, requestBody in
93+
apiKey, model, endpoint, requestBody, prompt in
8994
switch model.format {
9095
case .googleAI:
91-
return GoogleCompletionAPI(apiKey: apiKey, model: model, requestBody: requestBody)
96+
return GoogleCompletionAPI(
97+
apiKey: apiKey,
98+
model: model,
99+
requestBody: requestBody,
100+
prompt: prompt
101+
)
92102
case .openAI, .openAICompatible, .azureOpenAI:
93103
return OpenAICompletionAPI(
94104
apiKey: apiKey,
@@ -305,7 +315,8 @@ extension ChatGPTService {
305315
configuration.apiKey,
306316
model,
307317
url,
308-
requestBody
318+
requestBody,
319+
prompt
309320
)
310321

311322
#if DEBUG
@@ -432,7 +443,8 @@ extension ChatGPTService {
432443
configuration.apiKey,
433444
model,
434445
url,
435-
requestBody
446+
requestBody,
447+
prompt
436448
)
437449

438450
#if DEBUG

Tool/Sources/OpenAIService/Memory/AutoManagedChatGPTMemory.swift

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ public enum AutoManagedChatGPTMemoryActor: GlobalActor {
1212
protocol AutoManagedChatGPTMemoryStrategy {
1313
func countToken(_ message: ChatMessage) async -> Int
1414
func countToken<F: ChatGPTFunction>(_ function: F) async -> Int
15-
func reformat(_ prompt: ChatGPTPrompt) async -> ChatGPTPrompt
1615
}
1716

1817
/// A memory that automatically manages the history according to max tokens and max message count.
@@ -172,12 +171,10 @@ extension AutoManagedChatGPTMemory {
172171
""")
173172
#endif
174173

175-
let reformattedPrompt = await strategy.reformat(.init(
174+
return .init(
176175
history: allMessages,
177176
references: retrievedContent
178-
))
179-
180-
return reformattedPrompt
177+
)
181178
}
182179

183180
func generateMandatoryMessages(strategy: AutoManagedChatGPTMemoryStrategy) async -> (

Tool/Sources/OpenAIService/Memory/AutoManagedChatGPTMemoryStrategy/AutoManagedChatGPTMemoryGoogleAIStrategy.swift

Lines changed: 0 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -26,81 +26,6 @@ extension AutoManagedChatGPTMemory {
2626
// function is not supported.
2727
return 0
2828
}
29-
30-
/// Gemini only supports turn-based conversation. A user message must be followed
31-
/// by an model message.
32-
func reformat(_ prompt: ChatGPTPrompt) async -> ChatGPTPrompt {
33-
var history = prompt.history
34-
var reformattedHistory = [ChatMessage]()
35-
36-
// We don't want to combine the new user message with others.
37-
let newUserMessage: ChatMessage? = if history.last?.role == .user {
38-
history.removeLast()
39-
} else {
40-
nil
41-
}
42-
43-
for message in history {
44-
let lastIndex = reformattedHistory.endIndex - 1
45-
guard lastIndex >= 0 else { // first message
46-
if message.role == .system {
47-
reformattedHistory.append(.init(
48-
role: .user,
49-
content: ModelContent.convertContent(of: message)
50-
))
51-
reformattedHistory.append(.init(
52-
role: .assistant,
53-
content: "Got it. Let's start our conversation."
54-
))
55-
continue
56-
}
57-
58-
reformattedHistory.append(message)
59-
continue
60-
}
61-
62-
let lastMessage = reformattedHistory[lastIndex]
63-
64-
if ModelContent.convertRole(lastMessage.role) == ModelContent
65-
.convertRole(message.role)
66-
{
67-
let newMessage = ChatMessage(
68-
role: message.role == .assistant ? .assistant : .user,
69-
content: """
70-
\(ModelContent.convertContent(of: lastMessage))
71-
72-
======
73-
74-
\(ModelContent.convertContent(of: message))
75-
"""
76-
)
77-
reformattedHistory[lastIndex] = newMessage
78-
} else {
79-
reformattedHistory.append(message)
80-
}
81-
}
82-
83-
if let newUserMessage {
84-
if let last = reformattedHistory.last,
85-
ModelContent.convertRole(last.role) == ModelContent
86-
.convertRole(newUserMessage.role)
87-
{
88-
// Add dummy message
89-
let dummyMessage = ChatMessage(
90-
role: .assistant,
91-
content: "OK"
92-
)
93-
reformattedHistory.append(dummyMessage)
94-
}
95-
reformattedHistory.append(newUserMessage)
96-
}
97-
98-
return .init(
99-
history: reformattedHistory,
100-
references: prompt.references,
101-
remainingTokenCount: prompt.remainingTokenCount
102-
)
103-
}
10429
}
10530
}
10631

Tool/Sources/OpenAIService/Memory/AutoManagedChatGPTMemoryStrategy/AutoManagedChatGPTMemoryOpenAIStrategy.swift

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,6 @@ extension AutoManagedChatGPTMemory {
2222

2323
return await (nameTokenCount + descriptionTokenCount + schemaTokenCount)
2424
}
25-
26-
func reformat(_ prompt: ChatGPTPrompt) async -> ChatGPTPrompt {
27-
prompt
28-
}
2925
}
3026
}
3127

Tool/Tests/OpenAIServiceTests/ChatGPTStreamTests.swift

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ final class ChatGPTStreamTests: XCTestCase {
1515
functionProvider: functionProvider
1616
)
1717
var requestBody: CompletionRequestBody?
18-
service.changeBuildCompletionStreamAPI { _, _, _, _requestBody in
18+
service.changeBuildCompletionStreamAPI { _, _, _, _requestBody, _ in
1919
requestBody = _requestBody
2020
return MockCompletionStreamAPI_Message()
2121
}
@@ -76,7 +76,7 @@ final class ChatGPTStreamTests: XCTestCase {
7676
functionProvider: functionProvider
7777
)
7878
var requestBody: CompletionRequestBody?
79-
service.changeBuildCompletionStreamAPI { _, _, _, _requestBody in
79+
service.changeBuildCompletionStreamAPI { _, _, _, _requestBody, _ in
8080
requestBody = _requestBody
8181
if _requestBody.messages.count <= 2 {
8282
return MockCompletionStreamAPI_Function()
@@ -160,7 +160,7 @@ final class ChatGPTStreamTests: XCTestCase {
160160
)
161161
var requestBody: CompletionRequestBody?
162162

163-
service.changeBuildCompletionStreamAPI { _, _, _, _requestBody in
163+
service.changeBuildCompletionStreamAPI { _, _, _, _requestBody, _ in
164164
requestBody = _requestBody
165165
if _requestBody.messages.count <= 4 {
166166
return MockCompletionStreamAPI_Function()
@@ -266,7 +266,7 @@ final class ChatGPTStreamTests: XCTestCase {
266266
functionProvider: functionProvider
267267
)
268268
var requestBody: CompletionRequestBody?
269-
service.changeBuildCompletionStreamAPI { _, _, _, _requestBody in
269+
service.changeBuildCompletionStreamAPI { _, _, _, _requestBody, _ in
270270
requestBody = _requestBody
271271
if _requestBody.messages.count <= 2 {
272272
return MockCompletionStreamAPI_Function()

0 commit comments

Comments (0)