Skip to content

Commit ed811e8

Browse files
committed
Add OpenAIEmbeddingService
1 parent 58483b0 commit ed811e8

8 files changed

Lines changed: 218 additions & 121 deletions

Pro

Submodule Pro updated from 322e945 to de3b6ef
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
import AIModel
2+
import Foundation
3+
import Preferences
4+
5+
protocol EmbeddingAPI {
6+
func embed(text: String) async throws -> EmbeddingResponse
7+
func embed(texts: [String]) async throws -> EmbeddingResponse
8+
func embed(tokens: [[Int]]) async throws -> EmbeddingResponse
9+
}
10+
11+
public struct EmbeddingResponse: Decodable {
12+
public struct Object: Decodable {
13+
public var embedding: [Float]
14+
public var index: Int
15+
public var object: String
16+
}
17+
18+
public var data: [Object]
19+
public var model: String
20+
21+
public struct Usage: Decodable {
22+
public var prompt_tokens: Int
23+
public var total_tokens: Int
24+
}
25+
26+
public var usage: Usage
27+
}
28+
29+
30+

Tool/Sources/OpenAIService/APIs/GoogleAIService.swift renamed to Tool/Sources/OpenAIService/APIs/GoogleAIChatCompletionsService.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ import Foundation
33
import GoogleGenerativeAI
44
import Preferences
55

