Skip to content

Commit 47086c6

Browse files
committed
Support setting Google AI api version
1 parent 6249d12 commit 47086c6

4 files changed

Lines changed: 44 additions & 26 deletions

File tree

Core/Sources/HostApp/AccountSettings/ChatModelManagement/ChatModelEdit.swift

Lines changed: 7 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
import AIModel
2-
import Toast
32
import ComposableArchitecture
43
import Dependencies
54
import Keychain
65
import OpenAIService
76
import Preferences
87
import SwiftUI
8+
import Toast
99

1010
struct ChatModelEdit: ReducerProtocol {
1111
struct State: Equatable, Identifiable {
@@ -16,6 +16,7 @@ struct ChatModelEdit: ReducerProtocol {
1616
@BindingState var supportsFunctionCalling: Bool = true
1717
@BindingState var modelName: String = ""
1818
@BindingState var ollamaKeepAlive: String = ""
19+
@BindingState var apiVersion: String = ""
1920
var apiKeyName: String { apiKeySelection.apiKeyName }
2021
var baseURL: String { baseURLSelection.baseURL }
2122
var isFullURL: Bool { baseURLSelection.isFullURL }
@@ -47,6 +48,7 @@ struct ChatModelEdit: ReducerProtocol {
4748
toast($0, $1, "ChatModelEdit")
4849
}
4950
}
51+
5052
@Dependency(\.apiKeyKeychain) var keychain
5153

5254
var body: some ReducerProtocol<State, Action> {
@@ -77,19 +79,7 @@ struct ChatModelEdit: ReducerProtocol {
7779
case .testButtonClicked:
7880
guard !state.isTesting else { return .none }
7981
state.isTesting = true
80-
let model = ChatModel(
81-
id: state.id,
82-
name: state.name,
83-
format: state.format,
84-
info: .init(
85-
apiKeyName: state.apiKeyName,
86-
baseURL: state.baseURL,
87-
isFullURL: state.isFullURL,
88-
maxTokens: state.maxTokens,
89-
supportsFunctionCalling: state.supportsFunctionCalling,
90-
modelName: state.modelName
91-
)
92-
)
82+
let model = ChatModel(state: state)
9383
return .run { send in
9484
do {
9585
let service = ChatGPTService(
@@ -194,6 +184,7 @@ extension ChatModelEdit.State {
194184
supportsFunctionCalling: model.info.supportsFunctionCalling,
195185
modelName: model.info.modelName,
196186
ollamaKeepAlive: model.info.ollamaInfo.keepAlive,
187+
apiVersion: model.info.googleGenerativeAIInfo.apiVersion,
197188
apiKeySelection: .init(
198189
apiKeyName: model.info.apiKeyName,
199190
apiKeyManagement: .init(availableAPIKeyNames: [model.info.apiKeyName])
@@ -223,7 +214,8 @@ extension ChatModel {
223214
}
224215
}(),
225216
modelName: state.modelName.trimmingCharacters(in: .whitespacesAndNewlines),
226-
ollamaInfo: .init(keepAlive: state.ollamaKeepAlive)
217+
ollamaInfo: .init(keepAlive: state.ollamaKeepAlive),
218+
googleGenerativeAIInfo: .init(apiVersion: state.apiVersion)
227219
)
228220
)
229221
}

Core/Sources/HostApp/AccountSettings/ChatModelManagement/ChatModelEditView.swift

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -331,7 +331,7 @@ struct ChatModelEditView: View {
331331
baseURLTextField(prompt: Text("https://generativelanguage.googleapis.com")) {
332332
Text("/v1")
333333
}
334-
334+
335335
apiKeyNamePicker
336336

337337
WithViewStore(
@@ -357,6 +357,10 @@ struct ChatModelEditView: View {
357357
}
358358

359359
maxTokensTextField
360+
361+
WithViewStore(store, removeDuplicates: { $0.apiVersion == $1.apiVersion }) { viewStore in
362+
TextField("API Version", text: viewStore.$apiVersion, prompt: Text("v1"))
363+
}
360364
}
361365

362366
@ViewBuilder
@@ -396,7 +400,7 @@ struct ChatModelEditView: View {
396400
baseURLTextField(prompt: Text("https://api.anthropic.com")) {
397401
Text("/v1/messages")
398402
}
399-
403+
400404
apiKeyNamePicker
401405

402406
WithViewStore(
@@ -425,7 +429,7 @@ struct ChatModelEditView: View {
425429
.frame(width: 20)
426430
}
427431
}
428-
432+
429433
maxTokensTextField
430434

431435
VStack(alignment: .leading, spacing: 8) {

Tool/Sources/AIModel/ChatModel.swift

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,15 @@ public struct ChatModel: Codable, Equatable, Identifiable {
4343
self.organizationID = organizationID
4444
}
4545
}
46+
47+
public struct GoogleGenerativeAIInfo: Codable, Equatable {
48+
@FallbackDecoding<EmptyString>
49+
public var apiVersion: String
50+
51+
public init(apiVersion: String = "") {
52+
self.apiVersion = apiVersion
53+
}
54+
}
4655

4756
@FallbackDecoding<EmptyString>
4857
public var apiKeyName: String
@@ -61,6 +70,8 @@ public struct ChatModel: Codable, Equatable, Identifiable {
6170
public var openAIInfo: OpenAIInfo
6271
@FallbackDecoding<EmptyChatModelOllamaInfo>
6372
public var ollamaInfo: OllamaInfo
73+
@FallbackDecoding<EmptyChatModelGoogleGenerativeAIInfo>
74+
public var googleGenerativeAIInfo: GoogleGenerativeAIInfo
6475

6576
public init(
6677
apiKeyName: String = "",
@@ -70,7 +81,8 @@ public struct ChatModel: Codable, Equatable, Identifiable {
7081
supportsFunctionCalling: Bool = true,
7182
modelName: String = "",
7283
openAIInfo: OpenAIInfo = OpenAIInfo(),
73-
ollamaInfo: OllamaInfo = OllamaInfo()
84+
ollamaInfo: OllamaInfo = OllamaInfo(),
85+
googleGenerativeAIInfo: GoogleGenerativeAIInfo = GoogleGenerativeAIInfo()
7486
) {
7587
self.apiKeyName = apiKeyName
7688
self.baseURL = baseURL
@@ -80,6 +92,7 @@ public struct ChatModel: Codable, Equatable, Identifiable {
8092
self.modelName = modelName
8193
self.openAIInfo = openAIInfo
8294
self.ollamaInfo = ollamaInfo
95+
self.googleGenerativeAIInfo = googleGenerativeAIInfo
8396
}
8497
}
8598

@@ -132,3 +145,6 @@ public struct EmptyChatModelOpenAIInfo: FallbackValueProvider {
132145
public static var defaultValue: ChatModel.Info.OpenAIInfo { .init() }
133146
}
134147

148+
public struct EmptyChatModelGoogleGenerativeAIInfo: FallbackValueProvider {
149+
public static var defaultValue: ChatModel.Info.GoogleGenerativeAIInfo { .init() }
150+
}

Tool/Sources/OpenAIService/APIs/GoogleAIChatCompletionsService.swift

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,11 @@ actor GoogleAIChatCompletionsService: ChatCompletionsAPI, ChatCompletionsStreamA
3030
apiKey: apiKey,
3131
generationConfig: .init(GenerationConfig(
3232
temperature: requestBody.temperature.map(Float.init)
33-
)),
34-
baseURL: baseURL
33+
)),
34+
baseURL: baseURL,
35+
requestOptions: model.info.googleGenerativeAIInfo.apiVersion.isEmpty
36+
? .init()
37+
: .init(apiVersion: model.info.googleGenerativeAIInfo.apiVersion)
3538
)
3639
let history = prompt.googleAICompatible.history.map { message in
3740
ModelContent(message)
@@ -59,7 +62,7 @@ actor GoogleAIChatCompletionsService: ChatCompletionsAPI, ChatCompletionsStreamA
5962
throw error
6063
case .promptImageContentError:
6164
throw error
62-
case let .invalidAPIKey(message: message):
65+
case .invalidAPIKey:
6366
throw error
6467
case .unsupportedUserLocation:
6568
throw error
@@ -77,8 +80,11 @@ actor GoogleAIChatCompletionsService: ChatCompletionsAPI, ChatCompletionsStreamA
7780
apiKey: apiKey,
7881
generationConfig: .init(GenerationConfig(
7982
temperature: requestBody.temperature.map(Float.init)
80-
)),
81-
baseURL: baseURL
83+
)),
84+
baseURL: baseURL,
85+
requestOptions: model.info.googleGenerativeAIInfo.apiVersion.isEmpty
86+
? .init()
87+
: .init(apiVersion: model.info.googleGenerativeAIInfo.apiVersion)
8288
)
8389
let history = prompt.googleAICompatible.history.map { message in
8490
ModelContent(message)
@@ -111,9 +117,9 @@ actor GoogleAIChatCompletionsService: ChatCompletionsAPI, ChatCompletionsStreamA
111117
continuation.finish(throwing: error)
112118
case .responseStoppedEarly:
113119
continuation.finish(throwing: error)
114-
case let .promptImageContentError(underlying: underlying):
120+
case .promptImageContentError:
115121
continuation.finish(throwing: error)
116-
case let .invalidAPIKey(message: message):
122+
case .invalidAPIKey:
117123
continuation.finish(throwing: error)
118124
case .unsupportedUserLocation:
119125
continuation.finish(throwing: error)

0 commit comments

Comments
 (0)