Skip to content

Commit 3f4706e

Browse files
committed
Add TemporaryUSearch for vector searching
1 parent 2c5df2c commit 3f4706e

File tree

9 files changed

+208
-44
lines changed

9 files changed

+208
-44
lines changed

Copilot for Xcode.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved

Lines changed: 9 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Tool/Package.swift

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ let package = Package(
2424
.package(url: "https://github.com/pointfreeco/swift-parsing", from: "0.12.1"),
2525
.package(url: "https://github.com/ChimeHQ/JSONRPC", from: "0.6.0"),
2626
.package(url: "https://github.com/scinfu/SwiftSoup.git", from: "2.6.0"),
27+
.package(url: "https://github.com/unum-cloud/usearch", from: "0.19.0"),
2728
],
2829
targets: [
2930
// MARK: - Helpers
@@ -60,6 +61,7 @@ let package = Package(
6061
.product(name: "PythonKit", package: "PythonKit"),
6162
.product(name: "Parsing", package: "swift-parsing"),
6263
.product(name: "SwiftSoup", package: "SwiftSoup"),
64+
.product(name: "USearch", package: "usearch"),
6365
]
6466
),
6567

Tool/Sources/LangChain/DocumentLoader/DocumentLoader.swift

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import Foundation
2+
import JSONRPC
23

