Skip to content

Commit 4f00c12

Browse files
committed
Add gemini support
1 parent 93e2592 commit 4f00c12

File tree

8 files changed

+164
-51
lines changed

8 files changed

+164
-51
lines changed

Core/Sources/SuggestionWidget/FeatureReducers/PromptToCode.swift

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ public struct PromptToCode: ReducerProtocol {
128128
case revertButtonTapped
129129
case stopRespondingButtonTapped
130130
case modifyCodeFinished
131-
case modifyCodeTrunkReceived(code: String, description: String)
131+
case modifyCodeChunkReceived(code: String, description: String)
132132
case modifyCodeFailed(error: String)
133133
case modifyCodeCancelled
134134
case cancelButtonTapped
@@ -189,7 +189,7 @@ public struct PromptToCode: ReducerProtocol {
189189
)
190190
for try await fragment in stream {
191191
try Task.checkCancellation()
192-
await send(.modifyCodeTrunkReceived(
192+
await send(.modifyCodeChunkReceived(
193193
code: fragment.code,
194194
description: fragment.description
195195
))
@@ -221,7 +221,7 @@ public struct PromptToCode: ReducerProtocol {
221221
promptToCodeService.stopResponding()
222222
return .cancel(id: CancellationKey.modifyCode(state.id))
223223

224-
case let .modifyCodeTrunkReceived(code, description):
224+
case let .modifyCodeChunkReceived(code, description):
225225
state.code = code
226226
state.description = description
227227
return .none

Tool/Sources/LangChain/ChatModel/OpenAIChat.swift

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,9 @@ public struct OpenAIChat: ChatModel {
3838
if stream {
3939
let stream = try await service.send(content: "")
4040
var message = ""
41-
for try await trunk in stream {
42-
message.append(trunk)
43-
callbackManagers.send(CallbackEvents.LLMDidProduceNewToken(info: trunk))
41+
for try await chunk in stream {
42+
message.append(chunk)
43+
callbackManagers.send(CallbackEvents.LLMDidProduceNewToken(info: chunk))
4444
}
4545
return await memory.history.last ?? .init(role: .assistant, content: "")
4646
} else {

Tool/Sources/LangChain/DocumentTransformer/TextSplitter.swift

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,9 @@ public extension TextSplitter {
2424
let paddingLength = texts.count - metadata.count
2525
let metadata = metadata + .init(repeating: [:], count: paddingLength)
2626
for (text, metadata) in zip(texts, metadata) {
27-
let trunks = try await split(text: text)
28-
for trunk in trunks {
29-
let document = Document(pageContent: trunk, metadata: metadata)
27+
let chunks = try await split(text: text)
28+
for chunk in chunks {
29+
let document = Document(pageContent: chunk, metadata: metadata)
3030
documents.append(document)
3131
}
3232
}
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
import AIModel
2+
import Foundation
3+
import GoogleGenerativeAI
4+
import Preferences
5+
6+
/// A non-streaming completion client backed by Google's Generative AI (Gemini) SDK.
///
/// Translates an OpenAI-style `CompletionRequestBody` into a Gemini request,
/// performs a single `generateContent` call, and maps the response back into
/// a `CompletionResponseBody` so callers can treat every provider uniformly.
struct GoogleCompletionAPI: CompletionAPI {
    let apiKey: String
    let model: ChatModel
    var requestBody: CompletionRequestBody

    func callAsFunction() async throws -> CompletionResponseBody {
        let generativeModel = GenerativeModel(
            name: model.name,
            apiKey: apiKey,
            generationConfig: GenerationConfig(
                temperature: requestBody.temperature.map(Float.init),
                topP: requestBody.top_p.map(Float.init)
            )
        )

        // Convert the OpenAI-style message list into Gemini `ModelContent` history.
        let conversation = requestBody.messages.map { message -> ModelContent in
            let chatMessage = ChatMessage(
                role: message.role,
                content: message.content,
                name: message.name,
                functionCall: message.function_call.map {
                    .init(name: $0.name, arguments: $0.arguments ?? "")
                }
            )
            return ModelContent(chatMessage)
        }

        let response = try await generativeModel.generateContent(conversation)

        // NOTE(review): token usage appears unavailable from this SDK call, so
        // all counts are zeroed — confirm against the GoogleGenerativeAI API.
        return .init(
            object: "chat.completion",
            model: model.name,
            usage: .init(prompt_tokens: 0, completion_tokens: 0, total_tokens: 0),
            choices: response.candidates.enumerated().map { index, candidate in
                // Use the first part that carries non-empty text as the message body.
                let text = candidate.content.parts
                    .compactMap(\.text)
                    .first { !$0.isEmpty }
                return .init(
                    message: .init(role: .assistant, content: text ?? ""),
                    index: index,
                    finish_reason: candidate.finishReason?.rawValue ?? ""
                )
            }
        )
    }
}
59+
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
import AIModel
2+
import Foundation
3+
import GoogleGenerativeAI
4+
import Preferences
5+
6+
/// A streaming completion client backed by Google's Generative AI (Gemini) SDK.
///
/// Translates an OpenAI-style `CompletionRequestBody` into a Gemini request and
/// adapts `generateContentStream` into the `AsyncThrowingStream` of
/// `CompletionStreamDataChunk` values that the rest of the service layer consumes.
struct GoogleCompletionStreamAPI: CompletionStreamAPI {
    let apiKey: String
    let model: ChatModel
    var requestBody: CompletionRequestBody

    func callAsFunction() async throws -> AsyncThrowingStream<CompletionStreamDataChunk, Error> {
        let aiModel = GenerativeModel(
            name: model.name,
            apiKey: apiKey,
            generationConfig: .init(GenerationConfig(
                temperature: requestBody.temperature.map(Float.init),
                topP: requestBody.top_p.map(Float.init)
            ))
        )

        // Convert the OpenAI-style message list into Gemini `ModelContent` history.
        let history = requestBody.messages.map { message in
            ModelContent(
                ChatMessage(
                    role: message.role,
                    content: message.content,
                    name: message.name,
                    functionCall: message.function_call.map {
                        .init(name: $0.name, arguments: $0.arguments ?? "")
                    }
                )
            )
        }

        let stream = AsyncThrowingStream<CompletionStreamDataChunk, Error> { continuation in
            let stream = aiModel.generateContentStream(history)
            let task = Task {
                do {
                    for try await response in stream {
                        if Task.isCancelled { break }
                        let chunk = CompletionStreamDataChunk(
                            object: "",
                            model: model.name,
                            choices: response.candidates.map { candidate in
                                .init(delta: .init(
                                    role: .assistant,
                                    // Consistency fix: match GoogleCompletionAPI by
                                    // preferring the first part with NON-EMPTY text,
                                    // not merely non-nil text, so an empty text part
                                    // cannot shadow real content in a later part.
                                    content: candidate.content.parts
                                        .compactMap(\.text)
                                        .first(where: { !$0.isEmpty }) ?? ""
                                ))
                            }
                        )
                        continuation.yield(chunk)
                    }
                    continuation.finish()
                } catch {
                    continuation.finish(throwing: error)
                }
            }
            // Cancel the producer task when the consumer stops iterating,
            // so the underlying SDK stream is not left running.
            continuation.onTermination = { _ in
                task.cancel()
            }
        }

        return stream
    }
}
65+

Tool/Sources/OpenAIService/APIs/OpenAICompletionStreamAPI.swift

Lines changed: 16 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,13 @@
1+
import AIModel
12
import AsyncAlgorithms
23
import Foundation
34
import Preferences
4-
import AIModel
55

66
typealias CompletionStreamAPIBuilder = (String, ChatModel, URL, CompletionRequestBody)
7-
-> CompletionStreamAPI
7+
-> any CompletionStreamAPI
88

99
protocol CompletionStreamAPI {
10-
func callAsFunction() async throws -> (
11-
trunkStream: AsyncThrowingStream<CompletionStreamDataTrunk, Error>,
12-
cancel: Cancellable
13-
)
10+
func callAsFunction() async throws -> AsyncThrowingStream<CompletionStreamDataChunk, Error>
1411
}
1512

1613
public enum FunctionCallStrategy: Codable, Equatable {
@@ -128,7 +125,7 @@ struct CompletionRequestBody: Codable, Equatable {
128125
}
129126
}
130127

131-
struct CompletionStreamDataTrunk: Codable {
128+
struct CompletionStreamDataChunk: Codable {
132129
var id: String?
133130
var object: String?
134131
var model: String?
@@ -171,10 +168,7 @@ struct OpenAICompletionStreamAPI: CompletionStreamAPI {
171168
self.model = model
172169
}
173170

174-
func callAsFunction() async throws -> (
175-
trunkStream: AsyncThrowingStream<CompletionStreamDataTrunk, Error>,
176-
cancel: Cancellable
177-
) {
171+
func callAsFunction() async throws -> AsyncThrowingStream<CompletionStreamDataChunk, Error> {
178172
var request = URLRequest(url: endpoint)
179173
request.httpMethod = "POST"
180174
let encoder = JSONEncoder()
@@ -187,7 +181,7 @@ struct OpenAICompletionStreamAPI: CompletionStreamAPI {
187181
case .azureOpenAI:
188182
request.setValue(apiKey, forHTTPHeaderField: "api-key")
189183
case .googleAI:
190-
assert(false, "Unsupported")
184+
assertionFailure("Unsupported")
191185
}
192186
}
193187

@@ -207,35 +201,31 @@ struct OpenAICompletionStreamAPI: CompletionStreamAPI {
207201
throw error ?? ChatGPTServiceError.responseInvalid
208202
}
209203

210-
var receivingDataTask: Task<Void, Error>?
211-
212-
let stream = AsyncThrowingStream<CompletionStreamDataTrunk, Error> { continuation in
213-
receivingDataTask = Task {
204+
let stream = AsyncThrowingStream<CompletionStreamDataChunk, Error> { continuation in
205+
let task = Task {
214206
do {
215207
for try await line in result.lines {
216208
if Task.isCancelled { break }
217209
let prefix = "data: "
218210
guard line.hasPrefix(prefix),
219211
let content = line.dropFirst(prefix.count).data(using: .utf8),
220-
let trunk = try? JSONDecoder()
221-
.decode(CompletionStreamDataTrunk.self, from: content)
212+
let chunk = try? JSONDecoder()
213+
.decode(CompletionStreamDataChunk.self, from: content)
222214
else { continue }
223-
continuation.yield(trunk)
215+
continuation.yield(chunk)
224216
}
225217
continuation.finish()
226218
} catch {
227219
continuation.finish(throwing: error)
228220
}
229221
}
230-
}
231-
232-
return (
233-
stream,
234-
Cancellable {
222+
continuation.onTermination = { _ in
223+
task.cancel()
235224
result.task.cancel()
236-
receivingDataTask?.cancel()
237225
}
238-
)
226+
}
227+
228+
return stream
239229
}
240230
}
241231

Tool/Sources/OpenAIService/ChatGPTService.swift

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ public class ChatGPTService: ChatGPTServiceType {
7373
apiKey, model, endpoint, requestBody in
7474
switch model.format {
7575
case .googleAI:
76-
fatalError()
76+
return GoogleCompletionStreamAPI(apiKey: apiKey, model: model, requestBody: requestBody)
7777
case .openAI, .openAICompatible, .azureOpenAI:
7878
return OpenAICompletionStreamAPI(
7979
apiKey: apiKey,
@@ -88,7 +88,7 @@ public class ChatGPTService: ChatGPTServiceType {
8888
apiKey, model, endpoint, requestBody in
8989
switch model.format {
9090
case .googleAI:
91-
fatalError()
91+
return GoogleCompletionAPI(apiKey: apiKey, model: model, requestBody: requestBody)
9292
case .openAI, .openAICompatible, .azureOpenAI:
9393
return OpenAICompletionAPI(
9494
apiKey: apiKey,
@@ -321,13 +321,12 @@ extension ChatGPTService {
321321
id: proposedId,
322322
references: prompt.references
323323
)
324-
let (trunks, cancel) = try await api()
325-
for try await trunk in trunks {
324+
let chunks = try await api()
325+
for try await chunk in chunks {
326326
if Task.isCancelled {
327-
cancel()
328327
throw CancellationError()
329328
}
330-
guard let delta = trunk.choices?.first?.delta else { continue }
329+
guard let delta = chunk.choices?.first?.delta else { continue }
331330

332331
// The api will always return a function call with JSON object.
333332
// The first round will contain the function name and an empty argument.

Tool/Tests/OpenAIServiceTests/ChatGPTStreamTests.swift

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -240,13 +240,13 @@ extension ChatGPTStreamTests {
240240
struct MockCompletionStreamAPI_Message: CompletionStreamAPI {
241241
@Dependency(\.uuid) var uuid
242242
func callAsFunction() async throws -> (
243-
trunkStream: AsyncThrowingStream<CompletionStreamDataTrunk, Error>,
243+
chunkStream: AsyncThrowingStream<CompletionStreamDataChunk, Error>,
244244
cancel: OpenAIService.Cancellable
245245
) {
246246
let id = uuid().uuidString
247247
return (
248-
AsyncThrowingStream<CompletionStreamDataTrunk, Error> { continuation in
249-
let trunks: [CompletionStreamDataTrunk] = [
248+
AsyncThrowingStream<CompletionStreamDataChunk, Error> { continuation in
249+
let chunks: [CompletionStreamDataChunk] = [
250250
.init(id: id, object: "", model: "", choices: [
251251
.init(delta: .init(role: .assistant), index: 0, finish_reason: ""),
252252
]),
@@ -260,8 +260,8 @@ extension ChatGPTStreamTests {
260260
.init(delta: .init(content: "friends"), index: 0, finish_reason: ""),
261261
]),
262262
]
263-
for trunk in trunks {
264-
continuation.yield(trunk)
263+
for chunk in chunks {
264+
continuation.yield(chunk)
265265
}
266266
continuation.finish()
267267
},
@@ -273,13 +273,13 @@ extension ChatGPTStreamTests {
273273
struct MockCompletionStreamAPI_Function: CompletionStreamAPI {
274274
@Dependency(\.uuid) var uuid
275275
func callAsFunction() async throws -> (
276-
trunkStream: AsyncThrowingStream<CompletionStreamDataTrunk, Error>,
276+
chunkStream: AsyncThrowingStream<CompletionStreamDataChunk, Error>,
277277
cancel: OpenAIService.Cancellable
278278
) {
279279
let id = uuid().uuidString
280280
return (
281-
AsyncThrowingStream<CompletionStreamDataTrunk, Error> { continuation in
282-
let trunks: [CompletionStreamDataTrunk] = [
281+
AsyncThrowingStream<CompletionStreamDataChunk, Error> { continuation in
282+
let chunks: [CompletionStreamDataChunk] = [
283283
.init(id: id, object: "", model: "", choices: [
284284
.init(
285285
delta: .init(
@@ -317,8 +317,8 @@ extension ChatGPTStreamTests {
317317
finish_reason: ""
318318
)]),
319319
]
320-
for trunk in trunks {
321-
continuation.yield(trunk)
320+
for chunk in chunks {
321+
continuation.yield(chunk)
322322
}
323323
continuation.finish()
324324
},

0 commit comments

Comments
 (0)