Skip to content

Commit 2c5df2c

Browse files
committed
Update embedding to support unsafe embedding
1 parent 6856dbc commit 2c5df2c

File tree

2 files changed

+52
-6
lines changed

2 files changed

+52
-6
lines changed

Tool/Sources/LangChain/Embedding/OpenAIEmbedding.swift

Lines changed: 48 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,25 +6,68 @@ import TokenEncoder
66

77
/// An `Embeddings` implementation backed by the OpenAI embedding API.
public struct OpenAIEmbedding: Embeddings {
    /// The underlying service that performs the actual embedding requests.
    public var service: EmbeddingService
    /// When `true`, over-long inputs are split, embedded piecewise, and the
    /// pieces are averaged back into a single vector.
    public var shouldAverageLongEmbeddings: Bool
    /// Usually we won't hit the limit because the max token is 8191 and we will do text splitting
    /// before embedding.
    public var safe: Bool

    /// Creates an embedding client.
    /// - Parameters:
    ///   - configuration: Connection/model configuration for the embedding service.
    ///   - shouldAverageLongEmbeddings: Whether long inputs get chunk-averaged. Defaults to `false`.
    ///   - safe: Whether to use the length-safe embedding path. Defaults to `false`.
    public init(
        configuration: EmbeddingConfiguration,
        shouldAverageLongEmbeddings: Bool = false,
        safe: Bool = false
    ) {
        service = EmbeddingService(configuration: configuration)
        self.shouldAverageLongEmbeddings = shouldAverageLongEmbeddings
        self.safe = safe
    }

    /// Embeds a batch of documents, choosing the length-safe path when `safe` is set.
    /// - Returns: One embedding vector per input document.
    public func embed(documents: [String]) async throws -> [[Float]] {
        let pairs = safe
            ? try await getLenSafeEmbeddings(texts: documents)
            : try await getEmbeddings(texts: documents)
        return pairs.map(\.embeddings)
    }

    /// Embeds a single query string, choosing the length-safe path when `safe` is set.
    /// - Returns: The embedding vector, or an empty array if the service returned nothing.
    public func embed(query: String) async throws -> [Float] {
        let pairs = safe
            ? try await getLenSafeEmbeddings(texts: [query])
            : try await getEmbeddings(texts: [query])
        return pairs.first?.embeddings ?? []
    }
}
2638

2739
extension OpenAIEmbedding {
40+
func getEmbeddings(
41+
texts: [String]
42+
) async throws -> [(originalText: String, embeddings: [Float])] {
43+
try await withThrowingTaskGroup(
44+
of: (originalText: String, embeddings: [Float]).self
45+
) { group in
46+
for text in texts {
47+
group.addTask {
48+
var retryCount = 6
49+
var previousError: Error?
50+
while retryCount > 0 {
51+
do {
52+
let embeddings = try await service.embed(text: text).data
53+
.map(\.embeddings).first ?? []
54+
return (text, embeddings)
55+
} catch {
56+
retryCount -= 1
57+
previousError = error
58+
}
59+
}
60+
throw previousError ?? CancellationError()
61+
}
62+
}
63+
var all = [(originalText: String, embeddings: [Float])]()
64+
for try await result in group {
65+
all.append(result)
66+
}
67+
return all
68+
}
69+
}
70+
2871
func getLenSafeEmbeddings(
2972
texts: [String]
3073
) async throws -> [(originalText: String, embeddings: [Float])] {
@@ -116,7 +159,7 @@ extension OpenAIEmbedding {
116159
axis: 0,
117160
weights: embeddings.map(\.count)
118161
)
119-
let normalized = average / numpy.linalg.norm(embeddings)
162+
let normalized = average / numpy.linalg.norm(average)
120163
return [Float](normalized.tolist())
121164
}) else { throw CancellationError() }
122165
results.append((text, averagedEmbeddings))

Tool/Sources/OpenAIService/Memory/AutoManagedChatGPTMemory.swift

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ public actor AutoManagedChatGPTMemory: ChatGPTMemory {
4545
}
4646
}
4747

48+
/// https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
4849
func generateSendingHistory(
4950
maxNumberOfMessages: Int = UserDefaults.shared.value(for: \.chatGPTMaxMessageCount),
5051
encoder: TokenEncoder = AutoManagedChatGPTMemory.encoder
@@ -68,7 +69,7 @@ public actor AutoManagedChatGPTMemory: ChatGPTMemory {
6869
}
6970
partial += count
7071
}
71-
var allTokensCount = functionTokenCount
72+
var allTokensCount = functionTokenCount + 3 // every reply is primed with <|start|>assistant<|message|>
7273
allTokensCount += systemPrompt.isEmpty ? 0 : systemMessageTokenCount
7374

7475
for (index, message) in history.enumerated().reversed() {
@@ -110,13 +111,15 @@ public actor AutoManagedChatGPTMemory: ChatGPTMemory {
110111
}
111112

112113
extension TokenEncoder {
114+
/// https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
113115
func countToken(message: ChatMessage) -> Int {
114116
var total = 3
115117
if let content = message.content {
116118
total += encode(text: content).count
117119
}
118120
if let name = message.name {
119121
total += encode(text: name).count
122+
total += 1
120123
}
121124
if let functionCall = message.functionCall {
122125
total += encode(text: functionCall.name).count

0 commit comments

Comments
 (0)