Skip to content

Commit 39b387d

Browse files
committed
Support changing model according to scopes
1 parent 956f712 commit 39b387d

5 files changed

Lines changed: 53 additions & 18 deletions

File tree

Core/Sources/ChatService/ChatService.swift

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,16 +40,18 @@ public final class ChatService: ObservableObject {
4040

4141
public convenience init() {
4242
let configuration = UserPreferenceChatGPTConfiguration().overriding()
43+
/// Used by context collector
44+
let extraConfiguration = configuration.overriding()
4345
let memory = ContextAwareAutoManagedChatGPTMemory(
44-
configuration: configuration,
46+
configuration: extraConfiguration,
4547
functionProvider: ChatFunctionProvider()
4648
)
4749
self.init(
4850
memory: memory,
4951
configuration: configuration,
5052
chatGPTService: ChatGPTService(
5153
memory: memory,
52-
configuration: configuration,
54+
configuration: extraConfiguration,
5355
functionProvider: memory.functionProvider
5456
)
5557
)

Core/Sources/ChatService/ContextAwareAutoManagedChatGPTMemory.swift

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,11 @@ public final class ContextAwareAutoManagedChatGPTMemory: ChatGPTMemory {
1414
public var remainingTokens: Int? {
1515
get async { await memory.remainingTokens }
1616
}
17-
17+
1818
public var history: [ChatMessage] {
1919
get async { await memory.history }
2020
}
21-
21+
2222
func observeHistoryChange(_ observer: @escaping () -> Void) {
2323
memory.observeHistoryChange(observer)
2424
}
@@ -48,10 +48,13 @@ public final class ContextAwareAutoManagedChatGPTMemory: ChatGPTMemory {
4848
public func refresh() async {
4949
let content = (await memory.history)
5050
.last(where: { $0.role == .user || $0.role == .function })?.content
51-
try? await contextController.updatePromptToMatchContent(systemPrompt: """
52-
\(chatService?.systemPrompt ?? "")
53-
\(chatService?.extraSystemPrompt ?? "")
54-
""", content: content ?? "")
51+
try? await contextController.collectContextInformation(
52+
systemPrompt: """
53+
\(chatService?.systemPrompt ?? "")
54+
\(chatService?.extraSystemPrompt ?? "")
55+
""",
56+
content: content ?? ""
57+
)
5558
await memory.refresh()
5659
}
5760
}

Core/Sources/ChatService/DynamicContextController.swift

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,37 @@ final class DynamicContextController {
3737
self.contextCollectors = contextCollectors
3838
}
3939

40-
func updatePromptToMatchContent(systemPrompt: String, content: String) async throws {
40+
func collectContextInformation(systemPrompt: String, content: String) async throws {
4141
var content = content
4242
var scopes = Self.parseScopes(&content)
4343
scopes.formUnion(defaultScopes)
44+
45+
let overridingChatModelId = {
46+
var ids = [String]()
47+
if scopes.contains(.sense) {
48+
ids.append(UserDefaults.shared.value(for: \.preferredChatModelIdForSenseScope))
49+
}
50+
51+
if scopes.contains(.project) {
52+
ids.append(UserDefaults.shared.value(for: \.preferredChatModelIdForProjectScope))
53+
}
54+
55+
if scopes.contains(.web) {
56+
ids.append(UserDefaults.shared.value(for: \.preferredChatModelIdForWebScope))
57+
}
58+
59+
let chatModels = UserDefaults.shared.value(for: \.chatModels)
60+
let idIndexMap = chatModels.enumerated().reduce(into: [String: Int]()) {
61+
$0[$1.element.id] = $1.offset
62+
}
63+
return ids.sorted(by: {
64+
let lhs = idIndexMap[$0] ?? Int.max
65+
let rhs = idIndexMap[$1] ?? Int.max
66+
return lhs < rhs
67+
}).first
68+
}()
69+
70+
configuration.overriding.modelId = overridingChatModelId
4471

4572
functionProvider.removeAll()
4673
let language = UserDefaults.shared.value(for: \.chatGPTLanguage)

Tool/Sources/OpenAIService/ChatGPTService.swift

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -212,15 +212,15 @@ extension ChatGPTService {
212212

213213
/// Send the memory as prompt to ChatGPT, with stream enabled.
214214
func sendMemory() async throws -> AsyncThrowingStream<StreamContent, Error> {
215+
await memory.refresh()
216+
215217
guard let model = configuration.model else {
216218
throw ChatGPTServiceError.chatModelNotAvailable
217219
}
218220
guard let url = URL(string: configuration.endpoint) else {
219221
throw ChatGPTServiceError.endpointIncorrect
220222
}
221223

222-
await memory.refresh()
223-
224224
let messages = await memory.messages.map {
225225
CompletionRequestBody.Message(
226226
role: $0.role,
@@ -325,15 +325,15 @@ extension ChatGPTService {
325325

326326
/// Send the memory as prompt to ChatGPT, with stream disabled.
327327
func sendMemoryAndWait() async throws -> ChatMessage? {
328+
await memory.refresh()
329+
328330
guard let model = configuration.model else {
329331
throw ChatGPTServiceError.chatModelNotAvailable
330332
}
331333
guard let url = URL(string: configuration.endpoint) else {
332334
throw ChatGPTServiceError.endpointIncorrect
333335
}
334336

335-
await memory.refresh()
336-
337337
let messages = await memory.messages.map {
338338
CompletionRequestBody.Message(
339339
role: $0.role,

Tool/Sources/OpenAIService/Configuration/UserPreferenceChatGPTConfiguration.swift

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,25 @@
11
import AIModel
22
import Foundation
3+
import Keychain
34
import Preferences
45

56
public struct UserPreferenceChatGPTConfiguration: ChatGPTConfiguration {
67
public var chatModelKey: KeyPath<UserDefaultPreferenceKeys, PreferenceKey<String>>?
7-
8+
89
public var temperature: Double {
910
min(max(0, UserDefaults.shared.value(for: \.chatGPTTemperature)), 2)
1011
}
1112

1213
public var model: ChatModel? {
1314
let models = UserDefaults.shared.value(for: \.chatModels)
14-
15+
1516
if let chatModelKey {
1617
let id = UserDefaults.shared.value(for: chatModelKey)
1718
if let model = models.first(where: { $0.id == id }) {
1819
return model
1920
}
2021
}
21-
22+
2223
let id = UserDefaults.shared.value(for: \.defaultChatFeatureChatModelId)
2324
return models.first { $0.id == id }
2425
?? models.first
@@ -116,9 +117,11 @@ public class OverridingChatGPTConfiguration: ChatGPTConfiguration {
116117
public var runFunctionsAutomatically: Bool {
117118
overriding.runFunctionsAutomatically ?? configuration.runFunctionsAutomatically
118119
}
119-
120+
120121
public var apiKey: String {
121-
overriding.apiKey ?? configuration.apiKey
122+
if let apiKey = overriding.apiKey { return apiKey }
123+
guard let name = model?.info.apiKeyName else { return configuration.apiKey }
124+
return (try? Keychain.apiKey.get(name)) ?? configuration.apiKey
122125
}
123126
}
124127

0 commit comments

Comments
 (0)