Commit 0864650

Update to call the batch embedding method
1 parent a81fd9d commit 0864650

1 file changed: Tool/Sources/LangChain/Embedding/OpenAIEmbedding.swift
Lines changed: 11 additions & 32 deletions
@@ -45,33 +45,12 @@ extension OpenAIEmbedding {
     func getEmbeddings(
         documents: [Document]
     ) async throws -> [EmbeddedDocument] {
-        try await withThrowingTaskGroup(
-            of: (document: Document, embeddings: [Float]).self
-        ) { group in
-            for document in documents {
-                group.addTask {
-                    var retryCount = 6
-                    var previousError: Error?
-                    while retryCount > 0 {
-                        do {
-                            let embeddings = try await service.embed(text: document.pageContent)
-                                .data
-                                .map(\.embedding).first ?? []
-                            return (document, embeddings)
-                        } catch {
-                            retryCount -= 1
-                            previousError = error
-                        }
-                    }
-                    throw previousError ?? CancellationError()
-                }
-            }
-            var all = [EmbeddedDocument]()
-            for try await result in group {
-                all.append(.init(document: result.document, embeddings: result.embeddings))
+        try await service.embed(text: documents.map(\.pageContent)).data
+            .compactMap {
+                let index = $0.index
+                guard index >= 0, index < documents.endIndex else { return nil }
+                return EmbeddedDocument(document: documents[index], embeddings: $0.embedding)
             }
-            return all
-        }
     }
 
     /// OpenAI's embedding API doesn't support embedding inputs longer than the max token.
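The change above replaces the per-document task group (and its retry loop) with a single batch request: every document's pageContent is sent in one embed(text:) call, and each item of the response is matched back to its source document through the item's index, with out-of-range indices dropped. A minimal, self-contained sketch of that index-pairing pattern, using hypothetical stand-in types (Doc, EmbeddingItem, Embedded) rather than the repo's own Document/EmbeddedDocument:

// Hypothetical stand-ins for the repo's types; only the pairing logic matters here.
struct Doc { let pageContent: String }
struct EmbeddingItem { let index: Int; let embedding: [Float] }
struct Embedded { let document: Doc; let embeddings: [Float] }

// Pair each batch-response item back to its source document by index,
// dropping anything whose index falls outside the input range.
func pair(_ items: [EmbeddingItem], with documents: [Doc]) -> [Embedded] {
    items.compactMap { (item) -> Embedded? in
        guard item.index >= 0, item.index < documents.endIndex else { return nil }
        return Embedded(document: documents[item.index], embeddings: item.embedding)
    }
}

// Example (e.g. in a playground): the response may arrive out of order,
// but the index still resolves each vector to the right document.
let docs = [Doc(pageContent: "alpha"), Doc(pageContent: "beta")]
let response = [
    EmbeddingItem(index: 1, embedding: [0.2, 0.4]),
    EmbeddingItem(index: 0, embedding: [0.1, 0.3]),
]
let paired = pair(response, with: docs)
// paired[0].document.pageContent == "beta"; paired[1].document.pageContent == "alpha"

One batch call also drops the old six-attempt per-document retry loop, so transient failures now surface directly from the single request.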
@@ -112,27 +91,27 @@ extension OpenAIEmbedding {
         do {
             if text.chunkedTokens.count <= 1 {
                 // if possible, we should just let OpenAI do the tokenization.
-                return (
+                return try (
                     text.document,
-                    try await service.embed(text: text.document.pageContent)
+                    await service.embed(text: text.document.pageContent)
                         .data
                         .map(\.embedding)
                 )
             }
 
             if shouldAverageLongEmbeddings {
-                return (
+                return try (
                     text.document,
-                    try await service.embed(tokens: text.chunkedTokens)
+                    await service.embed(tokens: text.chunkedTokens)
                         .data
                         .map(\.embedding)
                 )
             }
             // if `shouldAverageLongEmbeddings` is false,
             // we only embed the first chunk to save some money.
-            return (
+            return try (
                 text.document,
-                try await service.embed(tokens: [text.chunkedTokens.first ?? []])
+                await service.embed(tokens: [text.chunkedTokens.first ?? []])
                     .data
                     .map(\.embedding)
             )
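The doc comment above notes that the embedding API rejects inputs longer than the model's max token count, which is why long documents are split into chunkedTokens: with shouldAverageLongEmbeddings enabled every chunk is embedded, otherwise only the first chunk is sent to save cost. How the per-chunk vectors are later combined is not part of this diff; as a general illustration only (not this repo's code), averaging chunk embeddings typically means a token-count-weighted mean of the chunk vectors, normalized back to unit length:

// General illustration of chunk-embedding averaging (hypothetical helper, not from this repo):
// weight each chunk's vector by its token count, take the mean, then L2-normalize.
func averageChunkEmbeddings(_ chunks: [(tokenCount: Int, embedding: [Float])]) -> [Float] {
    guard let dimension = chunks.first?.embedding.count, dimension > 0 else { return [] }
    var sum = [Float](repeating: 0, count: dimension)
    var totalTokens: Float = 0
    for chunk in chunks where chunk.embedding.count == dimension {
        let weight = Float(chunk.tokenCount)
        totalTokens += weight
        for i in 0..<dimension { sum[i] += chunk.embedding[i] * weight }
    }
    guard totalTokens > 0 else { return sum }
    for i in 0..<dimension { sum[i] /= totalTokens }
    // Normalize so dot products on the averaged vector behave like cosine similarity.
    let sumOfSquares = sum.reduce(0) { $0 + $1 * $1 }
    let norm = sumOfSquares.squareRoot()
    return norm > 0 ? sum.map { $0 / norm } : sum
}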
