Skip to content

Commit 6856dbc

Browse files
committed
Support embedding
1 parent 55c35a9 commit 6856dbc

7 files changed

Lines changed: 408 additions & 4 deletions

File tree

TestPlan.xctestplan

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,13 @@
8484
"identifier" : "ChatServiceTests",
8585
"name" : "ChatServiceTests"
8686
}
87+
},
88+
{
89+
"target" : {
90+
"containerPath" : "container:",
91+
"identifier" : "TokenEncoderTests",
92+
"name" : "TokenEncoderTests"
93+
}
8794
}
8895
],
8996
"version" : 1

Tool/Sources/LangChain/Embedding/Embedding.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
import Foundation

/// A type that converts text into vector embeddings.
///
/// Conformers are expected to be backed by an async embedding service
/// (see `OpenAIEmbedding`), hence the `async throws` requirements.
public protocol Embeddings {
    /// Embed search docs.
    ///
    /// - Parameter documents: The texts to embed.
    /// - Returns: One embedding vector per input document.
    func embed(documents: [String]) async throws -> [[Float]]
    /// Embed query text.
    ///
    /// - Parameter query: The query text to embed.
    /// - Returns: The embedding vector for the query.
    func embed(query: String) async throws -> [Float]
}
Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
import Foundation
2+
import OpenAIService
3+
import PythonHelper
4+
import PythonKit
5+
import TokenEncoder
6+
7+
/// An `Embeddings` implementation backed by the OpenAI embedding service.
public struct OpenAIEmbedding: Embeddings {
    public var service: EmbeddingService
    /// Usually we won't hit the limit because the max token is 8191 and we will do text splitting
    /// before embedding.
    public var shouldAverageLongEmbeddings: Bool

    public init(configuration: EmbeddingConfiguration, shouldAverageLongEmbeddings: Bool = false) {
        service = EmbeddingService(configuration: configuration)
        self.shouldAverageLongEmbeddings = shouldAverageLongEmbeddings
    }

    /// Embed search docs.
    ///
    /// - Parameter documents: The texts to embed.
    /// - Returns: One embedding per input document, in input order. A document
    ///   whose embedding could not be produced maps to an empty vector.
    public func embed(documents: [String]) async throws -> [[Float]] {
        // Fix: this was a stub that always returned `[]`, silently dropping
        // every document. Route through the same length-safe helper used by
        // `embed(query:)`.
        guard !documents.isEmpty else { return [] }
        let results = try await getLenSafeEmbeddings(texts: documents)
        // The helper returns results in task-completion order; map them back
        // to the caller's input order by original text.
        let embeddingsByText = Dictionary(
            results.map { ($0.originalText, $0.embeddings) },
            uniquingKeysWith: { first, _ in first }
        )
        return documents.map { embeddingsByText[$0] ?? [] }
    }

