Skip to content

Commit 5543c78

Browse files
committed
Add TemplateChatGPTMemory
1 parent caaf4a3 commit 5543c78

1 file changed

Lines changed: 256 additions & 0 deletions

File tree

Lines changed: 256 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,256 @@
1+
import ChatBasic
2+
import Foundation
3+
import Logger
4+
import Preferences
5+
import TokenEncoder
6+
7+
/// A memory that automatically manages the history according to max tokens and the template rules.
///
/// The stored template is never truncated in place: `generatePrompt()` truncates a copy of it
/// until the copy fits into the token budget.
public actor TemplateChatGPTMemory: ChatGPTMemory {
    /// The template the history and the prompt are resolved from.
    public private(set) var memoryTemplate: MemoryTemplate
    /// The untruncated history, i.e. the stored template resolved as-is.
    public var history: [ChatMessage] { memoryTemplate.resolved() }
    public var configuration: ChatGPTConfiguration
    public var functionProvider: ChatGPTFunctionProvider

    public init(
        memoryTemplate: MemoryTemplate,
        configuration: ChatGPTConfiguration,
        functionProvider: ChatGPTFunctionProvider
    ) {
        self.memoryTemplate = memoryTemplate
        self.configuration = configuration
        self.functionProvider = functionProvider
    }

    /// Mutates the follow-up messages of the template.
    ///
    /// Only `memoryTemplate.followUpMessages` is handed to `update`; the template messages
    /// themselves are not reachable through this method.
    public func mutateHistory(_ update: (inout [ChatMessage]) -> Void) async {
        update(&memoryTemplate.followUpMessages)
    }

    /// Builds a prompt from a copy of the template, truncating the copy until the tokens of
    /// its resolved messages plus the tokens of the provided functions are at most
    /// `configuration.maxTokens - configuration.minimumReplyTokens`.
    ///
    /// If truncation throws, the error is logged and the prompt is built from whatever the
    /// copy looks like at that point, which may still exceed the budget.
    public func generatePrompt() async -> ChatGPTPrompt {
        let strategy: AutoManagedChatGPTMemoryStrategy = switch configuration.model?.format {
        case .googleAI: AutoManagedChatGPTMemory.GoogleAIStrategy(configuration: configuration)
        default: AutoManagedChatGPTMemory.OpenAIStrategy()
        }

        /// Whether `template`, once resolved, fits into the token budget.
        func fitsTokenBudget(_ template: MemoryTemplate) async -> Bool {
            var tokenCount = 0
            // Count the template that is being truncated, not `self.history`: the stored
            // template never shrinks, so measuring it would make every truncation look
            // ineffective and strip the prompt until truncation fails.
            for message in template.resolved() {
                tokenCount += await strategy.countToken(message)
            }
            for function in functionProvider.functions {
                tokenCount += await strategy.countToken(function)
            }
            return tokenCount <= configuration.maxTokens - configuration.minimumReplyTokens
        }

        var candidate = memoryTemplate
        while !(await fitsTokenBudget(candidate)) {
            do {
                try candidate.truncate()
            } catch {
                Logger.service.error("Failed to truncate prompt template: \(error)")
                break
            }
        }

        return ChatGPTPrompt(history: candidate.resolved())
    }
}
59+
60+
/// A prompt template: a list of template messages followed by free-form follow-up messages,
/// together with a rule describing how to shrink it.
public struct MemoryTemplate {
    /// A template message: a base `ChatMessage` whose content can be replaced by the
    /// concatenation of dynamic contents.
    public struct Message {
        /// A piece of content that can be truncated independently of its siblings.
        public struct DynamicContent: ExpressibleByStringLiteral {
            public enum Content: ExpressibleByStringLiteral {
                /// Plain text. The default truncate rule removes it as a whole.
                case text(String)
                /// A list of items rendered by `formatter`. The default truncate rule
                /// shortens the list before removing it.
                case list([String], formatter: ([String]) -> String)

                public init(stringLiteral value: String) {
                    self = .text(value)
                }
            }

            public var content: Content
            /// Among the contents of a message, the default truncate rule truncates the one
            /// with the highest value first.
            public var truncatePriority: Int = 0
            /// `true` when the text is empty or the list has no items.
            public var isEmpty: Bool {
                switch content {
                case let .text(text):
                    return text.isEmpty
                case let .list(list, _):
                    return list.isEmpty
                }
            }

            public init(stringLiteral value: String) {
                content = .text(value)
            }

            public init(content: Content, truncatePriority: Int = 0) {
                self.content = content
                self.truncatePriority = truncatePriority
            }
        }

        /// The base message. Its `content` is overwritten when `dynamicContent` is not empty.
        public var chatMessage: ChatMessage
        public var dynamicContent: [DynamicContent] = []
        /// The default truncate rule only touches messages whose priority is greater than 0,
        /// highest value first.
        public var truncatePriority: Int = 0

        /// Resolves the template message into a `ChatMessage`.
        ///
        /// - Without dynamic content, returns `chatMessage`, or `nil` when it is empty.
        /// - With dynamic content, returns `chatMessage` with its content replaced by the
        ///   non-empty dynamic contents joined by a blank line.
        public func resolved() -> ChatMessage? {
            var baseMessage = chatMessage
            guard !dynamicContent.isEmpty else {
                if baseMessage.isEmpty { return nil }
                return baseMessage
            }

            let contents: [String] = dynamicContent.compactMap { content in
                if content.isEmpty { return nil }
                switch content.content {
                case let .text(text):
                    return text
                case let .list(list, formatter):
                    return formatter(list)
                }
            }

            baseMessage.content = contents.joined(separator: "\n\n")

            return baseMessage
        }

