Skip to content

Commit e187a35

Browse files
committed
Add function calling agent
1 parent 8c2f2b8 commit e187a35

4 files changed

Lines changed: 196 additions & 7 deletions

File tree

Tool/Sources/LangChain/Agent.swift

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,9 @@ public protocol Agent {
100100

101101
func validateTools(tools: [AgentTool]) throws
102102
func constructScratchpad(intermediateSteps: [AgentAction]) -> AgentScratchPad
103-
func parseOutput(_ output: String) -> AgentNextStep
103+
func extraPlan(input: AgentInput<Input>)
104+
func prepareForEarlyStopWithGenerate()
105+
func parseOutput(_ output: ChatModelChain<AgentInput<Input>>.Output) async -> AgentNextStep
104106
}
105107

106108
public extension Agent {
@@ -115,8 +117,9 @@ public extension Agent {
115117
callbackManagers: [CallbackManager]
116118
) async throws -> AgentNextStep {
117119
let input = getFullInputs(input: input, intermediateSteps: intermediateSteps)
120+
extraPlan(input: input)
118121
let output = try await chatModelChain.call(input, callbackManagers: callbackManagers)
119-
return parseOutput(output.content ?? "")
122+
return await parseOutput(output)
120123
}
121124

122125
func returnStoppedResponse(
@@ -139,14 +142,14 @@ public extension Agent {
139142
(Please continue with `Final Answer:`)
140143
"""
141144
let input = AgentInput(input: input, thoughts: .text(thoughts))
145+
prepareForEarlyStopWithGenerate()
142146
let output = try await chatModelChain.call(input, callbackManagers: callbackManagers)
143-
let reply = output.content ?? ""
144-
let nextAction = parseOutput(reply)
147+
let nextAction = await parseOutput(output)
145148
switch nextAction {
146149
case let .finish(finish):
147150
return finish
148151
case .actions:
149-
return AgentFinish(returnValue: reply, log: reply)
152+
return AgentFinish(returnValue: output.content ?? "", log: output.content ?? "")
150153
}
151154
}
152155
}

Tool/Sources/LangChain/AgentTool.swift

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import Foundation
2+
import OpenAIService
23

34
public protocol AgentTool {
45
var name: String { get }
@@ -30,3 +31,18 @@ public struct SimpleAgentTool: AgentTool {
3031
}
3132
}
3233

34+
public struct FunctionCallingAgentTool: AgentTool {
35+
public let function: any ChatGPTFunction
36+
public var name: String { function.name }
37+
public var description: String { function.description }
38+
public let returnDirectly: Bool
39+
40+
public init(function: any ChatGPTFunction, returnDirectly: Bool = false) {
41+
self.function = function
42+
self.returnDirectly = returnDirectly
43+
}
44+
45+
public func run(input: String) async throws -> String {
46+
try await function.call(argumentsJsonString: input).botReadableContent
47+
}
48+
}

Tool/Sources/LangChain/Agents/ChatAgent.swift

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,12 +102,22 @@ public class ChatAgent: Agent {
102102
(Please continue with `Thought:` or `Final Answer:`)
103103
""")
104104
}
105-
105+
106106
public func validateTools(tools: [AgentTool]) throws {
107107
// no validation
108108
}
109109

110-
public func parseOutput(_ text: String) -> AgentNextStep {
110+
public func extraPlan(input: AgentInput<String>) {
111+
// do nothing
112+
}
113+
114+
public func prepareForEarlyStopWithGenerate() {
115+
// do nothing
116+
}
117+
118+
public func parseOutput(_ output: ChatMessage) async -> AgentNextStep {
119+
let text = output.content ?? ""
120+
111121
func parseFinalAnswerIfPossible() -> AgentNextStep? {
112122
let throughAnswerParser = PrefixThrough("Final Answer:")
113123
var parsableContent = text[...]
Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
import Foundation
2+
import Logger
3+
import OpenAIService
4+
5+
public class FunctionCallingChatAgent: Agent {
6+
struct EndFunction: ChatGPTFunction {
7+
struct Argument: Codable {
8+
let finalAnswer: String
9+
}
10+
11+
typealias Result = String
12+
13+
var name: String { "sendFinalAnswer" }
14+
var description: String { "Send the final answer to user" }
15+
var argumentSchema: JSONSchemaValue {
16+
[
17+
.type: "object",
18+
.properties: [
19+
"finalAnswer": [
20+
.type: "string",
21+
.description: "the final answer to send to user",
22+
],
23+
],
24+
.required: ["finalAnswer"],
25+
]
26+
}
27+
28+
var reportProgress: (String) async -> Void = { _ in }
29+
func prepare() async {}
30+
func call(arguments: Argument) async throws -> Result {
31+
return arguments.finalAnswer
32+
}
33+
}
34+
35+
struct FunctionProvider: ChatGPTFunctionProvider {
36+
var tools: [AgentTool] = []
37+
var functionTools: [any ChatGPTFunction] = []
38+
var functions: [any ChatGPTFunction] {
39+
functionTools + [EndFunction()]
40+
}
41+
42+
var functionCallStrategy: FunctionCallStrategy? = nil
43+
}
44+
45+
public typealias Input = String
46+
public var observationPrefix: String { "Observation: " }
47+
public var llmPrefix: String { "Thought: " }
48+
public let chatModelChain: ChatModelChain<AgentInput<String>>
49+
var functionProvider: FunctionProvider
50+
51+
public init(
52+
configuration: ChatGPTConfiguration = UserPreferenceChatGPTConfiguration(),
53+
memory: ChatGPTMemory = ConversationChatGPTMemory(systemPrompt: ""),
54+
functions: [any ChatGPTFunction] = [],
55+
tools: [AgentTool] = []
56+
) {
57+
functionProvider = .init(tools: tools, functionTools: functions)
58+
chatModelChain = .init(
59+
chatModel: OpenAIChat(
60+
configuration: configuration.overriding {
61+
$0.runFunctionsAutomatically = false
62+
},
63+
memory: memory,
64+
functionProvider: functionProvider,
65+
stream: false
66+
),
67+
stops: ["Observation:"],
68+
promptTemplate: { agentInput in
69+
[
70+
.init(
71+
role: .system,
72+
content: """
73+
Respond to the human as helpfully and accurately as possible. \
74+
Format final answer to be more readable, in a ordered list if possible. \
75+
76+
Begin!
77+
"""
78+
),
79+
agentInput.thoughts.isEmpty
80+
? .init(role: .user, content: agentInput.input)
81+
: .init(
82+
role: .user,
83+
content: """
84+
\(agentInput.input)
85+
86+
\({
87+
switch agentInput.thoughts {
88+
case let .text(text):
89+
return text
90+
case let .messages(messages):
91+
return messages.map { message in
92+
"""
93+
\(message)
94+
"""
95+
}.joined(separator: "\n")
96+
}
97+
}())
98+
"""
99+
),
100+
]
101+
}
102+
)
103+
}
104+
105+
public func extraPlan(input: AgentInput<String>) {
106+
// no extra plan
107+
}
108+
109+
public func prepareForEarlyStopWithGenerate() {
110+
functionProvider.functionTools = []
111+
functionProvider.tools = []
112+
functionProvider.functionCallStrategy = .name("finalAnswer")
113+
}
114+
115+
public func constructScratchpad(intermediateSteps: [AgentAction]) -> AgentScratchPad {
116+
let baseScratchpad = constructBaseScratchpad(intermediateSteps: intermediateSteps)
117+
if baseScratchpad.isEmpty { return .text("") }
118+
return .text("""
119+
This was your previous work (but I haven't seen any of it! I only see what you return as `Final Answer`):
120+
\(baseScratchpad)
121+
(Please continue with `Thought:` or call a function)
122+
""")
123+
}
124+
125+
public func validateTools(tools: [AgentTool]) throws {
126+
// no validation
127+
}
128+
129+
public func parseOutput(_ message: ChatMessage) async -> AgentNextStep {
130+
if message.role == .function, let functionCall = message.functionCall {
131+
if let function = functionProvider.functionTools.first(where: {
132+
$0.name == functionCall.name
133+
}) {
134+
do {
135+
let result = try await function
136+
.call(argumentsJsonString: functionCall.arguments)
137+
return .actions([.init(
138+
toolName: functionCall.name,
139+
toolInput: result.botReadableContent,
140+
log: result.botReadableContent
141+
)])
142+
} catch {
143+
return .actions([.init(
144+
toolName: functionCall.name,
145+
toolInput: error.localizedDescription,
146+
log: error.localizedDescription
147+
)])
148+
}
149+
}
150+
}
151+
152+
return await ChatAgent(
153+
chatModel: chatModelChain.chatModel,
154+
tools: functionProvider.tools,
155+
preferredLanguage: ""
156+
)
157+
.parseOutput(message)
158+
}
159+
}
160+

0 commit comments

Comments
 (0)