    /// Embed query text.
    ///
    /// - Parameter query: The query text to embed.
    /// - Returns: The embedding vector, or an empty vector when none was produced.
    public func embed(query: String) async throws -> [Float] {
        return try await getLenSafeEmbeddings(texts: [query]).first?.embeddings ?? []
    }
}
26+
27+
extension OpenAIEmbedding {
    /// Embeds `texts` while staying under the model's token limit.
    ///
    /// Each text is tokenized and split into chunks of at most
    /// `service.configuration.maxToken - 10` tokens. Single-chunk texts are sent
    /// as raw text; multi-chunk texts are either averaged
    /// (`shouldAverageLongEmbeddings == true`) or truncated to their first chunk.
    ///
    /// - Parameter texts: The texts to embed.
    /// - Returns: One `(originalText, embeddings)` pair per input text. The
    ///   array follows task-completion order, not input order.
    /// - Throws: The last service error after 6 failed attempts per text, or
    ///   `CancellationError` on cancellation.
    func getLenSafeEmbeddings(
        texts: [String]
    ) async throws -> [(originalText: String, embeddings: [Float])] {
        struct Text {
            var rawText: String
            var chunkedTokens: [[Int]]
        }

        var texts = texts.map { Text(rawText: $0, chunkedTokens: []) }
        let encoding = TiktokenCl100kBaseTokenEncoder()

        for (index, text) in texts.enumerated() {
            let token = encoding.encode(text: text.rawText)
            // Leave a small headroom in case the token calculation is incorrect.
            let maxToken = max(10, service.configuration.maxToken - 10)

            for j in stride(from: 0, to: token.count, by: maxToken) {
                texts[index].chunkedTokens.append(
                    Array(token[j..<min(j + maxToken, token.count)])
                )
            }
        }

        let batchedEmbeddings = try await withThrowingTaskGroup(
            of: (String, [[Float]]).self
        ) { group in
            for text in texts {
                group.addTask {
                    // Retry each text up to 6 times, remembering the last error.
                    var retryCount = 6
                    var previousError: Error?
                    guard !text.chunkedTokens.isEmpty else { return (text.rawText, []) }
                    while retryCount > 0 {
                        do {
                            if text.chunkedTokens.count <= 1 {
                                // if possible, we should just let OpenAI do the tokenization.
                                return (
                                    text.rawText,
                                    try await service.embed(text: text.rawText)
                                        .data
                                        .map(\.embeddings)
                                )
                            }
                            if shouldAverageLongEmbeddings {
                                return (
                                    text.rawText,
                                    try await service.embed(tokens: text.chunkedTokens)
                                        .data
                                        .map(\.embeddings)
                                )
                            }
                            // if `shouldAverageLongEmbeddings` is false,
                            // we only embed the first chunk to save some money.
                            return (
                                text.rawText,
                                try await service.embed(tokens: [text.chunkedTokens.first ?? []])
                                    .data
                                    .map(\.embeddings)
                            )
                        } catch {
                            retryCount -= 1
                            previousError = error
                        }
                    }
                    throw previousError ?? CancellationError()
                }
            }
            var result = [(originalText: String, embeddings: [[Float]])]()
            for try await response in group {
                try Task.checkCancellation()
                result.append((response.0, response.1))
            }
            return result
        }

        var results = [(originalText: String, embeddings: [Float])]()

        for (text, embeddings) in batchedEmbeddings {
            if embeddings.count == 1, let first = embeddings.first {
                results.append((text, first))
            } else if embeddings.isEmpty {
                results.append((text, []))
            } else if shouldAverageLongEmbeddings {
                // Weighted-average the chunk embeddings and re-normalize,
                // following the OpenAI long-input averaging recipe. Untested.
                do {
                    guard let averagedEmbeddings = try await runPython({
                        let numpy = try Python.attemptImportOnPythonThread("numpy")
                        // NOTE(review): the weights are each vector's dimension,
                        // which is the same for every chunk (i.e. uniform).
                        // Consider weighting by chunk token counts instead.
                        let average = numpy.average(
                            embeddings,
                            axis: 0,
                            weights: embeddings.map(\.count)
                        )
                        // Fix: normalize by the norm of the averaged vector
                        // itself, not the norm of the whole embeddings matrix.
                        let normalized = average / numpy.linalg.norm(average)
                        return [Float](normalized.tolist())
                    }) else { throw CancellationError() }
                    results.append((text, averagedEmbeddings))
                } catch {
                    // Fall back to the first chunk's embedding when averaging fails.
                    if let first = embeddings.first {
                        results.append((text, first))
                    }
                }
            } else if let first = embeddings.first {
                results.append((text, first))
            }
        }

        return results
    }
}
136+
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
import Foundation
import Preferences

public typealias EmbeddingFeatureProvider = ChatFeatureProvider

/// Everything needed to reach an embedding API endpoint.
public protocol EmbeddingConfiguration {
    var featureProvider: EmbeddingFeatureProvider { get }
    var endpoint: String { get }
    var apiKey: String { get }
    var maxToken: Int { get }
    var model: String { get }
}

extension EmbeddingConfiguration {
    /// Resolves the embeddings endpoint URL for the given provider, reading
    /// base URLs and deployment names from the shared user defaults.
    func endpoint(for provider: EmbeddingFeatureProvider) -> String {
        switch provider {
        case .openAI:
            let baseURL = UserDefaults.shared.value(for: \.openAIBaseURL)
            if baseURL.isEmpty { return "https://api.openai.com/v1/embeddings" }
            // Fix: a custom base URL must also target the embeddings route;
            // this previously returned the chat completions route.
            return "\(baseURL)/v1/embeddings"
        case .azureOpenAI:
            let baseURL = UserDefaults.shared.value(for: \.azureOpenAIBaseURL)
            // NOTE(review): reuses the chat deployment name; confirm a
            // dedicated embedding deployment preference isn't available.
            let deployment = UserDefaults.shared.value(for: \.azureChatGPTDeployment)
            let version = "2023-05-15"
            if baseURL.isEmpty { return "" }
            return "\(baseURL)/openai/deployments/\(deployment)/embeddings?api-version=\(version)"
        }
    }

