@@ -21,37 +21,45 @@ public struct OpenAIEmbedding: Embeddings {
2121 self . safe = safe
2222 }
2323
/// Produces an embedded representation for each of the given documents.
///
/// - Parameter documents: The documents to embed.
/// - Returns: The embedded documents produced by whichever helper handled the request.
/// - Throws: Rethrows any error raised by the selected helper.
public func embed(documents: [Document]) async throws -> [EmbeddedDocument] {
    // The `safe` flag picks the length-safe helper; otherwise the plain helper is used.
    guard safe else {
        return try await getEmbeddings(documents: documents)
    }
    return try await getLenSafeEmbeddings(documents: documents)
}
3030
/// Embeds a single query string.
///
/// The query is wrapped in a document with `.null` metadata so it can go through
/// the same helpers as whole documents.
///
/// - Parameter query: The text to embed.
/// - Returns: The embeddings of the first result, or an empty array when no result comes back.
/// - Throws: Rethrows any error raised by the selected helper.
public func embed(query: String) async throws -> [Float] {
    let queryDocuments: [Document] = [.init(pageContent: query, metadata: .null)]

    let embedded: [EmbeddedDocument]
    if safe {
        embedded = try await getLenSafeEmbeddings(documents: queryDocuments)
    } else {
        embedded = try await getEmbeddings(documents: queryDocuments)
    }

    guard let firstResult = embedded.first else { return [] }
    return firstResult.embeddings
}
3744}
3845
3946extension OpenAIEmbedding {
4047 func getEmbeddings(
41- texts : [ String ]
42- ) async throws -> [ ( originalText : String , embeddings : [ Float ] ) ] {
48+ documents : [ Document ]
49+ ) async throws -> [ EmbeddedDocument ] {
4350 try await withThrowingTaskGroup (
44- of: ( originalText : String , embeddings: [ Float] ) . self
51+ of: ( document : Document , embeddings: [ Float] ) . self
4552 ) { group in
46- for text in texts {
53+ for document in documents {
4754 group. addTask {
4855 var retryCount = 6
4956 var previousError : Error ?
5057 while retryCount > 0 {
5158 do {
52- let embeddings = try await service. embed ( text: text) . data
59+ let embeddings = try await service. embed ( text: document. pageContent)
60+ . data
5361 . map ( \. embeddings) . first ?? [ ]
54- return ( text , embeddings)
62+ return ( document , embeddings)
5563 } catch {
5664 retryCount -= 1
5765 previousError = error
@@ -60,27 +68,27 @@ extension OpenAIEmbedding {
6068 throw previousError ?? CancellationError ( )
6169 }
6270 }
63- var all = [ ( originalText : String , embeddings : [ Float ] ) ] ( )
71+ var all = [ EmbeddedDocument ] ( )
6472 for try await result in group {
65- all. append ( result)
73+ all. append ( . init ( document : result. document , embeddings : result . embeddings ) )
6674 }
6775 return all
6876 }
6977 }
7078
7179 func getLenSafeEmbeddings(
72- texts : [ String ]
73- ) async throws -> [ ( originalText : String , embeddings : [ Float ] ) ] {
80+ documents : [ Document ]
81+ ) async throws -> [ EmbeddedDocument ] {
7482 struct Text {
75- var rawText : String
83+ var document : Document
7684 var chunkedTokens : [ [ Int ] ]
7785 }
7886
79- var texts = texts . map { Text ( rawText : $0, chunkedTokens: [ ] ) }
87+ var texts = documents . map { Text ( document : $0, chunkedTokens: [ ] ) }
8088 let encoding = TiktokenCl100kBaseTokenEncoder ( )
8189
8290 for (index, text) in texts. enumerated ( ) {
83- let token = encoding. encode ( text: text. rawText )
91+ let token = encoding. encode ( text: text. document . pageContent )
8492 // just incase the calculation is incorrect
8593 let maxToken = max ( 10 , service. configuration. maxToken - 10 )
8694
@@ -92,27 +100,28 @@ extension OpenAIEmbedding {
92100 }
93101
94102 let batchedEmbeddings = try await withThrowingTaskGroup (
95- of: ( String , [ [ Float] ] ) . self
103+ of: ( Document , [ [ Float] ] ) . self
96104 ) { group in
97105 for text in texts {
98106 group. addTask {
99107 var retryCount = 6
100108 var previousError : Error ?
101- guard !text. chunkedTokens. isEmpty else { return ( text. rawText, [ ] ) }
109+ guard !text. chunkedTokens. isEmpty
110+ else { return ( text. document, [ ] ) }
102111 while retryCount > 0 {
103112 do {
104113 if text. chunkedTokens. count <= 1 {
105114 // if possible, we should just let OpenAI do the tokenization.
106115 return (
107- text. rawText ,
108- try await service. embed ( text: text. rawText )
116+ text. document ,
117+ try await service. embed ( text: text. document . pageContent )
109118 . data
110119 . map ( \. embeddings)
111120 )
112121 }
113122 if shouldAverageLongEmbeddings {
114123 return (
115- text. rawText ,
124+ text. document ,
116125 try await service. embed ( tokens: text. chunkedTokens)
117126 . data
118127 . map ( \. embeddings)
@@ -121,7 +130,7 @@ extension OpenAIEmbedding {
121130 // if `shouldAverageLongEmbeddings` is false,
122131 // we only embed the first chunk to save some money.
123132 return (
124- text. rawText ,
133+ text. document ,
125134 try await service. embed ( tokens: [ text. chunkedTokens. first ?? [ ] ] )
126135 . data
127136 . map ( \. embeddings)
@@ -134,21 +143,21 @@ extension OpenAIEmbedding {
134143 throw previousError ?? CancellationError ( )
135144 }
136145 }
137- var result = [ ( originalText : String , embeddings: [ [ Float] ] ) ] ( )
146+ var result = [ ( document : Document , embeddings: [ [ Float] ] ) ] ( )
138147 for try await response in group {
139148 try Task . checkCancellation ( )
140149 result. append ( ( response. 0 , response. 1 ) )
141150 }
142151 return result
143152 }
144153
145- var results = [ ( originalText : String , embeddings : [ Float ] ) ] ( )
154+ var results = [ EmbeddedDocument ] ( )
146155
147- for (text , embeddings) in batchedEmbeddings {
156+ for (document , embeddings) in batchedEmbeddings {
148157 if embeddings. count == 1 , let first = embeddings. first {
149- results. append ( ( text , first) )
158+ results. append ( . init ( document : document , embeddings : first) )
150159 } else if embeddings. isEmpty {
151- results. append ( ( text , [ ] ) )
160+ results. append ( . init ( document : document , embeddings : [ ] ) )
152161 } else if shouldAverageLongEmbeddings {
153162 // untested
154163 do {
@@ -162,14 +171,14 @@ extension OpenAIEmbedding {
162171 let normalized = average / numpy. linalg. norm ( average)
163172 return [ Float] ( normalized. tolist ( ) )
164173 } ) else { throw CancellationError ( ) }
165- results. append ( ( text , averagedEmbeddings) )
174+ results. append ( . init ( document : document , embeddings : averagedEmbeddings) )
166175 } catch {
167176 if let first = embeddings. first {
168- results. append ( ( text , first) )
177+ results. append ( . init ( document : document , embeddings : first) )
169178 }
170179 }
171180 } else if let first = embeddings. first {
172- results. append ( ( text , first) )
181+ results. append ( . init ( document : document , embeddings : first) )
173182 }
174183 }
175184
0 commit comments