-
Notifications
You must be signed in to change notification settings - Fork 1.8k
Expand file tree
/
Copy pathConversationStorage.swift
More file actions
245 lines (215 loc) · 9.32 KB
/
ConversationStorage.swift
File metadata and controls
245 lines (215 loc) · 9.32 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
import Foundation
import SQLite
/// Persistence operations for conversations and their turns.
public protocol ConversationStorageProtocol {
    /// Returns the turns stored for the conversation identified by `conversationID`.
    func fetchTurnItems(for conversationID: String) throws -> [TurnItem]

    /// Returns the conversations selected by `fetchType`.
    func fetchConversationItems(_ fetchType: ConversationFetchType) throws -> [ConversationItem]

    /// Applies the write operations carried by `operationRequest`.
    func operate(_ operationRequest: OperationRequest) throws
}
/// SQLite-backed implementation of `ConversationStorageProtocol`.
///
/// A single `Connection` is opened in `init` and held for the lifetime of the
/// instance. Reads go through `withDB`; `operate` and `createTableIfNeeded` run
/// their statements inside one transaction via `withDBTransaction`.
public final class ConversationStorage: ConversationStorageProtocol {
    /// Seconds SQLite keeps retrying on a locked database before reporting a busy error.
    static let BusyTimeout: Double = 5 // error after 5 seconds

    /// Database path this instance was opened with.
    private let path: String
    /// Open connection: assigned in `init`, cleared in `deinit`.
    private var db: Connection?

    let conversationTable = ConversationTable()
    let turnTable = TurnTable()

    /// Opens the SQLite database at `path` and applies `BusyTimeout` to it.
    ///
    /// - Parameter path: Database path handed to `Connection`; must be non-empty.
    /// - Throws: `DatabaseError.invalidPath` for an empty path, or
    ///   `DatabaseError.connectionFailed` carrying the underlying error's description.
    public init(_ path: String) throws {
        guard !path.isEmpty else { throw DatabaseError.invalidPath(path) }
        self.path = path
        do {
            let db = try Connection(path)
            db.busyTimeout = ConversationStorage.BusyTimeout
            self.db = db
        } catch {
            throw DatabaseError.connectionFailed(error.localizedDescription)
        }
    }

    // NOTE(review): presumably releasing the `Connection` closes the SQLite handle — confirm.
    deinit { db = nil }

    /// Runs `operation` against the open connection and returns its result.
    /// - Throws: `DatabaseError.connectionLost` when no connection is held, or whatever `operation` throws.
    private func withDB<T>(_ operation: (Connection) throws -> T) throws -> T {
        guard let db = self.db else {
            throw DatabaseError.connectionLost
        }
        return try operation(db)
    }

    /// Runs `operation` inside a `Connection.transaction` on the open connection.
    /// - Throws: `DatabaseError.connectionLost` when no connection is held, or whatever the transaction throws.
    private func withDBTransaction(_ operation: (Connection) throws -> Void) throws {
        guard let db = self.db else {
            throw DatabaseError.connectionLost
        }
        try db.transaction {
            try operation(db)
        }
    }

    /// Creates the `Conversation` and `Turn` tables if they do not already exist.
    ///
    /// The DDL runs through `withDBTransaction` rather than an inline
    /// `BEGIN … COMMIT`, so a failing statement cannot leave the connection
    /// stuck inside an open transaction.
    public func createTableIfNeeded() throws {
        try withDBTransaction { db in
            try db.execute("""
                CREATE TABLE IF NOT EXISTS Conversation (
                    id TEXT NOT NULL PRIMARY KEY,
                    title TEXT,
                    isSelected INTEGER NOT NULL,
                    CLSConversationID TEXT,
                    data BLOB NOT NULL,
                    createdAt REAL DEFAULT (strftime('%s','now')),
                    updatedAt REAL DEFAULT (strftime('%s','now'))
                );
                CREATE TABLE IF NOT EXISTS Turn (
                    rowID INTEGER PRIMARY KEY AUTOINCREMENT,
                    id TEXT NOT NULL UNIQUE,
                    conversationID TEXT NOT NULL,
                    CLSTurnID TEXT,
                    role TEXT NOT NULL,
                    data BLOB NOT NULL,
                    createdAt REAL DEFAULT (strftime('%s','now')),
                    updatedAt REAL DEFAULT (strftime('%s','now')),
                    UNIQUE (conversationID, id)
                );
                """)
        }
    }

