Skip to content

Commit dca1c0d

Browse files
committed
Support joining split documents
1 parent e6dee6b commit dca1c0d

File tree

1 file changed

+32
-3
lines changed

1 file changed

+32
-3
lines changed

Tool/Sources/LangChain/DocumentTransformer/TextSplitter.swift

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@ public extension TextSplitter {
2626
for (text, metadata) in zip(texts, metadata) {
2727
let chunks = try await split(text: text)
2828
for chunk in chunks {
29+
var metadata = metadata
30+
metadata["startUTF16Offset"] = .number(Double(chunk.startUTF16Offset))
31+
metadata["endUTF16Offset"] = .number(Double(chunk.endUTF16Offset))
2932
let document = Document(pageContent: chunk.text, metadata: metadata)
3033
documents.append(document)
3134
}
@@ -48,6 +51,32 @@ public extension TextSplitter {
4851
/// Transforms the given documents by forwarding them to `splitDocuments`.
///
/// - Parameter documents: The documents to transform.
/// - Returns: The documents returned by `splitDocuments`.
/// - Throws: Rethrows any error thrown by `splitDocuments`.
func transformDocuments(_ documents: [Document]) async throws -> [Document] {
    try await splitDocuments(documents)
}
54+
55+
/// Joins documents that carry split-offset metadata back into a single document.
///
/// Every document is expected to hold numeric `startUTF16Offset` and
/// `endUTF16Offset` metadata entries. Documents missing either entry, or whose
/// value cannot be represented as an `Int` (NaN, infinite, out of range), are
/// skipped. The remaining chunks are ordered by start offset, passed through
/// `mergeSplits`, and their text is concatenated.
///
/// - Parameter documents: The documents to join.
/// - Returns: A document whose content is the joined text and whose metadata is
///   that of the first input document with the two offset entries removed
///   (empty metadata when `documents` is empty).
func joinDocuments(_ documents: [Document]) -> Document {
    /// Reads an integer offset stored as a `.number` under `key`.
    func offset(forKey key: String, in document: Document) -> Int? {
        guard case let .number(value) = document.metadata[key] else { return nil }
        // `Int(value)` traps on NaN, infinite, or out-of-range doubles. Truncate
        // first (preserving the previous rounding behavior) and convert failably.
        return Int(exactly: value.rounded(.towardZero))
    }

    let textChunks: [TextChunk] = documents
        .compactMap { document -> TextChunk? in
            guard let start = offset(forKey: "startUTF16Offset", in: document),
                  let end = offset(forKey: "endUTF16Offset", in: document)
            else { return nil }
            return TextChunk(
                text: document.pageContent,
                startUTF16Offset: start,
                endUTF16Offset: end
            )
        }
        .sorted { $0.startUTF16Offset < $1.startUTF16Offset }

    let mergedChunks = mergeSplits(textChunks)
    let pageContent = mergedChunks.map(\.text).joined()

    // The offsets describe individual chunks and no longer apply to the joined text.
    var metadata = documents.first?.metadata ?? [:]
    metadata["startUTF16Offset"] = nil
    metadata["endUTF16Offset"] = nil

    return Document(pageContent: pageContent, metadata: metadata)
}
5180
}
5281

5382
public struct TextChunk: Equatable {
@@ -83,14 +112,14 @@ public extension TextSplitter {
83112
let text = (a + b).map(\.text).joined()
84113
var l = Int.max
85114
var u = 0
86-
115+
87116
for chunk in a + b {
88117
l = min(l, chunk.startUTF16Offset)
89118
u = max(u, chunk.endUTF16Offset)
90119
}
91-
120+
92121
guard l < u else { return nil }
93-
122+
94123
return .init(text: text, startUTF16Offset: l, endUTF16Offset: u)
95124
}
96125

0 commit comments

Comments
 (0)