Skip to content

Commit f4397db

Browse files
committed
Add cancellation support to agent executor
1 parent e22ec05 commit f4397db

File tree

3 files changed

+21
-7
lines changed

3 files changed

+21
-7
lines changed

Core/Sources/ChatPlugins/SearchChatPlugin/SearchChatPlugin.swift

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,14 +28,17 @@ public actor SearchChatPlugin: ChatPlugin {
2828
}
2929

3030
do {
31-
let eventStream = try await search(content)
31+
let (eventStream, cancelAgent) = try await search(content)
3232

3333
var actions = [String]()
3434
var finishedActions = Set<String>()
3535
var message = ""
3636

3737
for try await event in eventStream {
38-
guard !isCancelled else { return }
38+
guard !isCancelled else {
39+
await cancelAgent()
40+
break
41+
}
3942
switch event {
4043
case let .startAction(content):
4144
actions.append(content)

Core/Sources/ChatPlugins/SearchChatPlugin/SearchQuery.swift

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@ enum SearchEvent {
99
case finishAnswer(String, [(title: String, link: String)])
1010
}
1111

12-
func search(_ query: String) async throws -> AsyncThrowingStream<SearchEvent, Error> {
12+
func search(_ query: String) async throws
13+
-> (stream: AsyncThrowingStream<SearchEvent, Error>, cancel: () async -> Void)
14+
{
1315
let bingSearch = BingSearchService(
1416
subscriptionKey: UserDefaults.shared.value(for: \.bingSearchSubscriptionKey),
1517
searchURL: UserDefaults.shared.value(for: \.bingSearchEndpoint)
@@ -91,7 +93,7 @@ func search(_ query: String) async throws -> AsyncThrowingStream<SearchEvent, Er
9193
}
9294
}
9395

94-
return AsyncThrowingStream<SearchEvent, Error> { continuation in
96+
return (AsyncThrowingStream<SearchEvent, Error> { continuation in
9597
let callback = ResultCallbackManager(
9698
onFinalAnswerToken: {
9799
continuation.yield(.answerToken($0))
@@ -112,6 +114,8 @@ func search(_ query: String) async throws -> AsyncThrowingStream<SearchEvent, Er
112114
continuation.finish(throwing: error)
113115
}
114116
}
115-
}
117+
}, {
118+
await agentExecutor.cancel()
119+
})
116120
}
117121

Tool/Sources/LangChain/AgentExecutor.swift

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,9 @@ public actor AgentExecutor<InnerAgent: Agent>: Chain where InnerAgent.Input == S
1919
let tools: [String: AgentTool]
2020
let maxIteration: Int?
2121
let maxExecutionTime: Double?
22-
let earlyStopHandleType: AgentEarlyStopHandleType
22+
var earlyStopHandleType: AgentEarlyStopHandleType
2323
var now: () -> Date = { Date() }
24+
var isCancelled = false
2425

2526
public init(
2627
agent: InnerAgent,
@@ -47,6 +48,7 @@ public actor AgentExecutor<InnerAgent: Agent>: Chain where InnerAgent.Input == S
4748
var intermediateSteps: [AgentAction] = []
4849

4950
func shouldContinue() -> Bool {
51+
if isCancelled { return false }
5052
if let maxIteration = maxIteration, iterations >= maxIteration {
5153
return false
5254
}
@@ -87,7 +89,7 @@ public actor AgentExecutor<InnerAgent: Agent>: Chain where InnerAgent.Input == S
8789
}
8890
iterations += 1
8991
}
90-
92+
9193
let output = try await agent.returnStoppedResponse(
9294
input: input,
9395
earlyStoppedHandleType: earlyStopHandleType,
@@ -104,6 +106,11 @@ public actor AgentExecutor<InnerAgent: Agent>: Chain where InnerAgent.Input == S
104106
public nonisolated func parseOutput(_ output: Output) -> String {
105107
output.finalOutput
106108
}
109+
110+
public func cancel() {
111+
isCancelled = true
112+
earlyStopHandleType = .force
113+
}
107114
}
108115

109116
struct InvalidToolError: Error {}

0 commit comments

Comments
 (0)