Skip to content

Commit d1d205e

Browse files
committed
Add native implementation of RecursiveCharacterTextSplitter
1 parent 3f4706e commit d1d205e

6 files changed

Lines changed: 616 additions & 28 deletions

File tree

Tool/Sources/LangChain/DocumentTransformer/DocumentTransformer.swift

Lines changed: 87 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,10 @@ public protocol DocumentTransformer {
66
}
77

88
public protocol TextSplitter: DocumentTransformer {
9+
var chunkSize: Int { get }
10+
var chunkOverlap: Int { get }
11+
var lengthFunction: (String) -> Int { get }
12+
913
/// Split text into multiple components.
1014
func split(text: String) async throws -> [String]
1115
}
@@ -46,5 +50,87 @@ public extension TextSplitter {
4650
}
4751
}
4852

49-
extension TextSplitter {}
53+
public extension TextSplitter {
54+
/// Merge small splits to just fit in the chunk size.
55+
func mergeSplits(_ splits: [String]) -> [String] {
56+
let chunkOverlap = chunkOverlap < chunkSize ? chunkOverlap : 0
57+
58+
var chunks = [String]()
59+
var currentChunk = [String]()
60+
var overlappingChunks = [String]()
61+
var currentChunkSize = 0
62+
63+
func join(_ a: [String], _ b: [String]) -> String {
64+
return (a + b).joined().trimmingCharacters(in: .whitespaces)
65+
}
66+
67+
for text in splits {
68+
let textLength = lengthFunction(text)
69+
if currentChunkSize + textLength > chunkSize {
70+
let currentChunkText = join(overlappingChunks, currentChunk)
71+
chunks.append(currentChunkText)
72+
73+
overlappingChunks = []
74+
var overlappingSize = 0
75+
// use small chunks as overlap if possible
76+
for chunk in currentChunk.reversed() {
77+
let length = lengthFunction(chunk)
78+
if overlappingSize + length > chunkOverlap { break }
79+
if overlappingSize + length + textLength > chunkSize { break }
80+
overlappingSize += length
81+
overlappingChunks.insert(chunk, at: 0)
82+
}
83+
// // fallback to use suffix if no small chunk found
84+
// if overlappingChunks.isEmpty {
85+
// let suffix = String(
86+
// currentChunkText.suffix(min(chunkOverlap, chunkSize - textLength))
87+
// )
88+
// overlappingChunks.append(suffix)
89+
// overlappingSize = lengthFunction(suffix)
90+
// }
91+
92+
currentChunkSize = overlappingSize + textLength
93+
currentChunk = [text]
94+
} else {
95+
currentChunkSize += textLength
96+
currentChunk.append(text)
97+
}
98+
}
99+
100+
if !currentChunk.isEmpty {
101+
chunks.append(join(overlappingChunks, currentChunk))
102+
}
103+
104+
return chunks
105+
}
106+
107+
/// Split the text by separator.
108+
func split(text: String, separator: String) -> [String] {
109+
guard !separator.isEmpty else {
110+
return [text]
111+
}
112+
113+
let pattern = "(\(separator))"
114+
if let regex = try? NSRegularExpression(pattern: pattern) {
115+
let matches = regex.matches(in: text, range: NSRange(text.startIndex..., in: text))
116+
var all = [String]()
117+
var start = text.startIndex
118+
for match in matches {
119+
guard let range = Range(match.range, in: text) else { break }
120+
guard range.lowerBound > start else { break }
121+
let result = text[start..<range.lowerBound]
122+
start = range.lowerBound
123+
if !result.isEmpty {
124+
all.append(String(result))
125+
}
126+
}
127+
if start < text.endIndex {
128+
all.append(String(text[start...]))
129+
}
130+
return all
131+
} else {
132+
return [text]
133+
}
134+
}
135+
}
50136

Lines changed: 81 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,99 @@
11
import Foundation
2-
import PythonHelper
3-
import PythonKit
42

