|
| 1 | +import ChatBasic |
| 2 | +import Foundation |
| 3 | + |
| 4 | +/// A singleton that stores all the possible capabilities of an ``RAGChatAgent``. |
| 5 | +public enum RAGChatAgentCapabilityContainer { |
| 6 | + static var capabilities: [String: any RAGChatAgentCapability] = [:] |
| 7 | + static func add(_ capability: any RAGChatAgentCapability) { |
| 8 | + capabilities[capability.id] = capability |
| 9 | + } |
| 10 | + |
| 11 | + static func add(_ capabilities: [any RAGChatAgentCapability]) { |
| 12 | + capabilities.forEach { add($0) } |
| 13 | + } |
| 14 | +} |
| 15 | + |
| 16 | +/// A protocol that defines the capability of an ``RAGChatAgent``. |
| 17 | +protocol RAGChatAgentCapability: Identifiable { |
| 18 | + typealias Request = ChatAgentRequest |
| 19 | + typealias Reference = ChatAgentContext.Reference |
| 20 | + |
| 21 | + /// The name to be displayed to the user. |
| 22 | + var name: String { get } |
| 23 | + /// The identifier of the capability. |
| 24 | + var id: String { get } |
| 25 | + /// Fetch the context for a given request. It can return a portion of the context at a time. |
| 26 | + func fetchContext(for request: ChatAgentRequest) async -> AsyncStream<ChatAgentContext> |
| 27 | +} |
| 28 | + |
| 29 | +public struct ChatAgentContext { |
| 30 | + public typealias Reference = ChatMessage.Reference |
| 31 | + |
| 32 | + /// Extra system prompt to be included in the chat request. |
| 33 | + public var extraSystemPrompt: String? |
| 34 | + /// References to be included in the chat request. |
| 35 | + public var references: [Reference] |
| 36 | + /// Functions to be included in the chat request. |
| 37 | + public var functions: [any ChatGPTFunction] |
| 38 | + |
| 39 | + public init( |
| 40 | + extraSystemPrompt: String? = nil, |
| 41 | + references: [ChatMessage.Reference] = [], |
| 42 | + functions: [any ChatGPTFunction] = [] |
| 43 | + ) { |
| 44 | + self.extraSystemPrompt = extraSystemPrompt |
| 45 | + self.references = references |
| 46 | + self.functions = functions |
| 47 | + } |
| 48 | +} |
| 49 | + |
| 50 | +// MARK: - Default Implementation |
| 51 | + |
| 52 | +extension RAGChatAgentCapability { |
| 53 | + func fetchContext(for request: ChatAgentRequest) async -> AsyncStream<ChatAgentContext> { |
| 54 | + return AsyncStream { continuation in |
| 55 | + continuation.finish() |
| 56 | + } |
| 57 | + } |
| 58 | +} |
| 59 | + |
0 commit comments