@@ -43,7 +43,7 @@ public final class RetrievalQAChain: Chain {
4343
4444public extension CallbackEvents {
4545 struct RetrievalQADidGenerateIntermediateAnswer : CallbackEvent {
46- public let info : String
46+ public let info : RefineDocumentChain . IntermediateAnswer
4747 }
4848}
4949
@@ -66,16 +66,38 @@ public final class RefineDocumentChain: Chain {
6666 var distance : Float
6767 }
6868
69+ public struct IntermediateAnswer : Decodable {
70+ public var answer : String
71+ public var score : Double
72+ public var more : Bool
73+
74+ public enum CodingKeys : String , CodingKey {
75+ case answer
76+ case score
77+ case more
78+ }
79+
80+ init ( answer: String , score: Double , more: Bool ) {
81+ self . answer = answer
82+ self . score = score
83+ self . more = more
84+ }
85+
86+ public init ( from decoder: Decoder ) throws {
87+ let container = try decoder. container ( keyedBy: CodingKeys . self)
88+ answer = try container. decode ( String . self, forKey: . answer)
89+ score = ( try ? container. decode ( Double . self, forKey: . score) ) ?? 0
90+ more = ( try ? container. decode ( Bool . self, forKey: . more) ) ?? ( score < 6 )
91+ }
92+ }
93+
6994 class FunctionProvider : ChatGPTFunctionProvider {
70- var functions : [ any ChatGPTFunction ] = [ ]
95+ var functionCallStrategy : FunctionCallStrategy ? = . name( " respond " )
96+ var functions : [ any ChatGPTFunction ] = [ RespondFunction ( ) ]
7197 }
7298
7399 struct RespondFunction : ChatGPTFunction {
74- struct Arguments : Codable {
75- var answer : String
76- var score : Double
77- var more : Bool
78- }
100+ typealias Arguments = IntermediateAnswer
79101
80102 struct Result : ChatGPTFunctionResult {
81103 var botReadableContent : String { " " }
@@ -91,17 +113,18 @@ public final class RefineDocumentChain: Chain {
91113 . properties: [
92114 " answer " : [
93115 . type: " string " ,
94- . description: " The answer " ,
116+ . description: " The refined answer " ,
95117 ] ,
96118 " score " : [
97119 . type: " number " ,
98- . description: " The score of the answer, the higher the better " ,
120+ . description: " The score of the answer, the higher the better. 0 to 10. " ,
99121 ] ,
100122 " more " : [
101123 . type: " boolean " ,
102124 . description: " Whether more information is needed to complete the answer " ,
103125 ] ,
104126 ] ,
127+ . required: [ " answer " , " score " , " more " ] ,
105128 ]
106129 }
107130
@@ -114,18 +137,16 @@ public final class RefineDocumentChain: Chain {
114137
115138 let initialChatModel : ChatModelChain < InitialInput >
116139 let refinementChatModel : ChatModelChain < RefinementInput >
117- let initialChatMemory : ChatGPTMemory
118- let refinementChatMemory : ChatGPTMemory
119140
120141 public init ( ) {
121- initialChatMemory = ConversationChatGPTMemory ( systemPrompt: " " )
122- refinementChatMemory = ConversationChatGPTMemory ( systemPrompt: " " )
123-
124142 initialChatModel = . init(
125143 chatModel: OpenAIChat (
126- configuration: UserPreferenceChatGPTConfiguration ( )
127- . overriding ( . init( temperature: 0 ) ) ,
128- memory: initialChatMemory,
144+ configuration: UserPreferenceChatGPTConfiguration ( ) . overriding {
145+ $0. temperature = 0
146+ $0. runFunctionsAutomatically = false
147+ } ,
148+ memory: EmptyChatGPTMemory ( ) ,
149+ functionProvider: FunctionProvider ( ) ,
129150 stream: false
130151 ) ,
131152 promptTemplate: { input in [
@@ -140,9 +161,12 @@ public final class RefineDocumentChain: Chain {
140161 )
141162 refinementChatModel = . init(
142163 chatModel: OpenAIChat (
143- configuration: UserPreferenceChatGPTConfiguration ( )
144- . overriding ( . init( temperature: 0 ) ) ,
145- memory: refinementChatMemory,
164+ configuration: UserPreferenceChatGPTConfiguration ( ) . overriding {
165+ $0. temperature = 0
166+ $0. runFunctionsAutomatically = false
167+ } ,
168+ memory: EmptyChatGPTMemory ( ) ,
169+ functionProvider: FunctionProvider ( ) ,
146170 stream: false
147171 ) ,
148172 promptTemplate: { input in [
@@ -168,6 +192,26 @@ public final class RefineDocumentChain: Chain {
168192 guard let firstDocument = input. documents. first else {
169193 return " "
170194 }
195+
196+ func extractAnswer( _ chatMessage: ChatMessage ) -> IntermediateAnswer {
197+ if let functionCall = chatMessage. functionCall {
198+ do {
199+ let intermediateAnswer = try JSONDecoder ( ) . decode (
200+ IntermediateAnswer . self,
201+ from: functionCall. arguments. data ( using: . utf8) ?? Data ( )
202+ )
203+ return intermediateAnswer
204+ } catch {
205+ let intermediateAnswer = IntermediateAnswer (
206+ answer: functionCall. arguments,
207+ score: 0 ,
208+ more: true
209+ )
210+ return intermediateAnswer
211+ }
212+ }
213+ return . init( answer: chatMessage. content ?? " " , score: 0 , more: true )
214+ }
171215 var output = try await initialChatModel. call (
172216 . init(
173217 question: input. question,
@@ -176,24 +220,27 @@ public final class RefineDocumentChain: Chain {
176220 ) ,
177221 callbackManagers: callbackManagers
178222 )
179- guard var content = output. content else { return " " }
180- callbackManagers
181- . send ( CallbackEvents . RetrievalQADidGenerateIntermediateAnswer ( info: content) )
182- for document in input. documents. dropFirst ( 1 ) {
223+ var intermediateAnswer = extractAnswer ( output)
224+ callbackManagers. send (
225+ CallbackEvents . RetrievalQADidGenerateIntermediateAnswer ( info: intermediateAnswer)
226+ )
227+
228+ for document in input. documents. dropFirst ( 1 ) where intermediateAnswer. more {
183229 output = try await refinementChatModel. call (
184230 . init(
185231 question: input. question,
186- previousAnswer: content ,
232+ previousAnswer: intermediateAnswer . answer ,
187233 document: document. document. pageContent,
188234 distance: document. distance
189235 ) ,
190236 callbackManagers: callbackManagers
191237 )
192- content = output. content ?? " "
193- callbackManagers
194- . send ( CallbackEvents . RetrievalQADidGenerateIntermediateAnswer ( info: content) )
238+ intermediateAnswer = extractAnswer ( output)
239+ callbackManagers. send (
240+ CallbackEvents . RetrievalQADidGenerateIntermediateAnswer ( info: intermediateAnswer)
241+ )
195242 }
196- return content
243+ return intermediateAnswer . answer
197244 }
198245
199246 public func parseOutput( _ output: String ) -> String {
0 commit comments