|
1 | 1 | import Foundation |
2 | | -import PythonHelper |
3 | | -import PythonKit |
4 | 2 |
|
/// Splits text by trying a list of separators in order, recursing into pieces
/// that are still too long with the separators that remain.
///
/// For each piece of text, the first separator that occurs in it is used to
/// split it. Pieces shorter than `chunkSize` are collected and handed to
/// `mergeSplits`; any other piece is split again with the separators that
/// follow the one just used, or emitted unchanged when none remain.
class RecursiveCharacterTextSplitter: TextSplitter {
    /// Candidate separators, tried in order. Each entry is interpreted as a
    /// regular-expression pattern when checking whether it occurs in the text.
    public var separators: [String]

    /// Size threshold: pieces whose measured length is strictly below this
    /// value are merged; pieces at or above it are split further.
    public var chunkSize: Int

    /// Requested overlap between chunks. Not read in this class directly;
    /// presumably consumed by `mergeSplits` — TODO confirm.
    public var chunkOverlap: Int

    /// Measures the length of a piece of text for comparison with `chunkSize`.
    public var lengthFunction: (String) -> Int

    /// Creates a splitter from an explicit, ordered list of separators.
    ///
    /// - Parameters:
    ///   - separators: Separators to try, in priority order.
    ///   - chunkSize: Size threshold for a piece to be considered small enough.
    ///   - chunkOverlap: Requested overlap; must not exceed `chunkSize`
    ///     (checked with `assert`, i.e. in debug builds only).
    ///   - lengthFunction: Length measure; defaults to the character count.
    init(
        separators: [String] = ["\n\n", "\n", " ", ""],
        chunkSize: Int = 4000,
        chunkOverlap: Int = 200,
        lengthFunction: @escaping (String) -> Int = { $0.count }
    ) {
        assert(chunkOverlap <= chunkSize, "chunkOverlap must not exceed chunkSize")
        self.separators = separators
        self.chunkSize = chunkSize
        self.chunkOverlap = chunkOverlap
        self.lengthFunction = lengthFunction
    }

    /// Creates a splitter whose separators are taken from `separatorSet`.
    ///
    /// - Parameters:
    ///   - separatorSet: Source of the ordered separator list.
    ///   - chunkSize: Size threshold for a piece to be considered small enough.
    ///   - chunkOverlap: Requested overlap; must not exceed `chunkSize`
    ///     (checked with `assert`, i.e. in debug builds only).
    ///   - lengthFunction: Length measure; defaults to the character count.
    init(
        separatorSet: TextSplitterSeparatorSet,
        chunkSize: Int = 4000,
        chunkOverlap: Int = 200,
        lengthFunction: @escaping (String) -> Int = { $0.count }
    ) {
        assert(chunkOverlap <= chunkSize, "chunkOverlap must not exceed chunkSize")
        separators = separatorSet.separators
        self.chunkSize = chunkSize
        self.chunkOverlap = chunkOverlap
        self.lengthFunction = lengthFunction
    }

    /// Splits `text` into chunks using the configured `separators`.
    ///
    /// - Parameter text: The text to split.
    /// - Returns: The resulting chunks, in document order.
    public func split(text: String) async throws -> [String] {
        return split(text: text, separators: separators)
    }

    /// Splits `text` with the first of `separators` that occurs in it and
    /// recurses into oversized pieces with the separators that follow it.
    ///
    /// - Parameters:
    ///   - text: The text to split.
    ///   - separators: Candidate separators for this recursion level, in order.
    /// - Returns: The chunks produced for `text`.
    private func split(text: String, separators: [String]) -> [String] {
        // The search range is the same for every candidate, so build it once.
        let wholeText = NSRange(text.startIndex..., in: text)

        // Pick the first separator that occurs in the text. A separator that is
        // not a valid regular expression is treated as "not found" and skipped.
        let firstSeparatorIndex = separators.firstIndex { candidate in
            guard let regex = try? NSRegularExpression(pattern: "(\(candidate))") else {
                return false
            }
            return regex.firstMatch(in: text, options: [], range: wholeText) != nil
        }

        let separator: String
        let nextSeparators: [String]
        if let index = firstSeparatorIndex {
            separator = separators[index]
            // Only separators after the chosen one are tried on deeper levels;
            // this is empty when the chosen separator was the last one.
            nextSeparators = Array(separators.dropFirst(index + 1))
        } else {
            separator = ""
            nextSeparators = []
        }

        let pieces = split(text: text, separator: separator)

        // Merge runs of small pieces; recursively split the ones that are too long.
        var finalChunks = [String]()
        var pendingPieces = [String]()
        for piece in pieces {
            if lengthFunction(piece) < chunkSize {
                pendingPieces.append(piece)
                continue
            }

            // Flush the small pieces gathered so far so output order is preserved.
            if !pendingPieces.isEmpty {
                finalChunks.append(contentsOf: mergeSplits(pendingPieces))
                pendingPieces.removeAll()
            }

            if nextSeparators.isEmpty {
                // Nothing left to split with: emit the oversized piece as is.
                finalChunks.append(piece)
            } else {
                finalChunks.append(contentsOf: split(text: piece, separators: nextSeparators))
            }
        }

        if !pendingPieces.isEmpty {
            finalChunks.append(contentsOf: mergeSplits(pendingPieces))
        }
        return finalChunks
    }
}
39 | 99 |
|
0 commit comments