Skip to content

Commit ba3b360

Browse files
committed
Simplify usage of StructuredOutputChatModelChain
1 parent 4860fc8 commit ba3b360

1 file changed

Lines changed: 45 additions & 19 deletions

File tree

Tool/Sources/LangChain/Chains/StructuredOutputChatModelChain.swift

Lines changed: 45 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,44 @@ import OpenAIService
55
/// This is an agent used to get a structured output.
66
public class StructuredOutputChatModelChain<Output: Decodable>: Chain {
77
public struct EndFunction: ChatGPTArgumentsCollectingFunction {
8-
public typealias Arguments = Output
9-
public var name: String { "saveFinalAnswer" }
8+
public struct Arguments: Decodable {
9+
var finalAnswer: Output
10+
}
11+
12+
public var name: String { "FinalAnswer" }
1013
public var description: String { "Save the final answer when it's ready" }
11-
public let argumentSchema: JSONSchemaValue
12-
public init(argumentSchema: JSONSchemaValue) {
13-
self.argumentSchema = argumentSchema
14+
public var argumentSchema: JSONSchemaValue {
15+
return [
16+
.type: "object",
17+
.properties: [
18+
"finalAnswer": .hash(finalAnswerSchema),
19+
],
20+
.required: ["finalAnswer"],
21+
]
22+
}
23+
24+
public let finalAnswerSchema: [String: JSONSchemaValue]
25+
26+
public init(argumentSchema: [String: JSONSchemaValue]) {
27+
finalAnswerSchema = argumentSchema
28+
}
29+
30+
public init() where Output == String {
31+
finalAnswerSchema = [
32+
JSONSchemaKey.type.key: "string",
33+
]
34+
}
35+
36+
public init() where Output == Int {
37+
finalAnswerSchema = [
38+
JSONSchemaKey.type.key: "number",
39+
]
40+
}
41+
42+
public init() where Output == Double {
43+
finalAnswerSchema = [
44+
JSONSchemaKey.type.key: "number",
45+
]
1446
}
1547
}
1648

@@ -79,20 +111,14 @@ public class StructuredOutputChatModelChain<Output: Decodable>: Chain {
79111

80112
public func parseOutput(_ message: ChatMessage) async -> Output? {
81113
if let functionCall = message.functionCall {
82-
if let function = functionProvider.functions.first(where: {
83-
$0.name == functionCall.name
84-
}) {
85-
if function.name == functionProvider.endFunction.name {
86-
do {
87-
let result = try JSONDecoder().decode(
88-
Output.self,
89-
from: functionCall.arguments.data(using: .utf8) ?? Data()
90-
)
91-
return result
92-
} catch {
93-
return nil
94-
}
95-
}
114+
do {
115+
let result = try JSONDecoder().decode(
116+
EndFunction.Arguments.self,
117+
from: functionCall.arguments.data(using: .utf8) ?? Data()
118+
)
119+
return result.finalAnswer
120+
} catch {
121+
return nil
96122
}
97123
}
98124

0 commit comments

Comments
 (0)