    /// Applies every operation in `request` within a single transaction.
    ///
    /// - Conversations are upserted with `id` as the conflict target.
    /// - Turns are upserted with `(conversationID, id)` as the conflict target.
    /// - Deletes remove a conversation by id, a turn by id, or all turns of a conversation.
    ///
    /// A `nil` `CLSConversationID` / `CLSTurnID` is written as `""`.
    /// An empty request returns without touching the database.
    public func operate(_ request: OperationRequest) throws {
        guard !request.operations.isEmpty else { return }
        try withDBTransaction { db in
            for operation in request.operations {
                switch operation {
                case .upsertConversation(let conversationItems):
                    let column = conversationTable.column
                    for conversationItem in conversationItems {
                        try db.run(
                            conversationTable.table.upsert(
                                column.id <- conversationItem.id,
                                column.title <- conversationItem.title,
                                column.isSelected <- conversationItem.isSelected,
                                column.CLSConversationID <- conversationItem.CLSConversationID ?? "",
                                column.data <- conversationItem.data.toBlob(),
                                column.createdAt <- conversationItem.createdAt.timeIntervalSince1970,
                                column.updatedAt <- conversationItem.updatedAt.timeIntervalSince1970,
                                onConflictOf: column.id
                            )
                        )
                    }
                case .upsertTurn(let turnItems):
                    let column = turnTable.column
                    for turnItem in turnItems {
                        try db.run(
                            turnTable.table.upsert(
                                column.conversationID <- turnItem.conversationID,
                                column.id <- turnItem.id,
                                column.CLSTurnID <- turnItem.CLSTurnID ?? "",
                                column.role <- turnItem.role,
                                column.data <- turnItem.data.toBlob(),
                                column.createdAt <- turnItem.createdAt.timeIntervalSince1970,
                                column.updatedAt <- turnItem.updatedAt.timeIntervalSince1970,
                                // Matches the table's UNIQUE (conversationID, id) constraint.
                                onConflictOf: SQLite.Expression<Void>(literal: "\"conversationID\", \"id\"")
                            )
                        )
                    }
                case .delete(let deleteItems):
                    for deleteItem in deleteItems {
                        switch deleteItem {
                        case .conversation(let id):
                            try db.run(conversationTable.table.filter(conversationTable.column.id == id).delete())
                        case .turn(let id):
                            // Filter with the Turn table's own `id` column expression.
                            try db.run(turnTable.table.filter(turnTable.column.id == id).delete())
                        case .turnByConversationID(let conversationID):
                            try db.run(turnTable.table.filter(turnTable.column.conversationID == conversationID).delete())
                        }
                    }
                }
            }
        }
    }

    /// Returns the turns of `conversationID` in insertion order (ascending `rowID`).
    public func fetchTurnItems(for conversationID: String) throws -> [TurnItem] {
        return try withDB { db -> [TurnItem] in
            let column = turnTable.column
            let query = turnTable.table
                .filter(column.conversationID == conversationID)
                .order(column.rowID.asc)
            return try db.prepareRowIterator(query).map { row in
                TurnItem(
                    id: row[column.id],
                    conversationID: row[column.conversationID],
                    CLSTurnID: row[column.CLSTurnID],
                    role: row[column.role],
                    data: row[column.data].toString(),
                    createdAt: row[column.createdAt].toDate(),
                    updatedAt: row[column.updatedAt].toDate()
                )
            }
        }
    }

    /// Returns conversations according to `type`:
    /// - `.all`: every conversation, most recently updated first.
    /// - `.selected`: at most one row with `isSelected == true` (no ordering applied).
    /// - `.latest`: the most recently updated conversation, if any.
    /// - `.id`: the conversation with the given id, if any.
    public func fetchConversationItems(_ type: ConversationFetchType) throws -> [ConversationItem] {
        return try withDB { db -> [ConversationItem] in
            let column = conversationTable.column
            var query = conversationTable.table
            switch type {
            case .all:
                query = query.order(column.updatedAt.desc)
            case .selected:
                query = query
                    .filter(column.isSelected == true)
                    .limit(1)
            case .latest:
                query = query
                    .order(column.updatedAt.desc)
                    .limit(1)
            case .id(let id):
                query = query
                    .filter(column.id == id)
                    .limit(1)
            }
            return try db.prepareRowIterator(query).map { row in
                ConversationItem(
                    id: row[column.id],
                    title: row[column.title],
                    isSelected: row[column.isSelected],
                    CLSConversationID: row[column.CLSConversationID],
                    data: row[column.data].toString(),
                    createdAt: row[column.createdAt].toDate(),
                    updatedAt: row[column.updatedAt].toDate()
                )
            }
        }
    }

    /// Returns lightweight previews (id, title, selection flag, update time) of all
    /// conversations, most recently updated first, without loading the `data` blob.
    public func fetchConversationPreviewItems() throws -> [ConversationPreviewItem] {
        return try withDB { db -> [ConversationPreviewItem] in
            let column = conversationTable.column
            let query = conversationTable.table
                .select(column.id, column.title, column.isSelected, column.updatedAt)
                .order(column.updatedAt.desc)
            return try db.prepareRowIterator(query).map { row in
                ConversationPreviewItem(
                    id: row[column.id],
                    title: row[column.title],
                    isSelected: row[column.isSelected],
                    updatedAt: row[column.updatedAt].toDate()
                )
            }
        }
    }
}
extension String {
    /// Encodes the string as UTF-8 and wraps the bytes in a SQLite `Blob`.
    ///
    /// A Swift `String` is always valid Unicode, so its UTF-8 view is always
    /// available. Reading it directly removes the unreachable empty-`Data`
    /// fallback that `data(using:)` required.
    func toBlob() -> Blob {
        return Blob(bytes: Array(utf8))
    }
}
extension Blob {
    /// Decodes the blob's bytes as UTF-8 text.
    ///
    /// - Returns: The decoded string, or `""` when the bytes are not valid UTF-8.
    func toString() -> String {
        let payload = Data(bytes)
        guard let decoded = String(data: payload, encoding: .utf8) else {
            return ""
        }
        return decoded
    }
}
extension Double {
    /// Treats the receiver as seconds since 1970-01-01 00:00:00 UTC and returns that `Date`.
    func toDate() -> Date {
        let secondsSinceEpoch: TimeInterval = self
        return Date(timeIntervalSince1970: secondsSinceEpoch)
    }
}