6-
actor GoogleAIService: ChatCompletionsAPI, ChatCompletionsStreamAPI {
6+
actor GoogleAIChatCompletionsService: ChatCompletionsAPI, ChatCompletionsStreamAPI {
77
let apiKey: String
88
let model: ChatModel
99
var requestBody: ChatCompletionsRequestBody

Tool/Sources/OpenAIService/APIs/OlamaService.swift renamed to Tool/Sources/OpenAIService/APIs/OlamaChatCompletionsService.swift

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ import AIModel
22
import Foundation
33
import Preferences
44

5-
public actor OllamaService {
5+
public actor OllamaChatCompletionsService {
66
var apiKey: String
77
var endpoint: URL
88
var requestBody: ChatCompletionsRequestBody
@@ -26,7 +26,7 @@ public actor OllamaService {
2626
}
2727
}
2828

29-
extension OllamaService: ChatCompletionsAPI {
29+
extension OllamaChatCompletionsService: ChatCompletionsAPI {
3030
func callAsFunction() async throws -> ChatCompletionResponseBody {
3131
let requestBody = ChatCompletionRequestBody(
3232
model: model.info.modelName,
@@ -105,7 +105,7 @@ extension OllamaService: ChatCompletionsAPI {
105105
}
106106
}
107107

108-
extension OllamaService: ChatCompletionsStreamAPI {
108+
extension OllamaChatCompletionsService: ChatCompletionsStreamAPI {
109109
func callAsFunction() async throws
110110
-> AsyncThrowingStream<ChatCompletionsStreamDataChunk, Swift.Error>
111111
{
@@ -192,7 +192,7 @@ extension OllamaService: ChatCompletionsStreamAPI {
192192
}
193193
}
194194

195-
extension OllamaService {
195+
extension OllamaChatCompletionsService {
196196
struct Message: Codable, Equatable {
197197
public enum Role: String, Codable {
198198
case user
@@ -224,7 +224,7 @@ extension OllamaService {
224224
// MARK: - Chat Completion API
225225

226226
/// https://github.com/ollama/ollama/blob/main/docs/api.md#chat-request-streaming
227-
extension OllamaService {
227+
extension OllamaChatCompletionsService {
228228
struct ChatCompletionRequestBody: Codable {
229229
struct Options: Codable {
230230
var temperature: Double?

Tool/Sources/OpenAIService/APIs/OpenAIService.swift renamed to Tool/Sources/OpenAIService/APIs/OpenAIChatCompletionsService.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ import AsyncAlgorithms
33
import Foundation
44
import Preferences
55

6-
actor OpenAIService: ChatCompletionsStreamAPI, ChatCompletionsAPI {
6+
actor OpenAIChatCompletionsService: ChatCompletionsStreamAPI, ChatCompletionsAPI {
77
struct CompletionAPIError: Error, Codable, LocalizedError {
88
struct E: Codable {
99
var message: String
Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
import AIModel
2+
import Foundation
3+
import Logger
4+
5+
struct OpenAIEmbeddingService: EmbeddingAPI {
6+
struct EmbeddingRequestBody: Encodable {
7+
var input: [String]
8+
var model: String
9+
}
10+
11+
struct EmbeddingFromTokensRequestBody: Encodable {
12+
var input: [[Int]]
13+
var model: String
14+
}
15+
16+
let apiKey: String
17+
let model: EmbeddingModel
18+
let endpoint: String
19+
20+
public func embed(text: String) async throws -> EmbeddingResponse {
21+
return try await embed(texts: [text])
22+
}
23+
24+
public func embed(texts text: [String]) async throws -> EmbeddingResponse {
25+
guard let url = URL(string: endpoint) else { throw ChatGPTServiceError.endpointIncorrect }
26+
var request = URLRequest(url: url)
27+
request.httpMethod = "POST"
28+
let encoder = JSONEncoder()
29+
request.httpBody = try encoder.encode(EmbeddingRequestBody(
30+
input: text,
31+
model: model.info.modelName
32+
))
33+
request.setValue("application/json", forHTTPHeaderField: "Content-Type")
34+
if !apiKey.isEmpty {
35+
switch model.format {
36+
case .openAI, .openAICompatible:
37+
request.setValue(
38+
"Bearer \(apiKey)",
39+
forHTTPHeaderField: "Authorization"
40+
)
41+
case .azureOpenAI:
42+
request.setValue(apiKey, forHTTPHeaderField: "api-key")
43+
case .ollama:
44+
assertionFailure("Unsupported")
45+
}
46+
}
47+
48+
let (result, response) = try await URLSession.shared.data(for: request)
49+
guard let response = response as? HTTPURLResponse else {
50+
throw ChatGPTServiceError.responseInvalid
51+
}
52+
53+
guard response.statusCode == 200 else {
54+
let error = try? JSONDecoder().decode(
55+
OpenAIChatCompletionsService.CompletionAPIError.self,
56+
from: result
57+
)
58+
throw error ?? ChatGPTServiceError
59+
.otherError(String(data: result, encoding: .utf8) ?? "Unknown Error")
60+
}
61+
62+
let embeddingResponse = try JSONDecoder().decode(EmbeddingResponse.self, from: result)
63+
#if DEBUG
64+
Logger.service.info("""
65+
Embedding usage
66+
- number of strings: \(text.count)
67+
- prompt tokens: \(embeddingResponse.usage.prompt_tokens)
68+
- total tokens: \(embeddingResponse.usage.total_tokens)
69+
70+
""")
71+
#endif
72+
return embeddingResponse
73+
}
74+
75+
public func embed(tokens: [[Int]]) async throws -> EmbeddingResponse {
76+
guard let url = URL(string: endpoint) else { throw ChatGPTServiceError.endpointIncorrect }
77+
var request = URLRequest(url: url)
78+
request.httpMethod = "POST"
79+
let encoder = JSONEncoder()
80+
request.httpBody = try encoder.encode(EmbeddingFromTokensRequestBody(
81+
input: tokens,
82+
model: model.info.modelName
83+
))
84+
request.setValue("application/json", forHTTPHeaderField: "Content-Type")
85+
if !apiKey.isEmpty {
86+
switch model.format {
87+
case .openAI, .openAICompatible:
88+
request.setValue(
89+
"Bearer \(apiKey)",
90+
forHTTPHeaderField: "Authorization"
91+
)
92+
case .azureOpenAI:
93+
request.setValue(apiKey, forHTTPHeaderField: "api-key")
94+
case .ollama:
95+
assertionFailure("Unsupported")
96+
}
97+
}
98+
99+
let (result, response) = try await URLSession.shared.data(for: request)
100+
guard let response = response as? HTTPURLResponse else {
101+
throw ChatGPTServiceError.responseInvalid
102+
}
103+
104+
guard response.statusCode == 200 else {
105+
let error = try? JSONDecoder().decode(
106+
OpenAIChatCompletionsService.CompletionAPIError.self,
107+
from: result
108+
)
109+
throw error ?? ChatGPTServiceError
110+
.otherError(String(data: result, encoding: .utf8) ?? "Unknown Error")
111+
}
112+
113+
let embeddingResponse = try JSONDecoder().decode(EmbeddingResponse.self, from: result)
114+
#if DEBUG
115+
Logger.service.info("""
116+
Embedding usage
117+
- number of strings: \(tokens.count)
118+
- prompt tokens: \(embeddingResponse.usage.prompt_tokens)
119+
- total tokens: \(embeddingResponse.usage.total_tokens)
120+
121+
""")
122+
#endif
123+
return embeddingResponse
124+
}
125+
}
126+

Tool/Sources/OpenAIService/ChatGPTService.swift

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -90,21 +90,21 @@ public class ChatGPTService: ChatGPTServiceType {
9090
apiKey, model, endpoint, requestBody, prompt in
9191
switch model.format {
9292
case .googleAI:
93-
return GoogleAIService(
93+
return GoogleAIChatCompletionsService(
9494
apiKey: apiKey,
9595
model: model,
9696
requestBody: requestBody,
9797
prompt: prompt
9898
)
9999
case .openAI, .openAICompatible, .azureOpenAI:
100-
return OpenAIService(
100+
return OpenAIChatCompletionsService(
101101
apiKey: apiKey,
102102
model: model,
103103
endpoint: endpoint,
104104
requestBody: requestBody
105105
)
106106
case .ollama:
107-
return OllamaService(
107+
return OllamaChatCompletionsService(
108108
apiKey: apiKey,
109109
model: model,
110110
endpoint: endpoint,
@@ -117,21 +117,21 @@ public class ChatGPTService: ChatGPTServiceType {
117117
apiKey, model, endpoint, requestBody, prompt in
118118
switch model.format {
119119
case .googleAI:
120-
return GoogleAIService(
120+
return GoogleAIChatCompletionsService(
121121
apiKey: apiKey,
122122
model: model,
123123
requestBody: requestBody,
124124
prompt: prompt
125125
)
126126
case .openAI, .openAICompatible, .azureOpenAI:
127-
return OpenAIService(
127+
return OpenAIChatCompletionsService(
128128
apiKey: apiKey,
129129
model: model,
130130
endpoint: endpoint,
131131
requestBody: requestBody
132132
)
133133
case .ollama:
134-
return OllamaService(
134+
return OllamaChatCompletionsService(
135135
apiKey: apiKey,
136136
model: model,
137137
endpoint: endpoint,

0 commit comments

Comments
 (0)