3-
public struct Document {
4+
public struct Document: Codable {
45
public var pageContent: String
5-
public var metadata: [String: Any]
6-
public init(pageContent: String, metadata: [String: Any]) {
6+
public var metadata: JSONValue
7+
public init(pageContent: String, metadata: JSONValue) {
78
self.pageContent = pageContent
89
self.metadata = metadata
910
}

Tool/Sources/LangChain/DocumentLoader/TextLoader.swift

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,11 @@ public struct TextLoader: DocumentLoader {
2626
let modificationDate = try? url.resourceValues(forKeys: [.contentModificationDateKey])
2727
.contentModificationDate
2828
return [Document(pageContent: attributedString.string, metadata: [
29-
"filename": url.lastPathComponent,
30-
"extension": url.pathExtension,
31-
"contentModificationDate": modificationDate ?? Date(),
29+
"filename": .string(url.lastPathComponent),
30+
"extension": .string(url.pathExtension),
31+
"contentModificationDate": .number(
32+
(modificationDate ?? Date()).timeIntervalSince1970
33+
),
3234
])]
3335
}
3436
}

Tool/Sources/LangChain/DocumentLoader/WebLoader.swift

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,9 @@ public struct WebLoader: DocumentLoader {
3737

3838
if let body = body {
3939
let doc = Document(pageContent: body, metadata: [
40-
"title": title,
41-
"url": result.url,
42-
"date": Date(),
40+
"title": .string(title),
41+
"url": .string(result.url.absoluteString),
42+
"date": .number(Date().timeIntervalSince1970),
4343
])
4444
documents.append(doc)
4545
}

Tool/Sources/LangChain/DocumentTransformer/DocumentTransformer.swift

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import Foundation
2+
import JSONRPC
23

34
public protocol DocumentTransformer {
45
func transformDocuments(_ documents: [Document]) async throws -> [Document]
@@ -13,7 +14,7 @@ public extension TextSplitter {
1314
/// Create documents from a list of texts.
1415
func createDocuments(
1516
texts: [String],
16-
metadata: [[String: Any]] = []
17+
metadata: [JSONValue] = []
1718
) async throws -> [Document] {
1819
var documents = [Document]()
1920
let paddingLength = texts.count - metadata.count
@@ -31,7 +32,7 @@ public extension TextSplitter {
3132
/// Split documents.
3233
func splitDocuments(_ documents: [Document]) async throws -> [Document] {
3334
var texts = [String]()
34-
var metadata = [[String: Any]]()
35+
var metadata = [JSONValue]()
3536
for document in documents {
3637
texts.append(document.pageContent)
3738
metadata.append(document.metadata)

Tool/Sources/LangChain/Embedding/Embedding.swift

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,12 @@ import Foundation
22

33
public protocol Embeddings {
44
/// Embed search docs.
5-
func embed(documents: [String]) async throws -> [[Float]]
5+
func embed(documents: [Document]) async throws -> [EmbeddedDocument]
66
/// Embed query text.
77
func embed(query: String) async throws -> [Float]
88
}
9+
10+
public struct EmbeddedDocument: Codable {
11+
var document: Document
12+
var embeddings: [Float]
13+
}

Tool/Sources/LangChain/Embedding/OpenAIEmbedding.swift

Lines changed: 41 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -21,37 +21,45 @@ public struct OpenAIEmbedding: Embeddings {
2121
self.safe = safe
2222
}
2323

24-
public func embed(documents: [String]) async throws -> [[Float]] {
24+
public func embed(documents: [Document]) async throws -> [EmbeddedDocument] {
2525
if safe {
26-
return try await getLenSafeEmbeddings(texts: documents).map(\.embeddings)
26+
return try await getLenSafeEmbeddings(documents: documents)
2727
}
28-
return try await getEmbeddings(texts: documents).map(\.embeddings)
28+
return try await getEmbeddings(documents: documents)
2929
}
3030

3131
public func embed(query: String) async throws -> [Float] {
3232
if safe {
33-
return try await getLenSafeEmbeddings(texts: [query]).first?.embeddings ?? []
33+
return try await getLenSafeEmbeddings(documents: [.init(
34+
pageContent: query,
35+
metadata: .null
36+
)])
37+
.first?
38+
.embeddings ?? []
3439
}
35-
return try await getEmbeddings(texts: [query]).first?.embeddings ?? []
40+
return try await getEmbeddings(documents: [.init(pageContent: query, metadata: .null)])
41+
.first?
42+
.embeddings ?? []
3643
}
3744
}
3845

3946
extension OpenAIEmbedding {
4047
func getEmbeddings(
41-
texts: [String]
42-
) async throws -> [(originalText: String, embeddings: [Float])] {
48+
documents: [Document]
49+
) async throws -> [EmbeddedDocument] {
4350
try await withThrowingTaskGroup(
44-
of: (originalText: String, embeddings: [Float]).self
51+
of: (document: Document, embeddings: [Float]).self
4552
) { group in
46-
for text in texts {
53+
for document in documents {
4754
group.addTask {
4855
var retryCount = 6
4956
var previousError: Error?
5057
while retryCount > 0 {
5158
do {
52-
let embeddings = try await service.embed(text: text).data
59+
let embeddings = try await service.embed(text: document.pageContent)
60+
.data
5361
.map(\.embeddings).first ?? []
54-
return (text, embeddings)
62+
return (document, embeddings)
5563
} catch {
5664
retryCount -= 1
5765
previousError = error
@@ -60,27 +68,27 @@ extension OpenAIEmbedding {
6068
throw previousError ?? CancellationError()
6169
}
6270
}
63-
var all = [(originalText: String, embeddings: [Float])]()
71+
var all = [EmbeddedDocument]()
6472
for try await result in group {
65-
all.append(result)
73+
all.append(.init(document: result.document, embeddings: result.embeddings))
6674
}
6775
return all
6876
}
6977
}
7078

7179
func getLenSafeEmbeddings(
72-
texts: [String]
73-
) async throws -> [(originalText: String, embeddings: [Float])] {
80+
documents: [Document]
81+
) async throws -> [EmbeddedDocument] {
7482
struct Text {
75-
var rawText: String
83+
var document: Document
7684
var chunkedTokens: [[Int]]
7785
}
7886

79-
var texts = texts.map { Text(rawText: $0, chunkedTokens: []) }
87+
var texts = documents.map { Text(document: $0, chunkedTokens: []) }
8088
let encoding = TiktokenCl100kBaseTokenEncoder()
8189

8290
for (index, text) in texts.enumerated() {
83-
let token = encoding.encode(text: text.rawText)
91+
let token = encoding.encode(text: text.document.pageContent)
8492
// just in case the calculation is incorrect
8593
let maxToken = max(10, service.configuration.maxToken - 10)
8694

@@ -92,27 +100,28 @@ extension OpenAIEmbedding {
92100
}
93101

94102
let batchedEmbeddings = try await withThrowingTaskGroup(
95-
of: (String, [[Float]]).self
103+
of: (Document, [[Float]]).self
96104
) { group in
97105
for text in texts {
98106
group.addTask {
99107
var retryCount = 6
100108
var previousError: Error?
101-
guard !text.chunkedTokens.isEmpty else { return (text.rawText, []) }
109+
guard !text.chunkedTokens.isEmpty
110+
else { return (text.document, []) }
102111
while retryCount > 0 {
103112
do {
104113
if text.chunkedTokens.count <= 1 {
105114
// if possible, we should just let OpenAI do the tokenization.
106115
return (
107-
text.rawText,
108-
try await service.embed(text: text.rawText)
116+
text.document,
117+
try await service.embed(text: text.document.pageContent)
109118
.data
110119
.map(\.embeddings)
111120
)
112121
}
113122
if shouldAverageLongEmbeddings {
114123
return (
115-
text.rawText,
124+
text.document,
116125
try await service.embed(tokens: text.chunkedTokens)
117126
.data
118127
.map(\.embeddings)
@@ -121,7 +130,7 @@ extension OpenAIEmbedding {
121130
// if `shouldAverageLongEmbeddings` is false,
122131
// we only embed the first chunk to save some money.
123132
return (
124-
text.rawText,
133+
text.document,
125134
try await service.embed(tokens: [text.chunkedTokens.first ?? []])
126135
.data
127136
.map(\.embeddings)
@@ -134,21 +143,21 @@ extension OpenAIEmbedding {
134143
throw previousError ?? CancellationError()
135144
}
136145
}
137-
var result = [(originalText: String, embeddings: [[Float]])]()
146+
var result = [(document: Document, embeddings: [[Float]])]()
138147
for try await response in group {
139148
try Task.checkCancellation()
140149
result.append((response.0, response.1))
141150
}
142151
return result
143152
}
144153

145-
var results = [(originalText: String, embeddings: [Float])]()
154+
var results = [EmbeddedDocument]()
146155

147-
for (text, embeddings) in batchedEmbeddings {
156+
for (document, embeddings) in batchedEmbeddings {
148157
if embeddings.count == 1, let first = embeddings.first {
149-
results.append((text, first))
158+
results.append(.init(document: document, embeddings: first))
150159
} else if embeddings.isEmpty {
151-
results.append((text, []))
160+
results.append(.init(document: document, embeddings: []))
152161
} else if shouldAverageLongEmbeddings {
153162
// untested
154163
do {
@@ -162,14 +171,14 @@ extension OpenAIEmbedding {
162171
let normalized = average / numpy.linalg.norm(average)
163172
return [Float](normalized.tolist())
164173
}) else { throw CancellationError() }
165-
results.append((text, averagedEmbeddings))
174+
results.append(.init(document: document, embeddings: averagedEmbeddings))
166175
} catch {
167176
if let first = embeddings.first {
168-
results.append((text, first))
177+
results.append(.init(document: document, embeddings: first))
169178
}
170179
}
171180
} else if let first = embeddings.first {
172-
results.append((text, first))
181+
results.append(.init(document: document, embeddings: first))
173182
}
174183
}
175184

0 commit comments

Comments
 (0)