Skip to content

Commit 7a37f6c

Browse files
committed
Reimplement CallbackManager to be more easily extensible
1 parent 42016e3 commit 7a37f6c

File tree

8 files changed

+113
-87
lines changed

8 files changed

+113
-87
lines changed

Core/Sources/ChatPlugins/SearchChatPlugin/SearchQuery.swift

Lines changed: 27 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import BingSearchService
22
import Foundation
33
import LangChain
4+
import OpenAIService
45

56
enum SearchEvent {
67
case startAction(String)
@@ -42,7 +43,10 @@ func search(_ query: String) async throws
4243
),
4344
]
4445

45-
let chatModel = OpenAIChat(temperature: 0, stream: true)
46+
let chatModel = OpenAIChat(
47+
configuration: UserPreferenceChatGPTConfiguration().overriding { $0.temperature = 0 },
48+
stream: true
49+
)
4650

4751
let agentExecutor = AgentExecutor(
4852
agent: ChatAgent(
@@ -55,63 +59,37 @@ func search(_ query: String) async throws
5559
earlyStopHandleType: .generate
5660
)
5761

58-
class ResultCallbackManager: ChainCallbackManager {
62+
return (AsyncThrowingStream<SearchEvent, Error> { continuation in
5963
var accumulation: String = ""
6064
var isGeneratingFinalAnswer = false
61-
var onFinalAnswerToken: (String) -> Void
62-
var onAgentActionStart: (String) -> Void
63-
var onAgentActionEnd: (String) -> Void
64-
65-
init(
66-
onFinalAnswerToken: @escaping (String) -> Void,
67-
onAgentActionStart: @escaping (String) -> Void,
68-
onAgentActionEnd: @escaping (String) -> Void
69-
) {
70-
self.onFinalAnswerToken = onFinalAnswerToken
71-
self.onAgentActionStart = onAgentActionStart
72-
self.onAgentActionEnd = onAgentActionEnd
73-
}
74-
75-
func onChainStart<T>(type: T.Type, input: T.Input) where T: LangChain.Chain {}
7665

77-
func onAgentFinish(output: LangChain.AgentFinish) {}
78-
79-
func onAgentActionStart(action: LangChain.AgentAction) {
80-
onAgentActionStart("\(action.toolName): \(action.toolInput)")
81-
}
82-
83-
func onAgentActionEnd(action: LangChain.AgentAction) {
84-
onAgentActionEnd("\(action.toolName): \(action.toolInput)")
85-
}
86-
87-
func onLLMNewToken(token: String) {
88-
if isGeneratingFinalAnswer {
89-
onFinalAnswerToken(token)
90-
return
66+
let callbackManager = CallbackManager { manager in
67+
manager.on(CallbackEvents.AgentActionDidStart.self) {
68+
continuation.yield(.startAction("\($0.toolName): \($0.toolInput)"))
9169
}
92-
accumulation.append(token)
93-
if accumulation.hasSuffix("Final Answer: ") {
94-
isGeneratingFinalAnswer = true
95-
accumulation = ""
70+
71+
manager.on(CallbackEvents.AgentActionDidEnd.self) {
72+
continuation.yield(.endAction("\($0.toolName): \($0.toolInput)"))
9673
}
97-
}
98-
}
9974

100-
return (AsyncThrowingStream<SearchEvent, Error> { continuation in
101-
let callback = ResultCallbackManager(
102-
onFinalAnswerToken: {
103-
continuation.yield(.answerToken($0))
104-
},
105-
onAgentActionStart: {
106-
continuation.yield(.startAction($0))
107-
},
108-
onAgentActionEnd: {
109-
continuation.yield(.endAction($0))
75+
manager.on(CallbackEvents.LLMDidProduceNewToken.self) {
76+
if isGeneratingFinalAnswer {
77+
continuation.yield(.answerToken($0))
78+
return
79+
}
80+
accumulation.append($0)
81+
if accumulation.hasSuffix("Final Answer: ") {
82+
isGeneratingFinalAnswer = true
83+
accumulation = ""
84+
}
11085
}
111-
)
86+
}
11287
Task {
11388
do {
114-
let finalAnswer = try await agentExecutor.run(query, callbackManagers: [callback])
89+
let finalAnswer = try await agentExecutor.run(
90+
query,
91+
callbackManagers: [callbackManager]
92+
)
11593
continuation.yield(.finishAnswer(finalAnswer, linkStorage.links))
11694
continuation.finish()
11795
} catch {

Tool/Sources/LangChain/Agent.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ public extension Agent {
8686
func plan(
8787
input: Input,
8888
intermediateSteps: [AgentAction],
89-
callbackManagers: [ChainCallbackManager]
89+
callbackManagers: [CallbackManager]
9090
) async throws -> AgentNextStep {
9191
let input = getFullInputs(input: input, intermediateSteps: intermediateSteps)
9292
let output = try await chatModelChain.call(input, callbackManagers: callbackManagers)
@@ -97,7 +97,7 @@ public extension Agent {
9797
input: Input,
9898
earlyStoppedHandleType: AgentEarlyStopHandleType,
9999
intermediateSteps: [AgentAction],
100-
callbackManagers: [ChainCallbackManager]
100+
callbackManagers: [CallbackManager]
101101
) async throws -> AgentFinish {
102102
switch earlyStoppedHandleType {
103103
case .force:

Tool/Sources/LangChain/AgentExecutor.swift

Lines changed: 10 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,5 @@
11
import Foundation
22

3-
public protocol ChainCallbackManager {
4-
func onChainStart<T: Chain>(type: T.Type, input: T.Input)
5-
func onAgentFinish(output: AgentFinish)
6-
func onAgentActionStart(action: AgentAction)
7-
func onAgentActionEnd(action: AgentAction)
8-
func onLLMNewToken(token: String)
9-
}
10-
113
public actor AgentExecutor<InnerAgent: Agent>: Chain where InnerAgent.Input == String {
124
public typealias Input = String
135
public struct Output {
@@ -39,7 +31,7 @@ public actor AgentExecutor<InnerAgent: Agent>: Chain where InnerAgent.Input == S
3931

4032
public func callLogic(
4133
_ input: Input,
42-
callbackManagers: [ChainCallbackManager]
34+
callbackManagers: [CallbackManager]
4335
) async throws -> Output {
4436
try agent.validateTools(tools: Array(tools.values))
4537

@@ -89,7 +81,7 @@ public actor AgentExecutor<InnerAgent: Agent>: Chain where InnerAgent.Input == S
8981
}
9082
iterations += 1
9183
}
92-
84+
9385
let output = try await agent.returnStoppedResponse(
9486
input: input,
9587
earlyStoppedHandleType: earlyStopHandleType,
@@ -106,7 +98,7 @@ public actor AgentExecutor<InnerAgent: Agent>: Chain where InnerAgent.Input == S
10698
public nonisolated func parseOutput(_ output: Output) -> String {
10799
output.finalOutput
108100
}
109-
101+
110102
public func cancel() {
111103
isCancelled = true
112104
earlyStopHandleType = .force
@@ -119,10 +111,10 @@ extension AgentExecutor {
119111
func end(
120112
output: AgentFinish,
121113
intermediateSteps: [AgentAction],
122-
callbackManagers: [ChainCallbackManager]
114+
callbackManagers: [CallbackManager]
123115
) -> Output {
124116
for callbackManager in callbackManagers {
125-
callbackManager.onAgentFinish(output: output)
117+
callbackManager.send(CallbackEvents.AgentDidFinish(info: output))
126118
}
127119
let finalOutput = output.returnValue
128120
return .init(finalOutput: finalOutput, intermediateSteps: intermediateSteps)
@@ -131,7 +123,7 @@ extension AgentExecutor {
131123
func takeNextStep(
132124
input: Input,
133125
intermediateSteps: [AgentAction],
134-
callbackManagers: [ChainCallbackManager]
126+
callbackManagers: [CallbackManager]
135127
) async throws -> AgentNextStep {
136128
let output = try await agent.plan(
137129
input: input,
@@ -144,7 +136,8 @@ extension AgentExecutor {
144136
let completedActions = try await withThrowingTaskGroup(of: AgentAction.self) {
145137
taskGroup in
146138
for action in actions {
147-
callbackManagers.forEach { $0.onAgentActionStart(action: action) }
139+
callbackManagers
140+
.forEach { $0.send(CallbackEvents.AgentActionDidStart(info: action)) }
148141
guard let tool = tools[action.toolName] else { throw InvalidToolError() }
149142
taskGroup.addTask {
150143
let observation = try await tool.run(input: action.toolInput)
@@ -154,7 +147,8 @@ extension AgentExecutor {
154147
var completedActions = [AgentAction]()
155148
for try await action in taskGroup {
156149
completedActions.append(action)
157-
callbackManagers.forEach { $0.onAgentActionEnd(action: action) }
150+
callbackManagers
151+
.forEach { $0.send(CallbackEvents.AgentActionDidEnd(info: action)) }
158152
}
159153
return completedActions
160154
}
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
import Foundation
2+
3+
public protocol CallbackEvent {
4+
associatedtype Info
5+
var info: Info { get }
6+
}
7+
8+
public enum CallbackEvents {}
9+
10+
public struct CallbackManager {
11+
private var observers = [Any]()
12+
13+
public init() {}
14+
15+
public init(observers: (inout CallbackManager) -> Void) {
16+
var manager = CallbackManager()
17+
observers(&manager)
18+
self = manager
19+
}
20+
21+
public mutating func on<Event: CallbackEvent>(
22+
_: Event.Type = Event.self,
23+
_ handler: @escaping (Event.Info) -> Void
24+
) {
25+
observers.append(handler)
26+
}
27+
28+
public func send<Event: CallbackEvent>(_ event: Event) {
29+
for case let observer as ((Event.Info) -> Void) in observers {
30+
observer(event.info)
31+
}
32+
}
33+
}
34+
35+
public extension CallbackEvents {
36+
struct AgentDidFinish: CallbackEvent {
37+
public let info: AgentFinish
38+
}
39+
40+
struct AgentActionDidStart: CallbackEvent {
41+
public let info: AgentAction
42+
}
43+
44+
struct AgentActionDidEnd: CallbackEvent {
45+
public let info: AgentAction
46+
}
47+
48+
struct LLMDidProduceNewToken: CallbackEvent {
49+
public let info: String
50+
}
51+
52+
struct ChainDidStart<T: Chain>: CallbackEvent {
53+
public let info: (type: T.Type, input: T.Input)
54+
}
55+
}
56+

Tool/Sources/LangChain/Chain.swift

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,20 @@ import Foundation
33
public protocol Chain {
44
associatedtype Input
55
associatedtype Output
6-
func callLogic(_ input: Input, callbackManagers: [ChainCallbackManager]) async throws -> Output
6+
func callLogic(_ input: Input, callbackManagers: [CallbackManager]) async throws -> Output
77
func parseOutput(_ output: Output) -> String
88
}
99

1010
public extension Chain {
11-
func run(_ input: Input, callbackManagers: [ChainCallbackManager] = []) async throws -> String {
11+
func run(_ input: Input, callbackManagers: [CallbackManager] = []) async throws -> String {
1212
let output = try await call(input, callbackManagers: callbackManagers)
1313
return parseOutput(output)
1414
}
1515

16-
func call(_ input: Input, callbackManagers: [ChainCallbackManager] = []) async throws -> Output {
16+
func call(_ input: Input, callbackManagers: [CallbackManager] = []) async throws -> Output {
1717
for callbackManager in callbackManagers {
18-
callbackManager.onChainStart(type: Self.self, input: input)
18+
callbackManager
19+
.send(CallbackEvents.ChainDidStart(info: (type: Self.self, input: input)))
1920
}
2021
return try await callLogic(input, callbackManagers: callbackManagers)
2122
}
@@ -35,7 +36,7 @@ public struct SimpleChain<Input, Output>: Chain {
3536

3637
public func callLogic(
3738
_ input: Input,
38-
callbackManagers: [ChainCallbackManager]
39+
callbackManagers: [CallbackManager]
3940
) async throws -> Output {
4041
return try await block(input)
4142
}
@@ -54,7 +55,7 @@ public struct ConnectedChain<A: Chain, B: Chain>: Chain where B.Input == A.Outpu
5455

5556
public func callLogic(
5657
_ input: Input,
57-
callbackManagers: [ChainCallbackManager] = []
58+
callbackManagers: [CallbackManager] = []
5859
) async throws -> Output {
5960
let a = try await chainA.call(input, callbackManagers: callbackManagers)
6061
let b = try await chainB.call(a, callbackManagers: callbackManagers)
@@ -75,7 +76,7 @@ public struct PairedChain<A: Chain, B: Chain>: Chain {
7576

7677
public func callLogic(
7778
_ input: Input,
78-
callbackManagers: [ChainCallbackManager] = []
79+
callbackManagers: [CallbackManager] = []
7980
) async throws -> Output {
8081
async let a = chainA.call(input.0, callbackManagers: callbackManagers)
8182
async let b = chainB.call(input.1, callbackManagers: callbackManagers)
@@ -96,7 +97,7 @@ public struct MappedChain<A: Chain, NewOutput>: Chain {
9697

9798
public func callLogic(
9899
_ input: Input,
99-
callbackManagers: [ChainCallbackManager]
100+
callbackManagers: [CallbackManager]
100101
) async throws -> Output {
101102
let output = try await chain.call(input, callbackManagers: callbackManagers)
102103
return map(output)

Tool/Sources/LangChain/Chains/LLMChain.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ public class ChatModelChain<Input>: Chain {
1919

2020
public func callLogic(
2121
_ input: Input,
22-
callbackManagers: [ChainCallbackManager]
22+
callbackManagers: [CallbackManager]
2323
) async throws -> Output {
2424
let prompt = promptTemplate(input)
2525
let output = try await chatModel.generate(

Tool/Sources/LangChain/ChatModel/ChatModel.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ public protocol ChatModel {
44
func generate(
55
prompt: [ChatMessage],
66
stops: [String],
7-
callbackManagers: [ChainCallbackManager]
7+
callbackManagers: [CallbackManager]
88
) async throws -> String
99
}
1010

Tool/Sources/LangChain/ChatModel/OpenAIChat.swift

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,26 +2,22 @@ import Foundation
22
import OpenAIService
33

44
public struct OpenAIChat: ChatModel {
5-
public var temperature: Double
5+
public var configuration: ChatGPTConfiguration
66
public var stream: Bool
77

88
public init(
9-
temperature: Double = 0.7,
10-
stream: Bool = false
9+
configuration: ChatGPTConfiguration,
10+
stream: Bool
1111
) {
12-
self.temperature = temperature
12+
self.configuration = configuration
1313
self.stream = stream
1414
}
1515

1616
public func generate(
1717
prompt: [ChatMessage],
1818
stops: [String],
19-
callbackManagers: [ChainCallbackManager]
19+
callbackManagers: [CallbackManager]
2020
) async throws -> String {
21-
let configuration = UserPreferenceChatGPTConfiguration().overriding(.init(
22-
temperature: temperature,
23-
stop: stops
24-
))
2521
let memory = AutoManagedChatGPTMemory(
2622
systemPrompt: "",
2723
configuration: configuration,
@@ -47,7 +43,8 @@ public struct OpenAIChat: ChatModel {
4743
var message = ""
4844
for try await trunk in stream {
4945
message.append(trunk)
50-
callbackManagers.forEach { $0.onLLMNewToken(token: trunk) }
46+
callbackManagers
47+
.forEach { $0.send(CallbackEvents.LLMDidProduceNewToken(info: trunk)) }
5148
}
5249
return message
5350
} else {

0 commit comments

Comments
 (0)