Commit 81f723c

Update AutoManagedChatGPTMemory to handle retrieved content
1 parent 68575c2 commit 81f723c

2 files changed: 81 additions & 26 deletions

Core/Sources/ChatService/DynamicContextController.swift

Lines changed: 1 addition & 6 deletions
@@ -66,8 +66,6 @@ final class DynamicContextController {
             return contexts
         }
 
-        let separator = String(repeating: "=", count: 32) // only 1 token
-
         let contextPrompts = contexts
             .flatMap(\.systemPrompt)
             .filter { !$0.content.isEmpty }
@@ -76,12 +74,9 @@ final class DynamicContextController {
         let contextualSystemPrompt = """
         \(language.isEmpty ? "" : "You must always reply in \(language)")
         \(systemPrompt)
-
-        Below are information related to the conversation, separated by \(separator)
-
-        \(contextPrompts.map(\.content).joined(separator: "\n\(separator)\n"))
         """
         await memory.mutateSystemPrompt(contextualSystemPrompt)
+        await memory.mutateRetrievedContent(contextPrompts.map(\.content))
         functionProvider.append(functions: contexts.flatMap(\.functions))
     }
 }
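Note: before this change, DynamicContextController joined the retrieved context into the system prompt string itself; it now hands the plain prompt and the raw pieces to the memory, which can trim them against the token budget later. A minimal, self-contained sketch of the difference, with invented sample strings:

```swift
// Sketch only: the sample prompt and snippets are made up; `separator`
// mirrors the constant this commit removes from DynamicContextController.
let systemPrompt = "You are a coding assistant."
let retrieved = [
    "struct User { let id: Int }",
    "func findUser(id: Int) -> User? { nil }",
]
let separator = String(repeating: "=", count: 32)

// Old behaviour: retrieved context was baked into one big system prompt up front.
let oldSystemPrompt = """
\(systemPrompt)

Below are information related to the conversation, separated by \(separator)

\(retrieved.joined(separator: "\n\(separator)\n"))
"""

// New behaviour: the memory gets the plain prompt plus the raw pieces and can
// drop low-priority pieces later, once the token budget is known.
let newSystemPrompt = systemPrompt
let newRetrievedContent = retrieved

print(oldSystemPrompt)
print(newSystemPrompt, newRetrievedContent)
```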

Tool/Sources/OpenAIService/Memory/AutoManagedChatGPTMemory.swift

Lines changed: 80 additions & 20 deletions
@@ -8,7 +8,8 @@ public actor AutoManagedChatGPTMemory: ChatGPTMemory {
     public private(set) var messages: [ChatMessage] = []
     public private(set) var remainingTokens: Int?
 
-    public var systemPrompt: ChatMessage
+    public var systemPrompt: String
+    public var retrievedContent: [String] = []
     public var history: [ChatMessage] = [] {
         didSet { onHistoryChange() }
     }
@@ -25,7 +26,7 @@ public actor AutoManagedChatGPTMemory: ChatGPTMemory {
         configuration: ChatGPTConfiguration,
         functionProvider: ChatGPTFunctionProvider
     ) {
-        self.systemPrompt = .init(role: .system, content: systemPrompt)
+        self.systemPrompt = systemPrompt
         self.configuration = configuration
         self.functionProvider = functionProvider
         _ = Self.encoder // force pre-initialize
@@ -36,7 +37,11 @@ public actor AutoManagedChatGPTMemory: ChatGPTMemory {
     }
 
     public func mutateSystemPrompt(_ newPrompt: String) {
-        systemPrompt.content = newPrompt
+        systemPrompt = newPrompt
+    }
+
+    public func mutateRetrievedContent(_ newContent: [String]) {
+        retrievedContent = newContent
     }
 
     public nonisolated
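AutoManagedChatGPTMemory is an actor, so these mutators are called with await from the outside. A simplified, self-contained sketch of the new state shape and its mutators (the real type also tracks history, configuration, and a function provider):

```swift
// Simplified stand-in for the real actor, for illustration only.
actor PromptMemorySketch {
    var systemPrompt: String
    var retrievedContent: [String] = []

    init(systemPrompt: String) {
        self.systemPrompt = systemPrompt
    }

    func mutateSystemPrompt(_ newPrompt: String) {
        systemPrompt = newPrompt
    }

    func mutateRetrievedContent(_ newContent: [String]) {
        retrievedContent = newContent
    }
}

// Usage from outside the actor requires await, mirroring the calls added
// in DynamicContextController above.
func demo() async {
    let memory = PromptMemorySketch(systemPrompt: "You are a coding assistant.")
    await memory.mutateSystemPrompt("You must always reply in English\nYou are a coding assistant.")
    await memory.mutateRetrievedContent(["snippet A", "snippet B"])
}
```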
@@ -52,6 +57,17 @@ public actor AutoManagedChatGPTMemory: ChatGPTMemory {
     }
 
     /// https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
+    ///
+    /// Format:
+    /// ```
+    /// [System Prompt] priority: high
+    /// [Retrieved Content] priority: low
+    ///     [Retrieved Content A]
+    ///     <separator>
+    ///     [Retrieved Content B]
+    /// [Functions] priority: high
+    /// [Message History] priority: medium
+    /// ```
     func generateSendingHistory(
         maxNumberOfMessages: Int = UserDefaults.shared.value(for: \.chatGPTMaxMessageCount),
         encoder: TokenEncoder = AutoManagedChatGPTMemory.encoder
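The priorities above translate into a fixed charging order: the system prompt, the function definitions, and the 3-token reply priming are charged first, the reply reservation is subtracted, and only the remainder is shared by message history and retrieved content. A small sketch with invented numbers (the real values come from ChatGPTConfiguration and the token encoder):

```swift
// All numbers are invented for illustration only.
let maxTokens = 4_096            // model context window
let minimumReplyTokens = 500     // reserved for the model's reply
let systemPromptTokens = 120     // [System Prompt] priority: high
let functionTokens = 180         // [Functions] priority: high
let replyPriming = 3             // every reply is primed with <|start|>assistant<|message|>

let mandatory = systemPromptTokens + functionTokens + replyPriming
let availableForMessages = maxTokens - minimumReplyTokens - mandatory
// [Message History] priority: medium — filled newest-first from this budget.
// [Retrieved Content] priority: low — gets whatever the history leaves over.
print(mandatory, availableForMessages) // 303 3293
```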
@@ -63,8 +79,8 @@ public actor AutoManagedChatGPTMemory: ChatGPTMemory {
             return count
         }
 
-        var all: [ChatMessage] = []
-        let systemMessageTokenCount = countToken(&systemPrompt)
+        var smallestSystemPromptMessage = ChatMessage(role: .system, content: systemPrompt)
+        let smallestSystemMessageTokenCount = countToken(&smallestSystemPromptMessage)
         let functionTokenCount = functionProvider.functions.reduce(into: 0) { partial, function in
             var count = encoder.countToken(text: function.name)
                 + encoder.countToken(text: function.description)
@@ -75,38 +91,82 @@ public actor AutoManagedChatGPTMemory: ChatGPTMemory {
             }
             partial += count
         }
-        var allTokensCount = functionTokenCount +
-            3 // every reply is primed with <|start|>assistant<|message|>
-        allTokensCount += systemPrompt.isEmpty ? 0 : systemMessageTokenCount
+        let mandatoryContentTokensCount = smallestSystemMessageTokenCount
+            + functionTokenCount
+            + 3 // every reply is primed with <|start|>assistant<|message|>
+
+        /// the available tokens count for other messages and retrieved content
+        let availableTokenCountForMessages = configuration.maxTokens
+            - configuration.minimumReplyTokens
+            - mandatoryContentTokensCount
+
+        var messageTokenCount = 0
+        var allMessages: [ChatMessage] = []
 
         for (index, message) in history.enumerated().reversed() {
-            if maxNumberOfMessages > 0, all.count >= maxNumberOfMessages { break }
+            if maxNumberOfMessages > 0, allMessages.count >= maxNumberOfMessages { break }
             if message.isEmpty { continue }
             let tokensCount = countToken(&history[index])
-            if tokensCount + allTokensCount >
-                configuration.maxTokens - configuration.minimumReplyTokens
-            {
-                break
+            if tokensCount + messageTokenCount > availableTokenCountForMessages { break }
+            messageTokenCount += tokensCount
+            allMessages.append(message)
+        }
+
+        /// the available tokens count for retrieved content
+        let availableTokenCountForRetrievedContent = availableTokenCountForMessages
+            - messageTokenCount
+        var retrievedContentTokenCount = 0
+
+        let separator = String(repeating: "=", count: 32) // only 1 token
+
+        var systemPrompt = systemPrompt
+
+        func appendToSystemPrompt(_ text: String) -> Bool {
+            let tokensCount = encoder.countToken(text: text)
+            if tokensCount + retrievedContentTokenCount >
+                availableTokenCountForRetrievedContent { return false }
+            retrievedContentTokenCount += tokensCount
+            systemPrompt += text
+            return true
+        }
+
+        for (index, content) in retrievedContent.filter({ !$0.isEmpty }).enumerated() {
+            if index == 0 {
+                if !appendToSystemPrompt("""
+
+                Below are information related to the conversation, separated by \(separator)
+
+                """) { break }
+            } else {
+                if !appendToSystemPrompt(separator) { break }
             }
-            allTokensCount += tokensCount
-            all.append(message)
+
+            if !appendToSystemPrompt(content) { break }
         }
 
         if !systemPrompt.isEmpty {
-            all.append(systemPrompt)
+            let message = ChatMessage(role: .system, content: systemPrompt)
+            allMessages.append(message)
        }
 
         #if DEBUG
         Logger.service.info("""
         Sending tokens count
-        - system prompt: \(systemMessageTokenCount)
+        - system prompt: \(smallestSystemPromptMessage)
         - functions: \(functionTokenCount)
-        - total: \(allTokensCount)
-
+        - messages: \(messageTokenCount)
+        - retrieved content: \(retrievedContentTokenCount)
+        - total: \(
+            smallestSystemMessageTokenCount
+                + functionTokenCount
+                + messageTokenCount
+                + retrievedContentTokenCount
+        )
+
         """)
         #endif
 
-        return all.reversed()
+        return allMessages.reversed()
     }
 
     func generateRemainingTokens(
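Taken together, the new generateSendingHistory fills the budget in two passes: history is admitted newest-first until its share runs out, and the leftover becomes the budget for retrieved content, which is appended to the system prompt piece by piece until a piece no longer fits. A self-contained sketch of that second pass, with a crude word count standing in for the real TokenEncoder:

```swift
// Sketch only: the real code counts tokens with a tiktoken-style encoder;
// a naive word count stands in here so the example runs on its own.
func estimateTokens(_ text: String) -> Int {
    text.split { $0.isWhitespace }.count
}

func appendRetrievedContent(
    to systemPrompt: String,
    pieces: [String],
    budget: Int
) -> String {
    var prompt = systemPrompt
    var used = 0
    let separator = String(repeating: "=", count: 32)

    // Append a chunk only if it still fits the remaining budget, like the
    // nested appendToSystemPrompt helper in the commit.
    func append(_ text: String) -> Bool {
        let cost = estimateTokens(text)
        guard used + cost <= budget else { return false }
        used += cost
        prompt += text
        return true
    }

    for (index, piece) in pieces.filter({ !$0.isEmpty }).enumerated() {
        if index == 0 {
            if !append("\nBelow are information related to the conversation, separated by \(separator)\n") { break }
        } else {
            if !append(separator) { break }
        }
        if !append(piece) { break }
    }
    return prompt
}

// Example: with a tight budget, later pieces are simply dropped.
let result = appendRetrievedContent(
    to: "You are a coding assistant.",
    pieces: ["short snippet", String(repeating: "long snippet ", count: 200)],
    budget: 40
)
print(result)
```

With this word-count stand-in the 32-character separator costs one "token", which mirrors the "only 1 token" note on the separator constant in the commit.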
