|
1 | 1 | import Foundation |
2 | 2 |
|
3 | | -final class RetrievalQA: Chain { |
| 3 | +public final class RetrievalQAChain: Chain { |
4 | 4 | let vectorStore: VectorStore |
5 | 5 | let embedding: Embeddings |
| 6 | + let chatModelFactory: () -> ChatModel |
6 | 7 |
|
7 | | - struct Output { |
| 8 | + public struct Output { |
8 | 9 | var answer: String |
9 | 10 | var sourceDocuments: [Document] |
10 | 11 | } |
11 | 12 |
|
12 | | - init(vectorStore: VectorStore, embedding: Embeddings) { |
| 13 | + public init( |
| 14 | + vectorStore: VectorStore, |
| 15 | + embedding: Embeddings, |
| 16 | + chatModelFactory: @escaping () -> ChatModel |
| 17 | + ) { |
13 | 18 | self.vectorStore = vectorStore |
14 | 19 | self.embedding = embedding |
| 20 | + self.chatModelFactory = chatModelFactory |
15 | 21 | } |
16 | 22 |
|
17 | | - func callLogic( |
| 23 | + public func callLogic( |
18 | 24 | _ input: String, |
19 | | - callbackManagers: [ChainCallbackManager] |
| 25 | + callbackManagers: [CallbackManager] |
20 | 26 | ) async throws -> Output { |
21 | | - let embeddedQuestion = try awa |
22 | | - |
23 | | - return .init(answer: "", sourceDocuments: []) |
| 27 | + let embeddedQuestion = try await embedding.embed(query: input) |
| 28 | + let documents = try await vectorStore.searchWithDistance( |
| 29 | + embeddings: embeddedQuestion, |
| 30 | + count: 10 |
| 31 | + ) |
| 32 | + let refinementChain = RefineDocumentChain(chatModelFactory: chatModelFactory) |
| 33 | + let answer = try await refinementChain.run(.init(question: input, documents: documents)) |
| 34 | + |
| 35 | + return .init(answer: answer, sourceDocuments: documents.map(\.document)) |
24 | 36 | } |
25 | 37 |
|
26 | | - func parseOutput(_ output: Output) -> String { |
| 38 | + public func parseOutput(_ output: Output) -> String { |
27 | 39 | return output.answer |
28 | 40 | } |
29 | 41 | } |
30 | 42 |
|
| 43 | +public final class RefineDocumentChain: Chain { |
| 44 | + public struct Input { |
| 45 | + var question: String |
| 46 | + var documents: [(document: Document, distance: Float)] |
| 47 | + } |
| 48 | + |
| 49 | + struct InitialInput { |
| 50 | + var question: String |
| 51 | + var document: String |
| 52 | + var distance: Float |
| 53 | + } |
| 54 | + |
| 55 | + struct RefinementInput { |
| 56 | + var question: String |
| 57 | + var previousAnswer: String |
| 58 | + var document: String |
| 59 | + var distance: Float |
| 60 | + } |
| 61 | + |
| 62 | + let initialChatModel: ChatModelChain<InitialInput> |
| 63 | + let refinementChatModel: ChatModelChain<RefinementInput> |
| 64 | + |
| 65 | + public init(chatModelFactory: () -> ChatModel) { |
| 66 | + initialChatModel = .init( |
| 67 | + chatModel: chatModelFactory(), |
| 68 | + promptTemplate: { input in [ |
| 69 | + .init(role: .system, content: """ |
| 70 | + The user will send you a question, you must answer it at your best. |
| 71 | + You can use the following document as a reference:### |
| 72 | + \(input.document) |
| 73 | + ### |
| 74 | + """), |
| 75 | + .init(role: .user, content: input.question), |
| 76 | + ] } |
| 77 | + ) |
| 78 | + refinementChatModel = .init( |
| 79 | + chatModel: chatModelFactory(), |
| 80 | + promptTemplate: { input in [ |
| 81 | + .init(role: .system, content: """ |
| 82 | + The user will send you a question, you must update your previous answer to it at your best. |
| 83 | + Previous answer:### |
| 84 | + \(input.previousAnswer) |
| 85 | + ### |
| 86 | + You can use the following document as a reference:### |
| 87 | + \(input.document) |
| 88 | + ### |
| 89 | + """), |
| 90 | + .init(role: .user, content: input.question), |
| 91 | + ] } |
| 92 | + ) |
| 93 | + } |
| 94 | + |
| 95 | + public func callLogic( |
| 96 | + _ input: Input, |
| 97 | + callbackManagers: [CallbackManager] |
| 98 | + ) async throws -> String { |
| 99 | + guard let firstDocument = input.documents.first else { |
| 100 | + return "" |
| 101 | + } |
| 102 | + var output = try await initialChatModel.call( |
| 103 | + .init( |
| 104 | + question: input.question, |
| 105 | + document: firstDocument.document.pageContent, |
| 106 | + distance: firstDocument.distance |
| 107 | + ), |
| 108 | + callbackManagers: callbackManagers |
| 109 | + ) |
| 110 | + for document in input.documents.dropFirst(1) { |
| 111 | + output = try await refinementChatModel.call( |
| 112 | + .init( |
| 113 | + question: input.question, |
| 114 | + previousAnswer: output, |
| 115 | + document: document.document.pageContent, |
| 116 | + distance: document.distance |
| 117 | + ), |
| 118 | + callbackManagers: callbackManagers |
| 119 | + ) |
| 120 | + } |
| 121 | + return output |
| 122 | + } |
| 123 | + |
| 124 | + public func parseOutput(_ output: String) -> String { |
| 125 | + return output |
| 126 | + } |
| 127 | +} |
| 128 | + |
0 commit comments