@@ -21,44 +21,49 @@ public struct AgentAction: Equatable {
2121}
2222
2323public extension CallbackEvents {
24- struct AgentDidFinish : CallbackEvent {
25- public let info : AgentFinish
24+ struct AgentDidFinish < Output : AgentOutputParsable > : CallbackEvent {
25+ public let info : AgentFinish < Output >
2626 }
27-
28- var agentDidFinish : AgentDidFinish . Type {
29- AgentDidFinish . self
27+
28+ func agentDidFinish< Output : AgentOutputParsable > ( ) -> AgentDidFinish < Output > . Type {
29+ AgentDidFinish< Output > . self
3030 }
3131
3232 struct AgentActionDidStart : CallbackEvent {
3333 public let info : AgentAction
3434 }
35-
35+
3636 var agentActionDidStart : AgentActionDidStart . Type {
3737 AgentActionDidStart . self
3838 }
3939
4040 struct AgentActionDidEnd : CallbackEvent {
4141 public let info : AgentAction
4242 }
43-
43+
4444 var agentActionDidEnd : AgentActionDidEnd . Type {
4545 AgentActionDidEnd . self
4646 }
4747}
4848
49- public struct AgentFinish : Equatable {
50- public var returnValue : String
49+ public struct AgentFinish < Output: AgentOutputParsable > {
50+ public enum ReturnValue {
51+ case success( Output )
52+ case failure( String )
53+ }
54+
55+ public var returnValue : ReturnValue
5156 public var log : String
5257
53- public init ( returnValue: String , log: String ) {
58+ public init ( returnValue: ReturnValue , log: String ) {
5459 self . returnValue = returnValue
5560 self . log = log
5661 }
5762}
5863
59- public enum AgentNextStep : Equatable {
64+ public enum AgentNextStep < Output : AgentOutputParsable > {
6065 case actions( [ AgentAction ] )
61- case finish( AgentFinish )
66+ case finish( AgentFinish < Output > )
6267}
6368
6469public enum AgentScratchPad : Equatable {
@@ -94,15 +99,17 @@ public enum AgentEarlyStopHandleType: Equatable {
9499
95100public protocol Agent {
96101 associatedtype Input
102+ associatedtype Output : AgentOutputParsable
97103 var chatModelChain : ChatModelChain < AgentInput < Input > > { get }
98104 var observationPrefix : String { get }
99105 var llmPrefix : String { get }
100106
101107 func validateTools( tools: [ AgentTool ] ) throws
102108 func constructScratchpad( intermediateSteps: [ AgentAction ] ) -> AgentScratchPad
103109 func extraPlan( input: AgentInput < Input > )
104- func prepareForEarlyStopWithGenerate( )
105- func parseOutput( _ output: ChatModelChain < AgentInput < Input > > . Output ) async -> AgentNextStep
110+ func prepareForEarlyStopWithGenerate( ) -> String
111+ func parseOutput( _ output: ChatModelChain < AgentInput < Input > > . Output ) async
112+ -> AgentNextStep < Output >
106113}
107114
108115public extension Agent {
@@ -115,7 +122,7 @@ public extension Agent {
115122 input: Input ,
116123 intermediateSteps: [ AgentAction ] ,
117124 callbackManagers: [ CallbackManager ]
118- ) async throws -> AgentNextStep {
125+ ) async throws -> AgentNextStep < Output > {
119126 let input = getFullInputs ( input: input, intermediateSteps: intermediateSteps)
120127 extraPlan ( input: input)
121128 let output = try await chatModelChain. call ( input, callbackManagers: callbackManagers)
@@ -127,29 +134,29 @@ public extension Agent {
127134 earlyStoppedHandleType: AgentEarlyStopHandleType ,
128135 intermediateSteps: [ AgentAction ] ,
129136 callbackManagers: [ CallbackManager ]
130- ) async throws -> AgentFinish {
137+ ) async throws -> AgentFinish < Output > {
131138 switch earlyStoppedHandleType {
132139 case . force:
133140 return AgentFinish (
134- returnValue: " Agent stopped due to iteration limit or time limit. " ,
141+ returnValue: . failure ( " Agent stopped due to iteration limit or time limit. " ) ,
135142 log: " "
136143 )
137144 case . generate:
138145 var thoughts = constructBaseScratchpad ( intermediateSteps: intermediateSteps)
139146 thoughts += """
140147
141148 \( llmPrefix) I now need to return a final answer based on the previous steps:
142- (Please continue with `Final Answer:` )
149+ \( prepareForEarlyStopWithGenerate ( ) )
143150 """
144151 let input = AgentInput ( input: input, thoughts: . text( thoughts) )
145- prepareForEarlyStopWithGenerate ( )
152+
146153 let output = try await chatModelChain. call ( input, callbackManagers: callbackManagers)
147154 let nextAction = await parseOutput ( output)
148155 switch nextAction {
149156 case let . finish( finish) :
150157 return finish
151158 case . actions:
152- return AgentFinish ( returnValue: output. content ?? " " , log: output. content ?? " " )
159+ return . init ( returnValue: . failure ( output. content ?? " " ) , log: output. content ?? " " )
153160 }
154161 }
155162 }
0 commit comments