Skip to content

Commit 2500225

Browse files
committed
Merge branch 'feature/google-gemini-support' into develop
2 parents a3858b7 + aff7de2 commit 2500225

23 files changed

Lines changed: 675 additions & 149 deletions

Core/Sources/HostApp/AccountSettings/ChatModelManagement/ChatModelEdit.swift

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -114,15 +114,26 @@ struct ChatModelEdit: ReducerProtocol {
114114
return .none
115115

116116
case .checkSuggestedMaxTokens:
117-
guard state.format == .openAI,
118-
let knownModel = ChatGPTModel(rawValue: state.modelName)
119-
else {
117+
switch state.format {
118+
case .openAI:
119+
if let knownModel = ChatGPTModel(rawValue: state.modelName) {
120+
state.suggestedMaxTokens = knownModel.maxToken
121+
} else {
122+
state.suggestedMaxTokens = nil
123+
}
124+
return .none
125+
case .googleAI:
126+
if let knownModel = GoogleGenerativeAIModel(rawValue: state.modelName) {
127+
state.suggestedMaxTokens = knownModel.maxToken
128+
} else {
129+
state.suggestedMaxTokens = nil
130+
}
131+
return .none
132+
default:
120133
state.suggestedMaxTokens = nil
121134
return .none
122135
}
123-
state.suggestedMaxTokens = knownModel.maxToken
124-
return .none
125-
136+
126137
case .apiKeySelection:
127138
return .none
128139

Core/Sources/HostApp/AccountSettings/ChatModelManagement/ChatModelEditView.swift

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ struct ChatModelEditView: View {
2222
azureOpenAI
2323
case .openAICompatible:
2424
openAICompatible
25+
case .googleAI:
26+
googleAI
2527
}
2628
}
2729
}
@@ -88,6 +90,8 @@ struct ChatModelEditView: View {
8890
Text("Azure OpenAI").tag(format)
8991
case .openAICompatible:
9092
Text("OpenAI Compatible").tag(format)
93+
case .googleAI:
94+
Text("Google Generative AI").tag(format)
9195
}
9296
}
9397
},
@@ -269,6 +273,35 @@ struct ChatModelEditView: View {
269273
maxTokensTextField
270274
supportsFunctionCallingToggle
271275
}
276+
277+
@ViewBuilder
278+
var googleAI: some View {
279+
apiKeyNamePicker
280+
281+
WithViewStore(
282+
store,
283+
removeDuplicates: { $0.modelName == $1.modelName }
284+
) { viewStore in
285+
TextField("Model Name", text: viewStore.$modelName)
286+
.overlay(alignment: .trailing) {
287+
Picker(
288+
"",
289+
selection: viewStore.$modelName,
290+
content: {
291+
if GoogleGenerativeAIModel(rawValue: viewStore.state.modelName) == nil {
292+
Text("Custom Model").tag(viewStore.state.modelName)
293+
}
294+
ForEach(GoogleGenerativeAIModel.allCases, id: \.self) { model in
295+
Text(model.rawValue).tag(model.rawValue)
296+
}
297+
}
298+
)
299+
.frame(width: 20)
300+
}
301+
}
302+
303+
maxTokensTextField
304+
}
272305
}
273306

274307
#Preview("OpenAI") {

Core/Sources/HostApp/AccountSettings/ChatModelManagement/ChatModelManagement.swift

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ extension ChatModel: ManageableAIModel {
1010
case .openAI: return "OpenAI"
1111
case .azureOpenAI: return "Azure OpenAI"
1212
case .openAICompatible: return "OpenAI Compatible"
13+
case .googleAI: return "Google Generative AI"
1314
}
1415
}
1516

