
Commit 7604e87

Update AutoManagedChatGPTMemory to support strategies
1 parent f1eeffd commit 7604e87

3 files changed

Lines changed: 315 additions & 85 deletions

File tree

Tool/Sources/OpenAIService/Memory/AutoManagedChatGPTMemory.swift

Lines changed: 133 additions & 85 deletions
@@ -3,6 +3,18 @@ import Logger
 import Preferences
 import TokenEncoder
 
+@globalActor
+public enum AutoManagedChatGPTMemoryActor: GlobalActor {
+    public actor Actor {}
+    public static let shared = Actor()
+}
+
+protocol AutoManagedChatGPTMemoryStrategy {
+    func countToken(_ message: ChatMessage) async -> Int
+    func countToken<F: ChatGPTFunction>(_ function: F) async -> Int
+    func reformat(_ prompt: ChatGPTPrompt) async -> ChatGPTPrompt
+}
+
 /// A memory that automatically manages the history according to max tokens and max message count.
 public actor AutoManagedChatGPTMemory: ChatGPTMemory {
     public struct ComposableMessages {
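The commit header says 3 files changed, but this view shows only AutoManagedChatGPTMemory.swift, so the OpenAIStrategy and GoogleAIStrategy conformances referenced later in the diff are not visible here. To make the new protocol concrete, here is a minimal sketch of a conforming strategy. It assumes TokenEncoder.countToken(text:) returns an Int (as the removed code implies) and reuses the TiktokenCl100kBaseTokenEncoder this commit deletes from the actor; the type name and details are illustrative, not the commit's actual implementation.

// Hypothetical conforming strategy; names and details are illustrative only.
struct SketchOpenAIStrategy: AutoManagedChatGPTMemoryStrategy {
    // The same encoder this commit removes from AutoManagedChatGPTMemory.
    static let encoder: TokenEncoder = TiktokenCl100kBaseTokenEncoder()

    func countToken(_ message: ChatMessage) async -> Int {
        // 3 tokens of per-message overhead, matching the OpenAI cookbook rules
        // used by the TokenEncoder extension removed at the bottom of this diff.
        var total = 3
        if let content = message.content {
            total += await Self.encoder.countToken(text: content)
        }
        return total
    }

    func countToken<F: ChatGPTFunction>(_ function: F) async -> Int {
        // Name and description only; a fuller sketch that also counts the
        // JSON-encoded argument schema follows a later hunk.
        let nameCount = await Self.encoder.countToken(text: function.name)
        let descriptionCount = await Self.encoder.countToken(text: function.description)
        return nameCount + descriptionCount
    }

    func reformat(_ prompt: ChatGPTPrompt) async -> ChatGPTPrompt {
        // OpenAI-format models need no prompt restructuring, so pass through.
        prompt
    }
}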
@@ -27,8 +39,6 @@ public actor AutoManagedChatGPTMemory: ChatGPTMemory {
     public var configuration: ChatGPTConfiguration
     public var functionProvider: ChatGPTFunctionProvider
 
-    static let encoder: TokenEncoder = TiktokenCl100kBaseTokenEncoder()
-
     var onHistoryChange: () -> Void = {}
 
     let composeHistory: HistoryComposer
@@ -60,7 +70,6 @@ public actor AutoManagedChatGPTMemory: ChatGPTMemory {
         self.configuration = configuration
         self.functionProvider = functionProvider
         self.composeHistory = composeHistory
-        _ = Self.encoder // force pre-initialize
     }
 
     public func mutateHistory(_ update: (inout [ChatMessage]) -> Void) {
@@ -87,20 +96,32 @@ public actor AutoManagedChatGPTMemory: ChatGPTMemory {
     }
 
     public func generatePrompt() async -> ChatGPTPrompt {
-        return await generateSendingHistory()
+        let strategy: AutoManagedChatGPTMemoryStrategy = switch configuration.model?.format {
+        case .googleAI: GoogleAIStrategy(configuration: configuration)
+        default: OpenAIStrategy()
+        }
+        return await generateSendingHistory(strategy: strategy)
     }
 
+    func setOnHistoryChangeBlock(_ onChange: @escaping () -> Void) {
+        onHistoryChange = onChange
+    }
+}
+
+extension AutoManagedChatGPTMemory {
     /// https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
     func generateSendingHistory(
         maxNumberOfMessages: Int = UserDefaults.shared.value(for: \.chatGPTMaxMessageCount),
-        encoder: TokenEncoder = AutoManagedChatGPTMemory.encoder
+        strategy: AutoManagedChatGPTMemoryStrategy
     ) async -> ChatGPTPrompt {
+        // handle no function support models
+
         let (
             systemPromptMessage,
             contextSystemPromptMessage,
             availableTokenCountForMessages,
             mandatoryUsage
-        ) = await generateMandatoryMessages(encoder: encoder)
+        ) = await generateMandatoryMessages(strategy: strategy)
 
         let (
             historyMessage,
@@ -110,7 +131,7 @@ public actor AutoManagedChatGPTMemory: ChatGPTMemory {
         ) = await generateMessageHistory(
             maxNumberOfMessages: maxNumberOfMessages - 1, // for the new message
             maxTokenCount: availableTokenCountForMessages,
-            encoder: encoder
+            strategy: strategy
         )
 
         let (
@@ -120,7 +141,7 @@ public actor AutoManagedChatGPTMemory: ChatGPTMemory {
             retrievedContent
         ) = await generateRetrievedContentMessage(
             maxTokenCount: availableTokenCountForRetrievedContent,
-            encoder: encoder
+            strategy: strategy
         )
 
         let allMessages = composeHistory(.init(
@@ -151,45 +172,42 @@ public actor AutoManagedChatGPTMemory: ChatGPTMemory {
         """)
         #endif
 
-        return .init(history: allMessages, references: retrievedContent)
-    }
+        let reformattedPrompt = await strategy.reformat(.init(
+            history: allMessages,
+            references: retrievedContent
+        ))
 
-    func setOnHistoryChangeBlock(_ onChange: @escaping () -> Void) {
-        onHistoryChange = onChange
+        return reformattedPrompt
     }
-}
 
-extension AutoManagedChatGPTMemory {
-    func generateMandatoryMessages(encoder: TokenEncoder) async -> (
+    func generateMandatoryMessages(strategy: AutoManagedChatGPTMemoryStrategy) async -> (
         systemPrompt: ChatMessage,
         contextSystemPrompt: ChatMessage,
         remainingTokenCount: Int,
         usage: (systemPrompt: Int, contextSystemPrompt: Int, functions: Int)
     ) {
-        var smallestSystemPromptMessage = ChatMessage(role: .system, content: systemPrompt)
-        var contextSystemPromptMessage = ChatMessage(role: .user, content: contextSystemPrompt)
-        let smallestSystemMessageTokenCount = await encoder.countToken(&smallestSystemPromptMessage)
+        let smallestSystemPromptMessage = ChatMessage(
+            role: .system,
+            content: systemPrompt
+        )
+        let contextSystemPromptMessage = ChatMessage(
+            role: .user,
+            content: contextSystemPrompt
+        )
+        let smallestSystemMessageTokenCount = await strategy
+            .countToken(smallestSystemPromptMessage)
         let contextSystemPromptTokenCount = !contextSystemPrompt.isEmpty
-            ? (await encoder.countToken(&contextSystemPromptMessage))
+            ? (await strategy.countToken(contextSystemPromptMessage))
             : 0
 
         let functionTokenCount = await {
            var totalTokenCount = 0
-            for function in functionProvider.functions {
-                async let nameTokenCount = encoder.countToken(text: function.name)
-                async let descriptionTokenCount = encoder.countToken(text: function.description)
-                async let schemaTokenCount = {
-                    guard let data = try? JSONEncoder().encode(function.argumentSchema),
-                          let string = String(data: data, encoding: .utf8)
-                    else { return 0 }
-                    return await encoder.countToken(text: string)
-                }()
-
-                await totalTokenCount += nameTokenCount + descriptionTokenCount + schemaTokenCount
+            for function in self.functionProvider.functions {
+                totalTokenCount += await strategy.countToken(function)
             }
             return totalTokenCount
         }()
-
+
         let mandatoryContentTokensCount = smallestSystemMessageTokenCount
             + contextSystemPromptTokenCount
             + functionTokenCount
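The hunk above deletes the inline per-function counting (name, description, and the JSON-encoded argument schema) in favor of a single strategy.countToken(function) call, so that logic presumably now lives inside the strategy types in the files not shown here. Reconstructed from the removed lines, a strategy-side implementation could look roughly like this sketch (hypothetical method name, same assumed TokenEncoder API as the earlier sketch):

import Foundation

extension SketchOpenAIStrategy {
    // Sketch: per-function token count rebuilt from the closure deleted above.
    func countTokenIncludingSchema<F: ChatGPTFunction>(_ function: F) async -> Int {
        let nameTokenCount = await Self.encoder.countToken(text: function.name)
        let descriptionTokenCount = await Self.encoder.countToken(text: function.description)
        // JSON-encode the argument schema and count it as plain text,
        // exactly as the deleted code did.
        var schemaTokenCount = 0
        if let data = try? JSONEncoder().encode(function.argumentSchema),
           let string = String(data: data, encoding: .utf8) {
            schemaTokenCount = await Self.encoder.countToken(text: string)
        }
        return nameTokenCount + descriptionTokenCount + schemaTokenCount
    }
}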
@@ -217,7 +235,7 @@ extension AutoManagedChatGPTMemory {
     func generateMessageHistory(
         maxNumberOfMessages: Int,
         maxTokenCount: Int,
-        encoder: TokenEncoder
+        strategy: AutoManagedChatGPTMemoryStrategy
     ) async -> (
         history: [ChatMessage],
         newMessage: ChatMessage,
@@ -231,8 +249,7 @@ extension AutoManagedChatGPTMemory {
         for (index, message) in history.enumerated().reversed() {
             if maxNumberOfMessages > 0, allMessages.count >= maxNumberOfMessages { break }
             if message.isEmpty { continue }
-            let tokensCount = await encoder.countToken(message)
-            history[index].tokensCount = tokensCount
+            let tokensCount = await strategy.countToken(message)
             if tokensCount + messageTokenCount > maxTokenCount { break }
             messageTokenCount += tokensCount
             if index == history.endIndex - 1 {
@@ -252,7 +269,7 @@ extension AutoManagedChatGPTMemory {
 
     func generateRetrievedContentMessage(
         maxTokenCount: Int,
-        encoder: TokenEncoder
+        strategy: AutoManagedChatGPTMemoryStrategy
     ) async -> (
         retrievedContent: ChatMessage,
         remainingTokenCount: Int,
@@ -261,68 +278,99 @@ extension AutoManagedChatGPTMemory {
     ) {
         /// the available tokens count for retrieved content
         let thresholdMaxTokenCount = min(maxTokenCount, configuration.maxTokens / 2)
+        /// A separator that costs only 1 token
+        let separator = String(repeating: "=", count: 32)
+        let retrievedContent = retrievedContent.filter { !$0.content.isEmpty }
 
-        var retrievedContentTokenCount = 0
-        let separator = String(repeating: "=", count: 32) // only 1 token
-        var message = ""
-        var references = [ChatMessage.Reference]()
-
-        func appendToMessage(_ text: String) async -> Bool {
-            let tokensCount = await encoder.countToken(text: text)
-            if tokensCount + retrievedContentTokenCount > thresholdMaxTokenCount { return false }
-            retrievedContentTokenCount += tokensCount
-            message += text
-            return true
+        func buildMessage(retrievedContent: [ChatMessage.Reference]) -> ChatMessage {
+            var text = ""
+            for (index, content) in retrievedContent.enumerated() {
+                if index == 0 {
+                    text += """
+                    Here are the information you know about the system and the project, \
+                    separated by \(separator)
+
+
+                    """
+                } else {
+                    text += "\n\(separator)\n"
+                }
+
+                text += content.content
+            }
+
+            return .init(role: .user, content: text)
         }
 
-        for (index, content) in retrievedContent.filter({ !$0.content.isEmpty }).enumerated() {
-            if index == 0 {
-                if !(await appendToMessage("""
-                Here are the information you know about the system and the project, \
-                separated by \(separator)
+        func buildMessageThatFits() async
+            -> (message: ChatMessage, references: [ChatMessage.Reference], tokenCount: Int)
+        {
+            var right = retrievedContent.count
+            var left = 0
+            var retrievedContent = retrievedContent
+            var tokenCount: Int?
+            var proposedMessage = buildMessage(retrievedContent: [])
+
+            func checkValid(proposedMessage: ChatMessage) async
+                -> (isValid: Bool, tokenCount: Int?)
+            {
+                // if the size is way below the threshold
+                let characterCount = proposedMessage.content?.count ?? 0
+
+                if characterCount <= thresholdMaxTokenCount {
+                    return (true, nil) // guessing token count.
+                }
+
+                let tokensCount = await strategy.countToken(proposedMessage)
+                if tokensCount <= thresholdMaxTokenCount {
+                    return (true, tokenCount)
+                }
+                return (false, tokenCount)
+            }
 
+            // check if all retrieved content included
+            let maxMessage = buildMessage(retrievedContent: retrievedContent)
+            let (isValid, maxTokenCount) = await checkValid(proposedMessage: maxMessage)
+            if isValid {
+                let tokenCount = if let maxTokenCount { maxTokenCount }
+                else { await strategy.countToken(maxMessage) }
+                return (maxMessage, retrievedContent, tokenCount)
+            }
 
-                """)) { break }
-            } else {
-                if !(await appendToMessage("\n\(separator)\n")) { break }
+            // binary search to reduce countToken calls
+            while left <= right {
+                let count = (right + left) / 2
+                let _retrievedContent = Array(retrievedContent.prefix(count))
+                let _proposedMessage = buildMessage(retrievedContent: retrievedContent)
+                let (isValid, _tokenCount) = await checkValid(proposedMessage: _proposedMessage)
+                if isValid {
+                    proposedMessage = _proposedMessage
+                    retrievedContent = _retrievedContent
+                    tokenCount = _tokenCount
+                    left = count + 1
+                } else {
+                    right = count - 1
+                }
             }
 
-            if !(await appendToMessage(content.content)) { break }
-            references.append(content)
+            let finalCount = if let tokenCount {
+                tokenCount
+            } else if proposedMessage.content?.isEmpty ?? true {
+                0
+            } else {
+                await strategy.countToken(proposedMessage)
+            }
+            return (proposedMessage, retrievedContent, finalCount)
         }
 
+        let (message, references, tokensCount) = await buildMessageThatFits()
+
         return (
-            .init(role: .user, content: message),
-            maxTokenCount - retrievedContentTokenCount,
-            retrievedContentTokenCount,
+            message,
+            maxTokenCount - tokensCount,
+            tokensCount,
             references
         )
     }
 }
 
-public extension TokenEncoder {
-    /// https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
-    func countToken(_ message: ChatMessage) async -> Int {
-        var total = 3
-        if let content = message.content {
-            total += await encode(text: content).count
-        }
-        if let name = message.name {
-            total += await encode(text: name).count
-            total += 1
-        }
-        if let functionCall = message.functionCall {
-            total += await encode(text: functionCall.name).count
-            total += await encode(text: functionCall.arguments).count
-        }
-        return total
-    }
-
-    func countToken(_ message: inout ChatMessage) async -> Int {
-        if let count = message.tokensCount { return count }
-        let count = await countToken(message)
-        message.tokensCount = count
-        return count
-    }
-}
-

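The buildMessageThatFits function added above binary-searches over prefix lengths of the retrieved content, so the potentially expensive strategy.countToken call runs O(log n) times instead of once per appended reference as the deleted appendToMessage loop did. Here is a self-contained sketch of the same idea, with a synchronous measure closure standing in for strategy.countToken; all names in it are illustrative, not from the commit.

// Find the longest prefix of `items` whose rendered text fits `budget`,
// calling the expensive `measure` function only O(log n) times.
func longestFittingPrefix(
    of items: [String],
    budget: Int,
    measure: (String) -> Int
) -> [String] {
    func render(_ prefix: ArraySlice<String>) -> String {
        prefix.joined(separator: "\n================================\n")
    }

    // Fast path: everything fits, mirroring the "check if all retrieved
    // content included" step in the diff.
    if measure(render(items[...])) <= budget { return items }

    var left = 0
    var right = items.count
    var best = 0
    while left <= right {
        let mid = (left + right) / 2
        if measure(render(items.prefix(mid))) <= budget {
            best = mid       // this prefix fits; try a longer one
            left = mid + 1
        } else {
            right = mid - 1  // too large; shrink the prefix
        }
    }
    return Array(items.prefix(best))
}

// Example: keep as many references as fit a 200-character budget, using
// character count as the measure (the commit's checkValid similarly treats
// a low character count as an automatic pass before counting real tokens).
let kept = longestFittingPrefix(
    of: ["reference A", "reference B", "reference C"],
    budget: 200
) { $0.count }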