@@ -3,6 +3,18 @@ import Logger
33import Preferences
44import TokenEncoder
55
/// A dedicated global actor on which auto-managed chat memory work is isolated,
/// so history bookkeeping never races with callers on other executors.
@globalActor
public enum AutoManagedChatGPTMemoryActor: GlobalActor {
    public actor Actor {}

    /// The single shared executor instance required by `GlobalActor`.
    public static let shared = Actor()
}
11+
/// Abstracts model-family-specific behavior (token counting and prompt
/// reshaping) so `AutoManagedChatGPTMemory` can serve multiple providers.
protocol AutoManagedChatGPTMemoryStrategy {
    /// Counts the tokens a single chat message will consume when sent.
    func countToken(_ message: ChatMessage) async -> Int
    /// Counts the tokens a function declaration will consume when sent.
    func countToken<F: ChatGPTFunction>(_ function: F) async -> Int
    /// Rewrites the prompt into a shape the target model accepts
    /// (e.g. models without system-message or function support).
    func reformat(_ prompt: ChatGPTPrompt) async -> ChatGPTPrompt
}
17+
618/// A memory that automatically manages the history according to max tokens and max message count.
719public actor AutoManagedChatGPTMemory : ChatGPTMemory {
820 public struct ComposableMessages {
@@ -27,8 +39,6 @@ public actor AutoManagedChatGPTMemory: ChatGPTMemory {
2739 public var configuration : ChatGPTConfiguration
2840 public var functionProvider : ChatGPTFunctionProvider
2941
30- static let encoder : TokenEncoder = TiktokenCl100kBaseTokenEncoder ( )
31-
3242 var onHistoryChange : ( ) -> Void = { }
3343
3444 let composeHistory : HistoryComposer
@@ -60,7 +70,6 @@ public actor AutoManagedChatGPTMemory: ChatGPTMemory {
6070 self . configuration = configuration
6171 self . functionProvider = functionProvider
6272 self . composeHistory = composeHistory
63- _ = Self . encoder // force pre-initialize
6473 }
6574
6675 public func mutateHistory( _ update: ( inout [ ChatMessage ] ) -> Void ) {
@@ -87,20 +96,32 @@ public actor AutoManagedChatGPTMemory: ChatGPTMemory {
8796 }
8897
8998 public func generatePrompt( ) async -> ChatGPTPrompt {
90- return await generateSendingHistory ( )
// Pick the token-counting/reformatting strategy from the model's format;
// anything not explicitly Google AI falls back to the OpenAI rules.
99+ let strategy : AutoManagedChatGPTMemoryStrategy = switch configuration. model? . format {
100+ case . googleAI: GoogleAIStrategy ( configuration: configuration)
101+ default : OpenAIStrategy ( )
102+ }
103+ return await generateSendingHistory ( strategy: strategy)
91104 }
92105
/// Replaces the callback invoked whenever the chat history changes.
106+ func setOnHistoryChangeBlock( _ onChange: @escaping ( ) -> Void ) {
107+ onHistoryChange = onChange
108+ }
109+ }
110+
111+ extension AutoManagedChatGPTMemory {
93112 /// https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
94113 func generateSendingHistory(
95114 maxNumberOfMessages: Int = UserDefaults . shared. value ( for: \. chatGPTMaxMessageCount) ,
96- encoder : TokenEncoder = AutoManagedChatGPTMemory . encoder
115+ strategy : AutoManagedChatGPTMemoryStrategy
97116 ) async -> ChatGPTPrompt {
117+ // handle no function support models
118+
98119 let (
99120 systemPromptMessage,
100121 contextSystemPromptMessage,
101122 availableTokenCountForMessages,
102123 mandatoryUsage
103- ) = await generateMandatoryMessages ( encoder : encoder )
124+ ) = await generateMandatoryMessages ( strategy : strategy )
104125
105126 let (
106127 historyMessage,
@@ -110,7 +131,7 @@ public actor AutoManagedChatGPTMemory: ChatGPTMemory {
110131 ) = await generateMessageHistory (
111132 maxNumberOfMessages: maxNumberOfMessages - 1 , // for the new message
112133 maxTokenCount: availableTokenCountForMessages,
113- encoder : encoder
134+ strategy : strategy
114135 )
115136
116137 let (
@@ -120,7 +141,7 @@ public actor AutoManagedChatGPTMemory: ChatGPTMemory {
120141 retrievedContent
121142 ) = await generateRetrievedContentMessage (
122143 maxTokenCount: availableTokenCountForRetrievedContent,
123- encoder : encoder
144+ strategy : strategy
124145 )
125146
126147 let allMessages = composeHistory ( . init(
@@ -151,45 +172,42 @@ public actor AutoManagedChatGPTMemory: ChatGPTMemory {
151172 """ )
152173 #endif
153174
154- return . init( history: allMessages, references: retrievedContent)
155- }
175+ let reformattedPrompt = await strategy. reformat ( . init(
176+ history: allMessages,
177+ references: retrievedContent
178+ ) )
156179
157- func setOnHistoryChangeBlock( _ onChange: @escaping ( ) -> Void ) {
158- onHistoryChange = onChange
180+ return reformattedPrompt
159181 }
160- }
161182
162- extension AutoManagedChatGPTMemory {
163- func generateMandatoryMessages( encoder: TokenEncoder ) async -> (
183+ func generateMandatoryMessages( strategy: AutoManagedChatGPTMemoryStrategy ) async -> (
164184 systemPrompt: ChatMessage ,
165185 contextSystemPrompt: ChatMessage ,
166186 remainingTokenCount: Int ,
167187 usage: ( systemPrompt: Int , contextSystemPrompt: Int , functions: Int )
168188 ) {
169- var smallestSystemPromptMessage = ChatMessage ( role: . system, content: systemPrompt)
170- var contextSystemPromptMessage = ChatMessage ( role: . user, content: contextSystemPrompt)
171- let smallestSystemMessageTokenCount = await encoder. countToken ( & smallestSystemPromptMessage)
189+ let smallestSystemPromptMessage = ChatMessage (
190+ role: . system,
191+ content: systemPrompt
192+ )
193+ let contextSystemPromptMessage = ChatMessage (
194+ role: . user,
195+ content: contextSystemPrompt
196+ )
197+ let smallestSystemMessageTokenCount = await strategy
198+ . countToken ( smallestSystemPromptMessage)
172199 let contextSystemPromptTokenCount = !contextSystemPrompt. isEmpty
173- ? ( await encoder . countToken ( & contextSystemPromptMessage) )
200+ ? ( await strategy . countToken ( contextSystemPromptMessage) )
174201 : 0
175202
176203 let functionTokenCount = await {
177204 var totalTokenCount = 0
178- for function in functionProvider. functions {
179- async let nameTokenCount = encoder. countToken ( text: function. name)
180- async let descriptionTokenCount = encoder. countToken ( text: function. description)
181- async let schemaTokenCount = {
182- guard let data = try ? JSONEncoder ( ) . encode ( function. argumentSchema) ,
183- let string = String ( data: data, encoding: . utf8)
184- else { return 0 }
185- return await encoder. countToken ( text: string)
186- } ( )
187-
188- await totalTokenCount += nameTokenCount + descriptionTokenCount + schemaTokenCount
205+ for function in self . functionProvider. functions {
206+ totalTokenCount += await strategy. countToken ( function)
189207 }
190208 return totalTokenCount
191209 } ( )
192-
210+
193211 let mandatoryContentTokensCount = smallestSystemMessageTokenCount
194212 + contextSystemPromptTokenCount
195213 + functionTokenCount
@@ -217,7 +235,7 @@ extension AutoManagedChatGPTMemory {
217235 func generateMessageHistory(
218236 maxNumberOfMessages: Int ,
219237 maxTokenCount: Int ,
220- encoder : TokenEncoder
238+ strategy : AutoManagedChatGPTMemoryStrategy
221239 ) async -> (
222240 history: [ ChatMessage ] ,
223241 newMessage: ChatMessage ,
@@ -231,8 +249,7 @@ extension AutoManagedChatGPTMemory {
231249 for (index, message) in history. enumerated ( ) . reversed ( ) {
232250 if maxNumberOfMessages > 0 , allMessages. count >= maxNumberOfMessages { break }
233251 if message. isEmpty { continue }
234- let tokensCount = await encoder. countToken ( message)
235- history [ index] . tokensCount = tokensCount
252+ let tokensCount = await strategy. countToken ( message)
236253 if tokensCount + messageTokenCount > maxTokenCount { break }
237254 messageTokenCount += tokensCount
238255 if index == history. endIndex - 1 {
@@ -252,7 +269,7 @@ extension AutoManagedChatGPTMemory {
252269
253270 func generateRetrievedContentMessage(
254271 maxTokenCount: Int ,
255- encoder : TokenEncoder
272+ strategy : AutoManagedChatGPTMemoryStrategy
256273 ) async -> (
257274 retrievedContent: ChatMessage ,
258275 remainingTokenCount: Int ,
@@ -261,68 +278,99 @@ extension AutoManagedChatGPTMemory {
261278 ) {
262279 /// the available tokens count for retrieved content
263280 let thresholdMaxTokenCount = min ( maxTokenCount, configuration. maxTokens / 2 )
281+ /// A separator that costs only 1 token
282+ let separator = String ( repeating: " = " , count: 32 )
283+ let retrievedContent = retrievedContent. filter { !$0. content. isEmpty }
264284
265- var retrievedContentTokenCount = 0
266- let separator = String ( repeating: " = " , count: 32 ) // only 1 token
267- var message = " "
268- var references = [ ChatMessage . Reference] ( )
269-
270- func appendToMessage( _ text: String ) async -> Bool {
271- let tokensCount = await encoder. countToken ( text: text)
272- if tokensCount + retrievedContentTokenCount > thresholdMaxTokenCount { return false }
273- retrievedContentTokenCount += tokensCount
274- message += text
275- return true
285+ func buildMessage( retrievedContent: [ ChatMessage . Reference ] ) -> ChatMessage {
286+ var text = " "
287+ for (index, content) in retrievedContent. enumerated ( ) {
288+ if index == 0 {
289+ text += """
290+ Here are the information you know about the system and the project, \
291+ separated by \( separator)
292+
293+
294+ """
295+ } else {
296+ text += " \n \( separator) \n "
297+ }
298+
299+ text += content. content
300+ }
301+
302+ return . init( role: . user, content: text)
276303 }
277304
278- for (index, content) in retrievedContent. filter ( { !$0. content. isEmpty } ) . enumerated ( ) {
279- if index == 0 {
280- if !( await appendToMessage ( """
281- Here are the information you know about the system and the project, \
282- separated by \( separator)
305+ func buildMessageThatFits( ) async
306+ -> ( message: ChatMessage , references: [ ChatMessage . Reference ] , tokenCount: Int )
307+ {
308+ var right = retrievedContent. count
309+ var left = 0
310+ var retrievedContent = retrievedContent
311+ var tokenCount : Int ?
312+ var proposedMessage = buildMessage ( retrievedContent: [ ] )
313+
314+ func checkValid( proposedMessage: ChatMessage ) async
315+ -> ( isValid: Bool , tokenCount: Int ? )
316+ {
317+ // if the size is way below the threshold
318+ let characterCount = proposedMessage. content? . count ?? 0
319+
320+ if characterCount <= thresholdMaxTokenCount {
321+ return ( true , nil ) // guessing token count.
322+ }
323+
324+ let tokensCount = await strategy. countToken ( proposedMessage)
325+ if tokensCount <= thresholdMaxTokenCount {
326+ return ( true , tokenCount)
327+ }
328+ return ( false , tokenCount)
329+ }
283330
331+ // check if all retrieved content included
332+ let maxMessage = buildMessage ( retrievedContent: retrievedContent)
333+ let ( isValid, maxTokenCount) = await checkValid ( proposedMessage: maxMessage)
334+ if isValid {
335+ let tokenCount = if let maxTokenCount { maxTokenCount }
336+ else { await strategy. countToken ( maxMessage) }
337+ return ( maxMessage, retrievedContent, tokenCount)
338+ }
284339
285- """ ) ) { break }
286- } else {
287- if !( await appendToMessage ( " \n \( separator) \n " ) ) { break }
340+ // binary search to reduce countToken calls
341+ while left <= right {
342+ let count = ( right + left) / 2
343+ let _retrievedContent = Array ( retrievedContent. prefix ( count) )
344+ let _proposedMessage = buildMessage ( retrievedContent: retrievedContent)
345+ let ( isValid, _tokenCount) = await checkValid ( proposedMessage: _proposedMessage)
346+ if isValid {
347+ proposedMessage = _proposedMessage
348+ retrievedContent = _retrievedContent
349+ tokenCount = _tokenCount
350+ left = count + 1
351+ } else {
352+ right = count - 1
353+ }
288354 }
289355
290- if !( await appendToMessage ( content. content) ) { break }
291- references. append ( content)
356+ let finalCount = if let tokenCount {
357+ tokenCount
358+ } else if proposedMessage. content? . isEmpty ?? true {
359+ 0
360+ } else {
361+ await strategy. countToken ( proposedMessage)
362+ }
363+ return ( proposedMessage, retrievedContent, finalCount)
292364 }
293365
366+ let ( message, references, tokensCount) = await buildMessageThatFits ( )
367+
294368 return (
295- . init ( role : . user , content : message) ,
296- maxTokenCount - retrievedContentTokenCount ,
297- retrievedContentTokenCount ,
369+ message,
370+ maxTokenCount - tokensCount ,
371+ tokensCount ,
298372 references
299373 )
300374 }
301375}
302376
303- public extension TokenEncoder {
304- /// https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
305- func countToken( _ message: ChatMessage ) async -> Int {
306- var total = 3
307- if let content = message. content {
308- total += await encode ( text: content) . count
309- }
310- if let name = message. name {
311- total += await encode ( text: name) . count
312- total += 1
313- }
314- if let functionCall = message. functionCall {
315- total += await encode ( text: functionCall. name) . count
316- total += await encode ( text: functionCall. arguments) . count
317- }
318- return total
319- }
320-
321- func countToken( _ message: inout ChatMessage ) async -> Int {
322- if let count = message. tokensCount { return count }
323- let count = await countToken ( message)
324- message. tokensCount = count
325- return count
326- }
327- }
328-
0 commit comments