Skip to content

Commit e2c469f

Browse files
committed
Adjust agent
1 parent 5e8596c commit e2c469f

3 files changed

Lines changed: 65 additions & 36 deletions

File tree

Tool/Sources/LangChain/Agent.swift

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ public struct AgentFinish<Output: AgentOutputParsable> {
6464
public enum AgentNextStep<Output: AgentOutputParsable> {
6565
case actions([AgentAction])
6666
case finish(AgentFinish<Output>)
67+
case thought(String)
6768
}
6869

6970
public struct AgentScratchPad<Content: Equatable>: Equatable {
@@ -145,6 +146,8 @@ public extension Agent {
145146
return finish
146147
case .actions:
147148
return .init(returnValue: .unstructured(output.content ?? ""), log: output.content ?? "")
149+
case let .thought(content):
150+
return .init(returnValue: .unstructured(content), log: content)
148151
}
149152
}
150153
}

Tool/Sources/LangChain/AgentExecutor.swift

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,12 +57,14 @@ public actor AgentExecutor<InnerAgent: Agent>: Chain
5757
}
5858

5959
while shouldContinue() {
60+
try Task.checkCancellation()
6061
let nextStepOutput = try await takeNextStep(
6162
input: input,
6263
intermediateSteps: intermediateSteps,
6364
callbackManagers: callbackManagers
6465
)
6566

67+
try Task.checkCancellation()
6668
switch nextStepOutput {
6769
case let .finish(finish):
6870
return end(
@@ -82,6 +84,8 @@ public actor AgentExecutor<InnerAgent: Agent>: Chain
8284
callbackManagers: callbackManagers
8385
)
8486
}
87+
case .thought:
88+
break
8589
}
8690
iterations += 1
8791
}
@@ -156,6 +160,7 @@ extension AgentExecutor {
156160
}
157161
var completedActions = [AgentAction]()
158162
for try await action in taskGroup {
163+
try Task.checkCancellation()
159164
completedActions.append(action)
160165
callbackManagers
161166
.forEach { $0.send(CallbackEvents.AgentActionDidEnd(info: action)) }
@@ -164,6 +169,15 @@ extension AgentExecutor {
164169
}
165170

166171
return .actions(completedActions)
172+
case let .thought(content):
173+
return .actions([
174+
.init(
175+
toolName: "Thought",
176+
toolInput: content,
177+
log: "Thought: \(content)",
178+
observation: nil
179+
),
180+
])
167181
}
168182
}
169183

@@ -212,3 +226,4 @@ extension Double: AgentOutputParsable {
212226

213227
public var botReadableContent: String { String(self) }
214228
}
229+

Tool/Sources/LangChain/Agents/FunctionCallingChatAgent.swift

Lines changed: 47 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ import OpenAIService
44

