@@ -10,14 +10,10 @@ struct QueryWebsiteFunction: ChatGPTFunction {
1010 }
1111
1212 struct Result : ChatGPTFunctionResult {
13- var relevantDocuments : [ Document ]
13+ var answers : [ String ]
1414
1515 var botReadableContent : String {
16- // don't forget to remove overlaps
17- if relevantDocuments. isEmpty {
18- return " No relevant information found "
19- }
20- return relevantDocuments. map ( \. pageContent) . joined ( separator: " \n \n " )
16+ return answers. joined ( separator: " \n " )
2117 }
2218 }
2319
@@ -57,66 +53,71 @@ struct QueryWebsiteFunction: ChatGPTFunction {
5753
5854 func call( arguments: Arguments ) async throws -> Result {
5955 do {
60- let embedding = OpenAIEmbedding (
61- configuration: UserPreferenceEmbeddingConfiguration ( )
62- )
56+ let embedding = OpenAIEmbedding ( configuration: UserPreferenceEmbeddingConfiguration ( ) )
6357
64- let queryEmbeddings = try await embedding. embed ( query: arguments. query)
65- let searchCount = UserDefaults . shared. value ( for: \. chatGPTMaxToken) > 5000 ? 3 : 20
66-
67- let result = try await withThrowingTaskGroup (
68- of: [ ( document: Document , distance: Float ) ] . self
69- ) { group in
58+ let result = try await withThrowingTaskGroup ( of: String . self) { group in
7059 for urlString in arguments. urls {
7160 guard let url = URL ( string: urlString) else { continue }
7261 group. addTask {
73- if let database = await TemporaryUSearch . view ( identifier: urlString) {
74- return try await database. searchWithDistance (
75- embeddings: queryEmbeddings,
76- count: searchCount
77- )
78- }
7962 // 1. grab the website content
8063 await reportProgress ( " Loading \( url) .. " )
81- print ( " == load \( url) " )
64+
65+ if let database = await TemporaryUSearch . view ( identifier: urlString) {
66+ await reportProgress ( " Generating answers.. " )
67+ let qa = RetrievalQAChain ( vectorStore: database, embedding: embedding) {
68+ OpenAIChat (
69+ configuration: UserPreferenceChatGPTConfiguration ( )
70+ . overriding ( . init( temperature: 0 ) ) ,
71+ stream: true
72+ )
73+ }
74+ return try await qa. call ( . init( arguments. query) ) . answer
75+ }
8276 let loader = WebLoader ( urls: [ url] )
8377 let documents = try await loader. load ( )
8478 await reportProgress ( " Processing \( url) .. " )
85- print ( " == loaded \( url) , documents: \( documents. count) " )
8679 // 2. split the content
8780 let splitter = RecursiveCharacterTextSplitter (
8881 chunkSize: 1000 ,
8982 chunkOverlap: 100
9083 )
9184 let splitDocuments = try await splitter. transformDocuments ( documents)
92- print ( " == split \( url) , documents: \( splitDocuments. count) " )
9385 // 3. embedding and store in db
9486 await reportProgress ( " Embedding \( url) .. " )
9587 let embeddedDocuments = try await embedding. embed ( documents: splitDocuments)
96- print ( " == embedded \( url) " )
9788 let database = TemporaryUSearch ( identifier: urlString)
9889 try await database. set ( embeddedDocuments)
99- print ( " == save to database \( url) " )
100- let result = try await database. searchWithDistance (
101- embeddings: queryEmbeddings,
102- count: searchCount
103- )
104- print ( " == result of \( url) : \( result) " )
105- return result
90+ // 4. generate answer
91+ await reportProgress ( " Generating answers.. " )
92+ let qa = RetrievalQAChain ( vectorStore: database, embedding: embedding) {
93+ OpenAIChat (
94+ configuration: UserPreferenceChatGPTConfiguration ( )
95+ . overriding ( . init( temperature: 0 ) ) ,
96+ stream: true
97+ )
98+ }
99+ let result = try await qa. call ( . init( arguments. query) )
100+ return result. answer
106101 }
107102 }
108103
109- var all = [ ( document : Document , distance : Float ) ] ( )
104+ var all = [ String ] ( )
110105 for try await result in group {
111- all. append ( contentsOf : result)
106+ all. append ( result)
112107 }
113- await reportProgress ( " Finish reading websites. " )
108+ await reportProgress ( """
109+ Finish reading websites.
110+ \(
111+ arguments. urls
112+ . map { " - [ \( $0) ]( \( $0) ) " }
113+ . joined ( separator: " \n " )
114+ )
115+ """ )
116+
114117 return all
115- . sorted { $0. distance < $1. distance }
116- . prefix ( searchCount)
117118 }
118119
119- return . init( relevantDocuments : result. map ( \ . document ) )
120+ return . init( answers : result)
120121 } catch {
121122 await reportProgress ( " Failed reading websites. " )
122123 throw error
0 commit comments