Skip to content

Commit 59e9a04

Browse files
committed
Make AgentScratchPad generic
1 parent 57925ce commit 59e9a04

5 files changed

Lines changed: 118 additions & 109 deletions

File tree

Tool/Sources/LangChain/Agent.swift

Lines changed: 17 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -66,25 +66,15 @@ public enum AgentNextStep<Output: AgentOutputParsable> {
6666
case finish(AgentFinish<Output>)
6767
}
6868

69-
public enum AgentScratchPad: Equatable {
70-
case text(String)
71-
case messages([String])
72-
73-
var isEmpty: Bool {
74-
switch self {
75-
case let .text(text):
76-
return text.isEmpty
77-
case let .messages(messages):
78-
return messages.isEmpty
79-
}
80-
}
69+
/// A container for the intermediate work an agent carries from one step to the next.
///
/// `Content` is chosen by each `Agent` conformer through its `ScratchPadContent` associated type.
public struct AgentScratchPad<Content: Equatable>: Equatable {
    /// The scratchpad payload that is handed to the chat model chain.
    public var content: Content

    /// Creates a scratchpad wrapping `content`.
    ///
    /// Declared explicitly because the synthesized memberwise initializer of a public struct is
    /// `internal`; without it, `Agent` conformers outside this module could not build the value
    /// their `constructScratchpad` requirement has to return.
    ///
    /// - Parameter content: The payload to store.
    public init(content: Content) {
        self.content = content
    }
}
8272

83-
public struct AgentInput<T> {
73+
public struct AgentInput<T, ScratchPadContent: Equatable> {
8474
var input: T
85-
var thoughts: AgentScratchPad
75+
var thoughts: AgentScratchPad<ScratchPadContent>
8676

87-
public init(input: T, thoughts: AgentScratchPad) {
77+
public init(input: T, thoughts: AgentScratchPad<ScratchPadContent>) {
8878
self.input = input
8979
self.thoughts = thoughts
9080
}
@@ -100,20 +90,24 @@ public enum AgentEarlyStopHandleType: Equatable {
10090
public protocol Agent {
10191
associatedtype Input
10292
associatedtype Output: AgentOutputParsable
103-
var chatModelChain: ChatModelChain<AgentInput<Input>> { get }
93+
associatedtype ScratchPadContent: Equatable
94+
var chatModelChain: ChatModelChain<AgentInput<Input, ScratchPadContent>> { get }
10495
var observationPrefix: String { get }
10596
var llmPrefix: String { get }
10697

10798
func validateTools(tools: [AgentTool]) throws
108-
func constructScratchpad(intermediateSteps: [AgentAction]) -> AgentScratchPad
109-
func extraPlan(input: AgentInput<Input>)
110-
func prepareForEarlyStopWithGenerate() -> String
111-
func parseOutput(_ output: ChatModelChain<AgentInput<Input>>.Output) async
99+
func constructScratchpad(intermediateSteps: [AgentAction]) -> AgentScratchPad<ScratchPadContent>
100+
func constructFinalScratchpad(intermediateSteps: [AgentAction]) -> AgentScratchPad<ScratchPadContent>
101+
func extraPlan(input: AgentInput<Input, ScratchPadContent>)
102+
func parseOutput(_ output: ChatModelChain<AgentInput<Input, ScratchPadContent>>.Output) async
112103
-> AgentNextStep<Output>
113104
}
114105

115106
public extension Agent {
116-
func getFullInputs(input: Input, intermediateSteps: [AgentAction]) -> AgentInput<Input> {
107+
/// Builds the complete chain input for the next planning step.
///
/// - Parameters:
///   - input: The caller-supplied input, passed through unchanged.
///   - intermediateSteps: The actions taken so far; they are turned into a scratchpad by
///     `constructScratchpad(intermediateSteps:)`.
/// - Returns: An `AgentInput` pairing `input` with the constructed scratchpad.
func getFullInputs(
    input: Input,
    intermediateSteps: [AgentAction]
) -> AgentInput<Input, ScratchPadContent> {
    AgentInput(
        input: input,
        thoughts: constructScratchpad(intermediateSteps: intermediateSteps)
    )
}
@@ -142,14 +136,8 @@ public extension Agent {
142136
log: ""
143137
)
144138
case .generate:
145-
var thoughts = constructBaseScratchpad(intermediateSteps: intermediateSteps)
146-
thoughts += """
147-
148-
\(llmPrefix)I now need to return a final answer based on the previous steps:
149-
\(prepareForEarlyStopWithGenerate())
150-
"""
151-
let input = AgentInput(input: input, thoughts: .text(thoughts))
152-
139+
let thoughts = constructFinalScratchpad(intermediateSteps: intermediateSteps)
140+
let input = AgentInput(input: input, thoughts: thoughts)
153141
let output = try await chatModelChain.call(input, callbackManagers: callbackManagers)
154142
let nextAction = await parseOutput(output)
155143
switch nextAction {
@@ -160,16 +148,4 @@ public extension Agent {
160148
}
161149
}
162150
}
163-
164-
func constructBaseScratchpad(intermediateSteps: [AgentAction]) -> String {
165-
var thoughts = ""
166-
for step in intermediateSteps {
167-
thoughts += """
168-
\(step.log)
169-
\(observationPrefix)\(step.observation ?? "")
170-
"""
171-
}
172-
return thoughts
173-
}
174151
}
175-

Tool/Sources/LangChain/AgentTool.swift

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -31,18 +31,37 @@ public struct SimpleAgentTool: AgentTool {
3131
}
3232
}
3333

34-
public struct FunctionCallingAgentTool: AgentTool {
35-
public let function: any ChatGPTFunction
36-
public var name: String { function.name }
37-
public var description: String { function.description }
38-
public let returnDirectly: Bool
39-
40-
public init(function: any ChatGPTFunction, returnDirectly: Bool = false) {
34+
/// An agent tool backed by a `ChatGPTFunction`.
///
/// The wrapper conforms to `ChatGPTFunction` itself and forwards `call`, `argumentSchema`,
/// `prepare` and `reportProgress` to the wrapped `function`.
public struct FunctionCallingAgentTool<F: ChatGPTFunction>: AgentTool, ChatGPTFunction {
    public typealias Arguments = F.Arguments
    public typealias Result = F.Result

    // MARK: Stored state

    /// The wrapped function that performs the actual work.
    public var function: F
    /// Copied from `function.name` in `init`; it is not refreshed if `function` is replaced later.
    public var name: String
    /// Copied from `function.description` in `init`; it is not refreshed if `function` is replaced later.
    public var description: String
    public var returnDirectly: Bool

    // MARK: Forwarded properties

    public var argumentSchema: OpenAIService.JSONSchemaValue {
        function.argumentSchema
    }

    public var reportProgress: (String) async -> Void {
        get { function.reportProgress }
        set { function.reportProgress = newValue }
    }

    // MARK: Lifecycle

    /// Wraps `function`, capturing its current name and description.
    ///
    /// - Parameters:
    ///   - function: The function to expose as an agent tool.
    ///   - returnDirectly: Stored as-is in `returnDirectly`; defaults to `false`.
    public init(function: F, returnDirectly: Bool = false) {
        self.function = function
        self.returnDirectly = returnDirectly
        name = function.name
        description = function.description
    }

    // MARK: Forwarded behavior

    public func prepare() async {
        await function.prepare()
    }

    public func call(arguments: F.Arguments) async throws -> F.Result {
        try await function.call(arguments: arguments)
    }

    /// Runs the wrapped function with `input` as its JSON argument string and returns the
    /// result's `botReadableContent`.
    public func run(input: String) async throws -> String {
        try await function.call(argumentsJsonString: input).botReadableContent
    }
}
67+

Tool/Sources/LangChain/Agents/ChatAgent.swift

Lines changed: 30 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,10 @@ private func formatInstruction(toolsNames: String, preferredLanguage: String) ->
3636
public class ChatAgent: Agent {
3737
public typealias Input = String
3838
public typealias Output = String
39+
public typealias ScratchPadContent = String
3940
public var observationPrefix: String { "Observation: " }
4041
public var llmPrefix: String { "Thought: " }
41-
public let chatModelChain: ChatModelChain<AgentInput<String>>
42+
public let chatModelChain: ChatModelChain<AgentInput<String, String>>
4243
let tools: [AgentTool]
4344

4445
public init(chatModel: ChatModel, tools: [AgentTool], preferredLanguage: String) {
@@ -68,54 +69,61 @@ public class ChatAgent: Agent {
6869
Begin! Reminder to always use the exact characters `Final Answer` when responding.
6970
"""
7071
),
71-
agentInput.thoughts.isEmpty
72+
agentInput.thoughts.content.isEmpty
7273
? .init(role: .user, content: agentInput.input)
7374
: .init(
7475
role: .user,
7576
content: """
7677
\(agentInput.input)
7778
78-
\({
79-
switch agentInput.thoughts {
80-
case let .text(text):
81-
return text
82-
case let .messages(messages):
83-
return messages.map { message in
84-
"""
85-
\(message)
86-
"""
87-
}.joined(separator: "\n")
88-
}
89-
}())
79+
\(agentInput.thoughts.content)
9080
"""
9181
),
9282
]
9383
}
9484
)
9585
}
86+
87+
/// Renders the intermediate steps as plain text for the prompt.
///
/// Each step contributes its `log`, a newline, then `observationPrefix` followed by the step's
/// observation (an empty string when the observation is `nil`). The per-step texts are
/// concatenated directly; no separator is inserted between consecutive steps.
///
/// - Parameter intermediateSteps: The actions taken so far.
/// - Returns: The concatenated text, or an empty string when there are no steps.
func constructBaseScratchpad(intermediateSteps: [AgentAction]) -> String {
    intermediateSteps
        .map { step in
            """
            \(step.log)
            \(observationPrefix)\(step.observation ?? "")
            """
        }
        .joined()
}
9697

97-
public func constructScratchpad(intermediateSteps: [AgentAction]) -> AgentScratchPad {
98+
public func constructScratchpad(intermediateSteps: [AgentAction]) -> AgentScratchPad<String> {
9899
let baseScratchpad = constructBaseScratchpad(intermediateSteps: intermediateSteps)
99-
if baseScratchpad.isEmpty { return .text("") }
100-
return .text("""
100+
if baseScratchpad.isEmpty { return .init(content: "") }
101+
return .init(content: """
101102
This was your previous work (but I haven't seen any of it! I only see what you return as `Final Answer`):
102103
\(baseScratchpad)
103104
(Please continue with `Thought:` or `Final Answer:`)
104105
""")
105106
}
107+
108+
/// Builds the scratchpad used when the agent has to stop and produce a final answer.
///
/// - Parameter intermediateSteps: The actions taken so far.
/// - Returns: An empty scratchpad when the steps render to an empty string; otherwise the rendered
///   previous work followed by an `llmPrefix` line asking for a final answer and a reminder to
///   continue with `Final Answer:`.
public func constructFinalScratchpad(intermediateSteps: [AgentAction]) -> AgentScratchPad<String> {
    let baseScratchpad = constructBaseScratchpad(intermediateSteps: intermediateSteps)
    guard !baseScratchpad.isEmpty else { return .init(content: "") }
    // The reminder on the last line is written bare: inside a multi-line string literal,
    // surrounding double quotes are ordinary characters and would be sent to the model verbatim.
    return .init(content: """
    This was your previous work (but I haven't seen any of it! I only see what you return as `Final Answer`):
    \(baseScratchpad)
    \(llmPrefix)I now need to return a final answer based on the previous steps:
    (Please continue with `Final Answer:`)
    """)
}
106118

107119
/// Accepts any set of tools.
///
/// - Parameter tools: The tools offered to the agent; they are not inspected.
/// - Throws: Never, in this implementation.
public func validateTools(tools: [AgentTool]) throws {
    // Intentionally empty: this agent performs no tool validation.
}
110122

111-
public func extraPlan(input: AgentInput<String>) {
123+
/// Hook for additional planning work; this agent has none.
///
/// - Parameter input: The full agent input; unused.
public func extraPlan(input: AgentInput<String, String>) {
    // Intentionally empty: nothing to do.
}
114126

115-
public func prepareForEarlyStopWithGenerate() -> String {
116-
"(Please continue with `Final Answer:`)"
117-
}
118-
119127
public func parseOutput(_ output: ChatMessage) async -> AgentNextStep<Output> {
120128
let text = output.content ?? ""
121129

Tool/Sources/LangChain/Agents/FunctionCallingChatAgent.swift

Lines changed: 40 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,13 @@ import Logger
33
import OpenAIService
44

55
public class FunctionCallingChatAgent<Output: AgentOutputParsable & Decodable>: Agent {
6+
public typealias ScratchPadContent = [ChatMessage]
7+
68
public struct EndFunction: ChatGPTFunction {
79
public typealias Argument = Output
810
public typealias Result = String
9-
public var name: String { "sendFinalAnswer" }
10-
public var description: String { "Send the final answer to user" }
11+
public var name: String { "saveFinalAnswer" }
12+
public var description: String { "Save the final answer when it's ready" }
1113
public let argumentSchema: JSONSchemaValue
1214
public var reportProgress: (String) async -> Void = { _ in }
1315
public func prepare() async {}
@@ -71,16 +73,16 @@ public class FunctionCallingChatAgent<Output: AgentOutputParsable & Decodable>:
7173
public typealias Input = String
7274
public var observationPrefix: String { "Observation: " }
7375
public var llmPrefix: String { "Thought: " }
74-
public let chatModelChain: ChatModelChain<AgentInput<String>>
76+
public let chatModelChain: ChatModelChain<AgentInput<String, ScratchPadContent>>
7577
var functionProvider: FunctionProvider
7678

7779
public init(
7880
configuration: ChatGPTConfiguration = UserPreferenceChatGPTConfiguration(),
7981
tools: [AgentTool] = [],
8082
endFunction: EndFunction
8183
) {
82-
let functions = tools.compactMap { $0 as? FunctionCallingAgentTool }.map(\.function)
83-
let otherTools = tools.filter { !($0 is FunctionCallingAgentTool) }
84+
let functions = tools.compactMap { $0 as? (any ChatGPTFunction) }
85+
let otherTools = tools.filter { !($0 is (any ChatGPTFunction)) }
8486
functionProvider = .init(
8587
tools: otherTools,
8688
functionTools: functions,
@@ -102,38 +104,18 @@ public class FunctionCallingChatAgent<Output: AgentOutputParsable & Decodable>:
102104
role: .system,
103105
content: """
104106
Respond to the human as helpfully and accurately as possible. \
105-
Format final answer to be more readable, in a ordered list if possible. \
107+
Save the final answer when it's ready
106108
107109
Begin!
108110
"""
109111
),
110-
agentInput.thoughts.isEmpty
111-
? .init(role: .user, content: agentInput.input)
112-
: .init(
113-
role: .user,
114-
content: """
115-
\(agentInput.input)
116-
117-
\({
118-
switch agentInput.thoughts {
119-
case let .text(text):
120-
return text
121-
case let .messages(messages):
122-
return messages.map { message in
123-
"""
124-
\(message)
125-
"""
126-
}.joined(separator: "\n")
127-
}
128-
}())
129-
"""
130-
),
131-
]
112+
.init(role: .user, content: agentInput.input)
113+
] + agentInput.thoughts.content
132114
}
133115
)
134116
}
135117

136-
public func extraPlan(input: AgentInput<String>) {
118+
/// Hook for additional planning work; this agent has none.
///
/// - Parameter input: The full agent input; unused.
public func extraPlan(input: AgentInput<String, ScratchPadContent>) {
    // Intentionally empty: no extra plan.
}
139121

@@ -142,14 +124,34 @@ public class FunctionCallingChatAgent<Output: AgentOutputParsable & Decodable>:
142124
return "(call sendFinalAnswer to finish)"
143125
}
144126

145-
public func constructScratchpad(intermediateSteps: [AgentAction]) -> AgentScratchPad {
146-
let baseScratchpad = constructBaseScratchpad(intermediateSteps: intermediateSteps)
147-
if baseScratchpad.isEmpty { return .text("") }
148-
return .text("""
149-
This was your previous work (but I haven't seen any of it! I only see what you return as `Final Answer`):
150-
\(baseScratchpad)
151-
(Please continue with `Thought:` or call a function)
152-
""")
127+
/// Converts the intermediate steps into chat messages.
///
/// Each step yields two messages, in order: an assistant message with `nil` content carrying a
/// function call built from the step's `toolName` and `toolInput`, then a function-role message
/// whose content is the step's `observation`.
///
/// - Parameter intermediateSteps: The actions taken so far.
/// - Returns: A scratchpad holding the messages in step order; empty when there are no steps.
public func constructScratchpad(
    intermediateSteps: [AgentAction]
) -> AgentScratchPad<ScratchPadContent> {
    var messages: ScratchPadContent = []
    for step in intermediateSteps {
        let request = ChatMessage(
            role: .assistant,
            content: nil,
            functionCall: .init(name: step.toolName, arguments: step.toolInput)
        )
        // NOTE(review): the function-role message is built without a function name — confirm
        // that `ChatMessage` / the chat API does not require one here.
        let response = ChatMessage(role: .function, content: step.observation)
        messages += [request, response]
    }
    return AgentScratchPad(content: messages)
}
142+
143+
/// Builds the scratchpad used when the agent has to stop and produce a final answer.
///
/// For every step this emits an assistant message with `nil` content carrying the function call
/// (`toolName` / `toolInput`), followed by a function-role message with the step's `observation`.
/// No additional instruction message is appended.
///
/// - Parameter intermediateSteps: The actions taken so far.
/// - Returns: A scratchpad holding the messages in step order; empty when there are no steps.
public func constructFinalScratchpad(intermediateSteps: [AgentAction]) -> AgentScratchPad<ScratchPadContent> {
    let messages = intermediateSteps.reduce(into: ScratchPadContent()) { history, step in
        history.append(
            ChatMessage(
                role: .assistant,
                content: nil,
                functionCall: .init(name: step.toolName, arguments: step.toolInput)
            )
        )
        history.append(ChatMessage(role: .function, content: step.observation))
    }
    return .init(content: messages)
}
154156

155157
public func validateTools(tools: [AgentTool]) throws {
@@ -185,7 +187,7 @@ public class FunctionCallingChatAgent<Output: AgentOutputParsable & Decodable>:
185187
}
186188
}
187189
}
188-
190+
189191
// fallback to normal agent.
190192

191193
let stringBaseOutput = await ChatAgent(

Tool/Sources/OpenAIService/Models.swift

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,10 @@ public struct ChatMessage: Equatable, Codable {
1818
/// A function invocation attached to a chat message.
public struct FunctionCall: Codable, Equatable {
    /// The name of the function being called.
    public var name: String
    /// The arguments of the call as a string (presumably JSON-encoded — confirm against callers).
    public var arguments: String

    /// Creates a function call.
    ///
    /// - Parameters:
    ///   - name: The name of the function being called.
    ///   - arguments: The argument string for the call.
    public init(name: String, arguments: String) {
        self.arguments = arguments
        self.name = name
    }
}
2226

2327
/// The role of a message.

0 commit comments

Comments
 (0)