Skip to content

Commit 6d6fc49

Browse files
committed
Convert API types to GoogleAIService
1 parent d3c02a1 commit 6d6fc49

2 files changed

Lines changed: 87 additions & 85 deletions

File tree

Tool/Sources/OpenAIService/APIs/GoogleAICompletionStreamAPI.swift

Lines changed: 0 additions & 84 deletions
This file was deleted.

Tool/Sources/OpenAIService/APIs/GoogleAICompletionAPI.swift renamed to Tool/Sources/OpenAIService/APIs/GoogleAIService.swift

Lines changed: 87 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,24 @@ import Foundation
33
import GoogleGenerativeAI
44
import Preferences
55

6-
struct GoogleCompletionAPI: ChatCompletionsAPI {
6+
actor GoogleAIService: ChatCompletionsAPI, ChatCompletionsStreamAPI {
77
let apiKey: String
88
let model: ChatModel
99
var requestBody: ChatCompletionsRequestBody
1010
let prompt: ChatGPTPrompt
1111

12+
init(
13+
apiKey: String,
14+
model: ChatModel,
15+
requestBody: ChatCompletionsRequestBody,
16+
prompt: ChatGPTPrompt
17+
) {
18+
self.apiKey = apiKey
19+
self.model = model
20+
self.requestBody = requestBody
21+
self.prompt = prompt
22+
}
23+
1224
func callAsFunction() async throws -> ChatCompletionResponseBody {
1325
let aiModel = GenerativeModel(
1426
name: model.info.modelName,
@@ -78,6 +90,80 @@ struct GoogleCompletionAPI: ChatCompletionsAPI {
7890
throw error
7991
}
8092
}
93+
94+
func callAsFunction() async throws
95+
-> AsyncThrowingStream<ChatCompletionsStreamDataChunk, Error>
96+
{
97+
let aiModel = GenerativeModel(
98+
name: model.info.modelName,
99+
apiKey: apiKey,
100+
generationConfig: .init(GenerationConfig(
101+
temperature: requestBody.temperature.map(Float.init),
102+
topP: requestBody.top_p.map(Float.init)
103+
))
104+
)
105+
let history = prompt.googleAICompatible.history.map { message in
106+
ModelContent(
107+
ChatMessage(
108+
role: message.role,
109+
content: message.content,
110+
name: message.name,
111+
functionCall: message.functionCall.map {
112+
.init(name: $0.name, arguments: $0.arguments)
113+
}
114+
)
115+
)
116+
}
117+
118+
let stream = AsyncThrowingStream<ChatCompletionsStreamDataChunk, Error> { continuation in
119+
let stream = aiModel.generateContentStream(history)
120+
let task = Task {
121+
do {
122+
for try await response in stream {
123+
if Task.isCancelled { break }
124+
let chunk = ChatCompletionsStreamDataChunk(
125+
object: "",
126+
model: model.info.modelName,
127+
choices: response.candidates.map { candidate in
128+
.init(delta: .init(
129+
role: .assistant,
130+
content: candidate.content.parts
131+
.first(where: { $0.text != nil })?.text ?? ""
132+
))
133+
}
134+
)
135+
continuation.yield(chunk)
136+
}
137+
continuation.finish()
138+
} catch let error as GenerateContentError {
139+
struct ErrorWrapper: Error, LocalizedError {
140+
let error: Error
141+
var errorDescription: String? {
142+
var s = ""
143+
dump(error, to: &s)
144+
return "Internal Error: \(s)"
145+
}
146+
}
147+
148+
switch error {
149+
case let .internalError(underlying):
150+
continuation.finish(throwing: ErrorWrapper(error: underlying))
151+
case .promptBlocked:
152+
continuation.finish(throwing: error)
153+
case .responseStoppedEarly:
154+
continuation.finish(throwing: error)
155+
}
156+
} catch {
157+
continuation.finish(throwing: error)
158+
}
159+
}
160+
continuation.onTermination = { _ in
161+
task.cancel()
162+
}
163+
}
164+
165+
return stream
166+
}
81167
}
82168

83169
extension ChatGPTPrompt {

0 commit comments

Comments (0)