@@ -3,6 +3,83 @@ import Foundation
33import GoogleGenerativeAI
44import Preferences
55
6+ struct GoogleCompletionAPI : CompletionAPI {
7+ let apiKey : String
8+ let model : ChatModel
9+ var requestBody : CompletionRequestBody
10+ let prompt : ChatGPTPrompt
11+
12+ func callAsFunction( ) async throws -> CompletionResponseBody {
13+ let aiModel = GenerativeModel (
14+ name: model. info. modelName,
15+ apiKey: apiKey,
16+ generationConfig: . init( GenerationConfig (
17+ temperature: requestBody. temperature. map ( Float . init) ,
18+ topP: requestBody. top_p. map ( Float . init)
19+ ) )
20+ )
21+ let history = prompt. googleAICompatible. history. map { message in
22+ ModelContent (
23+ ChatMessage (
24+ role: message. role,
25+ content: message. content,
26+ name: message. name,
27+ functionCall: message. functionCall. map {
28+ . init( name: $0. name, arguments: $0. arguments)
29+ }
30+ )
31+ )
32+ }
33+
34+ do {
35+ let response = try await aiModel. generateContent ( history)
36+
37+ return . init(
38+ object: " chat.completion " ,
39+ model: model. info. modelName,
40+ usage: . init( prompt_tokens: 0 , completion_tokens: 0 , total_tokens: 0 ) ,
41+ choices: response. candidates. enumerated ( ) . map {
42+ let ( index, candidate) = $0
43+ return . init(
44+ message: . init(
45+ role: . assistant,
46+ content: candidate. content. parts. first ( where: { part in
47+ if let text = part. text {
48+ return !text. isEmpty
49+ } else {
50+ return false
51+ }
52+ } ) ? . text ?? " "
53+ ) ,
54+ index: index,
55+ finish_reason: candidate. finishReason? . rawValue ?? " "
56+ )
57+ }
58+ )
59+ } catch let error as GenerateContentError {
60+ struct ErrorWrapper : Error , LocalizedError {
61+ let error : Error
62+ var errorDescription : String ? {
63+ var s = " "
64+ dump ( error, to: & s)
65+ return " Internal Error: \( s) "
66+ }
67+ }
68+
69+ switch error {
70+ case let . internalError( underlying) :
71+ throw ErrorWrapper ( error: underlying)
72+ case . promptBlocked:
73+ throw error
74+ case . responseStoppedEarly:
75+ throw error
76+ }
77+ } catch {
78+ throw error
79+ }
80+ }
81+ }
82+
683extension ChatGPTPrompt {
784 var googleAICompatible : ChatGPTPrompt {
885 var history = self . history
@@ -20,6 +97,7 @@ extension ChatGPTPrompt {
2097 guard lastIndex >= 0 else { // first message
2198 if message. role == . system {
2299 reformattedHistory. append ( . init(
100+ id: message. id,
23101 role: . user,
24102 content: ModelContent . convertContent ( of: message)
25103 ) )
@@ -40,6 +118,7 @@ extension ChatGPTPrompt {
40118 . convertRole ( message. role)
41119 {
42120 let newMessage = ChatMessage (
121+ id: message. id,
43122 role: message. role == . assistant ? . assistant : . user,
44123 content: """
45124 \( ModelContent . convertContent ( of: lastMessage) )
@@ -78,80 +157,42 @@ extension ChatGPTPrompt {
78157 }
79158}
80159
81- struct GoogleCompletionAPI : CompletionAPI {
82- let apiKey : String
83- let model : ChatModel
84- var requestBody : CompletionRequestBody
85- let prompt : ChatGPTPrompt
86-
87- func callAsFunction( ) async throws -> CompletionResponseBody {
88- let aiModel = GenerativeModel (
89- name: model. info. modelName,
90- apiKey: apiKey,
91- generationConfig: . init( GenerationConfig (
92- temperature: requestBody. temperature. map ( Float . init) ,
93- topP: requestBody. top_p. map ( Float . init)
94- ) )
95- )
96- let history = prompt. googleAICompatible. history. map { message in
97- ModelContent (
98- ChatMessage (
99- role: message. role,
100- content: message. content,
101- name: message. name,
102- functionCall: message. functionCall. map {
103- . init( name: $0. name, arguments: $0. arguments)
104- }
105- )
106- )
160+ extension ModelContent {
161+ static func convertRole( _ role: ChatMessage . Role ) -> String {
162+ switch role {
163+ case . user, . system, . function:
164+ return " user "
165+ case . assistant:
166+ return " model "
107167 }
168+ }
108169
109- do {
110- let response = try await aiModel. generateContent ( history)
111-
112- return . init(
113- object: " chat.completion " ,
114- model: model. info. modelName,
115- usage: . init( prompt_tokens: 0 , completion_tokens: 0 , total_tokens: 0 ) ,
116- choices: response. candidates. enumerated ( ) . map {
117- let ( index, candidate) = $0
118- return . init(
119- message: . init(
120- role: . assistant,
121- content: candidate. content. parts. first ( where: { part in
122- if let text = part. text {
123- return !text. isEmpty
124- } else {
125- return false
126- }
127- } ) ? . text ?? " "
128- ) ,
129- index: index,
130- finish_reason: candidate. finishReason? . rawValue ?? " "
131- )
132- }
133- )
134- } catch let error as GenerateContentError {
135- struct ErrorWrapper : Error , LocalizedError {
136- let error : Error
137- var errorDescription : String ? {
138- var s = " "
139- dump ( error, to: & s)
140- return " Internal Error: \( s) "
141- }
142- }
143-
144- switch error {
145- case let . internalError( underlying) :
146- throw ErrorWrapper ( error: underlying)
147- case . promptBlocked:
148- throw error
149- case . responseStoppedEarly:
150- throw error
170+ static func convertContent( of message: ChatMessage ) -> String {
171+ switch message. role {
172+ case . system:
173+ return " System Prompt: \n \( message. content ?? " " ) "
174+ case . user:
175+ return message. content ?? " "
176+ case . function:
177+ return """
178+ Result of \( message. name ?? " function " ) : \( message. content ?? " N/A " )
179+ """
180+ case . assistant:
181+ if let functionCall = message. functionCall {
182+ return """
183+ Call function: \( functionCall. name)
184+ Arguments: \( functionCall. arguments)
185+ """
186+ } else {
187+ return message. content ?? " "
151188 }
152- } catch {
153- throw error
154189 }
155190 }
191+
192+ init ( _ message: ChatMessage ) {
193+ let role = Self . convertRole ( message. role)
194+ let parts = [ ModelContent . Part. text ( Self . convertContent ( of: message) ) ]
195+ self = . init( role: role, parts: parts)
196+ }
156197}
157198
0 commit comments