@@ -6,25 +6,68 @@ import TokenEncoder
66
/// An `Embeddings` implementation that requests vectors from an `EmbeddingService`.
public struct OpenAIEmbedding: Embeddings {
    /// The service used to perform embedding requests.
    public var service: EmbeddingService
    /// Stored from the initializer; not read by the methods declared in this struct body.
    /// NOTE(review): presumably consumed by the length-safe embedding path — confirm.
    public var shouldAverageLongEmbeddings: Bool
    /// Usually we won't hit the limit because the max token is 8191 and we will do text splitting
    /// before embedding.
    ///
    /// When `true`, `embed(documents:)` and `embed(query:)` go through `getLenSafeEmbeddings`;
    /// when `false`, they go through `getEmbeddings`.
    public var safe: Bool

    /// Creates an embedding client.
    /// - Parameters:
    ///   - configuration: Passed to `EmbeddingService` to build `service`.
    ///   - shouldAverageLongEmbeddings: Stored as-is; defaults to `false`.
    ///   - safe: Selects the length-safe request path when `true`; defaults to `false`.
    public init(
        configuration: EmbeddingConfiguration,
        shouldAverageLongEmbeddings: Bool = false,
        safe: Bool = false
    ) {
        self.service = EmbeddingService(configuration: configuration)
        self.shouldAverageLongEmbeddings = shouldAverageLongEmbeddings
        self.safe = safe
    }

    /// Embeds a batch of documents.
    /// - Parameter documents: The texts to embed.
    /// - Returns: The embedding vectors of the results produced for `documents`.
    /// - Throws: Whatever the selected embedding path throws.
    public func embed(documents: [String]) async throws -> [[Float]] {
        let pairs: [(originalText: String, embeddings: [Float])]
        if safe {
            pairs = try await getLenSafeEmbeddings(texts: documents)
        } else {
            pairs = try await getEmbeddings(texts: documents)
        }
        return pairs.map { $0.embeddings }
    }

    /// Embeds a single query string.
    /// - Parameter query: The text to embed.
    /// - Returns: The embedding of the first result, or an empty array when there is none.
    /// - Throws: Whatever the selected embedding path throws.
    public func embed(query: String) async throws -> [Float] {
        let pairs: [(originalText: String, embeddings: [Float])]
        if safe {
            pairs = try await getLenSafeEmbeddings(texts: [query])
        } else {
            pairs = try await getEmbeddings(texts: [query])
        }
        guard let head = pairs.first else { return [] }
        return head.embeddings
    }
}
2638
2739extension OpenAIEmbedding {
/// Embeds every text concurrently, one service request per text, and returns the results
/// in the same order as `texts`.
///
/// Each request is attempted up to 6 times. If every attempt fails, the last error is
/// thrown, which fails the whole batch.
///
/// - Parameter texts: The texts to embed.
/// - Returns: One `(originalText, embeddings)` pair per input text, in input order.
///   `embeddings` is empty when the service response carries no data.
/// - Throws: The last service error once retries are exhausted, or `CancellationError`
///   when the task is cancelled between attempts.
func getEmbeddings(
    texts: [String]
) async throws -> [(originalText: String, embeddings: [Float])] {
    try await withThrowingTaskGroup(
        of: (index: Int, originalText: String, embeddings: [Float]).self
    ) { group in
        // Tag each task with its input position: a task group yields results in
        // completion order, which is not the order the tasks were added in.
        for (index, text) in texts.enumerated() {
            group.addTask {
                var remainingAttempts = 6
                var lastError: Error?
                while remainingAttempts > 0 {
                    // Stop retrying once the group has been cancelled (e.g. a sibling failed).
                    try Task.checkCancellation()
                    do {
                        let embeddings = try await service.embed(text: text).data
                            .map(\.embeddings).first ?? []
                        return (index, text, embeddings)
                    } catch {
                        remainingAttempts -= 1
                        lastError = error
                    }
                }
                throw lastError ?? CancellationError()
            }
        }

        var indexed = [(index: Int, originalText: String, embeddings: [Float])]()
        indexed.reserveCapacity(texts.count)
        for try await result in group {
            indexed.append(result)
        }

        // Restore input order so callers can pair result i with texts[i].
        return indexed
            .sorted { $0.index < $1.index }
            .map { (originalText: $0.originalText, embeddings: $0.embeddings) }
    }
}
70+
2871 func getLenSafeEmbeddings(
2972 texts: [ String ]
3073 ) async throws -> [ ( originalText: String , embeddings: [ Float ] ) ] {
@@ -116,7 +159,7 @@ extension OpenAIEmbedding {
116159 axis: 0 ,
117160 weights: embeddings. map ( \. count)
118161 )
119- let normalized = average / numpy. linalg. norm ( embeddings )
162+ let normalized = average / numpy. linalg. norm ( average )
120163 return [ Float] ( normalized. tolist ( ) )
121164 } ) else { throw CancellationError ( ) }
122165 results. append ( ( text, averagedEmbeddings) )
0 commit comments