Commit bc39122

Adjust the prompt in chat
1 parent 03882d7 · commit bc39122

1 file changed: 148 additions & 65 deletions

Tool/Sources/OpenAIService/Memory/AutoManagedChatGPTMemory.swift
@@ -67,11 +67,11 @@ public actor AutoManagedChatGPTMemory: ChatGPTMemory {
     /// Format:
     /// ```
     /// [System Prompt] priority: high
-    /// [Retrieved Content] priority: low
+    /// [Functions] priority: high
+    /// [Retrieved Content] priority: low
     /// [Retrieved Content A]
     /// <separator>
     /// [Retrieved Content B]
-    /// [Functions] priority: high
     /// [Message History] priority: medium
     /// [Context System Prompt] priority: high
     /// [Latest Message] priority: high
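The priority tags in the layout above read as a trimming order rather than a positional one: when the context window is tight, low-priority sections give way before high-priority ones. A self-contained sketch of that idea, assuming hypothetical types that are not part of this file:

// Illustrative only: drop the lowest-priority sections until the
// prompt fits the token budget. Types and numbers are made up.
struct PromptSection {
    enum Priority: Int { case low, medium, high }
    var priority: Priority
    var tokenCount: Int
}

func fit(_ sections: [PromptSection], budget: Int) -> [PromptSection] {
    var kept = sections
    while kept.reduce(0, { $0 + $1.tokenCount }) > budget,
          let index = kept.indices.min(by: {
              kept[$0].priority.rawValue < kept[$1].priority.rawValue
          }) {
        kept.remove(at: index) // lowest priority goes first
    }
    return kept
}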
@@ -80,18 +80,88 @@ public actor AutoManagedChatGPTMemory: ChatGPTMemory {
         maxNumberOfMessages: Int = UserDefaults.shared.value(for: \.chatGPTMaxMessageCount),
         encoder: TokenEncoder = AutoManagedChatGPTMemory.encoder
     ) -> [ChatMessage] {
-        func countToken(_ message: inout ChatMessage) -> Int {
-            if let count = message.tokensCount { return count }
-            let count = encoder.countToken(message: message)
-            message.tokensCount = count
-            return count
+        let (
+            systemPromptMessage,
+            contextSystemPromptMessage,
+            availableTokenCountForMessages,
+            mandatoryUsage
+        ) = generateMandatoryMessages(encoder: encoder)
+
+        let (
+            historyMessage,
+            newMessage,
+            availableTokenCountForRetrievedContent,
+            messageUsage
+        ) = generateMessageHistory(
+            maxNumberOfMessages: maxNumberOfMessages,
+            maxTokenCount: availableTokenCountForMessages,
+            encoder: encoder
+        )
+
+        let (
+            retrievedContentMessage,
+            _,
+            retrievedContentUsage,
+            _
+        ) = generateRetrievedContentMessage(
+            maxTokenCount: availableTokenCountForRetrievedContent,
+            encoder: encoder
+        )
+
+        let allMessages: [ChatMessage] = (
+            [systemPromptMessage] +
+                historyMessage +
+                [retrievedContentMessage, contextSystemPromptMessage, newMessage]
+        ).filter {
+            !($0.content?.isEmpty ?? false)
         }
 
+        #if DEBUG
+        Logger.service.info("""
+        Sending tokens count
+        - system prompt: \(mandatoryUsage.systemPrompt)
+        - context system prompt: \(mandatoryUsage.contextSystemPrompt)
+        - functions: \(mandatoryUsage.functions)
+        - messages: \(messageUsage)
+        - retrieved content: \(retrievedContentUsage)
+        - total: \(
+            mandatoryUsage.systemPrompt
+                + mandatoryUsage.contextSystemPrompt
+                + mandatoryUsage.functions
+                + messageUsage
+                + retrievedContentUsage
+        )
+        """)
+        #endif
+
+        return allMessages
+    }
+
+    func generateRemainingTokens(
+        maxNumberOfMessages: Int = UserDefaults.shared.value(for: \.chatGPTMaxMessageCount),
+        encoder: TokenEncoder = AutoManagedChatGPTMemory.encoder
+    ) -> Int? {
+        // It should be fine to just let OpenAI decide.
+        return nil
+    }
+
+    func setOnHistoryChangeBlock(_ onChange: @escaping () -> Void) {
+        onHistoryChange = onChange
+    }
+}
+
+extension AutoManagedChatGPTMemory {
+    func generateMandatoryMessages(encoder: TokenEncoder) -> (
+        systemPrompt: ChatMessage,
+        contextSystemPrompt: ChatMessage,
+        remainingTokenCount: Int,
+        usage: (systemPrompt: Int, contextSystemPrompt: Int, functions: Int)
+    ) {
         var smallestSystemPromptMessage = ChatMessage(role: .system, content: systemPrompt)
         var contextSystemPromptMessage = ChatMessage(role: .user, content: contextSystemPrompt)
-        let smallestSystemMessageTokenCount = countToken(&smallestSystemPromptMessage)
+        let smallestSystemMessageTokenCount = encoder.countToken(&smallestSystemPromptMessage)
        let contextSystemPromptTokenCount = !contextSystemPrompt.isEmpty
-            ? countToken(&contextSystemPromptMessage)
+            ? encoder.countToken(&contextSystemPromptMessage)
             : 0
 
         let functionTokenCount = functionProvider.functions.reduce(into: 0) { partial, function in
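The rewritten entry point is a budget cascade: generateMandatoryMessages claims tokens for the system prompt, context system prompt, and functions first (with minimumReplyTokens reserved for the reply up front), generateMessageHistory spends from what remains, and generateRetrievedContentMessage gets whatever the history leaves over. A rough, runnable sketch of that flow; the stage helper and all numbers are invented for illustration:

// Hypothetical stand-in for the cascading token budget; not the actor's API.
func stage(_ label: String, cost: Int, budget: Int) -> Int {
    let used = min(cost, max(0, budget))
    print("\(label): used \(used) of \(budget)")
    return budget - used // remaining budget for the next stage
}

let total = 4096 - 500                                   // maxTokens - minimumReplyTokens
let afterMandatory = stage("mandatory", cost: 120, budget: total)
let afterHistory = stage("history", cost: 1800, budget: afterMandatory)
let afterRetrieved = stage("retrieved", cost: 900, budget: afterHistory)
// Anything still unspent is simply unused; the reply was funded up front.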
@@ -109,46 +179,86 @@ public actor AutoManagedChatGPTMemory: ChatGPTMemory {
             + functionTokenCount
             + 3 // every reply is primed with <|start|>assistant<|message|>
 
+        // build messages
+
         /// the available tokens count for other messages and retrieved content
         let availableTokenCountForMessages = configuration.maxTokens
             - configuration.minimumReplyTokens
             - mandatoryContentTokensCount
 
+        return (
+            smallestSystemPromptMessage,
+            contextSystemPromptMessage,
+            availableTokenCountForMessages,
+            (
+                smallestSystemMessageTokenCount,
+                contextSystemPromptTokenCount,
+                functionTokenCount
+            )
+        )
+    }
+
+    func generateMessageHistory(
+        maxNumberOfMessages: Int,
+        maxTokenCount: Int,
+        encoder: TokenEncoder
+    ) -> (
+        history: [ChatMessage],
+        newMessage: ChatMessage,
+        remainingTokenCount: Int,
+        usage: Int
+    ) {
         var messageTokenCount = 0
         var allMessages: [ChatMessage] = []
+        var newMessage: ChatMessage?
 
         for (index, message) in history.enumerated().reversed() {
             if maxNumberOfMessages > 0, allMessages.count >= maxNumberOfMessages { break }
             if message.isEmpty { continue }
-            let tokensCount = countToken(&history[index])
-            if tokensCount + messageTokenCount > availableTokenCountForMessages { break }
+            let tokensCount = encoder.countToken(&history[index])
+            if tokensCount + messageTokenCount > maxTokenCount { break }
             messageTokenCount += tokensCount
-            allMessages.append(message)
+            if index == history.endIndex - 1 {
+                newMessage = message
+            } else {
+                allMessages.append(message)
+            }
         }
 
-        /// the available tokens count for retrieved content
-        let availableTokenCountForRetrievedContent = min(
-            availableTokenCountForMessages - messageTokenCount,
-            configuration.maxTokens / 2
+        return (
+            allMessages.reversed(),
+            newMessage ?? .init(role: .user, content: ""),
+            maxTokenCount - messageTokenCount,
+            messageTokenCount
         )
-        var retrievedContentTokenCount = 0
+    }
 
+    func generateRetrievedContentMessage(
+        maxTokenCount: Int,
+        encoder: TokenEncoder
+    ) -> (
+        retrievedContent: ChatMessage,
+        remainingTokenCount: Int,
+        usage: Int,
+        includedRetrievedContent: [String]
+    ) {
+        var retrievedContentTokenCount = 0
         let separator = String(repeating: "=", count: 32) // only 1 token
+        var message = ""
+        var includedRetrievedContent = [String]()
 
-        var systemPrompt = systemPrompt
-
-        func appendToSystemPrompt(_ text: String) -> Bool {
+        func appendToMessage(_ text: String) -> Bool {
             let tokensCount = encoder.countToken(text: text)
-            if tokensCount + retrievedContentTokenCount >
-                availableTokenCountForRetrievedContent { return false }
+            if tokensCount + retrievedContentTokenCount > maxTokenCount { return false }
             retrievedContentTokenCount += tokensCount
-            systemPrompt += text
+            message += text
+            includedRetrievedContent.append(text)
             return true
         }
 
         for (index, content) in retrievedContent.filter({ !$0.isEmpty }).enumerated() {
             if index == 0 {
-                if !appendToSystemPrompt("""
+                if !appendToMessage("""
 
 
                 ## Relevant Content
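generateMessageHistory scans the history newest-first, stops at the first message that would overflow its budget, and splits the most recent entry out as newMessage so it can sit at the end of the prompt. The same scan in miniature, over plain tuples rather than ChatMessage; the helper below is hypothetical:

// Illustrative only: (text, tokens) pairs stand in for ChatMessage.
func truncateNewestFirst(
    _ history: [(text: String, tokens: Int)],
    budget: Int
) -> (kept: [String], newest: String?) {
    var used = 0
    var kept: [String] = []
    var newest: String?
    for (index, message) in history.enumerated().reversed() {
        if used + message.tokens > budget { break }
        used += message.tokens
        if index == history.count - 1 {
            newest = message.text              // the latest message is kept aside
        } else {
            kept.append(message.text)
        }
    }
    return (Array(kept.reversed()), newest)    // back to chronological order
}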
@@ -158,52 +268,18 @@ public actor AutoManagedChatGPTMemory: ChatGPTMemory {
 
                 """) { break }
             } else {
-                if !appendToSystemPrompt("\n\(separator)\n") { break }
+                if !appendToMessage("\n\(separator)\n") { break }
             }
 
-            if !appendToSystemPrompt(content) { break }
+            if !appendToMessage(content) { break }
         }
 
-        if !systemPrompt.isEmpty {
-            let message = ChatMessage(role: .system, content: systemPrompt)
-            allMessages.append(message)
-        }
-
-        if !contextSystemPrompt.isEmpty {
-            allMessages.insert(contextSystemPromptMessage, at: 1)
-        }
-
-        #if DEBUG
-        Logger.service.info("""
-        Sending tokens count
-        - system prompt: \(smallestSystemMessageTokenCount)
-        - context system prompt: \(contextSystemPromptTokenCount)
-        - functions: \(functionTokenCount)
-        - messages: \(messageTokenCount)
-        - retrieved content: \(retrievedContentTokenCount)
-        - total: \(
-            smallestSystemMessageTokenCount
-                + contextSystemPromptTokenCount
-                + functionTokenCount
-                + messageTokenCount
-                + retrievedContentTokenCount
+        return (
+            .init(role: .user, content: message),
+            maxTokenCount - retrievedContentTokenCount,
+            retrievedContentTokenCount,
+            includedRetrievedContent
         )
-        """)
-        #endif
-
-        return allMessages.reversed()
-    }
-
-    func generateRemainingTokens(
-        maxNumberOfMessages: Int = UserDefaults.shared.value(for: \.chatGPTMaxMessageCount),
-        encoder: TokenEncoder = AutoManagedChatGPTMemory.encoder
-    ) -> Int? {
-        // It should be fine to just let OpenAI decide.
-        return nil
-    }
-
-    func setOnHistoryChangeBlock(_ onChange: @escaping () -> Void) {
-        onHistoryChange = onChange
     }
 }
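generateRetrievedContentMessage appends chunks greedily until one no longer fits, joining them with a 32-character "=" rule chosen because it encodes to a single token. In miniature, with a naive word count standing in for the real TokenEncoder; both functions below are hypothetical:

// Naive stand-in for a token encoder; real token counts differ.
func roughTokenCount(_ text: String) -> Int {
    text.split(whereSeparator: \.isWhitespace).count
}

func joinUnderBudget(_ chunks: [String], budget: Int) -> String {
    let separator = "\n" + String(repeating: "=", count: 32) + "\n"
    var output = ""
    var used = 0
    for (index, chunk) in chunks.enumerated() {
        let piece = index == 0 ? chunk : separator + chunk
        let cost = roughTokenCount(piece)
        if used + cost > budget { break } // later chunks are assumed lower value
        used += cost
        output += piece
    }
    return output
}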
@@ -224,5 +300,12 @@ extension TokenEncoder {
         }
         return total
     }
+
+    func countToken(_ message: inout ChatMessage) -> Int {
+        if let count = message.tokensCount { return count }
+        let count = countToken(message: message)
+        message.tokensCount = count
+        return count
+    }
 }
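The new TokenEncoder helper memoizes the token count on the message itself, so rebuilding the prompt repeatedly never re-encodes an unchanged message. The same pattern reduced to its core, with a crude length-based count in place of a real encoder:

// Hypothetical miniature of the inout memoization above.
struct Message {
    var content: String
    var tokensCount: Int? // must be reset whenever content changes
}

func countToken(_ message: inout Message) -> Int {
    if let count = message.tokensCount { return count }
    let count = message.content.count / 4 // crude stand-in for real encoding
    message.tokensCount = count
    return count
}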
228311
