Skip to content

Commit 268c818

Browse files
committed
Add settings keys for prompt to code chat model and embedding model
1 parent ef1e9c9 commit 268c818

File tree

7 files changed

+80
-30
lines changed

7 files changed

+80
-30
lines changed

Tool/Sources/OpenAIService/ChatGPTService.swift

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,18 @@ public protocol ChatGPTServiceType {
1010
}
1111

1212
public enum ChatGPTServiceError: Error, LocalizedError {
13+
case chatModelNotAvailable
14+
case embeddingModelNotAvailable
1315
case endpointIncorrect
1416
case responseInvalid
1517
case otherError(String)
1618

1719
public var errorDescription: String? {
1820
switch self {
21+
case .chatModelNotAvailable:
22+
return "Chat model is not available, please add a model in the settings."
23+
case .embeddingModelNotAvailable:
24+
return "Embedding model is not available, please add a model in the settings."
1925
case .endpointIncorrect:
2026
return "ChatGPT endpoint is incorrect"
2127
case .responseInvalid:
@@ -180,8 +186,12 @@ extension ChatGPTService {
180186

181187
/// Send the memory as prompt to ChatGPT, with stream enabled.
182188
func sendMemory() async throws -> AsyncThrowingStream<StreamContent, Error> {
183-
guard let url = URL(string: configuration.endpoint)
184-
else { throw ChatGPTServiceError.endpointIncorrect }
189+
guard let model = configuration.model else {
190+
throw ChatGPTServiceError.chatModelNotAvailable
191+
}
192+
guard let url = URL(string: configuration.endpoint) else {
193+
throw ChatGPTServiceError.endpointIncorrect
194+
}
185195

186196
await memory.refresh()
187197

@@ -197,8 +207,6 @@ extension ChatGPTService {
197207
}
198208
let remainingTokens = await memory.remainingTokens
199209

200-
let model = configuration.model
201-
202210
let requestBody = CompletionRequestBody(
203211
model: model.info.modelName,
204212
messages: messages,
@@ -287,8 +295,12 @@ extension ChatGPTService {
287295

288296
/// Send the memory as prompt to ChatGPT, with stream disabled.
289297
func sendMemoryAndWait() async throws -> ChatMessage? {
290-
guard let url = URL(string: configuration.endpoint)
291-
else { throw ChatGPTServiceError.endpointIncorrect }
298+
guard let model = configuration.model else {
299+
throw ChatGPTServiceError.chatModelNotAvailable
300+
}
301+
guard let url = URL(string: configuration.endpoint) else {
302+
throw ChatGPTServiceError.endpointIncorrect
303+
}
292304

293305
await memory.refresh()
294306

@@ -304,8 +316,6 @@ extension ChatGPTService {
304316
}
305317
let remainingTokens = await memory.remainingTokens
306318

307-
let model = configuration.model
308-
309319
let requestBody = CompletionRequestBody(
310320
model: model.info.modelName,
311321
messages: messages,
@@ -357,7 +367,7 @@ extension ChatGPTService {
357367
/// When a function call is detected, but arguments are not yet ready, we can call this
358368
/// to insert a message placeholder in memory.
359369
func prepareFunctionCall(_ call: ChatMessage.FunctionCall, messageId: String) async {
360-
guard var function = functionProvider.function(named: call.name) else { return }
370+
guard let function = functionProvider.function(named: call.name) else { return }
361371
let responseMessage = ChatMessage(
362372
id: messageId,
363373
role: .function,
@@ -380,7 +390,7 @@ extension ChatGPTService {
380390
) async -> String {
381391
let messageId = messageId ?? uuidGenerator()
382392

383-
guard var function = functionProvider.function(named: call.name) else {
393+
guard let function = functionProvider.function(named: call.name) else {
384394
return await fallbackFunctionCall(call, messageId: messageId)
385395
}
386396

Tool/Sources/OpenAIService/Configuration/ChatGPTConfiguration.swift

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ import Preferences
44
import Keychain
55

66
public protocol ChatGPTConfiguration {
7-
var model: ChatModel { get }
7+
var model: ChatModel? { get }
88
var temperature: Double { get }
99
var apiKey: String { get }
1010
var stop: [String] { get }
@@ -15,11 +15,12 @@ public protocol ChatGPTConfiguration {
1515

1616
public extension ChatGPTConfiguration {
1717
var endpoint: String {
18-
model.endpoint
18+
model?.endpoint ?? ""
1919
}
2020

2121
var apiKey: String {
22-
(try? Keychain.apiKey.get(model.info.apiKeyName)) ?? ""
22+
guard let name = model?.info.apiKeyName else { return "" }
23+
return (try? Keychain.apiKey.get(name)) ?? ""
2324
}
2425

2526
func overriding(

Tool/Sources/OpenAIService/Configuration/EmbeddingConfiguration.swift

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,19 +4,20 @@ import Keychain
44
import Preferences
55

66
public protocol EmbeddingConfiguration {
7-
var model: EmbeddingModel { get }
7+
var model: EmbeddingModel? { get }
88
var apiKey: String { get }
99
var maxToken: Int { get }
1010
var dimensions: Int { get }
1111
}
1212

1313
public extension EmbeddingConfiguration {
1414
var endpoint: String {
15-
model.endpoint
15+
model?.endpoint ?? ""
1616
}
1717

1818
var apiKey: String {
19-
(try? Keychain.apiKey.get(model.info.apiKeyName)) ?? ""
19+
guard let name = model?.info.apiKeyName else { return "" }
20+
return (try? Keychain.apiKey.get(name)) ?? ""
2021
}
2122

2223
func overriding(

Tool/Sources/OpenAIService/Configuration/UserPreferenceChatGPTConfiguration.swift

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,29 @@ import Foundation
33
import Preferences
44

55
public struct UserPreferenceChatGPTConfiguration: ChatGPTConfiguration {
6+
public var chatModelKey: KeyPath<UserDefaultPreferenceKeys, PreferenceKey<String>>?
7+
68
public var temperature: Double {
79
min(max(0, UserDefaults.shared.value(for: \.chatGPTTemperature)), 2)
810
}
911

10-
public var model: ChatModel {
12+
public var model: ChatModel? {
1113
let models = UserDefaults.shared.value(for: \.chatModels)
14+
15+
if let chatModelKey {
16+
let id = UserDefaults.shared.value(for: chatModelKey)
17+
if let model = models.first(where: { $0.id == id }) {
18+
return model
19+
}
20+
}
21+
1222
let id = UserDefaults.shared.value(for: \.defaultChatFeatureChatModelId)
1323
return models.first { $0.id == id }
14-
?? models.first ?? .init(id: "", name: "", format: .openAI, info: .init())
24+
?? models.first
1525
}
1626

1727
public var maxTokens: Int {
18-
model.info.maxTokens
28+
model?.info.maxTokens ?? 0
1929
}
2030

2131
public var stop: [String] {
@@ -30,7 +40,9 @@ public struct UserPreferenceChatGPTConfiguration: ChatGPTConfiguration {
3040
true
3141
}
3242

33-
public init() {}
43+
public init(chatModelKey: KeyPath<UserDefaultPreferenceKeys, PreferenceKey<String>>? = nil) {
44+
self.chatModelKey = chatModelKey
45+
}
3446
}
3547

3648
public class OverridingChatGPTConfiguration: ChatGPTConfiguration {
@@ -77,7 +89,7 @@ public class OverridingChatGPTConfiguration: ChatGPTConfiguration {
7789
overriding.temperature ?? configuration.temperature
7890
}
7991

80-
public var model: ChatModel {
92+
public var model: ChatModel? {
8193
if let model = overriding.model { return model }
8294
let models = UserDefaults.shared.value(for: \.chatModels)
8395
guard let id = overriding.modelId,

Tool/Sources/OpenAIService/Configuration/UserPreferenceEmbeddingConfiguration.swift

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,26 +3,40 @@ import Foundation
33
import Preferences
44

55
public struct UserPreferenceEmbeddingConfiguration: EmbeddingConfiguration {
6-
public var model: EmbeddingModel {
6+
public var embeddingModelKey: KeyPath<UserDefaultPreferenceKeys, PreferenceKey<String>>?
7+
8+
public var model: EmbeddingModel? {
79
let models = UserDefaults.shared.value(for: \.embeddingModels)
10+
11+
if let embeddingModelKey {
12+
let id = UserDefaults.shared.value(for: embeddingModelKey)
13+
if let model = models.first(where: { $0.id == id }) {
14+
return model
15+
}
16+
}
17+
818
let id = UserDefaults.shared.value(for: \.defaultChatFeatureEmbeddingModelId)
919
return models.first { $0.id == id }
10-
?? models.first ?? .init(id: "", name: "", format: .openAI, info: .init())
20+
?? models.first
1121
}
1222

1323
public var maxToken: Int {
14-
model.info.maxTokens
24+
model?.info.maxTokens ?? 0
1525
}
1626

1727
public var dimensions: Int {
18-
let dimensions = model.info.dimensions
28+
let dimensions = model?.info.dimensions ?? 0
1929
if dimensions <= 0 {
2030
return 1536
2131
}
2232
return dimensions
2333
}
2434

25-
public init() {}
35+
public init(
36+
embeddingModelKey: KeyPath<UserDefaultPreferenceKeys, PreferenceKey<String>>? = nil
37+
) {
38+
self.embeddingModelKey = embeddingModelKey
39+
}
2640
}
2741

2842
public class OverridingEmbeddingConfiguration<
@@ -55,7 +69,7 @@ public class OverridingEmbeddingConfiguration<
5569
self.configuration = configuration
5670
}
5771

58-
public var model: EmbeddingModel {
72+
public var model: EmbeddingModel? {
5973
if let model = overriding.model { return model }
6074
let models = UserDefaults.shared.value(for: \.embeddingModels)
6175
guard let id = overriding.modelId,
@@ -67,7 +81,7 @@ public class OverridingEmbeddingConfiguration<
6781
public var maxToken: Int {
6882
overriding.maxTokens ?? configuration.maxToken
6983
}
70-
84+
7185
public var dimensions: Int {
7286
overriding.dimensions ?? configuration.dimensions
7387
}

Tool/Sources/OpenAIService/EmbeddingService.swift

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,10 +41,12 @@ public struct EmbeddingService {
4141
}
4242

4343
public func embed(text: [String]) async throws -> EmbeddingResponse {
44+
guard let model = configuration.model else {
45+
throw ChatGPTServiceError.embeddingModelNotAvailable
46+
}
4447
guard let url = URL(string: configuration.endpoint) else {
4548
throw ChatGPTServiceError.endpointIncorrect
4649
}
47-
let model = configuration.model
4850
var request = URLRequest(url: url)
4951
request.httpMethod = "POST"
5052
let encoder = JSONEncoder()
@@ -90,10 +92,12 @@ public struct EmbeddingService {
9092
}
9193

9294
public func embed(tokens: [[Int]]) async throws -> EmbeddingResponse {
95+
guard let model = configuration.model else {
96+
throw ChatGPTServiceError.embeddingModelNotAvailable
97+
}
9398
guard let url = URL(string: configuration.endpoint) else {
9499
throw ChatGPTServiceError.endpointIncorrect
95100
}
96-
let model = configuration.model
97101
var request = URLRequest(url: url)
98102
request.httpMethod = "POST"
99103
let encoder = JSONEncoder()

Tool/Sources/Preferences/Keys.swift

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,14 @@ public extension UserDefaultPreferenceKeys {
264264
var promptToCodeGenerateDescriptionInUserPreferredLanguage: PreferenceKey<Bool> {
265265
.init(defaultValue: true, key: "PromptToCodeGenerateDescriptionInUserPreferredLanguage")
266266
}
267+
268+
var promptToCodeChatModelId: PreferenceKey<String> {
269+
.init(defaultValue: "", key: "PromptToCodeChatModelId")
270+
}
271+
272+
var promptToCodeEmbeddingModelId: PreferenceKey<String> {
273+
.init(defaultValue: "", key: "PromptToCodeEmbeddingModelId")
274+
}
267275
}
268276

269277
// MARK: - Suggestion

0 commit comments

Comments
 (0)