Skip to content

Commit 477c6fc

Browse files
committed
Update function call base agent
1 parent e187a35 commit 477c6fc

4 files changed

Lines changed: 169 additions & 83 deletions

File tree

Tool/Sources/LangChain/Agent.swift

Lines changed: 27 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -21,44 +21,49 @@ public struct AgentAction: Equatable {
2121
}
2222

2323
public extension CallbackEvents {
24-
struct AgentDidFinish: CallbackEvent {
25-
public let info: AgentFinish
24+
struct AgentDidFinish<Output: AgentOutputParsable>: CallbackEvent {
25+
public let info: AgentFinish<Output>
2626
}
27-
28-
var agentDidFinish: AgentDidFinish.Type {
29-
AgentDidFinish.self
27+
28+
func agentDidFinish<Output: AgentOutputParsable>() -> AgentDidFinish<Output>.Type {
29+
AgentDidFinish<Output>.self
3030
}
3131

3232
struct AgentActionDidStart: CallbackEvent {
3333
public let info: AgentAction
3434
}
35-
35+
3636
var agentActionDidStart: AgentActionDidStart.Type {
3737
AgentActionDidStart.self
3838
}
3939

4040
struct AgentActionDidEnd: CallbackEvent {
4141
public let info: AgentAction
4242
}
43-
43+
4444
var agentActionDidEnd: AgentActionDidEnd.Type {
4545
AgentActionDidEnd.self
4646
}
4747
}
4848

49-
public struct AgentFinish: Equatable {
50-
public var returnValue: String
49+
public struct AgentFinish<Output: AgentOutputParsable> {
50+
public enum ReturnValue {
51+
case success(Output)
52+
case failure(String)
53+
}
54+
55+
public var returnValue: ReturnValue
5156
public var log: String
5257

53-
public init(returnValue: String, log: String) {
58+
public init(returnValue: ReturnValue, log: String) {
5459
self.returnValue = returnValue
5560
self.log = log
5661
}
5762
}
5863

59-
public enum AgentNextStep: Equatable {
64+
public enum AgentNextStep<Output: AgentOutputParsable> {
6065
case actions([AgentAction])
61-
case finish(AgentFinish)
66+
case finish(AgentFinish<Output>)
6267
}
6368

6469
public enum AgentScratchPad: Equatable {
@@ -94,15 +99,17 @@ public enum AgentEarlyStopHandleType: Equatable {
9499

95100
public protocol Agent {
96101
associatedtype Input
102+
associatedtype Output: AgentOutputParsable
97103
var chatModelChain: ChatModelChain<AgentInput<Input>> { get }
98104
var observationPrefix: String { get }
99105
var llmPrefix: String { get }
100106

101107
func validateTools(tools: [AgentTool]) throws
102108
func constructScratchpad(intermediateSteps: [AgentAction]) -> AgentScratchPad
103109
func extraPlan(input: AgentInput<Input>)
104-
func prepareForEarlyStopWithGenerate()
105-
func parseOutput(_ output: ChatModelChain<AgentInput<Input>>.Output) async -> AgentNextStep
110+
func prepareForEarlyStopWithGenerate() -> String
111+
func parseOutput(_ output: ChatModelChain<AgentInput<Input>>.Output) async
112+
-> AgentNextStep<Output>
106113
}
107114

108115
public extension Agent {
@@ -115,7 +122,7 @@ public extension Agent {
115122
input: Input,
116123
intermediateSteps: [AgentAction],
117124
callbackManagers: [CallbackManager]
118-
) async throws -> AgentNextStep {
125+
) async throws -> AgentNextStep<Output> {
119126
let input = getFullInputs(input: input, intermediateSteps: intermediateSteps)
120127
extraPlan(input: input)
121128
let output = try await chatModelChain.call(input, callbackManagers: callbackManagers)
@@ -127,29 +134,29 @@ public extension Agent {
127134
earlyStoppedHandleType: AgentEarlyStopHandleType,
128135
intermediateSteps: [AgentAction],
129136
callbackManagers: [CallbackManager]
130-
) async throws -> AgentFinish {
137+
) async throws -> AgentFinish<Output> {
131138
switch earlyStoppedHandleType {
132139
case .force:
133140
return AgentFinish(
134-
returnValue: "Agent stopped due to iteration limit or time limit.",
141+
returnValue: .failure("Agent stopped due to iteration limit or time limit."),
135142
log: ""
136143
)
137144
case .generate:
138145
var thoughts = constructBaseScratchpad(intermediateSteps: intermediateSteps)
139146
thoughts += """
140147
141148
\(llmPrefix)I now need to return a final answer based on the previous steps:
142-
(Please continue with `Final Answer:`)
149+
\(prepareForEarlyStopWithGenerate())
143150
"""
144151
let input = AgentInput(input: input, thoughts: .text(thoughts))
145-
prepareForEarlyStopWithGenerate()
152+
146153
let output = try await chatModelChain.call(input, callbackManagers: callbackManagers)
147154
let nextAction = await parseOutput(output)
148155
switch nextAction {
149156
case let .finish(finish):
150157
return finish
151158
case .actions:
152-
return AgentFinish(returnValue: output.content ?? "", log: output.content ?? "")
159+
return .init(returnValue: .failure(output.content ?? ""), log: output.content ?? "")
153160
}
154161
}
155162
}

Tool/Sources/LangChain/AgentExecutor.swift

Lines changed: 36 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,23 @@
11
import Foundation
22

3-
public actor AgentExecutor<InnerAgent: Agent>: Chain where InnerAgent.Input == String {
3+
public protocol AgentOutputParsable {
4+
static func parse(_ string: String) throws -> Self
5+
var botReadableContent: String { get }
6+
}
7+
8+
extension String: AgentOutputParsable {
9+
public static func parse(_ string: String) throws -> String { string }
10+
public var botReadableContent: String { self }
11+
}
12+
13+
public actor AgentExecutor<InnerAgent: Agent>: Chain
14+
where InnerAgent.Input == String, InnerAgent.Output: AgentOutputParsable
15+
{
416
public typealias Input = String
517
public struct Output {
6-
let finalOutput: String
18+
typealias FinalOutput = AgentFinish<InnerAgent.Output>.ReturnValue
19+
20+
let finalOutput: FinalOutput
721
let intermediateSteps: [AgentAction]
822
}
923

@@ -96,7 +110,10 @@ public actor AgentExecutor<InnerAgent: Agent>: Chain where InnerAgent.Input == S
96110
}
97111

98112
public nonisolated func parseOutput(_ output: Output) -> String {
99-
output.finalOutput
113+
switch output.finalOutput {
114+
case let .failure(error): return error
115+
case let .success(output): return output.botReadableContent
116+
}
100117
}
101118

102119
public func cancel() {
@@ -109,7 +126,7 @@ struct InvalidToolError: Error {}
109126

110127
extension AgentExecutor {
111128
func end(
112-
output: AgentFinish,
129+
output: AgentFinish<InnerAgent.Output>,
113130
intermediateSteps: [AgentAction],
114131
callbackManagers: [CallbackManager]
115132
) -> Output {
@@ -120,18 +137,21 @@ extension AgentExecutor {
120137
return .init(finalOutput: finalOutput, intermediateSteps: intermediateSteps)
121138
}
122139

140+
/// Plan the scratch pad and let the agent decide what to do next
123141
func takeNextStep(
124142
input: Input,
125143
intermediateSteps: [AgentAction],
126144
callbackManagers: [CallbackManager]
127-
) async throws -> AgentNextStep {
145+
) async throws -> AgentNextStep<InnerAgent.Output> {
128146
let output = try await agent.plan(
129147
input: input,
130148
intermediateSteps: intermediateSteps,
131149
callbackManagers: callbackManagers
132150
)
133151
switch output {
152+
// If the output says finish, then return the output immediately.
134153
case .finish: return output
154+
// If the output contains actions, run them, and append the results to the scratch pad.
135155
case let .actions(actions):
136156
let completedActions = try await withThrowingTaskGroup(of: AgentAction.self) {
137157
taskGroup in
@@ -157,10 +177,19 @@ extension AgentExecutor {
157177
}
158178
}
159179

160-
func getToolFinish(action: AgentAction) -> AgentFinish? {
180+
func getToolFinish(action: AgentAction) -> AgentFinish<InnerAgent.Output>? {
161181
guard let tool = tools[action.toolName] else { return nil }
162182
guard tool.returnDirectly else { return nil }
163-
return .init(returnValue: action.observation ?? "", log: "")
183+
184+
do {
185+
let result = try InnerAgent.Output.parse(action.observation ?? "")
186+
return .init(returnValue: .success(result), log: action.observation ?? "")
187+
} catch {
188+
return .init(
189+
returnValue: .failure(action.observation ?? "no observation"),
190+
log: action.observation ?? ""
191+
)
192+
}
164193
}
165194
}
166195

Tool/Sources/LangChain/Agents/ChatAgent.swift

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ private func formatInstruction(toolsNames: String, preferredLanguage: String) ->
3535

3636
public class ChatAgent: Agent {
3737
public typealias Input = String
38+
public typealias Output = String
3839
public var observationPrefix: String { "Observation: " }
3940
public var llmPrefix: String { "Thought: " }
4041
public let chatModelChain: ChatModelChain<AgentInput<String>>
@@ -111,28 +112,28 @@ public class ChatAgent: Agent {
111112
// do nothing
112113
}
113114

114-
public func prepareForEarlyStopWithGenerate() {
115-
// do nothing
115+
public func prepareForEarlyStopWithGenerate() -> String {
116+
"(Please continue with `Final Answer:`)"
116117
}
117118

118-
public func parseOutput(_ output: ChatMessage) async -> AgentNextStep {
119+
public func parseOutput(_ output: ChatMessage) async -> AgentNextStep<Output> {
119120
let text = output.content ?? ""
120121

121-
func parseFinalAnswerIfPossible() -> AgentNextStep? {
122+
func parseFinalAnswerIfPossible() -> AgentNextStep<Output>? {
122123
let throughAnswerParser = PrefixThrough("Final Answer:")
123124
var parsableContent = text[...]
124125
do {
125126
_ = try throughAnswerParser.parse(&parsableContent)
126127
let answer = String(parsableContent)
127128
let output = answer.trimmingCharacters(in: .whitespacesAndNewlines)
128-
return .finish(AgentFinish(returnValue: output, log: text))
129+
return .finish(AgentFinish(returnValue: .success(output), log: text))
129130
} catch {
130131
Logger.langchain.info("Could not parse LLM output final answer: \(error)")
131132
return nil
132133
}
133134
}
134135

135-
func parseNextActionIfPossible() -> AgentNextStep? {
136+
func parseNextActionIfPossible() -> AgentNextStep<Output>? {
136137
let throughActionBlockParser = PrefixThrough("""
137138
Action:
138139
```
@@ -179,7 +180,7 @@ public class ChatAgent: Agent {
179180
answer = "Sorry, I don't know."
180181
}
181182

182-
return .finish(AgentFinish(returnValue: String(answer), log: text))
183+
return .finish(AgentFinish(returnValue: .success(String(answer)), log: text))
183184
}
184185
}
185186

0 commit comments

Comments
 (0)