Skip to content

Commit e73363f

Browse files
committed
Update TextSplitter to generate TextChunk
1 parent 815ac67 commit e73363f

3 files changed

Lines changed: 74 additions & 38 deletions

File tree

Tool/Sources/LangChain/DocumentTransformer/RecursiveCharacterTextSplitter.swift

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ public class RecursiveCharacterTextSplitter: TextSplitter {
1919
/// - chunkOverlap: The maximum overlap between chunks.
2020
/// - lengthFunction: A function to compute the length of text.
2121
public init(
22-
separators: [String] = ["\n\n", "\r\n", "\n", "\r", " ", ""],
22+
separators: [String],
2323
chunkSize: Int = 4000,
2424
chunkOverlap: Int = 200,
2525
lengthFunction: @escaping (String) -> Int = { $0.count }
@@ -39,7 +39,7 @@ public class RecursiveCharacterTextSplitter: TextSplitter {
3939
/// - chunkOverlap: The maximum overlap between chunks.
4040
/// - lengthFunction: A function to compute the length of text.
4141
public init(
42-
separatorSet: TextSplitterSeparatorSet,
42+
separatorSet: TextSplitterSeparatorSet = .default,
4343
chunkSize: Int = 4000,
4444
chunkOverlap: Int = 200,
4545
lengthFunction: @escaping (String) -> Int = { $0.count }
@@ -51,12 +51,12 @@ public class RecursiveCharacterTextSplitter: TextSplitter {
5151
separators = separatorSet.separators
5252
}
5353

54-
public func split(text: String) async throws -> [String] {
55-
return split(text: text, separators: separators)
54+
public func split(text: String) async throws -> [TextChunk] {
55+
return split(text: text, separators: separators, startIndex: 0)
5656
}
5757

58-
private func split(text: String, separators: [String]) -> [String] {
59-
var finalChunks = [String]()
58+
private func split(text: String, separators: [String], startIndex: Int) -> [TextChunk] {
59+
var finalChunks = [TextChunk]()
6060

6161
// Get appropriate separator to use
6262
let firstSeparatorIndex = separators.firstIndex {
@@ -83,12 +83,12 @@ public class RecursiveCharacterTextSplitter: TextSplitter {
8383
nextSeparators = []
8484
}
8585

86-
let splits = split(text: text, separator: separator)
86+
let splits = split(text: text, separator: separator, startIndex: startIndex)
8787

8888
// Now go merging things, recursively splitting longer texts.
89-
var goodSplits = [String]()
89+
var goodSplits = [TextChunk]()
9090
for s in splits {
91-
if lengthFunction(s) < chunkSize {
91+
if lengthFunction(s.text) < chunkSize {
9292
goodSplits.append(s)
9393
} else {
9494
if !goodSplits.isEmpty {
@@ -99,7 +99,11 @@ public class RecursiveCharacterTextSplitter: TextSplitter {
9999
if nextSeparators.isEmpty {
100100
finalChunks.append(s)
101101
} else {
102-
let other_info = split(text: s, separators: nextSeparators)
102+
let other_info = split(
103+
text: s.text,
104+
separators: nextSeparators,
105+
startIndex: s.startUTF16Offset
106+
)
103107
finalChunks.append(contentsOf: other_info)
104108
}
105109
}

Tool/Sources/LangChain/DocumentTransformer/TextSplitter.swift

Lines changed: 56 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ public protocol TextSplitter: DocumentTransformer {
1111
var lengthFunction: (String) -> Int { get }
1212

1313
/// Split text into multiple components.
14-
func split(text: String) async throws -> [String]
14+
func split(text: String) async throws -> [TextChunk]
1515
}
1616

1717
public extension TextSplitter {
@@ -26,7 +26,7 @@ public extension TextSplitter {
2626
for (text, metadata) in zip(texts, metadata) {
2727
let chunks = try await split(text: text)
2828
for chunk in chunks {
29-
let document = Document(pageContent: chunk, metadata: metadata)
29+
let document = Document(pageContent: chunk.text, metadata: metadata)
3030
documents.append(document)
3131
}
3232
}
@@ -50,31 +50,48 @@ public extension TextSplitter {
5050
}
5151
}
5252

53+
public struct TextChunk {
54+
public var text: String
55+
public var startUTF16Offset: Int
56+
public var endUTF16Offset: Int
57+
}
58+
5359
public extension TextSplitter {
5460
/// Merge small splits to just fit in the chunk size.
55-
func mergeSplits(_ splits: [String]) -> [String] {
61+
func mergeSplits(_ splits: [TextChunk]) -> [TextChunk] {
5662
let chunkOverlap = chunkOverlap < chunkSize ? chunkOverlap : 0
5763

58-
var chunks = [String]()
59-
var currentChunk = [String]()
60-
var overlappingChunks = [String]()
64+
var chunks = [TextChunk]()
65+
var currentChunk = [TextChunk]()
66+
var overlappingChunks = [TextChunk]()
6167
var currentChunkSize = 0
62-
63-
func join(_ a: [String], _ b: [String]) -> String {
64-
return (a + b).joined().trimmingCharacters(in: .whitespaces)
68+
69+
func join(_ a: [TextChunk], _ b: [TextChunk]) -> TextChunk? {
70+
let text = (a + b).map(\.text).joined()
71+
var l = Int.max
72+
var u = 0
73+
74+
for chunk in a + b {
75+
l = min(l, chunk.startUTF16Offset)
76+
u = max(u, chunk.endUTF16Offset)
77+
}
78+
79+
guard l < u else { return nil }
80+
81+
return .init(text: text, startUTF16Offset: l, endUTF16Offset: u)
6582
}
6683

67-
for text in splits {
68-
let textLength = lengthFunction(text)
84+
for chunk in splits {
85+
let textLength = lengthFunction(chunk.text)
6986
if currentChunkSize + textLength > chunkSize {
70-
let currentChunkText = join(overlappingChunks, currentChunk)
87+
guard let currentChunkText = join(overlappingChunks, currentChunk) else { continue }
7188
chunks.append(currentChunkText)
7289

7390
overlappingChunks = []
7491
var overlappingSize = 0
7592
// use small chunks as overlap if possible
7693
for chunk in currentChunk.reversed() {
77-
let length = lengthFunction(chunk)
94+
let length = lengthFunction(chunk.text)
7895
if overlappingSize + length > chunkOverlap { break }
7996
if overlappingSize + length + textLength > chunkSize { break }
8097
overlappingSize += length
@@ -90,46 +107,57 @@ public extension TextSplitter {
90107
// }
91108

92109
currentChunkSize = overlappingSize + textLength
93-
currentChunk = [text]
110+
currentChunk = [chunk]
94111
} else {
95112
currentChunkSize += textLength
96-
currentChunk.append(text)
113+
currentChunk.append(chunk)
97114
}
98115
}
99116

100-
if !currentChunk.isEmpty {
101-
chunks.append(join(overlappingChunks, currentChunk))
117+
if !currentChunk.isEmpty, let joinedChunks = join(overlappingChunks, currentChunk) {
118+
chunks.append(joinedChunks)
119+
} else {
120+
chunks.append(contentsOf: overlappingChunks)
121+
chunks.append(contentsOf: currentChunk)
102122
}
103123

104124
return chunks
105125
}
106126

107127
/// Split the text by separator.
108-
func split(text: String, separator: String) -> [String] {
109-
guard !separator.isEmpty else {
110-
return [text]
111-
}
112-
128+
func split(text: String, separator: String, startIndex: Int = 0) -> [TextChunk] {
113129
let pattern = "(\(separator))"
114-
if let regex = try? NSRegularExpression(pattern: pattern) {
130+
if !separator.isEmpty, let regex = try? NSRegularExpression(pattern: pattern) {
115131
let matches = regex.matches(in: text, range: NSRange(text.startIndex..., in: text))
116-
var all = [String]()
132+
var all = [TextChunk]()
117133
var start = text.startIndex
118134
for match in matches {
119135
guard let range = Range(match.range, in: text) else { break }
120136
guard range.lowerBound > start else { break }
121137
let result = text[start..<range.lowerBound]
122-
start = range.lowerBound
123138
if !result.isEmpty {
124-
all.append(String(result))
139+
all.append(.init(
140+
text: String(result),
141+
startUTF16Offset: start.utf16Offset(in: text) + startIndex,
142+
endUTF16Offset: range.lowerBound.utf16Offset(in: text) + startIndex
143+
))
125144
}
145+
start = range.lowerBound
126146
}
127147
if start < text.endIndex {
128-
all.append(String(text[start...]))
148+
all.append(.init(
149+
text: String(text[start...]),
150+
startUTF16Offset: start.utf16Offset(in: text) + startIndex,
151+
endUTF16Offset: text.endIndex.utf16Offset(in: text) + startIndex
152+
))
129153
}
130154
return all
131155
} else {
132-
return [text]
156+
return [.init(
157+
text: text,
158+
startUTF16Offset: text.startIndex.utf16Offset(in: text) + startIndex,
159+
endUTF16Offset: text.endIndex.utf16Offset(in: text) + startIndex
160+
)]
133161
}
134162
}
135163
}

Tool/Sources/LangChain/DocumentTransformer/TextSplitterSeparatorSet.swift

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -347,5 +347,9 @@ public struct TextSplitterSeparatorSet: ExpressibleByArrayLiteral {
347347
"",
348348
]
349349
}
350+
351+
public static var `default`: TextSplitterSeparatorSet {
352+
["\n\n", "\r\n", "\n", "\r", " ", ""]
353+
}
350354
}
351355

0 commit comments

Comments
 (0)