Skip to content

Commit c844b1b

Browse files
committed
Update QueryWebsiteFunction to use RetrievalQAChain
1 parent 1b39b67 commit c844b1b

2 files changed

Lines changed: 41 additions & 40 deletions

File tree

Core/Sources/ChatContextCollectors/WebChatContextCollector/QueryWebsiteFunction.swift

Lines changed: 39 additions & 38 deletions
Original file line number | Diff line number | Diff line change
@@ -10,14 +10,10 @@ struct QueryWebsiteFunction: ChatGPTFunction {
1010
}
1111

1212
struct Result: ChatGPTFunctionResult {
13-
var relevantDocuments: [Document]
13+
var answers: [String]
1414

1515
var botReadableContent: String {
16-
// don't forget to remove overlaps
17-
if relevantDocuments.isEmpty {
18-
return "No relevant information found"
19-
}
20-
return relevantDocuments.map(\.pageContent).joined(separator: "\n\n")
16+
return answers.joined(separator: "\n")
2117
}
2218
}
2319

@@ -57,66 +53,71 @@ struct QueryWebsiteFunction: ChatGPTFunction {
5753

5854
func call(arguments: Arguments) async throws -> Result {
5955
do {
60-
let embedding = OpenAIEmbedding(
61-
configuration: UserPreferenceEmbeddingConfiguration()
62-
)
56+
let embedding = OpenAIEmbedding(configuration: UserPreferenceEmbeddingConfiguration())
6357

64-
let queryEmbeddings = try await embedding.embed(query: arguments.query)
65-
let searchCount = UserDefaults.shared.value(for: \.chatGPTMaxToken) > 5000 ? 3 : 20
66-
67-
let result = try await withThrowingTaskGroup(
68-
of: [(document: Document, distance: Float)].self
69-
) { group in
58+
let result = try await withThrowingTaskGroup(of: String.self) { group in
7059
for urlString in arguments.urls {
7160
guard let url = URL(string: urlString) else { continue }
7261
group.addTask {
73-
if let database = await TemporaryUSearch.view(identifier: urlString) {
74-
return try await database.searchWithDistance(
75-
embeddings: queryEmbeddings,
76-
count: searchCount
77-
)
78-
}
7962
// 1. grab the website content
8063
await reportProgress("Loading \(url)..")
81-
print("== load \(url)")
64+
65+
if let database = await TemporaryUSearch.view(identifier: urlString) {
66+
await reportProgress("Generating answers..")
67+
let qa = RetrievalQAChain(vectorStore: database, embedding: embedding) {
68+
OpenAIChat(
69+
configuration: UserPreferenceChatGPTConfiguration()
70+
.overriding(.init(temperature: 0)),
71+
stream: true
72+
)
73+
}
74+
return try await qa.call(.init(arguments.query)).answer
75+
}
8276
let loader = WebLoader(urls: [url])
8377
let documents = try await loader.load()
8478
await reportProgress("Processing \(url)..")
85-
print("== loaded \(url), documents: \(documents.count)")
8679
// 2. split the content
8780
let splitter = RecursiveCharacterTextSplitter(
8881
chunkSize: 1000,
8982
chunkOverlap: 100
9083
)
9184
let splitDocuments = try await splitter.transformDocuments(documents)
92-
print("== split \(url), documents: \(splitDocuments.count)")
9385
// 3. embedding and store in db
9486
await reportProgress("Embedding \(url)..")
9587
let embeddedDocuments = try await embedding.embed(documents: splitDocuments)
96-
print("== embedded \(url)")
9788
let database = TemporaryUSearch(identifier: urlString)
9889
try await database.set(embeddedDocuments)
99-
print("== save to database \(url)")
100-
let result = try await database.searchWithDistance(
101-
embeddings: queryEmbeddings,
102-
count: searchCount
103-
)
104-
print("== result of \(url): \(result)")
105-
return result
90+
// 4. generate answer
91+
await reportProgress("Generating answers..")
92+
let qa = RetrievalQAChain(vectorStore: database, embedding: embedding) {
93+
OpenAIChat(
94+
configuration: UserPreferenceChatGPTConfiguration()
95+
.overriding(.init(temperature: 0)),
96+
stream: true
97+
)
98+
}
99+
let result = try await qa.call(.init(arguments.query))
100+
return result.answer
106101
}
107102
}
108103

109-
var all = [(document: Document, distance: Float)]()
104+
var all = [String]()
110105
for try await result in group {
111-
all.append(contentsOf: result)
106+
all.append(result)
112107
}
113-
await reportProgress("Finish reading websites.")
108+
await reportProgress("""
109+
Finish reading websites.
110+
\(
111+
arguments.urls
112+
.map { "- [\($0)](\($0))" }
113+
.joined(separator: "\n")
114+
)
115+
""")
116+
114117
return all
115-
.sorted { $0.distance < $1.distance }
116-
.prefix(searchCount)
117118
}
118119

119-
return .init(relevantDocuments: result.map(\.document))
120+
return .init(answers: result)
120121
} catch {
121122
await reportProgress("Failed reading websites.")
122123
throw error

Tool/Sources/LangChain/Chains/RetrievalQA.swift

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -6,8 +6,8 @@ public final class RetrievalQAChain: Chain {
66
let chatModelFactory: () -> ChatModel
77

88
public struct Output {
9-
var answer: String
10-
var sourceDocuments: [Document]
9+
public var answer: String
10+
public var sourceDocuments: [Document]
1111
}
1212

1313
public init(

0 commit comments

Comments
 (0)