@@ -5,12 +5,44 @@ import OpenAIService
55/// This is an agent used to get a structured output.
66public class StructuredOutputChatModelChain < Output: Decodable > : Chain {
77 public struct EndFunction : ChatGPTArgumentsCollectingFunction {
8- public typealias Arguments = Output
9- public var name : String { " saveFinalAnswer " }
8+ public struct Arguments : Decodable {
9+ var finalAnswer : Output
10+ }
11+
12+ public var name : String { " FinalAnswer " }
1013 public var description : String { " Save the final answer when it's ready " }
11- public let argumentSchema : JSONSchemaValue
12- public init ( argumentSchema: JSONSchemaValue ) {
13- self . argumentSchema = argumentSchema
14+ public var argumentSchema : JSONSchemaValue {
15+ return [
16+ . type: " object " ,
17+ . properties: [
18+ " finalAnswer " : . hash( finalAnswerSchema) ,
19+ ] ,
20+ . required: [ " finalAnswer " ] ,
21+ ]
22+ }
23+
24+ public let finalAnswerSchema : [ String : JSONSchemaValue ]
25+
26+ public init ( argumentSchema: [ String : JSONSchemaValue ] ) {
27+ finalAnswerSchema = argumentSchema
28+ }
29+
30+ public init ( ) where Output == String {
31+ finalAnswerSchema = [
32+ JSONSchemaKey . type. key: " string " ,
33+ ]
34+ }
35+
36+ public init ( ) where Output == Int {
37+ finalAnswerSchema = [
38+ JSONSchemaKey . type. key: " number " ,
39+ ]
40+ }
41+
42+ public init ( ) where Output == Double {
43+ finalAnswerSchema = [
44+ JSONSchemaKey . type. key: " number " ,
45+ ]
1446 }
1547 }
1648
@@ -79,20 +111,14 @@ public class StructuredOutputChatModelChain<Output: Decodable>: Chain {
79111
80112 public func parseOutput( _ message: ChatMessage ) async -> Output ? {
81113 if let functionCall = message. functionCall {
82- if let function = functionProvider. functions. first ( where: {
83- $0. name == functionCall. name
84- } ) {
85- if function. name == functionProvider. endFunction. name {
86- do {
87- let result = try JSONDecoder ( ) . decode (
88- Output . self,
89- from: functionCall. arguments. data ( using: . utf8) ?? Data ( )
90- )
91- return result
92- } catch {
93- return nil
94- }
95- }
114+ do {
115+ let result = try JSONDecoder ( ) . decode (
116+ EndFunction . Arguments. self,
117+ from: functionCall. arguments. data ( using: . utf8) ?? Data ( )
118+ )
119+ return result. finalAnswer
120+ } catch {
121+ return nil
96122 }
97123 }
98124
0 commit comments