5-
public struct RecursiveCharacterTextSplitter: TextSplitter {
3+
class RecursiveCharacterTextSplitter: TextSplitter {
4+
/**
5+
Implementation of splitting text that looks at characters.
6+
Recursively tries to split by different characters to find one that works.
7+
*/
68
public var separators: [String]
79
public var chunkSize: Int
810
public var chunkOverlap: Int
9-
10-
public init(
11+
public var lengthFunction: (String) -> Int
12+
13+
init(
1114
separators: [String] = ["\n\n", "\n", " ", ""],
1215
chunkSize: Int = 4000,
13-
chunkOverlap: Int = 200
16+
chunkOverlap: Int = 200,
17+
lengthFunction: @escaping (String) -> Int = { $0.count }
1418
) {
19+
assert(chunkOverlap <= chunkSize)
20+
self.chunkSize = chunkSize
21+
self.chunkOverlap = chunkOverlap
22+
self.lengthFunction = lengthFunction
1523
self.separators = separators
24+
}
25+
26+
init(
27+
separatorSet: TextSplitterSeparatorSet,
28+
chunkSize: Int = 4000,
29+
chunkOverlap: Int = 200,
30+
lengthFunction: @escaping (String) -> Int = { $0.count }
31+
) {
32+
assert(chunkOverlap <= chunkSize)
1633
self.chunkSize = chunkSize
1734
self.chunkOverlap = chunkOverlap
35+
self.lengthFunction = lengthFunction
36+
separators = separatorSet.separators
1837
}
1938

2039
public func split(text: String) async throws -> [String] {
21-
try await runPython {
22-
let text_splitter = try Python.attemptImportOnPythonThread("langchain.text_splitter")
23-
let PythonRecursiveCharacterTextSplitter = text_splitter.RecursiveCharacterTextSplitter
24-
let splitter = PythonRecursiveCharacterTextSplitter(
25-
separators: separators,
26-
chunk_size: chunkSize,
27-
chunk_overlap: chunkOverlap
28-
// length_function: PythonFunction({ object in
29-
// if let string = String(object) { return string.count }
30-
// return 0
31-
// })
32-
)
33-
let result = splitter.split_text(text)
34-
guard let array = [String](result) else { return [] }
35-
return array
40+
return split(text: text, separators: separators)
41+
}
42+
43+
private func split(text: String, separators: [String]) -> [String] {
44+
var finalChunks = [String]()
45+
46+
// Get appropriate separator to use
47+
let firstSeparatorIndex = separators.firstIndex {
48+
let pattern = "(\($0))"
49+
guard let regex = try? NSRegularExpression(pattern: pattern) else { return false }
50+
return regex.firstMatch(
51+
in: text,
52+
options: [],
53+
range: NSRange(text.startIndex..., in: text)
54+
) != nil
55+
}
56+
var separator: String
57+
var nextSeparators: [String]
58+
59+
if let index = firstSeparatorIndex {
60+
separator = separators[index]
61+
if index < separators.endIndex - 1 {
62+
nextSeparators = Array(separators[(index + 1)...])
63+
} else {
64+
nextSeparators = []
65+
}
66+
} else {
67+
separator = ""
68+
nextSeparators = []
69+
}
70+
71+
let splits = split(text: text, separator: separator)
72+
73+
// Now go merging things, recursively splitting longer texts.
74+
var goodSplits = [String]()
75+
for s in splits {
76+
if lengthFunction(s) < chunkSize {
77+
goodSplits.append(s)
78+
} else {
79+
if !goodSplits.isEmpty {
80+
let mergedText = mergeSplits(goodSplits)
81+
finalChunks.append(contentsOf: mergedText)
82+
goodSplits.removeAll()
83+
}
84+
if nextSeparators.isEmpty {
85+
finalChunks.append(s)
86+
} else {
87+
let other_info = split(text: s, separators: nextSeparators)
88+
finalChunks.append(contentsOf: other_info)
89+
}
90+
}
91+
}
92+
if !goodSplits.isEmpty {
93+
let merged_text = mergeSplits(goodSplits)
94+
finalChunks.append(contentsOf: merged_text)
3695
}
96+
return finalChunks
3797
}
3898
}
3999

0 commit comments

Comments
 (0)