forked from intitni/CopilotForXcode
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathStructuredOutputChatModelChain.swift
More file actions
126 lines (110 loc) · 3.85 KB
/
StructuredOutputChatModelChain.swift
File metadata and controls
126 lines (110 loc) · 3.85 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import Foundation
import Logger
import OpenAIService
/// A chain that asks a chat model for an answer and decodes it into a structured `Output`.
///
/// The chain exposes a single function (`EndFunction`) to the model and sets the function call
/// strategy to that function. The arguments of the resulting function call are then decoded
/// into `Output`.
public class StructuredOutputChatModelChain<Output: Decodable>: Chain {
    /// The function the model is asked to call in order to deliver its final answer.
    public struct EndFunction: ChatGPTArgumentsCollectingFunction {
        /// The arguments of the function call, decoded from the JSON produced by the model.
        public struct Arguments: Decodable {
            var finalAnswer: Output
        }

        public var name: String { "FinalAnswer" }
        public var description: String { "Save the final answer when it's ready" }

        /// An object schema with one required property, `finalAnswer`, whose own schema is
        /// `finalAnswerSchema`.
        public var argumentSchema: JSONSchemaValue {
            return [
                .type: "object",
                .properties: [
                    "finalAnswer": .hash(finalAnswerSchema),
                ],
                .required: ["finalAnswer"],
            ]
        }

        /// The JSON schema describing the `finalAnswer` property.
        public let finalAnswerSchema: [String: JSONSchemaValue]

        /// Creates the function with a custom schema for the final answer.
        ///
        /// - Parameter argumentSchema: The JSON schema of the `finalAnswer` property. It should
        ///   describe JSON that can be decoded into `Output`.
        public init(argumentSchema: [String: JSONSchemaValue]) {
            finalAnswerSchema = argumentSchema
        }

        /// Creates the function for a plain string answer.
        public init() where Output == String {
            finalAnswerSchema = [
                JSONSchemaKey.type.key: "string",
            ]
        }

        /// Creates the function for an integer answer.
        public init() where Output == Int {
            // "integer" rather than "number": a fractional value such as `3.5` satisfies
            // "number" but cannot be decoded into `Int`, which would discard the answer.
            finalAnswerSchema = [
                JSONSchemaKey.type.key: "integer",
            ]
        }

        /// Creates the function for a floating point answer.
        public init() where Output == Double {
            finalAnswerSchema = [
                JSONSchemaKey.type.key: "number",
            ]
        }
    }

    /// Provides `endFunction` as the only available function and selects it by name as the
    /// function call strategy.
    struct FunctionProvider: ChatGPTFunctionProvider {
        var endFunction: EndFunction

        var functions: [any ChatGPTFunction] {
            [endFunction]
        }

        var functionCallStrategy: FunctionCallStrategy? {
            .function(name: endFunction.name)
        }
    }

    public typealias Input = String

    /// The underlying chain that sends the prompt to the chat model.
    public let chatModelChain: ChatModelChain<String>
    var functionProvider: FunctionProvider

    /// Creates a chain that produces a structured output.
    ///
    /// - Parameters:
    ///   - configuration: The base chat model configuration. It is used with
    ///     `runFunctionsAutomatically` overridden to `false`.
    ///   - endFunction: The function the model calls to deliver the final answer.
    ///   - promptTemplate: Builds the messages sent to the model from the input. When `nil`, a
    ///     default system prompt followed by the input as a user message is used.
    public init(
        configuration: ChatGPTConfiguration = UserPreferenceChatGPTConfiguration(),
        endFunction: EndFunction,
        promptTemplate: ((String) -> [ChatMessage])? = nil
    ) {
        functionProvider = .init(
            endFunction: endFunction
        )
        chatModelChain = .init(
            chatModel: OpenAIChat(
                // The function call is only read back by `parseOutput(_:)`, so automatic
                // function running is switched off.
                configuration: configuration.overriding {
                    $0.runFunctionsAutomatically = false
                },
                memory: nil,
                functionProvider: functionProvider,
                stream: false
            ),
            stops: ["Observation:"],
            promptTemplate: promptTemplate ?? { input in
                [
                    .init(
                        role: .system,
                        content: """
                        You are a helpful assistant
                        Generate a final answer to my query as concisely, helpfully and accurately as possible.
                        You don't ask me for additional information.
                        """
                    ),
                    .init(role: .user, content: input),
                ]
            }
        )
    }

    /// Sends `input` through the chat model chain and decodes the reply.
    ///
    /// - Parameters:
    ///   - input: The query to answer.
    ///   - callbackManagers: Passed through to the underlying chat model chain.
    /// - Returns: The decoded answer, or `nil` when the reply can't be parsed into `Output`.
    /// - Throws: Rethrows any error thrown by the underlying chat model chain.
    public func callLogic(
        _ input: String,
        callbackManagers: [CallbackManager]
    ) async throws -> Output? {
        let message = try await chatModelChain.call(input, callbackManagers: callbackManagers)
        return await parseOutput(message)
    }

    /// Returns a textual description of `output`.
    ///
    /// The optional itself is described, so a non-nil value is rendered as `Optional(...)`
    /// and a missing value as `nil`.
    public func parseOutput(_ output: Output?) -> String {
        return String(describing: output)
    }

    /// Extracts the final answer from the first tool call of `message`.
    ///
    /// - Parameter message: The message returned by the chat model.
    /// - Returns: The decoded `finalAnswer`, or `nil` when the message has no tool call or the
    ///   arguments of the call can't be decoded into `EndFunction.Arguments`.
    public func parseOutput(_ message: ChatMessage) async -> Output? {
        guard let functionCall = message.toolCalls?.first?.function else { return nil }

        // Encoding a Swift string as UTF-8 cannot fail, so no fallback value is needed.
        let payload = Data(functionCall.arguments.utf8)
        do {
            return try JSONDecoder()
                .decode(EndFunction.Arguments.self, from: payload)
                .finalAnswer
        } catch {
            // Best effort: arguments that don't match the schema are treated as "no answer".
            return nil
        }
    }
}