Core/Sources/PromptToCodeService/OpenAIPromptToCodeService.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,7 @@ extension OpenAIPromptToCodeService {
222222
{
223223
func extractCodeFromMarkdown(_ markdown: String) -> (code: String, endIndex: Int)? {
224224
let codeBlockRegex = try! NSRegularExpression(
225-
pattern: #"```(?:\w+)?[\n]([\s\S]+?)[\n]```"#,
225+
pattern: #"```(?:\w+)?\R([\s\S]+?)\R```"#,
226226
options: .dotMatchesLineSeparators
227227
)
228228
let range = NSRange(markdown.startIndex..<markdown.endIndex, in: markdown)
@@ -232,7 +232,7 @@ extension OpenAIPromptToCodeService {
232232
}
233233

234234
let incompleteCodeBlockRegex = try! NSRegularExpression(
235-
pattern: #"```(?:\w+)?[\n]([\s\S]+?)$"#,
235+
pattern: #"```(?:\w+)?\R([\s\S]+?)$"#,
236236
options: .dotMatchesLineSeparators
237237
)
238238
let range2 = NSRange(markdown.startIndex..<markdown.endIndex, in: markdown)

Core/Sources/SuggestionWidget/FeatureReducers/PromptToCode.swift

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ public struct PromptToCode: ReducerProtocol {
128128
case revertButtonTapped
129129
case stopRespondingButtonTapped
130130
case modifyCodeFinished
131-
case modifyCodeTrunkReceived(code: String, description: String)
131+
case modifyCodeChunkReceived(code: String, description: String)
132132
case modifyCodeFailed(error: String)
133133
case modifyCodeCancelled
134134
case cancelButtonTapped
@@ -189,7 +189,7 @@ public struct PromptToCode: ReducerProtocol {
189189
)
190190
for try await fragment in stream {
191191
try Task.checkCancellation()
192-
await send(.modifyCodeTrunkReceived(
192+
await send(.modifyCodeChunkReceived(
193193
code: fragment.code,
194194
description: fragment.description
195195
))
@@ -221,7 +221,7 @@ public struct PromptToCode: ReducerProtocol {
221221
promptToCodeService.stopResponding()
222222
return .cancel(id: CancellationKey.modifyCode(state.id))
223223

224-
case let .modifyCodeTrunkReceived(code, description):
224+
case let .modifyCodeChunkReceived(code, description):
225225
state.code = code
226226
state.description = description
227227
return .none

Pro

Submodule Pro updated from 57f7523 to eeae9b4

Tool/Package.swift

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ let package = Package(
6363
.package(url: "https://github.com/apple/swift-syntax.git", branch: "main"),
6464
.package(url: "https://github.com/GottaGetSwifty/CodableWrappers", from: "2.0.7"),
6565
.package(url: "https://github.com/krzyzanowskim/STTextView", from: "0.8.21"),
66+
.package(url: "https://github.com/google/generative-ai-swift", from: "0.4.4"),
6667

6768
// TreeSitter
6869
.package(url: "https://github.com/intitni/SwiftTreeSitter.git", branch: "main"),
@@ -130,6 +131,7 @@ let package = Package(
130131
name: "TokenEncoder",
131132
dependencies: [
132133
.product(name: "Tiktoken", package: "Tiktoken"),
134+
.product(name: "GoogleGenerativeAI", package: "generative-ai-swift"),
133135
],
134136
resources: [
135137
.copy("Resources/cl100k_base.tiktoken"),
@@ -313,6 +315,7 @@ let package = Package(
313315
"Keychain",
314316
.product(name: "JSONRPC", package: "JSONRPC"),
315317
.product(name: "AsyncAlgorithms", package: "swift-async-algorithms"),
318+
.product(name: "GoogleGenerativeAI", package: "generative-ai-swift"),
316319
.product(
317320
name: "ComposableArchitecture",
318321
package: "swift-composable-architecture"

Tool/Sources/AIModel/ChatModel.swift

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ public struct ChatModel: Codable, Equatable, Identifiable {
2020
case openAI
2121
case azureOpenAI
2222
case openAICompatible
23+
case googleAI
2324
}
2425

2526
public struct Info: Codable, Equatable {
@@ -69,6 +70,10 @@ public struct ChatModel: Codable, Equatable, Identifiable {
6970
let version = "2023-07-01-preview"
7071
if baseURL.isEmpty { return "" }
7172
return "\(baseURL)/openai/deployments/\(deployment)/chat/completions?api-version=\(version)"
73+
case .googleAI:
74+
let baseURL = info.baseURL
75+
if baseURL.isEmpty { return "https://generativelanguage.googleapis.com/v1" }
76+
return "\(baseURL)/v1/chat/completions"
7277
}
7378
}
7479
}

Tool/Sources/LangChain/ChatModel/OpenAIChat.swift

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,9 @@ public struct OpenAIChat: ChatModel {
3838
if stream {
3939
let stream = try await service.send(content: "")
4040
var message = ""
41-
for try await trunk in stream {
42-
message.append(trunk)
43-
callbackManagers.send(CallbackEvents.LLMDidProduceNewToken(info: trunk))
41+
for try await chunk in stream {
42+
message.append(chunk)
43+
callbackManagers.send(CallbackEvents.LLMDidProduceNewToken(info: chunk))
4444
}
4545
return await memory.history.last ?? .init(role: .assistant, content: "")
4646
} else {

Tool/Sources/LangChain/DocumentTransformer/TextSplitter.swift

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,9 @@ public extension TextSplitter {
2424
let paddingLength = texts.count - metadata.count
2525
let metadata = metadata + .init(repeating: [:], count: paddingLength)
2626
for (text, metadata) in zip(texts, metadata) {
27-
let trunks = try await split(text: text)
28-
for trunk in trunks {
29-
let document = Document(pageContent: trunk, metadata: metadata)
27+
let chunks = try await split(text: text)
28+
for chunk in chunks {
29+
let document = Document(pageContent: chunk, metadata: metadata)
3030
documents.append(document)
3131
}
3232
}

0 commit comments

Comments
 (0)