Skip to content

Commit 87f876a

Browse files
committed
Add StructuredOutputChatModelChain
1 parent 144e4e6 commit 87f876a

4 files changed

Lines changed: 125 additions & 221 deletions

File tree

Tool/Sources/LangChain/Agent.swift

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,8 +96,6 @@ public protocol Agent {
9696
associatedtype Output: AgentOutputParsable
9797
associatedtype ScratchPadContent: Equatable
9898
var chatModelChain: ChatModelChain<AgentInput<Input, ScratchPadContent>> { get }
99-
var observationPrefix: String { get }
100-
var llmPrefix: String { get }
10199

102100
func validateTools(tools: [AgentTool]) throws
103101
func constructScratchpad(intermediateSteps: [AgentAction]) -> AgentScratchPad<ScratchPadContent>

Tool/Sources/LangChain/Agents/FunctionCallingChatAgent.swift

Lines changed: 0 additions & 219 deletions
This file was deleted.

Tool/Sources/LangChain/Callback.swift

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,15 @@ public protocol CallbackEvent {
66
}
77

88
public struct CallbackEvents {
9+
public struct UnTypedEvent: CallbackEvent {
10+
public var info: String
11+
init(info: String) {
12+
self.info = info
13+
}
14+
}
15+
16+
public var untyped: UnTypedEvent.Type { UnTypedEvent.self }
17+
918
private init() {}
1019
}
1120

@@ -52,6 +61,12 @@ public struct CallbackManager {
5261
observer.handler(info)
5362
}
5463
}
64+
65+
public func send(_ string: String) {
66+
for case let observer as Observer<CallbackEvents.UnTypedEvent> in observers {
67+
observer.handler(string)
68+
}
69+
}
5570
}
5671

5772
public extension [CallbackManager] {
@@ -65,5 +80,9 @@ public extension [CallbackManager] {
6580
) {
6681
for cb in self { cb.send(keyPath, info) }
6782
}
83+
84+
func send(_ event: String) {
85+
for cb in self { cb.send(event) }
86+
}
6887
}
6988

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
import Foundation
2+
import Logger
3+
import OpenAIService
4+
5+
/// This is an agent used to get a structured output.
6+
public class StructuredOutputChatModelChain<Output: Decodable>: Chain {
7+
public struct EndFunction: ChatGPTFunction {
8+
public typealias Argument = Output
9+
public typealias Result = String
10+
public var name: String { "saveFinalAnswer" }
11+
public var description: String { "Save the final answer when it's ready" }
12+
public let argumentSchema: JSONSchemaValue
13+
public var reportProgress: (String) async -> Void = { _ in }
14+
public func prepare() async {}
15+
public func call(arguments: Argument) async throws -> Result { "" }
16+
public init(argumentSchema: JSONSchemaValue) {
17+
self.argumentSchema = argumentSchema
18+
}
19+
}
20+
21+
struct FunctionProvider: ChatGPTFunctionProvider {
22+
var endFunction: EndFunction
23+
var functions: [any ChatGPTFunction] {
24+
[endFunction]
25+
}
26+
27+
var functionCallStrategy: FunctionCallStrategy? {
28+
.name(endFunction.name)
29+
}
30+
}
31+
32+
public typealias Input = String
33+
public let chatModelChain: ChatModelChain<String>
34+
var functionProvider: FunctionProvider
35+
36+
public init(
37+
configuration: ChatGPTConfiguration = UserPreferenceChatGPTConfiguration(),
38+
tools: [AgentTool] = [],
39+
endFunction: EndFunction,
40+
extraSystemPrompt: String = ""
41+
) {
42+
functionProvider = .init(
43+
endFunction: endFunction
44+
)
45+
chatModelChain = .init(
46+
chatModel: OpenAIChat(
47+
configuration: configuration.overriding {
48+
$0.runFunctionsAutomatically = false
49+
},
50+
memory: nil,
51+
functionProvider: functionProvider,
52+
stream: false
53+
),
54+
stops: ["Observation:"],
55+
promptTemplate: { input in
56+
[
57+
.init(
58+
role: .system,
59+
content: """
60+
You are a helpful assistant
61+
Generate a final answer to my query as concisely, helpfully and accurately as possible.
62+
You don't ask me for additional information.
63+
\(extraSystemPrompt)
64+
"""
65+
),
66+
.init(role: .user, content: input),
67+
]
68+
}
69+
)
70+
}
71+
72+
public func callLogic(
73+
_ input: String,
74+
callbackManagers: [CallbackManager]
75+
) async throws -> Output? {
76+
let output = try await chatModelChain.call(input, callbackManagers: callbackManagers)
77+
return await parseOutput(output)
78+
}
79+
80+
public func parseOutput(_ output: Output?) -> String {
81+
return String(describing: output)
82+
}
83+
84+
public func parseOutput(_ message: ChatMessage) async -> Output? {
85+
if let functionCall = message.functionCall {
86+
if let function = functionProvider.functions.first(where: {
87+
$0.name == functionCall.name
88+
}) {
89+
if function.name == functionProvider.endFunction.name {
90+
do {
91+
let result = try JSONDecoder().decode(
92+
Output.self,
93+
from: functionCall.arguments.data(using: .utf8) ?? Data()
94+
)
95+
return result
96+
} catch {
97+
return nil
98+
}
99+
}
100+
}
101+
}
102+
103+
return nil
104+
}
105+
}
106+

0 commit comments

Comments
 (0)