        /// With dynamic content: `true` when every dynamic content is empty.
        /// Without: `true` when `chatMessage` has neither tool calls nor content.
        public var isEmpty: Bool {
            if !dynamicContent.isEmpty { return dynamicContent.allSatisfy { $0.isEmpty } }
            if let toolCalls = chatMessage.toolCalls, !toolCalls.isEmpty {
                return false
            }
            if let content = chatMessage.content, !content.isEmpty {
                return false
            }
            return true
        }

        public init(
            chatMessage: ChatMessage,
            dynamicContent: [DynamicContent] = [],
            truncatePriority: Int = 0
        ) {
            self.chatMessage = chatMessage
            self.dynamicContent = dynamicContent
            self.truncatePriority = truncatePriority
        }
    }

    public var messages: [Message]
    public var followUpMessages: [ChatMessage]

    /// Custom truncate rule. `defaultTruncateRule` is used when it is `nil`.
    let truncateRule: ((
        _ messages: inout [Message],
        _ followUpMessages: inout [ChatMessage]
    ) throws -> Void)?

    /// The resolved template messages (unresolvable ones are skipped) followed by the
    /// follow-up messages.
    func resolved() -> [ChatMessage] {
        messages.compactMap { message in message.resolved() } + followUpMessages
    }

    /// Returns a copy of the template with one truncation step applied.
    func truncated() throws -> MemoryTemplate {
        var copy = self
        try copy.truncate()
        return copy
    }

    /// Applies one truncation step with the custom rule if there is one, otherwise with
    /// `defaultTruncateRule`.
    mutating func truncate() throws {
        if let truncateRule = truncateRule {
            try truncateRule(&messages, &followUpMessages)
            return
        }

        try Self.defaultTruncateRule(&messages, &followUpMessages)
    }

    /// Performs a single truncation step.
    ///
    /// 1. With more than 20 follow-up messages, the older half is removed.
    /// 2. With more than 2 follow-up messages, the oldest two (or the oldest one, when the
    ///    count is odd) are removed.
    /// 3. Otherwise the first message with the highest `truncatePriority` greater than 0 is
    ///    picked. It is removed when it is empty. If not, its empty dynamic contents are
    ///    dropped and its first dynamic content with the highest `truncatePriority` is
    ///    truncated: text is removed; a list keeps its first two thirds and is removed once
    ///    that would leave no item.
    ///
    /// - Throws: `CancellationError` when nothing can be truncated: no message has a
    ///   priority greater than 0, or the picked message has no dynamic content left.
    public static func defaultTruncateRule(
        _ messages: inout [Message],
        _ followUpMessages: inout [ChatMessage]
    ) throws {
        // Remove the oldest followup messages when available.

        if followUpMessages.count > 20 {
            followUpMessages.removeFirst(followUpMessages.count / 2)
            return
        }

        if followUpMessages.count > 2 {
            if followUpMessages.count.isMultiple(of: 2) {
                followUpMessages.removeFirst(2)
            } else {
                followUpMessages.removeFirst(1)
            }
            return
        }

        // Remove according to the priority.

        var truncatingMessageIndex: Int?
        for (index, message) in messages.enumerated() where message.truncatePriority > 0 {
            // The first candidate is always taken; later ones only replace it when their
            // priority is strictly higher. Requiring a previous index before assigning
            // would leave the index `nil` forever.
            if let previousIndex = truncatingMessageIndex,
               message.truncatePriority <= messages[previousIndex].truncatePriority
            {
                continue
            }
            truncatingMessageIndex = index
        }

        guard let truncatingMessageIndex else { throw CancellationError() }

        if messages[truncatingMessageIndex].isEmpty {
            messages.remove(at: truncatingMessageIndex)
            return
        }

        messages[truncatingMessageIndex].dynamicContent.removeAll(where: { $0.isEmpty })
        let dynamicContent = messages[truncatingMessageIndex].dynamicContent

        var truncatingContentIndex: Int?
        for (index, content) in dynamicContent.enumerated() where !content.isEmpty {
            // Same selection as above: seed with the first content, replace on strictly
            // higher priority only.
            if let previousIndex = truncatingContentIndex,
               content.truncatePriority <= dynamicContent[previousIndex].truncatePriority
            {
                continue
            }
            truncatingContentIndex = index
        }

        guard let truncatingContentIndex else { throw CancellationError() }

        switch dynamicContent[truncatingContentIndex].content {
        case .text:
            messages[truncatingMessageIndex].dynamicContent.remove(at: truncatingContentIndex)
        case let .list(list, formatter: formatter):
            let count = list.count * 2 / 3
            if count > 0 {
                messages[truncatingMessageIndex].dynamicContent[truncatingContentIndex]
                    .content = .list(
                        Array(list.prefix(count)),
                        formatter: formatter
                    )
            } else {
                messages[truncatingMessageIndex].dynamicContent
                    .remove(at: truncatingContentIndex)
            }
        }
    }

    /// - Parameters:
    ///   - messages: The template messages.
    ///   - followUpMessages: Messages appended after the resolved template messages.
    ///   - truncateRule: Performs one truncation step. `defaultTruncateRule` is used when
    ///     it is `nil`.
    public init(
        messages: [Message],
        followUpMessages: [ChatMessage] = [],
        truncateRule: ((inout [Message], inout [ChatMessage]) -> Void)? = nil
    ) {
        self.messages = messages
        self.truncateRule = truncateRule
        self.followUpMessages = followUpMessages
    }
}
256+

0 commit comments

Comments
 (0)