    /// Looks up the API key for the given provider from the shared user defaults.
    func apiKey(for provider: ChatFeatureProvider) -> String {
        switch provider {
        case .openAI:
            return UserDefaults.shared.value(for: \.openAIAPIKey)
        case .azureOpenAI:
            return UserDefaults.shared.value(for: \.azureOpenAIAPIKey)
        }
    }

    /// Wraps this configuration so individual fields can be overridden.
    func overriding(
        _ overrides: OverridingEmbeddingConfiguration<Self>.Overriding = .init()
    ) -> OverridingEmbeddingConfiguration<Self> {
        .init(overriding: self, with: overrides)
    }
}
45+
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
import Foundation
import Preferences

/// An `EmbeddingConfiguration` whose values come from the shared user defaults.
public struct UserPreferenceEmbeddingConfiguration: EmbeddingConfiguration {
    // Embedding reuses the chat feature provider preference.
    public var featureProvider: ChatFeatureProvider {
        UserDefaults.shared.value(for: \.chatFeatureProvider)
    }

    public var model: String {
        // NOTE(review): this reads the chat model preference and only falls
        // back to "text-embedding-ada-002" when it's empty; an embeddings
        // endpoint may reject chat model names — confirm this is intended.
        let value = UserDefaults.shared.value(for: \.chatGPTModel)
        if value.isEmpty { return "text-embedding-ada-002" }
        return value
    }

    // Resolved via the protocol extension helper from the current provider.
    public var endpoint: String {
        endpoint(for: featureProvider)
    }

    // Resolved via the protocol extension helper from the current provider.
    public var apiKey: String {
        apiKey(for: featureProvider)
    }

    // Context limit of text-embedding-ada-002.
    public var maxToken: Int {
        8191
    }

    public init() {}
}
29+
30+
/// Wraps a base `EmbeddingConfiguration` and lets callers selectively override
/// individual fields; any field left `nil` falls through to the base value.
public class OverridingEmbeddingConfiguration<
    Configuration: EmbeddingConfiguration
>: EmbeddingConfiguration {
    /// The set of optional overrides. `nil` means "use the base configuration".
    public struct Overriding {
        var featureProvider: ChatFeatureProvider?
        var model: String?
        var endPoint: String?
        var apiKey: String?
        var maxTokens: Int?

        public init(
            model: String? = nil,
            featureProvider: ChatFeatureProvider? = nil,
            endPoint: String? = nil,
            apiKey: String? = nil,
            maxTokens: Int? = nil
        ) {
            self.model = model
            self.featureProvider = featureProvider
            self.endPoint = endPoint
            self.apiKey = apiKey
            self.maxTokens = maxTokens
        }
    }

    private let configuration: Configuration
    public var overriding = Overriding()

    public init(overriding configuration: Configuration, with overrides: Overriding = .init()) {
        self.overriding = overrides
        self.configuration = configuration
    }

    public var featureProvider: ChatFeatureProvider {
        if let overridden = overriding.featureProvider { return overridden }
        return configuration.featureProvider
    }

    public var model: String {
        if let overridden = overriding.model { return overridden }
        return configuration.model
    }

    public var endpoint: String {
        if let overridden = overriding.endPoint { return overridden }
        // An overridden provider re-resolves the endpoint for that provider.
        if let provider = overriding.featureProvider { return endpoint(for: provider) }
        return configuration.endpoint
    }

    public var apiKey: String {
        if let overridden = overriding.apiKey { return overridden }
        // An overridden provider re-resolves the API key for that provider.
        if let provider = overriding.featureProvider { return apiKey(for: provider) }
        return configuration.apiKey
    }

    public var maxToken: Int {
        if let overridden = overriding.maxTokens { return overridden }
        return configuration.maxToken
    }
}
87+

0 commit comments

Comments
 (0)