Skip to content

Commit 9f137ee

Browse files
committed
Add OllamaEmbeddingService
1 parent ed811e8 commit 9f137ee

4 files changed

Lines changed: 105 additions & 9 deletions

File tree

Pro

Submodule Pro updated from de3b6ef to 13a9fde

Tool/Sources/OpenAIService/APIs/EmbeddingAPIDefinitions.swift

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,5 +26,3 @@ public struct EmbeddingResponse: Decodable {
2626
public var usage: Usage
2727
}
2828

29-
30-
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
import AIModel
2+
import Foundation
3+
import Logger
4+
5+
/// An `EmbeddingAPI` implementation backed by a local/remote Ollama server.
///
/// Ollama's embeddings endpoint takes one prompt per request and returns a
/// single embedding vector; it reports no token-usage information.
struct OllamaEmbeddingService: EmbeddingAPI {
    /// Request payload for the Ollama embeddings endpoint.
    struct EmbeddingRequestBody: Encodable {
        var prompt: String
        var model: String
    }

    /// Successful response payload: one embedding vector.
    struct ResponseBody: Decodable {
        var embedding: [Float]
    }

    let model: EmbeddingModel
    let endpoint: String

    /// Embeds a single string by POSTing it to the configured endpoint.
    ///
    /// - Parameter text: The text to embed.
    /// - Returns: An `EmbeddingResponse` with exactly one data element at
    ///   index 0. Ollama reports no usage, so `prompt_tokens` is 0 and
    ///   `total_tokens` carries the embedding's dimension count instead.
    /// - Throws: `ChatGPTServiceError.endpointIncorrect` for a malformed
    ///   endpoint, `ChatGPTServiceError.responseInvalid` for a non-HTTP
    ///   response, a decoded API error (or `.otherError` with the raw body)
    ///   for non-200 statuses, and decoding errors for malformed payloads.
    public func embed(text: String) async throws -> EmbeddingResponse {
        guard let url = URL(string: endpoint) else { throw ChatGPTServiceError.endpointIncorrect }
        var request = URLRequest(url: url)
        request.httpMethod = "POST"
        let encoder = JSONEncoder()
        request.httpBody = try encoder.encode(EmbeddingRequestBody(
            prompt: text,
            model: model.info.modelName
        ))
        request.setValue("application/json", forHTTPHeaderField: "Content-Type")

        let (result, response) = try await URLSession.shared.data(for: request)
        guard let response = response as? HTTPURLResponse else {
            throw ChatGPTServiceError.responseInvalid
        }

        guard response.statusCode == 200 else {
            // Prefer a structured API error; fall back to the raw body text.
            let error = try? JSONDecoder().decode(
                OpenAIChatCompletionsService.CompletionAPIError.self,
                from: result
            )
            throw error ?? ChatGPTServiceError
                .otherError(String(data: result, encoding: .utf8) ?? "Unknown Error")
        }

        let embeddingResponse = try JSONDecoder().decode(ResponseBody.self, from: result)
        #if DEBUG
        // Fixed misleading labels: `text.count` is the input's character count
        // (not a number of strings), and `embedding.count` is the vector's
        // dimension (Ollama reports no token counts).
        Logger.service.info("""
        Embedding usage
        - text length (characters): \(text.count)
        - prompt tokens: N/A
        - embedding dimensions: \(embeddingResponse.embedding.count)

        """)
        #endif
        return .init(
            data: [.init(
                embedding: embeddingResponse.embedding,
                index: 0,
                object: model.info.modelName
            )],
            model: model.info.modelName,
            usage: .init(prompt_tokens: 0, total_tokens: embeddingResponse.embedding.count)
        )
    }

    /// Embeds multiple strings concurrently, one request per string.
    ///
    /// Fixed: the original merged child results in task-completion order, so
    /// `data` ordering was nondeterministic relative to `texts`. Each task is
    /// now tagged with its input offset and results are merged in input order.
    /// NOTE(review): every merged element still carries `index: 0` as produced
    /// by `embed(text:)`; confirm whether callers rely on per-element indices.
    public func embed(texts: [String]) async throws -> EmbeddingResponse {
        try await withThrowingTaskGroup(of: (Int, EmbeddingResponse).self) { group in
            for (offset, text) in texts.enumerated() {
                _ = group.addTaskUnlessCancelled {
                    (offset, try await self.embed(text: text))
                }
            }

            var collected = [(Int, EmbeddingResponse)]()
            for try await pair in group {
                collected.append(pair)
            }
            // Restore the order of `texts` before merging.
            collected.sort { $0.0 < $1.0 }

            var result = EmbeddingResponse(
                data: [],
                model: model.info.modelName,
                usage: .init(prompt_tokens: 0, total_tokens: 0)
            )
            for (_, response) in collected {
                result.data.append(contentsOf: response.data)
                result.usage.prompt_tokens += response.usage.prompt_tokens
                result.usage.total_tokens += response.usage.total_tokens
            }
            return result
        }
    }

    /// Embedding from pre-tokenized input is not supported by Ollama.
    ///
    /// Fixed: the original threw `CancellationError()`, which the concurrency
    /// runtime and callers treat as ordinary task cancellation, silently
    /// masking the unsupported call. An explicit service error makes the
    /// failure visible and diagnosable.
    public func embed(tokens: [[Int]]) async throws -> EmbeddingResponse {
        throw ChatGPTServiceError
            .otherError("Embedding from pre-tokenized input is not supported by Ollama.")
    }
}
92+

Tool/Sources/OpenAIService/EmbeddingService.swift

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,10 @@ public struct EmbeddingService {
2121
endpoint: configuration.endpoint
2222
).embed(text: text)
2323
case .ollama:
24-
#warning("MUSTDO:")
25-
fatalError()
24+
embeddingResponse = try await OllamaEmbeddingService(
25+
model: model,
26+
endpoint: configuration.endpoint
27+
).embed(text: text)
2628
}
2729

2830
#if DEBUG
@@ -50,8 +52,10 @@ public struct EmbeddingService {
5052
endpoint: configuration.endpoint
5153
).embed(texts: text)
5254
case .ollama:
53-
#warning("MUSTDO:")
54-
fatalError()
55+
embeddingResponse = try await OllamaEmbeddingService(
56+
model: model,
57+
endpoint: configuration.endpoint
58+
).embed(texts: text)
5559
}
5660

5761
#if DEBUG
@@ -79,8 +83,10 @@ public struct EmbeddingService {
7983
endpoint: configuration.endpoint
8084
).embed(tokens: tokens)
8185
case .ollama:
82-
#warning("MUSTDO:")
83-
fatalError()
86+
embeddingResponse = try await OllamaEmbeddingService(
87+
model: model,
88+
endpoint: configuration.endpoint
89+
).embed(tokens: tokens)
8490
}
8591

8692
#if DEBUG

0 commit comments

Comments
 (0)