|
1 | 1 | import Foundation |
2 | | -import JSONRPC |
3 | 2 |
|
4 | 3 | public protocol DocumentTransformer { |
5 | 4 | func transformDocuments(_ documents: [Document]) async throws -> [Document] |
6 | 5 | } |
7 | | - |
8 | | -public protocol TextSplitter: DocumentTransformer { |
9 | | - var chunkSize: Int { get } |
10 | | - var chunkOverlap: Int { get } |
11 | | - var lengthFunction: (String) -> Int { get } |
12 | | - |
13 | | - /// Split text into multiple components. |
14 | | - func split(text: String) async throws -> [String] |
15 | | -} |
16 | | - |
17 | | -public extension TextSplitter { |
18 | | - /// Create documents from a list of texts. |
19 | | - func createDocuments( |
20 | | - texts: [String], |
21 | | - metadata: [JSONValue] = [] |
22 | | - ) async throws -> [Document] { |
23 | | - var documents = [Document]() |
24 | | - let paddingLength = texts.count - metadata.count |
25 | | - let metadata = metadata + .init(repeating: [:], count: paddingLength) |
26 | | - for (text, metadata) in zip(texts, metadata) { |
27 | | - let trunks = try await split(text: text) |
28 | | - for trunk in trunks { |
29 | | - let document = Document(pageContent: trunk, metadata: metadata) |
30 | | - documents.append(document) |
31 | | - } |
32 | | - } |
33 | | - return documents |
34 | | - } |
35 | | - |
36 | | - /// Split documents. |
37 | | - func splitDocuments(_ documents: [Document]) async throws -> [Document] { |
38 | | - var texts = [String]() |
39 | | - var metadata = [JSONValue]() |
40 | | - for document in documents { |
41 | | - texts.append(document.pageContent) |
42 | | - metadata.append(document.metadata) |
43 | | - } |
44 | | - return try await createDocuments(texts: texts, metadata: metadata) |
45 | | - } |
46 | | - |
47 | | - /// Transform sequence of documents by splitting them. |
48 | | - func transformDocuments(_ documents: [Document]) async throws -> [Document] { |
49 | | - return try await splitDocuments(documents) |
50 | | - } |
51 | | -} |
52 | | - |
53 | | -public extension TextSplitter { |
54 | | - /// Merge small splits to just fit in the chunk size. |
55 | | - func mergeSplits(_ splits: [String]) -> [String] { |
56 | | - let chunkOverlap = chunkOverlap < chunkSize ? chunkOverlap : 0 |
57 | | - |
58 | | - var chunks = [String]() |
59 | | - var currentChunk = [String]() |
60 | | - var overlappingChunks = [String]() |
61 | | - var currentChunkSize = 0 |
62 | | - |
63 | | - func join(_ a: [String], _ b: [String]) -> String { |
64 | | - return (a + b).joined().trimmingCharacters(in: .whitespaces) |
65 | | - } |
66 | | - |
67 | | - for text in splits { |
68 | | - let textLength = lengthFunction(text) |
69 | | - if currentChunkSize + textLength > chunkSize { |
70 | | - let currentChunkText = join(overlappingChunks, currentChunk) |
71 | | - chunks.append(currentChunkText) |
72 | | - |
73 | | - overlappingChunks = [] |
74 | | - var overlappingSize = 0 |
75 | | - // use small chunks as overlap if possible |
76 | | - for chunk in currentChunk.reversed() { |
77 | | - let length = lengthFunction(chunk) |
78 | | - if overlappingSize + length > chunkOverlap { break } |
79 | | - if overlappingSize + length + textLength > chunkSize { break } |
80 | | - overlappingSize += length |
81 | | - overlappingChunks.insert(chunk, at: 0) |
82 | | - } |
83 | | -// // fallback to use suffix if no small chunk found |
84 | | -// if overlappingChunks.isEmpty { |
85 | | -// let suffix = String( |
86 | | -// currentChunkText.suffix(min(chunkOverlap, chunkSize - textLength)) |
87 | | -// ) |
88 | | -// overlappingChunks.append(suffix) |
89 | | -// overlappingSize = lengthFunction(suffix) |
90 | | -// } |
91 | | - |
92 | | - currentChunkSize = overlappingSize + textLength |
93 | | - currentChunk = [text] |
94 | | - } else { |
95 | | - currentChunkSize += textLength |
96 | | - currentChunk.append(text) |
97 | | - } |
98 | | - } |
99 | | - |
100 | | - if !currentChunk.isEmpty { |
101 | | - chunks.append(join(overlappingChunks, currentChunk)) |
102 | | - } |
103 | | - |
104 | | - return chunks |
105 | | - } |
106 | | - |
107 | | - /// Split the text by separator. |
108 | | - func split(text: String, separator: String) -> [String] { |
109 | | - guard !separator.isEmpty else { |
110 | | - return [text] |
111 | | - } |
112 | | - |
113 | | - let pattern = "(\(separator))" |
114 | | - if let regex = try? NSRegularExpression(pattern: pattern) { |
115 | | - let matches = regex.matches(in: text, range: NSRange(text.startIndex..., in: text)) |
116 | | - var all = [String]() |
117 | | - var start = text.startIndex |
118 | | - for match in matches { |
119 | | - guard let range = Range(match.range, in: text) else { break } |
120 | | - guard range.lowerBound > start else { break } |
121 | | - let result = text[start..<range.lowerBound] |
122 | | - start = range.lowerBound |
123 | | - if !result.isEmpty { |
124 | | - all.append(String(result)) |
125 | | - } |
126 | | - } |
127 | | - if start < text.endIndex { |
128 | | - all.append(String(text[start...])) |
129 | | - } |
130 | | - return all |
131 | | - } else { |
132 | | - return [text] |
133 | | - } |
134 | | - } |
135 | | -} |
136 | | - |
0 commit comments