11import Foundation
2+ import OpenAIService
23
34public final class RetrievalQAChain : Chain {
45 let vectorStore : VectorStore
56 let embedding : Embeddings
6- let chatModelFactory : ( ) -> ChatModel
77
88 public struct Output {
99 public var answer : String
@@ -12,12 +12,10 @@ public final class RetrievalQAChain: Chain {
1212
1313 public init (
1414 vectorStore: VectorStore ,
15- embedding: Embeddings ,
16- chatModelFactory: @escaping ( ) -> ChatModel
15+ embedding: Embeddings
1716 ) {
1817 self . vectorStore = vectorStore
1918 self . embedding = embedding
20- self . chatModelFactory = chatModelFactory
2119 }
2220
2321 public func callLogic(
@@ -29,7 +27,7 @@ public final class RetrievalQAChain: Chain {
2927 embeddings: embeddedQuestion,
3028 count: 5
3129 )
32- let refinementChain = RefineDocumentChain ( chatModelFactory : chatModelFactory )
30+ let refinementChain = RefineDocumentChain ( )
3331 let answer = try await refinementChain. run (
3432 . init( question: input, documents: documents) ,
3533 callbackManagers: callbackManagers
@@ -68,12 +66,68 @@ public final class RefineDocumentChain: Chain {
6866 var distance : Float
6967 }
7068
/// Minimal `ChatGPTFunctionProvider` used by the refinement chat models.
/// Starts with no registered functions; the list is mutable so callable
/// functions (e.g. `RespondFunction`) can be installed later — confirm at
/// the call sites, which are outside this view.
class FunctionProvider: ChatGPTFunctionProvider {
    var functions: [any ChatGPTFunction] = []
}
72+
/// A ChatGPT function the model calls to deliver its refined answer,
/// a self-assessed quality score, and whether more information is still
/// needed to complete the answer.
struct RespondFunction: ChatGPTFunction {
    /// Arguments the model supplies when invoking `respond`.
    struct Arguments: Codable {
        var answer: String
        var score: Double
        var more: Bool
    }

    /// The call yields nothing for the bot to read back.
    struct Result: ChatGPTFunctionResult {
        var botReadableContent: String { "" }
    }

    // No-op progress reporter by default; callers may replace it.
    var reportProgress: (String) async -> Void = { _ in }

    var name: String = "respond"
    var description: String = "Respond with the refined answer"

    /// JSON schema describing `Arguments` for the model.
    var argumentSchema: JSONSchemaValue {
        [
            .type: "object",
            .properties: [
                "answer": [
                    .type: "string",
                    .description: "The answer",
                ],
                "score": [
                    .type: "number",
                    .description: "The score of the answer, the higher the better",
                ],
                "more": [
                    .type: "boolean",
                    .description: "Whether more information is needed to complete the answer",
                ],
            ],
        ]
    }

    // Nothing to set up before the function becomes callable.
    func prepare() async {}

    // The function carries its payload in `arguments`; the call itself
    // returns an empty result.
    func call(arguments: Arguments) async throws -> Result {
        Result()
    }
}
114+
71115 let initialChatModel : ChatModelChain < InitialInput >
72116 let refinementChatModel : ChatModelChain < RefinementInput >
117+ let initialChatMemory : ChatGPTMemory
118+ let refinementChatMemory : ChatGPTMemory
119+
120+ public init ( ) {
121+ initialChatMemory = ConversationChatGPTMemory ( systemPrompt: " " )
122+ refinementChatMemory = ConversationChatGPTMemory ( systemPrompt: " " )
73123
74- public init ( chatModelFactory: ( ) -> ChatModel ) {
75124 initialChatModel = . init(
76- chatModel: chatModelFactory ( ) ,
125+ chatModel: OpenAIChat (
126+ configuration: UserPreferenceChatGPTConfiguration ( )
127+ . overriding ( . init( temperature: 0 ) ) ,
128+ memory: initialChatMemory,
129+ stream: false
130+ ) ,
77131 promptTemplate: { input in [
78132 . init( role: . system, content: """
79133 The user will send you a question, you must answer it at your best.
@@ -85,7 +139,12 @@ public final class RefineDocumentChain: Chain {
85139 ] }
86140 )
87141 refinementChatModel = . init(
88- chatModel: chatModelFactory ( ) ,
142+ chatModel: OpenAIChat (
143+ configuration: UserPreferenceChatGPTConfiguration ( )
144+ . overriding ( . init( temperature: 0 ) ) ,
145+ memory: refinementChatMemory,
146+ stream: false
147+ ) ,
89148 promptTemplate: { input in [
90149 . init( role: . system, content: """
91150 The user will send you a question, you must refine your previous answer to it at your best.
@@ -117,7 +176,9 @@ public final class RefineDocumentChain: Chain {
117176 ) ,
118177 callbackManagers: callbackManagers
119178 )
120- callbackManagers. send ( CallbackEvents . RetrievalQADidGenerateIntermediateAnswer ( info: output) )
179+ guard var content = output. content else { return " " }
180+ callbackManagers
181+ . send ( CallbackEvents . RetrievalQADidGenerateIntermediateAnswer ( info: content) )
121182 for document in input. documents. dropFirst ( 1 ) {
122183 output = try await refinementChatModel. call (
123184 . init(
0 commit comments