@@ -4,7 +4,7 @@ import OpenAIService
44
55public class FunctionCallingChatAgent < Output: AgentOutputParsable & Decodable > : Agent {
66 public typealias ScratchPadContent = [ ChatMessage ]
7-
7+
88 public struct EndFunction : ChatGPTFunction {
99 public typealias Argument = Output
1010 public typealias Result = String
@@ -79,7 +79,8 @@ public class FunctionCallingChatAgent<Output: AgentOutputParsable & Decodable>:
7979 public init (
8080 configuration: ChatGPTConfiguration = UserPreferenceChatGPTConfiguration ( ) ,
8181 tools: [ AgentTool ] = [ ] ,
82- endFunction: EndFunction
82+ endFunction: EndFunction ,
83+ extraSystemPrompt: String = " "
8384 ) {
8485 let functions = tools. compactMap { $0 as? ( any ChatGPTFunction ) }
8586 let otherTools = tools. filter { !( $0 is ( any ChatGPTFunction ) ) }
@@ -103,13 +104,13 @@ public class FunctionCallingChatAgent<Output: AgentOutputParsable & Decodable>:
103104 . init(
104105 role: . system,
105106 content: """
106- Respond to the human as helpfully and accurately as possible. \
107- Save the final answer when it's ready
108-
109- Begin!
107+ Gather information using functions, and generate a final answer to my query as helpfully and accurately as possible.
108+ You don't ask me for additional information.
109+ \( extraSystemPrompt )
110+ When you have the final answer, you MUST call ` \( endFunction . name ) ` to save it.
110111 """
111112 ) ,
112- . init( role: . user, content: agentInput. input)
113+ . init( role: . user, content: agentInput. input) ,
113114 ] + agentInput. thoughts. content
114115 }
115116 )
@@ -118,49 +119,57 @@ public class FunctionCallingChatAgent<Output: AgentOutputParsable & Decodable>:
118119 public func extraPlan( input: AgentInput < String , ScratchPadContent > ) {
119120 // no extra plan
120121 }
121-
122- public func prepareForEarlyStopWithGenerate( ) -> String {
123- functionProvider. shouldFinish = true
124- return " (call sendFinalAnswer to finish) "
125- }
126-
127- public func constructScratchpad(
122+
123+ func constructBaseScratchpad(
128124 intermediateSteps: [ AgentAction ]
129- ) -> AgentScratchPad < ScratchPadContent > {
130- let baseScratchpad = intermediateSteps. flatMap {
131- [
125+ ) -> ScratchPadContent {
126+ return intermediateSteps. flatMap {
127+ if let observation = $0. observation {
128+ return [
129+ ChatMessage (
130+ role: . assistant,
131+ content: nil ,
132+ functionCall: . init( name: $0. toolName, arguments: $0. toolInput)
133+ ) ,
134+ ChatMessage ( role: . function, content: observation, name: $0. toolName) ,
135+ ]
136+ }
137+ return [
138+ ChatMessage ( role: . assistant, content: $0. toolInput) ,
132139 ChatMessage (
133- role: . assistant,
134- content: nil ,
135- functionCall: . init( name: $0. toolName, arguments: $0. toolInput)
140+ role: . user,
141+ content: " Please continue, call \( functionProvider. endFunction. name) when you are done. "
136142 ) ,
137- ChatMessage ( role: . function, content: $0. observation) ,
138143 ]
139144 }
140- return . init( content: baseScratchpad)
141145 }
142-
143- public func constructFinalScratchpad( intermediateSteps: [ AgentAction ] ) -> AgentScratchPad < ScratchPadContent > {
144- let baseScratchpad = intermediateSteps. flatMap {
145- [
146- ChatMessage (
147- role: . assistant,
148- content: nil ,
149- functionCall: . init( name: $0. toolName, arguments: $0. toolInput)
150- ) ,
151- ChatMessage ( role: . function, content: $0. observation) ,
152- ]
153- }
146+
147+ public func constructScratchpad(
148+ intermediateSteps: [ AgentAction ]
149+ ) -> AgentScratchPad < ScratchPadContent > {
150+ functionProvider. shouldFinish = false
151+ let baseScratchpad = constructBaseScratchpad ( intermediateSteps: intermediateSteps)
154152 return . init( content: baseScratchpad)
155153 }
156154
155+ public func constructFinalScratchpad(
156+ intermediateSteps: [ AgentAction ]
157+ ) -> AgentScratchPad < ScratchPadContent > {
158+ functionProvider. shouldFinish = true
159+ let baseScratchpad = constructBaseScratchpad ( intermediateSteps: intermediateSteps)
160+ return . init( content: baseScratchpad + [
161+ ChatMessage ( role: . assistant, content: " Now I need to save the final answer " ) ,
162+ ChatMessage ( role: . user, content: " Please continue " ) ,
163+ ] )
164+ }
165+
157166 public func validateTools( tools: [ AgentTool ] ) throws {
158167 // no validation
159168 }
160169
161170 public func parseOutput( _ message: ChatMessage ) async -> AgentNextStep < Output > {
162- if message . role == . assistant , let functionCall = message. functionCall {
163- if let function = functionProvider. functionTools . first ( where: {
171+ if let functionCall = message. functionCall {
172+ if let function = functionProvider. functions . first ( where: {
164173 $0. name == functionCall. name
165174 } ) {
166175 if function. name == functionProvider. endFunction. name {
@@ -204,6 +213,8 @@ public class FunctionCallingChatAgent<Output: AgentOutputParsable & Decodable>:
204213 case let . unstructured( x) , let . structured( x) :
205214 return . finish( . init( returnValue: . unstructured( x) , log: finish. log) )
206215 }
216+ case let . thought( content) :
217+ return . finish( . init( returnValue: . unstructured( content) , log: content) )
207218 }
208219 }
209220}
0 commit comments