Skip to content
This repository was archived by the owner on Jun 18, 2025. It is now read-only.

Commit af2e417

Browse files
committed
Update MemoryTemplate
1 parent f0e4b88 commit af2e417

1 file changed

Lines changed: 113 additions & 79 deletions

File tree

Tool/Sources/OpenAIService/Memory/TemplateChatGPTMemory.swift

Lines changed: 113 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -63,36 +63,43 @@ public struct MemoryTemplate {
6363
public enum Content: ExpressibleByStringLiteral {
6464
case text(String)
6565
case list([String], formatter: ([String]) -> String)
66+
case priorityList(
67+
[(content: String, priority: Int)],
68+
formatter: ([String]) -> String
69+
)
6670

6771
public init(stringLiteral value: String) {
6872
self = .text(value)
6973
}
7074
}
7175

7276
public var content: Content
73-
public var truncatePriority: Int = 0
77+
public var priority: Int
7478
public var isEmpty: Bool {
7579
switch content {
7680
case let .text(text):
7781
return text.isEmpty
7882
case let .list(list, _):
7983
return list.isEmpty
84+
case let .priorityList(list, _):
85+
return list.isEmpty
8086
}
8187
}
8288

8389
public init(stringLiteral value: String) {
8490
content = .text(value)
91+
priority = .max
8592
}
8693

87-
public init(_ content: Content, truncatePriority: Int = 0) {
94+
public init(_ content: Content, priority: Int = .max) {
8895
self.content = content
89-
self.truncatePriority = truncatePriority
96+
self.priority = priority
9097
}
9198
}
9299

93100
public var chatMessage: ChatMessage
94101
public var dynamicContent: [DynamicContent] = []
95-
public var truncatePriority: Int = 0
102+
public var priority: Int
96103

97104
public func resolved() -> ChatMessage? {
98105
var baseMessage = chatMessage
@@ -108,6 +115,8 @@ public struct MemoryTemplate {
108115
return text
109116
case let .list(list, formatter):
110117
return formatter(list)
118+
case let .priorityList(list, formatter):
119+
return formatter(list.map { $0.0 })
111120
}
112121
}
113122

@@ -130,26 +139,28 @@ public struct MemoryTemplate {
130139
public init(
131140
chatMessage: ChatMessage,
132141
dynamicContent: [DynamicContent] = [],
133-
truncatePriority: Int = 0
142+
priority: Int = .max
134143
) {
135144
self.chatMessage = chatMessage
136145
self.dynamicContent = dynamicContent
137-
self.truncatePriority = truncatePriority
146+
self.priority = priority
138147
}
139148
}
140149

141150
public var messages: [Message]
142151
public var followUpMessages: [ChatMessage]
143152

144-
let truncateRule: ((
153+
public typealias TruncateRule = (
145154
_ messages: inout [Message],
146155
_ followUpMessages: inout [ChatMessage]
147-
) async throws -> Void)?
156+
) async throws -> Void
157+
158+
let truncateRule: TruncateRule?
148159

149160
public init(
150161
messages: [Message],
151162
followUpMessages: [ChatMessage] = [],
152-
truncateRule: ((inout [Message], inout [ChatMessage]) async throws -> Void)? = nil
163+
truncateRule: TruncateRule? = nil
153164
) {
154165
self.messages = messages
155166
self.truncateRule = truncateRule
@@ -172,84 +183,107 @@ public struct MemoryTemplate {
172183
return
173184
}
174185

175-
try Self.defaultTruncateRule(&messages, &followUpMessages)
186+
try await Self.defaultTruncateRule()(&messages, &followUpMessages)
187+
}
188+
189+
public struct DefaultTruncateRuleOptions {
190+
public var numberOfContentListItemToKeep: (Int) -> Int = { $0 * 2 / 3 }
176191
}
177192

178193
public static func defaultTruncateRule(
179-
_ messages: inout [Message],
180-
_ followUpMessages: inout [ChatMessage]
181-
) throws {
182-
// Remove the oldest followup messages when available.
183-
184-
if followUpMessages.count > 20 {
185-
followUpMessages.removeFirst(followUpMessages.count / 2)
186-
return
187-
}
188-
189-
if followUpMessages.count > 2 {
190-
if followUpMessages.count.isMultiple(of: 2) {
191-
followUpMessages.removeFirst(2)
192-
} else {
193-
followUpMessages.removeFirst(1)
194+
options updateOptions: (inout DefaultTruncateRuleOptions) -> Void = { _ in }
195+
) -> TruncateRule {
196+
var options = DefaultTruncateRuleOptions()
197+
updateOptions(&options)
198+
return { messages, followUpMessages in
199+
200+
// Remove the oldest followup messages when available.
201+
202+
if followUpMessages.count > 20 {
203+
followUpMessages.removeFirst(followUpMessages.count / 2)
204+
return
194205
}
195-
return
196-
}
197-
198-
// Remove according to the priority.
199-
200-
var truncatingMessageIndex: Int?
201-
for (index, message) in messages.enumerated() {
202-
if message.truncatePriority <= 0 { continue }
203-
if let previousIndex = truncatingMessageIndex,
204-
message.truncatePriority > messages[previousIndex].truncatePriority
205-
{
206-
truncatingMessageIndex = index
206+
207+
if followUpMessages.count > 2 {
208+
if followUpMessages.count.isMultiple(of: 2) {
209+
followUpMessages.removeFirst(2)
210+
} else {
211+
followUpMessages.removeFirst(1)
212+
}
213+
return
207214
}
208-
}
209-
210-
guard let truncatingMessageIndex else { throw CancellationError() }
211-
var truncatingMessage: Message {
212-
get { messages[truncatingMessageIndex] }
213-
set { messages[truncatingMessageIndex] = newValue }
214-
}
215-
216-
if truncatingMessage.isEmpty {
217-
messages.remove(at: truncatingMessageIndex)
218-
return
219-
}
220-
221-
truncatingMessage.dynamicContent.removeAll(where: { $0.isEmpty })
222-
223-
var truncatingContentIndex: Int?
224-
for (index, content) in truncatingMessage.dynamicContent.enumerated() {
225-
if content.isEmpty { continue }
226-
if let previousIndex = truncatingContentIndex,
227-
content.truncatePriority > truncatingMessage.dynamicContent[previousIndex]
228-
.truncatePriority
229-
{
230-
truncatingContentIndex = index
215+
216+
// Remove according to the priority.
217+
218+
var truncatingMessageIndex: Int?
219+
for (index, message) in messages.enumerated() {
220+
if message.priority == .max { continue }
221+
if let previousIndex = truncatingMessageIndex,
222+
message.priority < messages[previousIndex].priority
223+
{
224+
truncatingMessageIndex = index
225+
}
231226
}
232-
}
233-
234-
guard let truncatingContentIndex else { throw CancellationError() }
235-
var truncatingContent: Message.DynamicContent {
236-
get { truncatingMessage.dynamicContent[truncatingContentIndex] }
237-
set { truncatingMessage.dynamicContent[truncatingContentIndex] = newValue }
238-
}
239-
240-
switch truncatingContent.content {
241-
case .text:
242-
truncatingMessage.dynamicContent.remove(at: truncatingContentIndex)
243-
case let .list(list, formatter: formatter):
244-
let count = list.count * 2 / 3
245-
if count > 0 {
246-
truncatingContent.content = .list(
247-
Array(list.prefix(count)),
248-
formatter: formatter
249-
)
250-
} else {
227+
228+
guard let truncatingMessageIndex else { throw CancellationError() }
229+
var truncatingMessage: Message {
230+
get { messages[truncatingMessageIndex] }
231+
set { messages[truncatingMessageIndex] = newValue }
232+
}
233+
234+
if truncatingMessage.isEmpty {
235+
messages.remove(at: truncatingMessageIndex)
236+
return
237+
}
238+
239+
truncatingMessage.dynamicContent.removeAll(where: { $0.isEmpty })
240+
241+
var truncatingContentIndex: Int?
242+
for (index, content) in truncatingMessage.dynamicContent.enumerated() {
243+
if content.isEmpty { continue }
244+
if let previousIndex = truncatingContentIndex,
245+
content.priority < truncatingMessage.dynamicContent[previousIndex].priority
246+
{
247+
truncatingContentIndex = index
248+
}
249+
}
250+
251+
guard let truncatingContentIndex else { throw CancellationError() }
252+
var truncatingContent: Message.DynamicContent {
253+
get { truncatingMessage.dynamicContent[truncatingContentIndex] }
254+
set { truncatingMessage.dynamicContent[truncatingContentIndex] = newValue }
255+
}
256+
257+
switch truncatingContent.content {
258+
case .text:
251259
truncatingMessage.dynamicContent.remove(at: truncatingContentIndex)
260+
case let .list(list, formatter):
261+
let count = options.numberOfContentListItemToKeep(list.count)
262+
if count > 0 {
263+
truncatingContent.content = .list(
264+
Array(list.prefix(count)),
265+
formatter: formatter
266+
)
267+
} else {
268+
truncatingMessage.dynamicContent.remove(at: truncatingContentIndex)
269+
}
270+
case let .priorityList(list, formatter):
271+
let count = options.numberOfContentListItemToKeep(list.count)
272+
if count > 0 {
273+
let orderedList = list.enumerated()
274+
let orderedByPriority = orderedList
275+
.sorted { $0.element.priority >= $1.element.priority }
276+
let kept = orderedByPriority.prefix(count)
277+
let reordered = kept.sorted { $0.offset < $1.offset }
278+
truncatingContent.content = .priorityList(
279+
Array(reordered.map { $0.element }),
280+
formatter: formatter
281+
)
282+
} else {
283+
truncatingMessage.dynamicContent.remove(at: truncatingContentIndex)
284+
}
252285
}
253286
}
254287
}
255288
}
289+

0 commit comments

Comments
 (0)