Skip to content

Commit 970bc84

Browse files
committed
Support async truncation
1 parent bc34a0f commit 970bc84

1 file changed

Lines changed: 19 additions & 20 deletions

File tree

Tool/Sources/OpenAIService/Memory/TemplateChatGPTMemory.swift

Lines changed: 19 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ public actor TemplateChatGPTMemory: ChatGPTMemory {
4646

4747
while !(await checkTokenCount()) {
4848
do {
49-
try memoryTemplate.truncate()
49+
try await memoryTemplate.truncate()
5050
} catch {
5151
Logger.service.error("Failed to truncate prompt template: \(error)")
5252
break
@@ -84,7 +84,7 @@ public struct MemoryTemplate {
8484
content = .text(value)
8585
}
8686

87-
public init(content: Content, truncatePriority: Int = 0) {
87+
public init(_ content: Content, truncatePriority: Int = 0) {
8888
self.content = content
8989
self.truncatePriority = truncatePriority
9090
}
@@ -144,21 +144,31 @@ public struct MemoryTemplate {
144144
let truncateRule: ((
145145
_ messages: inout [Message],
146146
_ followUpMessages: inout [ChatMessage]
147-
) throws -> Void)?
147+
) async throws -> Void)?
148+
149+
public init(
150+
messages: [Message],
151+
followUpMessages: [ChatMessage] = [],
152+
truncateRule: ((inout [Message], inout [ChatMessage]) async throws -> Void)? = nil
153+
) {
154+
self.messages = messages
155+
self.truncateRule = truncateRule
156+
self.followUpMessages = followUpMessages
157+
}
148158

149159
func resolved() -> [ChatMessage] {
150160
messages.compactMap { message in message.resolved() } + followUpMessages
151161
}
152162

153-
func truncated() throws -> MemoryTemplate {
163+
func truncated() async throws -> MemoryTemplate {
154164
var copy = self
155-
try copy.truncate()
165+
try await copy.truncate()
156166
return copy
157167
}
158168

159-
mutating func truncate() throws {
169+
mutating func truncate() async throws {
160170
if let truncateRule = truncateRule {
161-
try truncateRule(&messages, &followUpMessages)
171+
try await truncateRule(&messages, &followUpMessages)
162172
return
163173
}
164174

@@ -170,7 +180,7 @@ public struct MemoryTemplate {
170180
_ followUpMessages: inout [ChatMessage]
171181
) throws {
172182
// Remove the oldest followup messages when available.
173-
183+
174184
if followUpMessages.count > 20 {
175185
followUpMessages.removeFirst(followUpMessages.count / 2)
176186
return
@@ -186,7 +196,7 @@ public struct MemoryTemplate {
186196
}
187197

188198
// Remove according to the priority.
189-
199+
190200
var truncatingMessageIndex: Int?
191201
for (index, message) in messages.enumerated() {
192202
if message.truncatePriority <= 0 { continue }
@@ -242,15 +252,4 @@ public struct MemoryTemplate {
242252
}
243253
}
244254
}
245-
246-
public init(
247-
messages: [Message],
248-
followUpMessages: [ChatMessage] = [],
249-
truncateRule: ((inout [Message], inout [ChatMessage]) -> Void)? = nil
250-
) {
251-
self.messages = messages
252-
self.truncateRule = truncateRule
253-
self.followUpMessages = followUpMessages
254-
}
255255
}
256-

0 commit comments

Comments
 (0)