55
public class FunctionCallingChatAgent<Output: AgentOutputParsable & Decodable>: Agent {
66
public typealias ScratchPadContent = [ChatMessage]
7-
7+
88
public struct EndFunction: ChatGPTFunction {
99
public typealias Argument = Output
1010
public typealias Result = String
@@ -79,7 +79,8 @@ public class FunctionCallingChatAgent<Output: AgentOutputParsable & Decodable>:
7979
public init(
8080
configuration: ChatGPTConfiguration = UserPreferenceChatGPTConfiguration(),
8181
tools: [AgentTool] = [],
82-
endFunction: EndFunction
82+
endFunction: EndFunction,
83+
extraSystemPrompt: String = ""
8384
) {
8485
let functions = tools.compactMap { $0 as? (any ChatGPTFunction) }
8586
let otherTools = tools.filter { !($0 is (any ChatGPTFunction)) }
@@ -103,13 +104,13 @@ public class FunctionCallingChatAgent<Output: AgentOutputParsable & Decodable>:
103104
.init(
104105
role: .system,
105106
content: """
106-
Respond to the human as helpfully and accurately as possible. \
107-
Save the final answer when it's ready
108-
109-
Begin!
107+
Gather information using functions, and generate a final answer to my query as helpfully and accurately as possible.
108+
You don't ask me for additional information.
109+
\(extraSystemPrompt)
110+
When you have the final answer, you MUST call `\(endFunction.name)` to save it.
110111
"""
111112
),
112-
.init(role: .user, content: agentInput.input)
113+
.init(role: .user, content: agentInput.input),
113114
] + agentInput.thoughts.content
114115
}
115116
)
@@ -118,49 +119,57 @@ public class FunctionCallingChatAgent<Output: AgentOutputParsable & Decodable>:
118119
public func extraPlan(input: AgentInput<String, ScratchPadContent>) {
119120
// no extra plan
120121
}
121-
122-
public func prepareForEarlyStopWithGenerate() -> String {
123-
functionProvider.shouldFinish = true
124-
return "(call sendFinalAnswer to finish)"
125-
}
126-
127-
public func constructScratchpad(
122+
123+
func constructBaseScratchpad(
128124
intermediateSteps: [AgentAction]
129-
) -> AgentScratchPad<ScratchPadContent> {
130-
let baseScratchpad = intermediateSteps.flatMap {
131-
[
125+
) -> ScratchPadContent {
126+
return intermediateSteps.flatMap {
127+
if let observation = $0.observation {
128+
return [
129+
ChatMessage(
130+
role: .assistant,
131+
content: nil,
132+
functionCall: .init(name: $0.toolName, arguments: $0.toolInput)
133+
),
134+
ChatMessage(role: .function, content: observation, name: $0.toolName),
135+
]
136+
}
137+
return [
138+
ChatMessage(role: .assistant, content: $0.toolInput),
132139
ChatMessage(
133-
role: .assistant,
134-
content: nil,
135-
functionCall: .init(name: $0.toolName, arguments: $0.toolInput)
140+
role: .user,
141+
content: "Please continue, call \(functionProvider.endFunction.name) when you are done."
136142
),
137-
ChatMessage(role: .function, content: $0.observation),
138143
]
139144
}
140-
return .init(content: baseScratchpad)
141145
}
142-
143-
public func constructFinalScratchpad(intermediateSteps: [AgentAction]) -> AgentScratchPad<ScratchPadContent> {
144-
let baseScratchpad = intermediateSteps.flatMap {
145-
[
146-
ChatMessage(
147-
role: .assistant,
148-
content: nil,
149-
functionCall: .init(name: $0.toolName, arguments: $0.toolInput)
150-
),
151-
ChatMessage(role: .function, content: $0.observation),
152-
]
153-
}
146+
147+
public func constructScratchpad(
148+
intermediateSteps: [AgentAction]
149+
) -> AgentScratchPad<ScratchPadContent> {
150+
functionProvider.shouldFinish = false
151+
let baseScratchpad = constructBaseScratchpad(intermediateSteps: intermediateSteps)
154152
return .init(content: baseScratchpad)
155153
}
156154

155+
public func constructFinalScratchpad(
156+
intermediateSteps: [AgentAction]
157+
) -> AgentScratchPad<ScratchPadContent> {
158+
functionProvider.shouldFinish = true
159+
let baseScratchpad = constructBaseScratchpad(intermediateSteps: intermediateSteps)
160+
return .init(content: baseScratchpad + [
161+
ChatMessage(role: .assistant, content: "Now I need to save the final answer"),
162+
ChatMessage(role: .user, content: "Please continue"),
163+
])
164+
}
165+
157166
public func validateTools(tools: [AgentTool]) throws {
158167
// no validation
159168
}
160169

161170
public func parseOutput(_ message: ChatMessage) async -> AgentNextStep<Output> {
162-
if message.role == .assistant, let functionCall = message.functionCall {
163-
if let function = functionProvider.functionTools.first(where: {
171+
if let functionCall = message.functionCall {
172+
if let function = functionProvider.functions.first(where: {
164173
$0.name == functionCall.name
165174
}) {
166175
if function.name == functionProvider.endFunction.name {
@@ -204,6 +213,8 @@ public class FunctionCallingChatAgent<Output: AgentOutputParsable & Decodable>:
204213
case let .unstructured(x), let .structured(x):
205214
return .finish(.init(returnValue: .unstructured(x), log: finish.log))
206215
}
216+
case let .thought(content):
217+
return .finish(.init(returnValue: .unstructured(content), log: content))
207218
}
208219
}
209220
}

0 commit comments

Comments
 (0)