Skip to content

Commit 94aaade

Browse files
committed
Change the definition of a memory
1 parent 9b22c29 commit 94aaade

File tree

13 files changed

+108
-84
lines changed

13 files changed

+108
-84
lines changed

Core/Sources/ChatGPTChatTab/Chat.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ public struct DisplayedChatMessage: Equatable {
1717
public var text: String
1818
public var references: [ChatMessage.Reference] = []
1919

20-
public init(id: String, role: Role, text: String, references: [ChatMessage.Reference]) {
20+
public init(id: String, role: Role, text: String, references: [ChatMessage.Reference] = []) {
2121
self.id = id
2222
self.role = role
2323
self.text = text

Core/Sources/ChatPlugins/ShortcutChatPlugin/ShortcutChatPlugin.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ public actor ShortcutChatPlugin: ChatPlugin {
4949
var input = String(content).trimmingCharacters(in: .whitespacesAndNewlines)
5050
if input.isEmpty {
5151
// if no input detected, use the previous message as input
52-
input = await chatGPTService.memory.messages.last?.content ?? ""
52+
input = await chatGPTService.memory.history.last?.content ?? ""
5353
await chatGPTService.memory.appendMessage(.init(role: .user, content: originalMessage))
5454
} else {
5555
await chatGPTService.memory.appendMessage(.init(role: .user, content: originalMessage))

Core/Sources/ChatPlugins/ShortcutChatPlugin/ShortcutInputChatPlugin.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ public actor ShortcutInputChatPlugin: ChatPlugin {
5252
var input = String(content).trimmingCharacters(in: .whitespacesAndNewlines)
5353
if input.isEmpty {
5454
// if no input detected, use the previous message as input
55-
input = await chatGPTService.memory.messages.last?.content ?? ""
55+
input = await chatGPTService.memory.history.last?.content ?? ""
5656
}
5757

5858
do {

Core/Sources/ChatService/ChatService.swift

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -100,15 +100,13 @@ public final class ChatService: ObservableObject {
100100
guard !isReceivingMessage else { throw CancellationError() }
101101
let handledInPlugin = try await pluginController.handleContent(content)
102102
if handledInPlugin { return }
103+
isReceivingMessage = true
104+
defer { isReceivingMessage = false }
103105

104106
let stream = try await chatGPTService.send(content: content, summary: nil)
105-
isReceivingMessage = true
106107
do {
107108
for try await _ in stream {}
108-
isReceivingMessage = false
109-
} catch {
110-
isReceivingMessage = false
111-
}
109+
} catch {}
112110
}
113111

114112
public func sendAndWait(content: String) async throws -> String {

Core/Sources/ChatService/ContextAwareAutoManagedChatGPTMemory.swift

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,6 @@ public final class ContextAwareAutoManagedChatGPTMemory: ChatGPTMemory {
77
let functionProvider: ChatFunctionProvider
88
weak var chatService: ChatService?
99

10-
public var messages: [ChatMessage] {
11-
get async { await memory.messages }
12-
}
13-
14-
public var remainingTokens: Int? {
15-
get async { await memory.remainingTokens }
16-
}
17-
1810
public var history: [ChatMessage] {
1911
get async { await memory.history }
2012
}
@@ -45,7 +37,7 @@ public final class ContextAwareAutoManagedChatGPTMemory: ChatGPTMemory {
4537
await memory.mutateHistory(update)
4638
}
4739

48-
public func refresh() async {
40+
public func generatePrompt() async -> ChatGPTPrompt {
4941
let content = (await memory.history)
5042
.last(where: { $0.role == .user || $0.role == .function })?.content
5143
try? await contextController.collectContextInformation(
@@ -55,7 +47,7 @@ public final class ContextAwareAutoManagedChatGPTMemory: ChatGPTMemory {
5547
""",
5648
content: content ?? ""
5749
)
58-
await memory.refresh()
50+
return await memory.generatePrompt()
5951
}
6052
}
6153

Core/Sources/ChatService/DynamicContextController.swift

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -41,21 +41,21 @@ final class DynamicContextController {
4141
var content = content
4242
var scopes = Self.parseScopes(&content)
4343
scopes.formUnion(defaultScopes)
44-
44+
4545
let overridingChatModelId = {
4646
var ids = [String]()
4747
if scopes.contains(.sense) {
4848
ids.append(UserDefaults.shared.value(for: \.preferredChatModelIdForSenseScope))
4949
}
50-
50+
5151
if scopes.contains(.project) {
5252
ids.append(UserDefaults.shared.value(for: \.preferredChatModelIdForProjectScope))
5353
}
54-
54+
5555
if scopes.contains(.web) {
5656
ids.append(UserDefaults.shared.value(for: \.preferredChatModelIdForWebScope))
5757
}
58-
58+
5959
let chatModels = UserDefaults.shared.value(for: \.chatModels)
6060
let idIndexMap = chatModels.enumerated().reduce(into: [String: Int]()) {
6161
$0[$1.element.id] = $1.offset
@@ -66,7 +66,7 @@ final class DynamicContextController {
6666
return lhs < rhs
6767
}).first
6868
}()
69-
69+
7070
configuration.overriding.modelId = overridingChatModelId
7171

7272
functionProvider.removeAll()
@@ -108,7 +108,17 @@ final class DynamicContextController {
108108
"""
109109
await memory.mutateSystemPrompt(contextualSystemPrompt)
110110
await memory.mutateContextSystemPrompt(contextSystemPrompt)
111-
await memory.mutateRetrievedContent(contextPrompts.map(\.content))
111+
await memory.mutateRetrievedContent(contextPrompts.map {
112+
.init(
113+
title: "",
114+
subTitle: "",
115+
uri: "",
116+
content: $0.content,
117+
startLine: nil,
118+
endLine: nil,
119+
metadata: [:]
120+
)
121+
})
112122
functionProvider.append(functions: contexts.flatMap(\.functions))
113123
}
114124
}

Tool/Sources/LangChain/ChatModel/OpenAIChat.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,10 @@ public struct OpenAIChat: ChatModel {
4242
message.append(trunk)
4343
callbackManagers.send(CallbackEvents.LLMDidProduceNewToken(info: trunk))
4444
}
45-
return await memory.messages.last ?? .init(role: .assistant, content: "")
45+
return await memory.history.last ?? .init(role: .assistant, content: "")
4646
} else {
4747
let _ = try await service.sendAndWait(content: "")
48-
return await memory.messages.last ?? .init(role: .assistant, content: "")
48+
return await memory.history.last ?? .init(role: .assistant, content: "")
4949
}
5050
}
5151
}

Tool/Sources/OpenAIService/ChatGPTService.swift

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,8 @@ public class ChatGPTService: ChatGPTServiceType {
9393
content: content,
9494
name: nil,
9595
functionCall: nil,
96-
summary: summary
96+
summary: summary,
97+
references: []
9798
)
9899
await memory.appendMessage(newMessage)
99100
}
@@ -218,7 +219,7 @@ extension ChatGPTService {
218219

219220
/// Send the memory as prompt to ChatGPT, with stream enabled.
220221
func sendMemory() async throws -> AsyncThrowingStream<StreamContent, Error> {
221-
await memory.refresh()
222+
let prompt = await memory.generatePrompt()
222223

223224
guard let model = configuration.model else {
224225
throw ChatGPTServiceError.chatModelNotAvailable
@@ -227,7 +228,7 @@ extension ChatGPTService {
227228
throw ChatGPTServiceError.endpointIncorrect
228229
}
229230

230-
let messages = await memory.messages.map {
231+
let messages = prompt.history.map {
231232
CompletionRequestBody.Message(
232233
role: $0.role,
233234
content: $0.content ?? "",
@@ -237,7 +238,7 @@ extension ChatGPTService {
237238
}
238239
)
239240
}
240-
let remainingTokens = await memory.remainingTokens
241+
let remainingTokens = prompt.remainingTokenCount
241242

242243
let requestBody = CompletionRequestBody(
243244
model: model.info.modelName,
@@ -278,9 +279,13 @@ extension ChatGPTService {
278279
return AsyncThrowingStream<StreamContent, Error> { continuation in
279280
let task = Task {
280281
do {
282+
let proposedId = UUID().uuidString + String(Date().timeIntervalSince1970)
283+
await memory.streamMessage(
284+
id: proposedId,
285+
references: prompt.references
286+
)
281287
let (trunks, cancel) = try await api()
282288
cancelTask = cancel
283-
let proposedId = UUID().uuidString + String(Date().timeIntervalSince1970)
284289
for try await trunk in trunks {
285290
try Task.checkCancellation()
286291
guard let delta = trunk.choices?.first?.delta else { continue }
@@ -336,7 +341,8 @@ extension ChatGPTService {
336341

337342
/// Send the memory as prompt to ChatGPT, with stream disabled.
338343
func sendMemoryAndWait() async throws -> ChatMessage? {
339-
await memory.refresh()
344+
let proposedId = UUID().uuidString + String(Date().timeIntervalSince1970)
345+
let prompt = await memory.generatePrompt()
340346

341347
guard let model = configuration.model else {
342348
throw ChatGPTServiceError.chatModelNotAvailable
@@ -345,7 +351,7 @@ extension ChatGPTService {
345351
throw ChatGPTServiceError.endpointIncorrect
346352
}
347353

348-
let messages = await memory.messages.map {
354+
let messages = prompt.history.map {
349355
CompletionRequestBody.Message(
350356
role: $0.role,
351357
content: $0.content ?? "",
@@ -355,7 +361,7 @@ extension ChatGPTService {
355361
}
356362
)
357363
}
358-
let remainingTokens = await memory.remainingTokens
364+
let remainingTokens = prompt.remainingTokenCount
359365

360366
let requestBody = CompletionRequestBody(
361367
model: model.info.modelName,
@@ -397,13 +403,14 @@ extension ChatGPTService {
397403

398404
guard let choice = response.choices.first else { return nil }
399405
let message = ChatMessage(
400-
id: response.id ?? UUID().uuidString,
406+
id: proposedId,
401407
role: choice.message.role,
402408
content: choice.message.content,
403409
name: choice.message.name,
404410
functionCall: choice.message.function_call.map {
405411
ChatMessage.FunctionCall(name: $0.name, arguments: $0.arguments ?? "")
406-
}
412+
},
413+
references: prompt.references
407414
)
408415
await memory.appendMessage(message)
409416
return message

Tool/Sources/OpenAIService/Memory/AutoManagedChatGPTMemory.swift

Lines changed: 18 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,14 @@ import TokenEncoder
55

66
/// A memory that automatically manages the history according to max tokens and max message count.
77
public actor AutoManagedChatGPTMemory: ChatGPTMemory {
8-
public private(set) var messages: [ChatMessage] = []
8+
public private(set) var history: [ChatMessage] = [] {
9+
didSet { onHistoryChange() }
10+
}
911
public private(set) var remainingTokens: Int?
1012

1113
public var systemPrompt: String
1214
public var contextSystemPrompt: String
13-
public var retrievedContent: [String] = []
14-
public var history: [ChatMessage] = [] {
15-
didSet { onHistoryChange() }
16-
}
17-
15+
public var retrievedContent: [ChatMessage.Reference] = []
1816
public var configuration: ChatGPTConfiguration
1917
public var functionProvider: ChatGPTFunctionProvider
2018

@@ -46,7 +44,7 @@ public actor AutoManagedChatGPTMemory: ChatGPTMemory {
4644
contextSystemPrompt = newPrompt
4745
}
4846

49-
public func mutateRetrievedContent(_ newContent: [String]) {
47+
public func mutateRetrievedContent(_ newContent: [ChatMessage.Reference]) {
5048
retrievedContent = newContent
5149
}
5250

@@ -57,9 +55,8 @@ public actor AutoManagedChatGPTMemory: ChatGPTMemory {
5755
}
5856
}
5957

60-
public func refresh() async {
61-
messages = generateSendingHistory()
62-
remainingTokens = generateRemainingTokens()
58+
public func generatePrompt() async -> ChatGPTPrompt {
59+
return generateSendingHistory()
6360
}
6461

6562
/// https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
@@ -79,7 +76,7 @@ public actor AutoManagedChatGPTMemory: ChatGPTMemory {
7976
func generateSendingHistory(
8077
maxNumberOfMessages: Int = UserDefaults.shared.value(for: \.chatGPTMaxMessageCount),
8178
encoder: TokenEncoder = AutoManagedChatGPTMemory.encoder
82-
) -> [ChatMessage] {
79+
) -> ChatGPTPrompt {
8380
let (
8481
systemPromptMessage,
8582
contextSystemPromptMessage,
@@ -102,7 +99,7 @@ public actor AutoManagedChatGPTMemory: ChatGPTMemory {
10299
retrievedContentMessage,
103100
_,
104101
retrievedContentUsage,
105-
_
102+
retrievedContent
106103
) = generateRetrievedContentMessage(
107104
maxTokenCount: availableTokenCountForRetrievedContent,
108105
encoder: encoder
@@ -134,15 +131,7 @@ public actor AutoManagedChatGPTMemory: ChatGPTMemory {
134131
""")
135132
#endif
136133

137-
return allMessages
138-
}
139-
140-
func generateRemainingTokens(
141-
maxNumberOfMessages: Int = UserDefaults.shared.value(for: \.chatGPTMaxMessageCount),
142-
encoder: TokenEncoder = AutoManagedChatGPTMemory.encoder
143-
) -> Int? {
144-
// It should be fine to just let OpenAI decide.
145-
return nil
134+
return .init(history: allMessages, references: retrievedContent)
146135
}
147136

148137
func setOnHistoryChangeBlock(_ onChange: @escaping () -> Void) {
@@ -240,41 +229,42 @@ extension AutoManagedChatGPTMemory {
240229
retrievedContent: ChatMessage,
241230
remainingTokenCount: Int,
242231
usage: Int,
243-
includedRetrievedContent: [String]
232+
references: [ChatMessage.Reference]
244233
) {
245234
var retrievedContentTokenCount = 0
246235
let separator = String(repeating: "=", count: 32) // only 1 token
247236
var message = ""
248-
var includedRetrievedContent = [String]()
237+
var references = [ChatMessage.Reference]()
249238

250239
func appendToMessage(_ text: String) -> Bool {
251240
let tokensCount = encoder.countToken(text: text)
252241
if tokensCount + retrievedContentTokenCount > maxTokenCount { return false }
253242
retrievedContentTokenCount += tokensCount
254243
message += text
255-
includedRetrievedContent.append(text)
256244
return true
257245
}
258246

259-
for (index, content) in retrievedContent.filter({ !$0.isEmpty }).enumerated() {
247+
for (index, content) in retrievedContent.filter({ !$0.content.isEmpty }).enumerated() {
260248
if index == 0 {
261249
if !appendToMessage("""
262-
Here are the information you know about the system and the project, separated by \(separator)
250+
Here are the information you know about the system and the project, \
251+
separated by \(separator)
263252
264253
265254
""") { break }
266255
} else {
267256
if !appendToMessage("\n\(separator)\n") { break }
268257
}
269258

270-
if !appendToMessage(content) { break }
259+
if !appendToMessage(content.content) { break }
260+
references.append(content)
271261
}
272262

273263
return (
274264
.init(role: .user, content: message),
275265
maxTokenCount - retrievedContentTokenCount,
276266
retrievedContentTokenCount,
277-
includedRetrievedContent
267+
references
278268
)
279269
}
280270
}

0 commit comments

Comments (0)