|
| 1 | +import Foundation |
| 2 | +import Logger |
| 3 | +import OpenAIService |
| 4 | + |
| 5 | +public class CombineAnswersChain: Chain { |
| 6 | + public struct Input: Decodable { |
| 7 | + public var question: String |
| 8 | + public var answers: [String] |
| 9 | + public init(question: String, answers: [String]) { |
| 10 | + self.question = question |
| 11 | + self.answers = answers |
| 12 | + } |
| 13 | + } |
| 14 | + |
| 15 | + public typealias Output = String |
| 16 | + public let chatModelChain: ChatModelChain<Input> |
| 17 | + |
| 18 | + public init( |
| 19 | + configuration: ChatGPTConfiguration = UserPreferenceChatGPTConfiguration(), |
| 20 | + extraInstructions: String = "" |
| 21 | + ) { |
| 22 | + chatModelChain = .init( |
| 23 | + chatModel: OpenAIChat( |
| 24 | + configuration: configuration.overriding { |
| 25 | + $0.runFunctionsAutomatically = false |
| 26 | + }, |
| 27 | + memory: nil, |
| 28 | + stream: false |
| 29 | + ), |
| 30 | + stops: ["Observation:"], |
| 31 | + promptTemplate: { input in |
| 32 | + [ |
| 33 | + .init( |
| 34 | + role: .system, |
| 35 | + content: """ |
| 36 | + You are a helpful assistant. |
| 37 | + Your job is to combine multiple answers from different sources to one question. |
| 38 | + \(extraInstructions) |
| 39 | + """ |
| 40 | + ), |
| 41 | + .init(role: .user, content: """ |
| 42 | + Question: \(input.question) |
| 43 | +
|
| 44 | + Answers: |
| 45 | + \(input.answers.joined(separator: "\n\(String(repeating: "-", count: 32))\n")) |
| 46 | +
|
| 47 | + What is the combined answer? |
| 48 | + """), |
| 49 | + ] |
| 50 | + } |
| 51 | + ) |
| 52 | + } |
| 53 | + |
| 54 | + public func callLogic( |
| 55 | + _ input: Input, |
| 56 | + callbackManagers: [CallbackManager] |
| 57 | + ) async throws -> String { |
| 58 | + let output = try await chatModelChain.call(input, callbackManagers: callbackManagers) |
| 59 | + return await parseOutput(output) |
| 60 | + } |
| 61 | + |
| 62 | + public func parseOutput(_ message: ChatMessage) async -> String { |
| 63 | + return message.content ?? "No answer." |
| 64 | + } |
| 65 | + |
| 66 | + public func parseOutput(_ output: String) -> String { |
| 67 | + output |
| 68 | + } |
| 69 | +} |
| 70 | + |
0 commit comments