diff --git a/Core/Sources/ChatContextCollectors/WebChatContextCollector/WebChatContextCollector.swift b/Core/Sources/ChatContextCollectors/WebChatContextCollector/WebChatContextCollector.swift index 29327cea..851fdcf7 100644 --- a/Core/Sources/ChatContextCollectors/WebChatContextCollector/WebChatContextCollector.swift +++ b/Core/Sources/ChatContextCollectors/WebChatContextCollector/WebChatContextCollector.swift @@ -32,7 +32,7 @@ extension WebChatContextCollector { static func detectLinks(from messages: [ChatMessage]) -> [String] { return messages.lazy .compactMap { - $0.content ?? $0.functionCall?.arguments + $0.content ?? $0.toolCalls?.map(\.function.arguments).joined(separator: " ") ?? "" } .map(detectLinks(from:)) .flatMap { $0 } diff --git a/Core/Sources/ChatGPTChatTab/Chat.swift b/Core/Sources/ChatGPTChatTab/Chat.swift index 868266bc..4d722d68 100644 --- a/Core/Sources/ChatGPTChatTab/Chat.swift +++ b/Core/Sources/ChatGPTChatTab/Chat.swift @@ -9,13 +9,13 @@ public struct DisplayedChatMessage: Equatable { public enum Role: Equatable { case user case assistant - case function + case tool case ignored } public struct Reference: Equatable { public typealias Kind = ChatMessage.Reference.Kind - + public var title: String public var subtitle: String public var uri: String @@ -135,7 +135,7 @@ struct Chat: ReducerProtocol { await send(.focusOnTextField) await send(.refresh) } - + case .refresh: return .run { send in await send(.chatMenu(.refresh)) @@ -298,8 +298,9 @@ struct Chat: ReducerProtocol { }.cancellable(id: CancelID.observeDefaultScopesChange(id), cancelInFlight: true) case .historyChanged: - state.history = service.chatHistory.map { message in - .init( + state.history = service.chatHistory.flatMap { message in + var all = [DisplayedChatMessage]() + all.append(.init( id: message.id, role: { switch message.role { @@ -312,7 +313,6 @@ struct Chat: ReducerProtocol { return .assistant } return .ignored - case .function: return .function } }(), text: message.summary ?? message.content ?? "", @@ -325,7 +325,18 @@ struct Chat: ReducerProtocol { kind: $0.kind ) } - ) + )) + + for call in message.toolCalls ?? [] { + all.append(.init( + id: message.id + call.id, + role: .tool, + text: call.response.summary ?? call.response.content, + references: [] + )) + } + + return all } state.title = { @@ -401,7 +412,7 @@ struct ChatMenu: ReducerProtocol { return .run { await $0(.refresh) } - + case .refresh: state.temperatureOverride = service.configuration.overriding.temperature state.chatModelIdOverride = service.configuration.overriding.modelId diff --git a/Core/Sources/ChatGPTChatTab/ChatPanel.swift b/Core/Sources/ChatGPTChatTab/ChatPanel.swift index 7a729bfd..2729b5e4 100644 --- a/Core/Sources/ChatGPTChatTab/ChatPanel.swift +++ b/Core/Sources/ChatGPTChatTab/ChatPanel.swift @@ -258,7 +258,7 @@ struct ChatHistory: View { trailing: -8 )) .padding(.vertical, 4) - case .function: + case .tool: FunctionMessage(id: message.id, text: text) case .ignored: EmptyView() @@ -453,7 +453,7 @@ struct ChatPanel_Preview: PreviewProvider { ), .init( id: "6", - role: .function, + role: .tool, text: """ Searching for something... 
- abc diff --git a/Core/Sources/ChatService/ChatService.swift b/Core/Sources/ChatService/ChatService.swift index e8b7dae0..4bb74639 100644 --- a/Core/Sources/ChatService/ChatService.swift +++ b/Core/Sources/ChatService/ChatService.swift @@ -124,9 +124,9 @@ public final class ChatService: ObservableObject { await chatGPTService.stopReceivingMessage() isReceivingMessage = false - // if it's stopped before the function finishes, remove the function call. + // if it's stopped before the tool calls finish, remove the message. await memory.mutateHistory { history in - if history.last?.role == .assistant, history.last?.functionCall != nil { + if history.last?.role == .assistant, history.last?.toolCalls != nil { history.removeLast() } } diff --git a/Core/Sources/ChatService/ContextAwareAutoManagedChatGPTMemory.swift b/Core/Sources/ChatService/ContextAwareAutoManagedChatGPTMemory.swift index 9f4a53e1..ac44d87c 100644 --- a/Core/Sources/ChatService/ContextAwareAutoManagedChatGPTMemory.swift +++ b/Core/Sources/ChatService/ContextAwareAutoManagedChatGPTMemory.swift @@ -39,7 +39,7 @@ public final class ContextAwareAutoManagedChatGPTMemory: ChatGPTMemory { public func generatePrompt() async -> ChatGPTPrompt { let content = (await memory.history) - .last(where: { $0.role == .user || $0.role == .function })?.content + .last(where: { $0.role == .user })?.content try? await contextController.collectContextInformation( systemPrompt: """ \(chatService?.systemPrompt ?? "") diff --git a/Core/Sources/HostApp/AccountSettings/ChatModelManagement/ChatModelEdit.swift b/Core/Sources/HostApp/AccountSettings/ChatModelManagement/ChatModelEdit.swift index 342ef862..0e9b2062 100644 --- a/Core/Sources/HostApp/AccountSettings/ChatModelManagement/ChatModelEdit.swift +++ b/Core/Sources/HostApp/AccountSettings/ChatModelManagement/ChatModelEdit.swift @@ -14,6 +14,7 @@ struct ChatModelEdit: ReducerProtocol { @BindingState var maxTokens: Int = 4000 @BindingState var supportsFunctionCalling: Bool = true @BindingState var modelName: String = "" + @BindingState var ollamaKeepAlive: String = "" var apiKeyName: String { apiKeySelection.apiKeyName } var baseURL: String { baseURLSelection.baseURL } var isFullURL: Bool { baseURLSelection.isFullURL } @@ -48,7 +49,7 @@ struct ChatModelEdit: ReducerProtocol { Scope(state: \.apiKeySelection, action: /Action.apiKeySelection) { APIKeySelection() } - + Scope(state: \.baseURLSelection, action: /Action.baseURLSelection) { BaseURLSelection() } @@ -135,10 +136,10 @@ struct ChatModelEdit: ReducerProtocol { state.suggestedMaxTokens = nil return .none } - + case .apiKeySelection: return .none - + case .baseURLSelection: return .none @@ -169,6 +170,7 @@ extension ChatModelEdit.State { maxTokens: model.info.maxTokens, supportsFunctionCalling: model.info.supportsFunctionCalling, modelName: model.info.modelName, + ollamaKeepAlive: model.info.ollamaInfo.keepAlive, apiKeySelection: .init( apiKeyName: model.info.apiKeyName, apiKeyManagement: .init(availableAPIKeyNames: [model.info.apiKeyName]) @@ -193,9 +195,13 @@ extension ChatModel { if case .googleAI = state.format { return false } + if case .ollama = state.format { + return false + } return state.supportsFunctionCalling }(), - modelName: state.modelName.trimmingCharacters(in: .whitespacesAndNewlines) + modelName: state.modelName.trimmingCharacters(in: .whitespacesAndNewlines), + ollamaInfo: .init(keepAlive: state.ollamaKeepAlive) ) ) } diff --git a/Core/Sources/HostApp/AccountSettings/ChatModelManagement/ChatModelEditView.swift 
b/Core/Sources/HostApp/AccountSettings/ChatModelManagement/ChatModelEditView.swift index b46a0baf..fd6b1e21 100644 --- a/Core/Sources/HostApp/AccountSettings/ChatModelManagement/ChatModelEditView.swift +++ b/Core/Sources/HostApp/AccountSettings/ChatModelManagement/ChatModelEditView.swift @@ -24,6 +24,8 @@ struct ChatModelEditView: View { openAICompatible case .googleAI: googleAI + case .ollama: + ollama } } } @@ -92,6 +94,8 @@ struct ChatModelEditView: View { Text("OpenAI Compatible").tag(format) case .googleAI: Text("Google Generative AI").tag(format) + case .ollama: + Text("Ollama").tag(format) } } }, @@ -171,7 +175,7 @@ struct ChatModelEditView: View { ) TextField(text: textFieldBinding) { - Text("Max Tokens (Including Reply)") + Text("Context Window") .multilineTextAlignment(.trailing) } .overlay(alignment: .trailing) { @@ -344,6 +348,38 @@ struct ChatModelEditView: View { maxTokensTextField } + + @ViewBuilder + var ollama: some View { + baseURLTextField(prompt: Text("http://127.0.0.1:11434")) { + Text("/api/chat") + } + + WithViewStore( + store, + removeDuplicates: { $0.modelName == $1.modelName } + ) { viewStore in + TextField("Model Name", text: viewStore.$modelName) + } + + maxTokensTextField + + WithViewStore( + store, + removeDuplicates: { $0.ollamaKeepAlive == $1.ollamaKeepAlive } + ) { viewStore in + TextField(text: viewStore.$ollamaKeepAlive, prompt: Text("Default Value")) { + Text("Keep Alive") + } + } + + VStack(alignment: .leading, spacing: 8) { + Text(Image(systemName: "exclamationmark.triangle.fill")) + Text( + " For more details, please visit [https://ollama.com](https://ollama.com)." + ) + } + .padding(.vertical) + } } #Preview("OpenAI") { diff --git a/Core/Sources/HostApp/AccountSettings/ChatModelManagement/ChatModelManagement.swift b/Core/Sources/HostApp/AccountSettings/ChatModelManagement/ChatModelManagement.swift index 6f34bdc6..1bbb109b 100644 --- a/Core/Sources/HostApp/AccountSettings/ChatModelManagement/ChatModelManagement.swift +++ b/Core/Sources/HostApp/AccountSettings/ChatModelManagement/ChatModelManagement.swift @@ -11,6 +11,7 @@ extension ChatModel: ManageableAIModel { case .azureOpenAI: return "Azure OpenAI" case .openAICompatible: return "OpenAI Compatible" case .googleAI: return "Google Generative AI" + case .ollama: return "Ollama" } } diff --git a/Core/Sources/HostApp/AccountSettings/EmbeddingModelManagement/EmbeddingModelEdit.swift b/Core/Sources/HostApp/AccountSettings/EmbeddingModelManagement/EmbeddingModelEdit.swift index df8dbf22..c5d1378e 100644 --- a/Core/Sources/HostApp/AccountSettings/EmbeddingModelManagement/EmbeddingModelEdit.swift +++ b/Core/Sources/HostApp/AccountSettings/EmbeddingModelManagement/EmbeddingModelEdit.swift @@ -13,6 +13,7 @@ struct EmbeddingModelEdit: ReducerProtocol { @BindingState var format: EmbeddingModel.Format @BindingState var maxTokens: Int = 8191 @BindingState var modelName: String = "" + @BindingState var ollamaKeepAlive: String = "" var apiKeyName: String { apiKeySelection.apiKeyName } var baseURL: String { baseURLSelection.baseURL } var isFullURL: Bool { baseURLSelection.isFullURL } @@ -83,14 +84,13 @@ struct EmbeddingModelEdit: ReducerProtocol { ) return .run { send in do { - let tokenUsage = - try await EmbeddingService( - configuration: UserPreferenceEmbeddingConfiguration() - .overriding { - $0.model = model - } - ).embed(text: "Hello").usage.total_tokens - await send(.testSucceeded("Used \(tokenUsage) tokens.")) + _ = try await EmbeddingService( + configuration: UserPreferenceEmbeddingConfiguration() + 
.overriding { + $0.model = model + } + ).embed(text: "Hello") + await send(.testSucceeded("Succeeded!")) } catch { await send(.testFailed(error.localizedDescription)) } @@ -155,6 +155,7 @@ extension EmbeddingModelEdit.State { format: model.format, maxTokens: model.info.maxTokens, modelName: model.info.modelName, + ollamaKeepAlive: model.info.ollamaInfo.keepAlive, apiKeySelection: .init( apiKeyName: model.info.apiKeyName, apiKeyManagement: .init(availableAPIKeyNames: [model.info.apiKeyName]) @@ -175,7 +176,8 @@ extension EmbeddingModel { baseURL: state.baseURL.trimmingCharacters(in: .whitespacesAndNewlines), isFullURL: state.isFullURL, maxTokens: state.maxTokens, - modelName: state.modelName.trimmingCharacters(in: .whitespacesAndNewlines) + modelName: state.modelName.trimmingCharacters(in: .whitespacesAndNewlines), + ollamaInfo: .init(keepAlive: state.ollamaKeepAlive) ) ) } diff --git a/Core/Sources/HostApp/AccountSettings/EmbeddingModelManagement/EmbeddingModelEditView.swift b/Core/Sources/HostApp/AccountSettings/EmbeddingModelManagement/EmbeddingModelEditView.swift index c9c7b452..2bad443f 100644 --- a/Core/Sources/HostApp/AccountSettings/EmbeddingModelManagement/EmbeddingModelEditView.swift +++ b/Core/Sources/HostApp/AccountSettings/EmbeddingModelManagement/EmbeddingModelEditView.swift @@ -22,6 +22,8 @@ struct EmbeddingModelEditView: View { azureOpenAI case .openAICompatible: openAICompatible + case .ollama: + ollama } } } @@ -88,6 +90,8 @@ struct EmbeddingModelEditView: View { Text("Azure OpenAI").tag(format) case .openAICompatible: Text("OpenAI Compatible").tag(format) + case .ollama: + Text("Ollama").tag(format) } } }, @@ -289,6 +293,38 @@ struct EmbeddingModelEditView: View { maxTokensTextField } + + @ViewBuilder + var ollama: some View { + baseURLTextField(prompt: Text("http://127.0.0.1:11434")) { + Text("/api/embeddings") + } + + WithViewStore( + store, + removeDuplicates: { $0.modelName == $1.modelName } + ) { viewStore in + TextField("Model Name", text: viewStore.$modelName) + } + + maxTokensTextField + + WithViewStore( + store, + removeDuplicates: { $0.ollamaKeepAlive == $1.ollamaKeepAlive } + ) { viewStore in + TextField(text: viewStore.$ollamaKeepAlive, prompt: Text("Default Value")) { + Text("Keep Alive") + } + } + + VStack(alignment: .leading, spacing: 8) { + Text(Image(systemName: "exclamationmark.triangle.fill")) + Text( + " For more details, please visit [https://ollama.com](https://ollama.com)." 
+ ) + } + .padding(.vertical) + } } class EmbeddingModelManagementView_Editing_Previews: PreviewProvider { diff --git a/Core/Sources/HostApp/AccountSettings/EmbeddingModelManagement/EmbeddingModelManagement.swift b/Core/Sources/HostApp/AccountSettings/EmbeddingModelManagement/EmbeddingModelManagement.swift index eda907d3..71b0d4a5 100644 --- a/Core/Sources/HostApp/AccountSettings/EmbeddingModelManagement/EmbeddingModelManagement.swift +++ b/Core/Sources/HostApp/AccountSettings/EmbeddingModelManagement/EmbeddingModelManagement.swift @@ -10,6 +10,7 @@ extension EmbeddingModel: ManageableAIModel { case .openAI: return "OpenAI" case .azureOpenAI: return "Azure OpenAI" case .openAICompatible: return "OpenAI Compatible" + case .ollama: return "Ollama" } } diff --git a/Core/Sources/SuggestionWidget/WidgetWindowsController.swift b/Core/Sources/SuggestionWidget/WidgetWindowsController.swift index 41a40820..454c79ee 100644 --- a/Core/Sources/SuggestionWidget/WidgetWindowsController.swift +++ b/Core/Sources/SuggestionWidget/WidgetWindowsController.swift @@ -93,6 +93,7 @@ actor WidgetWindowsController: NSObject { let xcodeInspector = self.xcodeInspector let activeApp = await xcodeInspector.safe.activeApplication let latestActiveXcode = await xcodeInspector.safe.latestActiveXcode + let previousActiveApplication = xcodeInspector.previousActiveApplication await MainActor.run { let state = store.withState { $0 } let isChatPanelDetached = state.chatPanelState.chatPanelInASeparateWindow @@ -123,9 +124,17 @@ actor WidgetWindowsController: NSObject { return true }() + let previousAppIsXcode = previousActiveApplication?.isXcode ?? false + windows.sharedPanelWindow.alphaValue = noFocus ? 0 : 1 windows.suggestionPanelWindow.alphaValue = noFocus ? 0 : 1 - windows.widgetWindow.alphaValue = noFocus ? 0 : 1 + windows.widgetWindow.alphaValue = if noFocus { + 0 + } else if previousAppIsXcode { + 1 + } else { + 0 + } windows.toastWindow.alphaValue = noFocus ? 
0 : 1 if isChatPanelDetached { windows.chatPanelWindow.isWindowHidden = !hasChat diff --git a/ExtensionService/AppDelegate+Menu.swift b/ExtensionService/AppDelegate+Menu.swift index e655522d..e5567c27 100644 --- a/ExtensionService/AppDelegate+Menu.swift +++ b/ExtensionService/AppDelegate+Menu.swift @@ -20,6 +20,7 @@ extension AppDelegate { .init("sourceEditorDebugMenu") } + @MainActor @objc func buildStatusBarMenu() { let statusBar = NSStatusBar.system statusBarItem = statusBar.statusItem( diff --git a/Tool/Sources/AIModel/ChatModel.swift b/Tool/Sources/AIModel/ChatModel.swift index 344c996b..325592a7 100644 --- a/Tool/Sources/AIModel/ChatModel.swift +++ b/Tool/Sources/AIModel/ChatModel.swift @@ -21,9 +21,28 @@ public struct ChatModel: Codable, Equatable, Identifiable { case azureOpenAI case openAICompatible case googleAI + case ollama } public struct Info: Codable, Equatable { + public struct OllamaInfo: Codable, Equatable { + @FallbackDecoding<EmptyString> + public var keepAlive: String + + public init(keepAlive: String = "") { + self.keepAlive = keepAlive + } + } + + public struct OpenAIInfo: Codable, Equatable { + @FallbackDecoding<EmptyString> + public var organizationID: String + + public init(organizationID: String = "") { + self.organizationID = organizationID + } + } + @FallbackDecoding<EmptyString> public var apiKeyName: String @FallbackDecoding<EmptyString> public var baseURL: String @@ -34,14 +53,13 @@ public struct ChatModel: Codable, Equatable, Identifiable { public var maxTokens: Int @FallbackDecoding<EmptyBool> public var supportsFunctionCalling: Bool - @FallbackDecoding<EmptyBool> - public var supportsOpenAIAPI2023_11: Bool @FallbackDecoding<EmptyString> public var modelName: String - public var azureOpenAIDeploymentName: String { - get { modelName } - set { modelName = newValue } - } + + @FallbackDecoding<EmptyChatModelOpenAIInfo> + public var openAIInfo: OpenAIInfo + @FallbackDecoding<EmptyChatModelOllamaInfo> + public var ollamaInfo: OllamaInfo public init( apiKeyName: String = "", @@ -49,16 +67,18 @@ public struct ChatModel: Codable, Equatable, Identifiable { isFullURL: Bool = false, maxTokens: Int = 4000, supportsFunctionCalling: Bool = true, - supportsOpenAIAPI2023_11: Bool = false, - modelName: String = "" + modelName: String = "", + openAIInfo: OpenAIInfo = OpenAIInfo(), + ollamaInfo: OllamaInfo = OllamaInfo() ) { self.apiKeyName = apiKeyName self.baseURL = baseURL self.isFullURL = isFullURL self.maxTokens = maxTokens self.supportsFunctionCalling = supportsFunctionCalling - self.supportsOpenAIAPI2023_11 = supportsOpenAIAPI2023_11 self.modelName = modelName + self.openAIInfo = openAIInfo + self.ollamaInfo = ollamaInfo } } @@ -75,7 +95,7 @@ public struct ChatModel: Codable, Equatable, Identifiable { return "\(baseURL)/v1/chat/completions" case .azureOpenAI: let baseURL = info.baseURL - let deployment = info.azureOpenAIDeploymentName + let deployment = info.modelName let version = "2024-02-15-preview" if baseURL.isEmpty { return "" } return "\(baseURL)/openai/deployments/\(deployment)/chat/completions?api-version=\(version)" @@ -83,6 +103,10 @@ public struct ChatModel: Codable, Equatable, Identifiable { let baseURL = info.baseURL if baseURL.isEmpty { return "https://generativelanguage.googleapis.com/v1" } return "\(baseURL)/v1/chat/completions" + case .ollama: + let baseURL = info.baseURL + if baseURL.isEmpty { return "http://localhost:11434/api/chat" } + return "\(baseURL)/api/chat" } } } @@ -95,3 +119,11 @@ public struct EmptyChatModelInfo: FallbackValueProvider { public static var defaultValue: ChatModel.Format { .openAI } } +public struct EmptyChatModelOllamaInfo: FallbackValueProvider { + public static var defaultValue:
ChatModel.Info.OllamaInfo { .init() } +} + +public struct EmptyChatModelOpenAIInfo: FallbackValueProvider { + public static var defaultValue: ChatModel.Info.OpenAIInfo { .init() } +} + diff --git a/Tool/Sources/AIModel/EmbeddingModel.swift b/Tool/Sources/AIModel/EmbeddingModel.swift index c942be9a..d86650fa 100644 --- a/Tool/Sources/AIModel/EmbeddingModel.swift +++ b/Tool/Sources/AIModel/EmbeddingModel.swift @@ -1,5 +1,5 @@ -import Foundation import CodableWrappers +import Foundation public struct EmbeddingModel: Codable, Equatable, Identifiable { public var id: String @@ -20,9 +20,13 @@ public struct EmbeddingModel: Codable, Equatable, Identifiable { case openAI case azureOpenAI case openAICompatible + case ollama } public struct Info: Codable, Equatable { + public typealias OllamaInfo = ChatModel.Info.OllamaInfo + public typealias OpenAIInfo = ChatModel.Info.OpenAIInfo + @FallbackDecoding<EmptyString> public var apiKeyName: String @FallbackDecoding<EmptyString> public var baseURL: String @@ -35,10 +39,11 @@ public struct EmbeddingModel: Codable, Equatable, Identifiable { public var dimensions: Int @FallbackDecoding<EmptyString> public var modelName: String - public var azureOpenAIDeploymentName: String { - get { modelName } - set { modelName = newValue } - } + + @FallbackDecoding<EmptyChatModelOpenAIInfo> + public var openAIInfo: OpenAIInfo + @FallbackDecoding<EmptyChatModelOllamaInfo> + public var ollamaInfo: OllamaInfo public init( apiKeyName: String = "", @@ -46,7 +51,9 @@ public struct EmbeddingModel: Codable, Equatable, Identifiable { isFullURL: Bool = false, maxTokens: Int = 8192, dimensions: Int = 1536, - modelName: String = "" + modelName: String = "", + openAIInfo: OpenAIInfo = OpenAIInfo(), + ollamaInfo: OllamaInfo = OllamaInfo() ) { self.apiKeyName = apiKeyName self.baseURL = baseURL @@ -54,9 +61,11 @@ public struct EmbeddingModel: Codable, Equatable, Identifiable { self.maxTokens = maxTokens self.dimensions = dimensions self.modelName = modelName + self.openAIInfo = openAIInfo + self.ollamaInfo = ollamaInfo } } - + public var endpoint: String { switch format { case .openAI: @@ -70,15 +79,18 @@ public struct EmbeddingModel: Codable, Equatable, Identifiable { return "\(baseURL)/v1/embeddings" case .azureOpenAI: let baseURL = info.baseURL - let deployment = info.azureOpenAIDeploymentName + let deployment = info.modelName let version = "2024-02-15-preview" if baseURL.isEmpty { return "" } return "\(baseURL)/openai/deployments/\(deployment)/embeddings?api-version=\(version)" + case .ollama: + let baseURL = info.baseURL + if baseURL.isEmpty { return "http://localhost:11434/api/embeddings" } + return "\(baseURL)/api/embeddings" } } } - public struct EmptyEmbeddingModelInfo: FallbackValueProvider { public static var defaultValue: EmbeddingModel.Info { .init() } } @@ -86,3 +98,4 @@ public struct EmptyEmbeddingModelInfo: FallbackValueProvider { public static var defaultValue: EmbeddingModel.Format { .openAI } } + diff --git a/Tool/Sources/CodeiumService/CodeiumInstallationManager.swift b/Tool/Sources/CodeiumService/CodeiumInstallationManager.swift index 7574c31b..9ea25108 100644 --- a/Tool/Sources/CodeiumService/CodeiumInstallationManager.swift +++ b/Tool/Sources/CodeiumService/CodeiumInstallationManager.swift @@ -3,7 +3,7 @@ import Terminal public struct CodeiumInstallationManager { private static var isInstalling = false - static let latestSupportedVersion = "1.6.9" + static let latestSupportedVersion = "1.8.5" public init() {} diff --git a/Tool/Sources/GitHubCopilotService/GitHubCopilotInstallationManager.swift
b/Tool/Sources/GitHubCopilotService/GitHubCopilotInstallationManager.swift index 92877938..f7013f08 100644 --- a/Tool/Sources/GitHubCopilotService/GitHubCopilotInstallationManager.swift +++ b/Tool/Sources/GitHubCopilotService/GitHubCopilotInstallationManager.swift @@ -10,7 +10,7 @@ public struct GitHubCopilotInstallationManager { return URL(string: link)! } - static let latestSupportedVersion = "1.19.2" + static let latestSupportedVersion = "1.25.0" public init() {} diff --git a/Tool/Sources/GitHubCopilotService/GitHubCopilotService.swift b/Tool/Sources/GitHubCopilotService/GitHubCopilotService.swift index 01225c22..e18eb58a 100644 --- a/Tool/Sources/GitHubCopilotService/GitHubCopilotService.swift +++ b/Tool/Sources/GitHubCopilotService/GitHubCopilotService.swift @@ -169,11 +169,11 @@ public class GitHubCopilotBaseService { Task { [weak self] in _ = try? await server.sendRequest(GitHubCopilotRequest.SetEditorInfo()) - for await notification in NotificationCenter.default + for await _ in NotificationCenter.default .notifications(named: .gitHubCopilotShouldRefreshEditorInformation) { print("Yes!") - guard let self else { return } + guard self != nil else { return } _ = try? await server.sendRequest(GitHubCopilotRequest.SetEditorInfo()) } } diff --git a/Tool/Sources/LangChain/Chains/LLMChain.swift b/Tool/Sources/LangChain/Chains/LLMChain.swift index 1201bd45..2ba4aef4 100644 --- a/Tool/Sources/LangChain/Chains/LLMChain.swift +++ b/Tool/Sources/LangChain/Chains/LLMChain.swift @@ -33,10 +33,11 @@ public class ChatModelChain: Chain { public func parseOutput(_ output: Output) -> String { if let content = output.content { return content - } else if let functionCall = output.functionCall { - return "\(functionCall.name): \(functionCall.arguments)" + } else if let toolCalls = output.toolCalls { + return toolCalls.map { "[\($0.id)] \($0.function.name): \($0.function.arguments)" } + .joined(separator: "\n") } - + return "" } } diff --git a/Tool/Sources/LangChain/Chains/RefineDocumentChain.swift b/Tool/Sources/LangChain/Chains/RefineDocumentChain.swift index 3c38eb3a..3b24e6ad 100644 --- a/Tool/Sources/LangChain/Chains/RefineDocumentChain.swift +++ b/Tool/Sources/LangChain/Chains/RefineDocumentChain.swift @@ -42,7 +42,7 @@ public final class RefineDocumentChain: Chain { } class FunctionProvider: ChatGPTFunctionProvider { - var functionCallStrategy: FunctionCallStrategy? = .name("respond") + var functionCallStrategy: FunctionCallStrategy? = .function(name: "respond") var functions: [any ChatGPTFunction] = [RespondFunction()] } @@ -153,7 +153,7 @@ public final class RefineDocumentChain: Chain { } func extractAnswer(_ chatMessage: ChatMessage) -> IntermediateAnswer { - if let functionCall = chatMessage.functionCall { + for functionCall in chatMessage.toolCalls?.map(\.function) ?? [] { do { let intermediateAnswer = try JSONDecoder().decode( IntermediateAnswer.self, diff --git a/Tool/Sources/LangChain/Chains/RelevantInformationExtractionChain.swift b/Tool/Sources/LangChain/Chains/RelevantInformationExtractionChain.swift index d55cafd4..4c9f696a 100644 --- a/Tool/Sources/LangChain/Chains/RelevantInformationExtractionChain.swift +++ b/Tool/Sources/LangChain/Chains/RelevantInformationExtractionChain.swift @@ -15,7 +15,7 @@ public final class RelevantInformationExtractionChain: Chain { public typealias Output = String class FunctionProvider: ChatGPTFunctionProvider { - var functionCallStrategy: FunctionCallStrategy? = .name("saveFinalAnswer") + var functionCallStrategy: FunctionCallStrategy? 
= .function(name: "saveFinalAnswer") var functions: [any ChatGPTFunction] = [FinalAnswer()] } @@ -103,8 +103,10 @@ public final class RelevantInformationExtractionChain: Chain { taskInput, callbackManagers: callbackManagers ) - - if let functionCall = output.functionCall { + + if let functionCall = output.toolCalls? + .first(where: { $0.function.name == FinalAnswer().name })?.function + { do { let arguments = try JSONDecoder().decode( FinalAnswer.Arguments.self, @@ -118,7 +120,7 @@ public final class RelevantInformationExtractionChain: Chain { return output.content ?? "" } } - + return output.content ?? "" } diff --git a/Tool/Sources/LangChain/Chains/StructuredOutputChatModelChain.swift b/Tool/Sources/LangChain/Chains/StructuredOutputChatModelChain.swift index 9c938cce..6ea1dbb5 100644 --- a/Tool/Sources/LangChain/Chains/StructuredOutputChatModelChain.swift +++ b/Tool/Sources/LangChain/Chains/StructuredOutputChatModelChain.swift @@ -53,7 +53,7 @@ public class StructuredOutputChatModelChain: Chain { } var functionCallStrategy: FunctionCallStrategy? { - .name(endFunction.name) + .function(name: endFunction.name) } } @@ -108,7 +108,7 @@ public class StructuredOutputChatModelChain: Chain { } public func parseOutput(_ message: ChatMessage) async -> Output? { - if let functionCall = message.functionCall { + if let functionCall = message.toolCalls?.first?.function { do { let result = try JSONDecoder().decode( EndFunction.Arguments.self, diff --git a/Tool/Sources/OpenAIService/APIs/ChatCompletionsAPIDefinition.swift b/Tool/Sources/OpenAIService/APIs/ChatCompletionsAPIDefinition.swift new file mode 100644 index 00000000..a86aba7b --- /dev/null +++ b/Tool/Sources/OpenAIService/APIs/ChatCompletionsAPIDefinition.swift @@ -0,0 +1,206 @@ +import AIModel +import CodableWrappers +import Foundation +import Preferences + +struct ChatCompletionsRequestBody: Codable, Equatable { + struct Message: Codable, Equatable { + enum Role: String, Codable, Equatable { + case system + case user + case assistant + case tool + + var asChatMessageRole: ChatMessage.Role { + switch self { + case .system: + return .system + case .user: + return .user + case .assistant: + return .assistant + case .tool: + return .user + } + } + } + + /// The role of the message. + var role: Role + /// The content of the message. + var content: String + /// When we want to reply to a function call with the result, we have to provide the + /// name of the function call, and include the result in `content`. + /// + /// - important: Only used by the legacy function calling API; with tools, reply + /// with the `tool` role and provide `toolCallId` instead. + var name: String? + /// Tool calls in an assistant message. + var toolCalls: [MessageToolCall]? + /// When we want to call a tool, we have to provide the id of the call. + /// + /// - important: It's required when the role is `tool`. + var toolCallId: String? + } + + struct MessageFunctionCall: Codable, Equatable { + /// The name of the function to call. + var name: String + /// A JSON string. + var arguments: String? + } + + struct MessageToolCall: Codable, Equatable { + /// The id of the tool call. + var id: String + /// The type of the tool. + var type: String + /// The function call. + var function: MessageFunctionCall + } + + struct Tool: Codable, Equatable { + var type: String = "function" + var function: ChatGPTFunctionSchema + } + + var model: String + var messages: [Message] + var temperature: Double? + var stream: Bool? + var stop: [String]? + var maxTokens: Int? + /// Pass nil to let the bot decide. + var toolChoice: FunctionCallStrategy? + var tools: [Tool]?
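+ + // The initializer below honors the user's disableFunctionCalling preference: + // when it's on, `toolChoice` and `tools` are left nil; otherwise an empty tool + // list is still encoded as nil so the key is omitted from the request JSON.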
+ + init( + model: String, + messages: [Message], + temperature: Double? = nil, + stream: Bool? = nil, + stop: [String]? = nil, + maxTokens: Int? = nil, + toolChoice: FunctionCallStrategy? = nil, + tools: [Tool] = [] + ) { + self.model = model + self.messages = messages + self.temperature = temperature + self.stream = stream + self.stop = stop + self.maxTokens = maxTokens + if UserDefaults.shared.value(for: \.disableFunctionCalling) { + self.toolChoice = nil + self.tools = nil + } else { + self.toolChoice = toolChoice + self.tools = tools.isEmpty ? nil : tools + } + } +} + +struct EmptyMessageFunctionCall: FallbackValueProvider { + static var defaultValue: ChatCompletionsRequestBody.MessageFunctionCall { + .init(name: "") + } +} + +public enum FunctionCallStrategy: Codable, Equatable { + /// Forbid the bot to call any function. + case none + /// Let the bot choose what function to call. + case auto + /// Force the bot to call a function with the given name. + case function(name: String) + + struct CallFunctionNamed: Codable { + var type = "function" + let function: Function + struct Function: Codable { + var name: String + } + } + + public func encode(to encoder: Encoder) throws { + var container = encoder.singleValueContainer() + switch self { + case .none: + try container.encode("none") + case .auto: + try container.encode("auto") + case let .function(name): + try container.encode(CallFunctionNamed(function: .init(name: name))) + } + } +} + +// MARK: - Stream API + +protocol ChatCompletionsStreamAPI { + func callAsFunction() async throws -> AsyncThrowingStream<ChatCompletionsStreamDataChunk, Error> +} + +extension AsyncSequence { + func toStream() -> AsyncThrowingStream<Element, Error> { + AsyncThrowingStream { continuation in + let task = Task { + do { + for try await element in self { + continuation.yield(element) + } + continuation.finish() + } catch { + continuation.finish(throwing: error) + } + } + + continuation.onTermination = { _ in + task.cancel() + } + } + } +} + +struct ChatCompletionsStreamDataChunk { + struct Delta { + struct FunctionCall { + var name: String? + var arguments: String? + } + + struct ToolCall { + var index: Int? + var id: String? + var type: String? + var function: FunctionCall? + } + + var role: ChatCompletionsRequestBody.Message.Role? + var content: String? + var toolCalls: [ToolCall]? + } + + var id: String? + var object: String? + var model: String? + var message: Delta? + var finishReason: String? +} + +// MARK: - Non Stream API + +protocol ChatCompletionsAPI { + func callAsFunction() async throws -> ChatCompletionResponseBody +} + +struct ChatCompletionResponseBody: Codable, Equatable { + typealias Message = ChatCompletionsRequestBody.Message + + var id: String?
+ var object: String + var model: String + var message: Message + var otherChoices: [Message] + var finishReason: String +} + diff --git a/Tool/Sources/OpenAIService/APIs/EmbeddingAPIDefinitions.swift b/Tool/Sources/OpenAIService/APIs/EmbeddingAPIDefinitions.swift new file mode 100644 index 00000000..0715e0f1 --- /dev/null +++ b/Tool/Sources/OpenAIService/APIs/EmbeddingAPIDefinitions.swift @@ -0,0 +1,28 @@ +import AIModel +import Foundation +import Preferences + +protocol EmbeddingAPI { + func embed(text: String) async throws -> EmbeddingResponse + func embed(texts: [String]) async throws -> EmbeddingResponse + func embed(tokens: [[Int]]) async throws -> EmbeddingResponse +} + +public struct EmbeddingResponse: Decodable { + public struct Object: Decodable { + public var embedding: [Float] + public var index: Int + public var object: String + } + + public var data: [Object] + public var model: String + + public struct Usage: Decodable { + public var prompt_tokens: Int + public var total_tokens: Int + } + + public var usage: Usage +} + diff --git a/Tool/Sources/OpenAIService/APIs/GoogleAIChatCompletionsService.swift b/Tool/Sources/OpenAIService/APIs/GoogleAIChatCompletionsService.swift new file mode 100644 index 00000000..80a5c9b8 --- /dev/null +++ b/Tool/Sources/OpenAIService/APIs/GoogleAIChatCompletionsService.swift @@ -0,0 +1,286 @@ +import AIModel +import Foundation +import GoogleGenerativeAI +import Preferences + +actor GoogleAIChatCompletionsService: ChatCompletionsAPI, ChatCompletionsStreamAPI { + let apiKey: String + let model: ChatModel + var requestBody: ChatCompletionsRequestBody + let prompt: ChatGPTPrompt + + init( + apiKey: String, + model: ChatModel, + requestBody: ChatCompletionsRequestBody, + prompt: ChatGPTPrompt + ) { + self.apiKey = apiKey + self.model = model + self.requestBody = requestBody + self.prompt = prompt + } + + func callAsFunction() async throws -> ChatCompletionResponseBody { + let aiModel = GenerativeModel( + name: model.info.modelName, + apiKey: apiKey, + generationConfig: .init(GenerationConfig( + temperature: requestBody.temperature.map(Float.init) + )) + ) + let history = prompt.googleAICompatible.history.map { message in + ModelContent(message) + } + + do { + let response = try await aiModel.generateContent(history) + return response.formalized() + } catch let error as GenerateContentError { + struct ErrorWrapper: Error, LocalizedError { + let error: Error + var errorDescription: String? 
{ + var s = "" + dump(error, to: &s) + return "Internal Error: \(s)" + } + } + + switch error { + case let .internalError(underlying): + throw ErrorWrapper(error: underlying) + case .promptBlocked: + throw error + case .responseStoppedEarly: + throw error + } + } catch { + throw error + } + } + + func callAsFunction() async throws + -> AsyncThrowingStream<ChatCompletionsStreamDataChunk, Error> + { + let aiModel = GenerativeModel( + name: model.info.modelName, + apiKey: apiKey, + generationConfig: .init(GenerationConfig( + temperature: requestBody.temperature.map(Float.init) + )) + ) + let history = prompt.googleAICompatible.history.map { message in + ModelContent(message) + } + + let stream = AsyncThrowingStream<ChatCompletionsStreamDataChunk, Error> { continuation in + let stream = aiModel.generateContentStream(history) + let task = Task { + do { + for try await response in stream { + if Task.isCancelled { break } + let chunk = response.formalizedAsChunk() + continuation.yield(chunk) + } + continuation.finish() + } catch let error as GenerateContentError { + struct ErrorWrapper: Error, LocalizedError { + let error: Error + var errorDescription: String? { + var s = "" + dump(error, to: &s) + return "Internal Error: \(s)" + } + } + + switch error { + case let .internalError(underlying): + continuation.finish(throwing: ErrorWrapper(error: underlying)) + case .promptBlocked: + continuation.finish(throwing: error) + case .responseStoppedEarly: + continuation.finish(throwing: error) + } + } catch { + continuation.finish(throwing: error) + } + } + continuation.onTermination = { _ in + task.cancel() + } + } + + return stream + } +} + +extension ChatGPTPrompt { + var googleAICompatible: ChatGPTPrompt { + var history = self.history + var reformattedHistory = [ChatMessage]() + + // We don't want to combine the new user message with others. + let newUserMessage: ChatMessage? = if history.last?.role == .user { + history.removeLast() + } else { + nil + } + + for message in history { + let lastIndex = reformattedHistory.endIndex - 1 + guard lastIndex >= 0 else { // first message + if message.role == .system { + reformattedHistory.append(.init( + id: message.id, + role: .user, + content: ModelContent.convertContent(of: message) + )) + reformattedHistory.append(.init( + role: .assistant, + content: "Got it. Let's start our conversation." + )) + continue + } + + reformattedHistory.append(message) + continue + } + + let lastMessage = reformattedHistory[lastIndex] + + if ModelContent.convertRole(lastMessage.role) == ModelContent + .convertRole(message.role) + { + let newMessage = ChatMessage( + id: message.id, + role: message.role == .assistant ?
.assistant : .user, + content: """ + \(ModelContent.convertContent(of: lastMessage)) + + ====== + + \(ModelContent.convertContent(of: message)) + """ + ) + reformattedHistory[lastIndex] = newMessage + } else { + reformattedHistory.append(message) + } + } + + if let newUserMessage { + if let last = reformattedHistory.last, + ModelContent.convertRole(last.role) == ModelContent + .convertRole(newUserMessage.role) + { + // Add dummy message + let dummyMessage = ChatMessage( + role: .assistant, + content: "OK" + ) + reformattedHistory.append(dummyMessage) + } + reformattedHistory.append(newUserMessage) + } + + return .init( + history: reformattedHistory, + references: references, + remainingTokenCount: remainingTokenCount + ) + } +} + +extension ModelContent { + static func convertRole(_ role: ChatMessage.Role) -> String { + switch role { + case .user, .system: + return "user" + case .assistant: + return "model" + } + } + + static func convertContent(of message: ChatMessage) -> String { + switch message.role { + case .system: + return "System Prompt:\n\(message.content ?? " ")" + case .user: + return message.content ?? " " + case .assistant: + if let toolCalls = message.toolCalls { + return toolCalls.map { call in + let response = call.response + return """ + Call function: \(call.function.name) + Arguments: \(call.function.arguments) + Result: \(response.content) + """ + }.joined(separator: "\n") + } else { + return message.content ?? " " + } + } + } + + init(_ message: ChatMessage) { + let role = Self.convertRole(message.role) + let parts = [ModelContent.Part.text(Self.convertContent(of: message))] + self = .init(role: role, parts: parts) + } +} + +extension GenerateContentResponse { + func formalized() -> ChatCompletionResponseBody { + let message: ChatCompletionResponseBody.Message + let otherMessages: [ChatCompletionResponseBody.Message] + + func convertMessage(_ candidate: CandidateResponse) -> ChatCompletionResponseBody.Message { + .init( + role: .assistant, + content: candidate.content.parts.first(where: { part in + if let text = part.text { + return !text.isEmpty + } else { + return false + } + })?.text ?? "" + ) + } + + if let first = candidates.first { + message = convertMessage(first) + otherMessages = candidates.dropFirst().map { convertMessage($0) } + } else { + message = .init(role: .assistant, content: "") + otherMessages = [] + } + + return .init( + object: "chat.completion", + model: "", + message: message, + otherChoices: otherMessages, + finishReason: candidates.first?.finishReason?.rawValue ?? "" + ) + } + + func formalizedAsChunk() -> ChatCompletionsStreamDataChunk { + func convertMessage( + _ candidate: CandidateResponse + ) -> ChatCompletionsStreamDataChunk.Delta { + .init( + role: .assistant, + content: candidate.content.parts + .first(where: { $0.text != nil })?.text ?? 
"" + ) + } + + return .init( + object: "", + model: "", + message: candidates.first.map(convertMessage) + ) + } +} + diff --git a/Tool/Sources/OpenAIService/APIs/GoogleAICompletionAPI.swift b/Tool/Sources/OpenAIService/APIs/GoogleAICompletionAPI.swift deleted file mode 100644 index ded6e372..00000000 --- a/Tool/Sources/OpenAIService/APIs/GoogleAICompletionAPI.swift +++ /dev/null @@ -1,198 +0,0 @@ -import AIModel -import Foundation -import GoogleGenerativeAI -import Preferences - -struct GoogleCompletionAPI: CompletionAPI { - let apiKey: String - let model: ChatModel - var requestBody: CompletionRequestBody - let prompt: ChatGPTPrompt - - func callAsFunction() async throws -> CompletionResponseBody { - let aiModel = GenerativeModel( - name: model.info.modelName, - apiKey: apiKey, - generationConfig: .init(GenerationConfig( - temperature: requestBody.temperature.map(Float.init), - topP: requestBody.top_p.map(Float.init) - )) - ) - let history = prompt.googleAICompatible.history.map { message in - ModelContent( - ChatMessage( - role: message.role, - content: message.content, - name: message.name, - functionCall: message.functionCall.map { - .init(name: $0.name, arguments: $0.arguments) - } - ) - ) - } - - do { - let response = try await aiModel.generateContent(history) - - return .init( - object: "chat.completion", - model: model.info.modelName, - usage: .init(prompt_tokens: 0, completion_tokens: 0, total_tokens: 0), - choices: response.candidates.enumerated().map { - let (index, candidate) = $0 - return .init( - message: .init( - role: .assistant, - content: candidate.content.parts.first(where: { part in - if let text = part.text { - return !text.isEmpty - } else { - return false - } - })?.text ?? "" - ), - index: index, - finish_reason: candidate.finishReason?.rawValue ?? "" - ) - } - ) - } catch let error as GenerateContentError { - struct ErrorWrapper: Error, LocalizedError { - let error: Error - var errorDescription: String? { - var s = "" - dump(error, to: &s) - return "Internal Error: \(s)" - } - } - - switch error { - case let .internalError(underlying): - throw ErrorWrapper(error: underlying) - case .promptBlocked: - throw error - case .responseStoppedEarly: - throw error - } - } catch { - throw error - } - } -} - -extension ChatGPTPrompt { - var googleAICompatible: ChatGPTPrompt { - var history = self.history - var reformattedHistory = [ChatMessage]() - - // We don't want to combine the new user message with others. - let newUserMessage: ChatMessage? = if history.last?.role == .user { - history.removeLast() - } else { - nil - } - - for message in history { - let lastIndex = reformattedHistory.endIndex - 1 - guard lastIndex >= 0 else { // first message - if message.role == .system { - reformattedHistory.append(.init( - id: message.id, - role: .user, - content: ModelContent.convertContent(of: message) - )) - reformattedHistory.append(.init( - role: .assistant, - content: "Got it. Let's start our conversation." - )) - continue - } - - reformattedHistory.append(message) - continue - } - - let lastMessage = reformattedHistory[lastIndex] - - if ModelContent.convertRole(lastMessage.role) == ModelContent - .convertRole(message.role) - { - let newMessage = ChatMessage( - id: message.id, - role: message.role == .assistant ? 
.assistant : .user, - content: """ - \(ModelContent.convertContent(of: lastMessage)) - - ====== - - \(ModelContent.convertContent(of: message)) - """ - ) - reformattedHistory[lastIndex] = newMessage - } else { - reformattedHistory.append(message) - } - } - - if let newUserMessage { - if let last = reformattedHistory.last, - ModelContent.convertRole(last.role) == ModelContent - .convertRole(newUserMessage.role) - { - // Add dummy message - let dummyMessage = ChatMessage( - role: .assistant, - content: "OK" - ) - reformattedHistory.append(dummyMessage) - } - reformattedHistory.append(newUserMessage) - } - - return .init( - history: reformattedHistory, - references: references, - remainingTokenCount: remainingTokenCount - ) - } -} - -extension ModelContent { - static func convertRole(_ role: ChatMessage.Role) -> String { - switch role { - case .user, .system, .function: - return "user" - case .assistant: - return "model" - } - } - - static func convertContent(of message: ChatMessage) -> String { - switch message.role { - case .system: - return "System Prompt:\n\(message.content ?? " ")" - case .user: - return message.content ?? " " - case .function: - return """ - Result of \(message.name ?? "function"): \(message.content ?? "N/A") - """ - case .assistant: - if let functionCall = message.functionCall { - return """ - Call function: \(functionCall.name) - Arguments: \(functionCall.arguments) - """ - } else { - return message.content ?? " " - } - } - } - - init(_ message: ChatMessage) { - let role = Self.convertRole(message.role) - let parts = [ModelContent.Part.text(Self.convertContent(of: message))] - self = .init(role: role, parts: parts) - } -} - diff --git a/Tool/Sources/OpenAIService/APIs/GoogleAICompletionStreamAPI.swift b/Tool/Sources/OpenAIService/APIs/GoogleAICompletionStreamAPI.swift deleted file mode 100644 index 47492340..00000000 --- a/Tool/Sources/OpenAIService/APIs/GoogleAICompletionStreamAPI.swift +++ /dev/null @@ -1,84 +0,0 @@ -import AIModel -import Foundation -import GoogleGenerativeAI -import Preferences - -struct GoogleCompletionStreamAPI: CompletionStreamAPI { - let apiKey: String - let model: ChatModel - var requestBody: CompletionRequestBody - let prompt: ChatGPTPrompt - - func callAsFunction() async throws -> AsyncThrowingStream<CompletionStreamDataChunk, Error> { - let aiModel = GenerativeModel( - name: model.info.modelName, - apiKey: apiKey, - generationConfig: .init(GenerationConfig( - temperature: requestBody.temperature.map(Float.init), - topP: requestBody.top_p.map(Float.init) - )) - ) - let history = prompt.googleAICompatible.history.map { message in - ModelContent( - ChatMessage( - role: message.role, - content: message.content, - name: message.name, - functionCall: message.functionCall.map { - .init(name: $0.name, arguments: $0.arguments) - } - ) - ) - } - - let stream = AsyncThrowingStream<CompletionStreamDataChunk, Error> { continuation in - let stream = aiModel.generateContentStream(history) - let task = Task { - do { - for try await response in stream { - if Task.isCancelled { break } - let chunk = CompletionStreamDataChunk( - object: "", - model: model.info.modelName, - choices: response.candidates.map { candidate in - .init(delta: .init( - role: .assistant, - content: candidate.content.parts - .first(where: { $0.text != nil })?.text ?? "" - )) - } - ) - continuation.yield(chunk) - } - continuation.finish() - } catch let error as GenerateContentError { - struct ErrorWrapper: Error, LocalizedError { - let error: Error - var errorDescription: String?
{ - var s = "" - dump(error, to: &s) - return "Internal Error: \(s)" - } - } - - switch error { - case let .internalError(underlying): - continuation.finish(throwing: ErrorWrapper(error: underlying)) - case .promptBlocked: - continuation.finish(throwing: error) - case .responseStoppedEarly: - continuation.finish(throwing: error) - } - } catch { - continuation.finish(throwing: error) - } - } - continuation.onTermination = { _ in - task.cancel() - } - } - - return stream - } -} - diff --git a/Tool/Sources/OpenAIService/APIs/OlamaChatCompletionsService.swift b/Tool/Sources/OpenAIService/APIs/OlamaChatCompletionsService.swift new file mode 100644 index 00000000..e2ef4d5a --- /dev/null +++ b/Tool/Sources/OpenAIService/APIs/OlamaChatCompletionsService.swift @@ -0,0 +1,248 @@ +import AIModel +import Foundation +import Preferences + +public actor OllamaChatCompletionsService { + var apiKey: String + var endpoint: URL + var requestBody: ChatCompletionsRequestBody + var model: ChatModel + + public enum ResponseFormat: String { + case none = "" + case json + } + + init( + apiKey: String, + model: ChatModel, + endpoint: URL, + requestBody: ChatCompletionsRequestBody + ) { + self.apiKey = apiKey + self.endpoint = endpoint + self.requestBody = requestBody + self.model = model + } +} + +extension OllamaChatCompletionsService: ChatCompletionsAPI { + func callAsFunction() async throws -> ChatCompletionResponseBody { + let requestBody = ChatCompletionRequestBody( + model: model.info.modelName, + messages: requestBody.messages.map { message in + .init(role: { + switch message.role { + case .assistant: + return .assistant + case .user: + return .user + case .system: + return .system + case .tool: + return .user + } + }(), content: message.content) + }, + stream: false, + options: .init( + temperature: requestBody.temperature, + stop: requestBody.stop, + num_predict: requestBody.maxTokens + ), + keep_alive: nil, + format: nil + ) + + var request = URLRequest(url: endpoint) + request.httpMethod = "POST" + let encoder = JSONEncoder() + request.httpBody = try encoder.encode(requestBody) + request.setValue("application/json", forHTTPHeaderField: "Content-Type") + let (result, response) = try await URLSession.shared.data(for: request) + + guard let response = response as? HTTPURLResponse else { + throw CancellationError() + } + + guard response.statusCode == 200 else { + let text = String(data: result, encoding: .utf8) + throw Error.otherError(text ?? "Unknown error") + } + + let body = try JSONDecoder().decode( + ChatCompletionResponseChunk.self, + from: result + ) + + return .init( + object: body.model, + model: body.model, + message: body.message.map { message in + .init( + role: { + switch message.role { + case .assistant: + return .assistant + case .user: + return .user + case .system: + return .system + } + }(), + content: message.content + ) + } ?? 
.init(role: .assistant, content: ""), + otherChoices: [], + finishReason: "" + ) + } +} + +extension OllamaChatCompletionsService: ChatCompletionsStreamAPI { + func callAsFunction() async throws + -> AsyncThrowingStream<ChatCompletionsStreamDataChunk, Error> + { + let requestBody = ChatCompletionRequestBody( + model: model.info.modelName, + messages: requestBody.messages.map { message in + .init(role: { + switch message.role { + case .assistant: + return .assistant + case .user: + return .user + case .system: + return .system + case .tool: + return .user + } + }(), content: message.content) + }, + stream: true, + options: .init( + temperature: requestBody.temperature, + stop: requestBody.stop, + num_predict: requestBody.maxTokens + ), + keep_alive: nil, + format: nil + ) + + var request = URLRequest(url: endpoint) + request.httpMethod = "POST" + let encoder = JSONEncoder() + request.httpBody = try encoder.encode(requestBody) + request.setValue("application/json", forHTTPHeaderField: "Content-Type") + let (result, response) = try await URLSession.shared.bytes(for: request) + + guard let response = response as? HTTPURLResponse else { + throw CancellationError() + } + + guard response.statusCode == 200 else { + let text = try await result.lines.reduce(into: "") { partialResult, current in + partialResult += current + } + throw Error.otherError(text) + } + + let stream = ResponseStream<ChatCompletionResponseChunk>(result: result) { + let chunk = try JSONDecoder().decode( + ChatCompletionResponseChunk.self, + from: $0.data(using: .utf8) ?? Data() + ) + return .init(chunk: chunk, done: chunk.done) + } + + let sequence = stream.map { chunk in + ChatCompletionsStreamDataChunk( + id: UUID().uuidString, + object: chunk.model, + model: chunk.model, + message: .init( + role: { + switch chunk.message?.role { + case .none: + return nil + case .assistant: + return .assistant + case .user: + return .user + case .system: + return .system + } + }(), + content: chunk.message?.content + ) + ) + } + + return sequence.toStream() + } +} + +extension OllamaChatCompletionsService { + struct Message: Codable, Equatable { + public enum Role: String, Codable { + case user + case assistant + case system + } + + /// The role of the message. + public var role: Role + /// The content of the message. + public var content: String + } + + enum Error: Swift.Error, LocalizedError { + case decodeError(Swift.Error) + case otherError(String) + + public var errorDescription: String? { + switch self { + case let .decodeError(error): + return error.localizedDescription + case let .otherError(message): + return message + } + } + } +} + +// MARK: - Chat Completion API + +/// https://github.com/ollama/ollama/blob/main/docs/api.md#chat-request-streaming +extension OllamaChatCompletionsService { + struct ChatCompletionRequestBody: Codable { + struct Options: Codable { + var temperature: Double? + var stop: [String]? + var num_predict: Int? + var top_k: Int? + var top_p: Double? + } + + var model: String + var messages: [Message] + var stream: Bool + var options: Options + var keep_alive: String? + var format: String? + } + + struct ChatCompletionResponseChunk: Decodable { + var model: String + var message: Message? + var response: String? + var done: Bool + var total_duration: Int64? + var load_duration: Int64? + var prompt_eval_count: Int? + var prompt_eval_duration: Int64? + var eval_count: Int? + var eval_duration: Int64?
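+ + // The *_duration fields above are in nanoseconds (per the Ollama API docs); + // they are decoded for completeness but not used by the mapping above.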
+ } +} + diff --git a/Tool/Sources/OpenAIService/APIs/OllamaEmbeddingService.swift b/Tool/Sources/OpenAIService/APIs/OllamaEmbeddingService.swift new file mode 100644 index 00000000..dfd170cc --- /dev/null +++ b/Tool/Sources/OpenAIService/APIs/OllamaEmbeddingService.swift @@ -0,0 +1,92 @@ +import AIModel +import Foundation +import Logger + +struct OllamaEmbeddingService: EmbeddingAPI { + struct EmbeddingRequestBody: Encodable { + var prompt: String + var model: String + } + + struct ResponseBody: Decodable { + var embedding: [Float] + } + + let model: EmbeddingModel + let endpoint: String + + public func embed(text: String) async throws -> EmbeddingResponse { + guard let url = URL(string: endpoint) else { throw ChatGPTServiceError.endpointIncorrect } + var request = URLRequest(url: url) + request.httpMethod = "POST" + let encoder = JSONEncoder() + request.httpBody = try encoder.encode(EmbeddingRequestBody( + prompt: text, + model: model.info.modelName + )) + request.setValue("application/json", forHTTPHeaderField: "Content-Type") + + let (result, response) = try await URLSession.shared.data(for: request) + guard let response = response as? HTTPURLResponse else { + throw ChatGPTServiceError.responseInvalid + } + + guard response.statusCode == 200 else { + let error = try? JSONDecoder().decode( + OpenAIChatCompletionsService.CompletionAPIError.self, + from: result + ) + throw error ?? ChatGPTServiceError + .otherError(String(data: result, encoding: .utf8) ?? "Unknown Error") + } + + let embeddingResponse = try JSONDecoder().decode(ResponseBody.self, from: result) + #if DEBUG + Logger.service.info(""" + Embedding usage + - number of strings: 1 + - prompt tokens: N/A + - total tokens: N/A + + """) + #endif + return .init( + data: [.init( + embedding: embeddingResponse.embedding, + index: 0, + object: model.info.modelName + )], + model: model.info.modelName, + usage: .init(prompt_tokens: 0, total_tokens: 0) + ) + } + + public func embed(texts: [String]) async throws -> EmbeddingResponse { + try await withThrowingTaskGroup(of: EmbeddingResponse.self) { group in + for text in texts { + _ = group.addTaskUnlessCancelled { + try await self.embed(text: text) + } + } + + var result = EmbeddingResponse( + data: [], + model: model.info.modelName, + usage: .init(prompt_tokens: 0, total_tokens: 0) + ) + + for try await response in group { + result.data.append(contentsOf: response.data) + result.usage.prompt_tokens += response.usage.prompt_tokens + result.usage.total_tokens += response.usage.total_tokens + } + + return result + } + } + + public func embed(tokens: [[Int]]) async throws -> EmbeddingResponse { + throw CancellationError() + } +} + diff --git a/Tool/Sources/OpenAIService/APIs/OpenAIChatCompletionsService.swift b/Tool/Sources/OpenAIService/APIs/OpenAIChatCompletionsService.swift new file mode 100644 index 00000000..96e1a97a --- /dev/null +++ b/Tool/Sources/OpenAIService/APIs/OpenAIChatCompletionsService.swift @@ -0,0 +1,500 @@ +import AIModel +import AsyncAlgorithms +import Foundation +import Logger +import Preferences + +/// https://platform.openai.com/docs/api-reference/chat/create +actor OpenAIChatCompletionsService: ChatCompletionsStreamAPI, ChatCompletionsAPI { + struct CompletionAPIError: Error, Decodable, LocalizedError { + struct ErrorDetail: Decodable { + var message: String + var type: String? + var param: String? + var code: String? + } + + struct MistralAIErrorMessage: Decodable { + struct Detail: Decodable { + var msg: String? + } + + var message: String? + var msg: String? 
+ var detail: [Detail]? + } + + enum Message { + case raw(String) + case mistralAI(MistralAIErrorMessage) + } + + var error: ErrorDetail? + var message: Message + + var errorDescription: String? { + if let message = error?.message { return message } + switch message { + case let .raw(string): + return string + case let .mistralAI(mistralAIErrorMessage): + return mistralAIErrorMessage.message + ?? mistralAIErrorMessage.msg + ?? mistralAIErrorMessage.detail?.first?.msg + ?? "Unknown Error" + } + } + + enum CodingKeys: String, CodingKey { + case error + case message + } + + init(from decoder: Decoder) throws { + let container = try decoder.container(keyedBy: CodingKeys.self) + + do { + error = try container.decode(ErrorDetail.self, forKey: .error) + } catch { + print(error) + self.error = nil + } + message = { + if let e = try? container.decode(MistralAIErrorMessage.self, forKey: .message) { + return CompletionAPIError.Message.mistralAI(e) + } + if let e = try? container.decode(String.self, forKey: .message) { + return .raw(e) + } + return .raw("Unknown Error") + }() + } + } + + enum MessageRole: String, Codable { + case system + case user + case assistant + case function + case tool + + var formalized: ChatCompletionsRequestBody.Message.Role { + switch self { + case .system: return .system + case .user: return .user + case .assistant: return .assistant + case .function: return .tool + case .tool: return .tool + } + } + } + + struct StreamDataChunk: Codable { + var id: String? + var object: String? + var model: String? + var choices: [Choice]? + + struct Choice: Codable { + var delta: Delta? + var index: Int? + var finish_reason: String? + + struct Delta: Codable { + var role: MessageRole? + var content: String? + var function_call: RequestBody.MessageFunctionCall? + var tool_calls: [RequestBody.MessageToolCall]? + } + } + } + + struct ResponseBody: Codable, Equatable { + struct Message: Codable, Equatable { + /// The role of the message. + var role: MessageRole + /// The content of the message. + var content: String? + /// When we want to reply to a function call with the result, we have to provide the + /// name of the function call, and include the result in `content`. + /// + /// - important: It's required when the role is `function`. + var name: String? + /// When the bot wants to call a function, it will reply with a function call in format: + /// ```json + /// { + /// "name": "weather", + /// "arguments": "{ \"location\": \"earth\" }" + /// } + /// ``` + var function_call: RequestBody.MessageFunctionCall? + /// Tool calls in an assistant message. + var tool_calls: [RequestBody.MessageToolCall]? + } + + struct Choice: Codable, Equatable { + var message: Message + var index: Int + var finish_reason: String + } + + struct Usage: Codable, Equatable { + var prompt_tokens: Int + var completion_tokens: Int + var total_tokens: Int + } + + var id: String? + var object: String + var model: String + var usage: Usage + var choices: [Choice] + } + + struct RequestBody: Codable, Equatable { + struct Message: Codable, Equatable { + /// The role of the message. + var role: MessageRole + /// The content of the message. + var content: String + /// When we want to reply to a function call with the result, we have to provide the + /// name of the function call, and include the result in `content`. + /// + /// - important: It's required when the role is `function`. + var name: String? + /// Tool calls in an assistant message. + var tool_calls: [MessageToolCall]? 
+            /// When we want to call a tool, we have to provide the id of the call.
+            ///
+            /// - important: It's required when the role is `tool`.
+            var tool_call_id: String?
+            /// When the bot wants to call a function, it will reply with a function call.
+            ///
+            /// Deprecated.
+            var function_call: MessageFunctionCall?
+        }
+
+        struct MessageFunctionCall: Codable, Equatable {
+            /// The name of the function.
+            var name: String?
+            /// A JSON string.
+            var arguments: String?
+        }
+
+        struct MessageToolCall: Codable, Equatable {
+            /// When it's returned as a data chunk, use the index to identify the tool call.
+            var index: Int?
+            /// The id of the tool call.
+            var id: String?
+            /// The type of the tool.
+            var type: String?
+            /// The function call.
+            var function: MessageFunctionCall?
+        }
+
+        struct Tool: Codable, Equatable {
+            var type: String = "function"
+            var function: ChatGPTFunctionSchema
+        }
+
+        var model: String
+        var messages: [Message]
+        var temperature: Double?
+        var stream: Bool?
+        var stop: [String]?
+        var max_tokens: Int?
+        var tool_choice: FunctionCallStrategy?
+        var tools: [Tool]?
+    }
+
+    var apiKey: String
+    var endpoint: URL
+    var requestBody: RequestBody
+    var model: ChatModel
+
+    init(
+        apiKey: String,
+        model: ChatModel,
+        endpoint: URL,
+        requestBody: ChatCompletionsRequestBody
+    ) {
+        self.apiKey = apiKey
+        self.endpoint = endpoint
+        self.requestBody = .init(requestBody)
+        self.model = model
+    }
+
+    func callAsFunction() async throws
+        -> AsyncThrowingStream<ChatCompletionsStreamDataChunk, Error>
+    {
+        requestBody.stream = true
+        var request = URLRequest(url: endpoint)
+        request.httpMethod = "POST"
+        let encoder = JSONEncoder()
+        request.httpBody = try encoder.encode(requestBody)
+        request.setValue("application/json", forHTTPHeaderField: "Content-Type")
+        if !apiKey.isEmpty {
+            switch model.format {
+            case .openAI:
+                if !model.info.openAIInfo.organizationID.isEmpty {
+                    // The organization ID is the header value; the header
+                    // name is "OpenAI-Organization".
+                    request.setValue(
+                        model.info.openAIInfo.organizationID,
+                        forHTTPHeaderField: "OpenAI-Organization"
+                    )
+                }
+                request.setValue("Bearer \(apiKey)", forHTTPHeaderField: "Authorization")
+            case .openAICompatible:
+                request.setValue("Bearer \(apiKey)", forHTTPHeaderField: "Authorization")
+            case .azureOpenAI:
+                request.setValue(apiKey, forHTTPHeaderField: "api-key")
+            case .googleAI:
+                assertionFailure("Unsupported")
+            case .ollama:
+                assertionFailure("Unsupported")
+            }
+        }
+
+        let (result, response) = try await URLSession.shared.bytes(for: request)
+        guard let response = response as? HTTPURLResponse else {
+            throw ChatGPTServiceError.responseInvalid
+        }
+
+        guard response.statusCode == 200 else {
+            let text = try await result.lines.reduce(into: "") { partialResult, current in
+                partialResult += current
+            }
+            guard let data = text.data(using: .utf8)
+            else { throw ChatGPTServiceError.responseInvalid }
+            let decoder = JSONDecoder()
+            let error = try? decoder.decode(CompletionAPIError.self, from: data)
+            throw error ?? ChatGPTServiceError.responseInvalid
+        }
+
+        let stream = ResponseStream(result: result) {
+            var line = $0
+            let prefix = "data: "
+            if line.hasPrefix(prefix) {
+                line.removeFirst(prefix.count)
+            }
+
+            if line == "[DONE]" { return .init(chunk: nil, done: true) }
+
+            do {
+                let chunk = try JSONDecoder().decode(
+                    StreamDataChunk.self,
+                    from: line.data(using: .utf8) ?? Data()
+                )
+                return .init(chunk: chunk, done: false)
+            } catch {
+                Logger.service.error("Error decoding stream data: \(error)")
+                throw error
+            }
+        }
+
+        return stream.map { $0.formalized() }.toStream()
+    }
+
+    func callAsFunction() async throws -> ChatCompletionResponseBody {
+        requestBody.stream = false
+        var request = URLRequest(url: endpoint)
+        request.httpMethod = "POST"
+        let encoder = JSONEncoder()
+        request.httpBody = try encoder.encode(requestBody)
+        request.setValue("application/json", forHTTPHeaderField: "Content-Type")
+        if !apiKey.isEmpty {
+            switch model.format {
+            case .openAI:
+                if !model.info.openAIInfo.organizationID.isEmpty {
+                    request.setValue(
+                        model.info.openAIInfo.organizationID,
+                        forHTTPHeaderField: "OpenAI-Organization"
+                    )
+                }
+                request.setValue("Bearer \(apiKey)", forHTTPHeaderField: "Authorization")
+            case .openAICompatible:
+                request.setValue("Bearer \(apiKey)", forHTTPHeaderField: "Authorization")
+            case .azureOpenAI:
+                request.setValue(apiKey, forHTTPHeaderField: "api-key")
+            case .googleAI:
+                assertionFailure("Unsupported")
+            case .ollama:
+                assertionFailure("Unsupported")
+            }
+        }
+
+        let (result, response) = try await URLSession.shared.data(for: request)
+        guard let response = response as? HTTPURLResponse else {
+            throw ChatGPTServiceError.responseInvalid
+        }
+
+        guard response.statusCode == 200 else {
+            let error = try? JSONDecoder().decode(CompletionAPIError.self, from: result)
+            throw error ?? ChatGPTServiceError
+                .otherError(String(data: result, encoding: .utf8) ?? "Unknown Error")
+        }
+
+        do {
+            let body = try JSONDecoder().decode(ResponseBody.self, from: result)
+            return body.formalized()
+        } catch {
+            dump(error)
+            throw error
+        }
+    }
+}
+
+extension OpenAIChatCompletionsService.ResponseBody {
+    func formalized() -> ChatCompletionResponseBody {
+        let message: ChatCompletionResponseBody.Message
+        let otherMessages: [ChatCompletionResponseBody.Message]
+
+        func convertMessage(_ message: Message) -> ChatCompletionResponseBody.Message {
+            .init(
+                role: message.role.formalized,
+                content: message.content ?? "",
+                toolCalls: {
+                    if let toolCalls = message.tool_calls {
+                        return toolCalls.map { toolCall in
+                            .init(
+                                id: toolCall.id ?? "",
+                                type: toolCall.type ?? "function",
+                                function: .init(
+                                    name: toolCall.function?.name ?? "",
+                                    arguments: toolCall.function?.arguments
+                                )
+                            )
+                        }
+                    } else if let functionCall = message.function_call {
+                        return [
+                            .init(
+                                id: functionCall.name ?? "",
+                                type: "function",
+                                function: .init(
+                                    name: functionCall.name ?? "",
+                                    arguments: functionCall.arguments
+                                )
+                            ),
+                        ]
+                    } else {
+                        return []
+                    }
+                }()
+            )
+        }
+
+        if let first = choices.first?.message {
+            message = convertMessage(first)
+            otherMessages = choices.dropFirst().map { convertMessage($0.message) }
+        } else {
+            message = .init(role: .assistant, content: "")
+            otherMessages = []
+        }
+
+        return .init(
+            id: id,
+            object: object,
+            model: model,
+            message: message,
+            otherChoices: otherMessages,
+            finishReason: choices.first?.finish_reason ??
"" + ) + } +} + +extension OpenAIChatCompletionsService.StreamDataChunk { + func formalized() -> ChatCompletionsStreamDataChunk { + .init( + id: id, + object: object, + model: model, + message: { + if let choice = self.choices?.first { + return .init( + role: choice.delta?.role?.formalized, + content: choice.delta?.content, + toolCalls: { + if let toolCalls = choice.delta?.tool_calls { + return toolCalls.map { + .init( + index: $0.index, + id: $0.id, + type: $0.type, + function: .init( + name: $0.function?.name, + arguments: $0.function?.arguments + ) + ) + } + } + + if let functionCall = choice.delta?.function_call { + return [ + .init( + index: 0, + id: functionCall.name, + type: "function", + function: .init( + name: functionCall.name, + arguments: functionCall.arguments + ) + ), + ] + } + + return nil + }() + ) + } + return nil + }(), + finishReason: choices?.first?.finish_reason + ) + } +} + +extension OpenAIChatCompletionsService.RequestBody { + init(_ body: ChatCompletionsRequestBody) { + model = body.model + messages = body.messages.map { message in + .init( + role: { + switch message.role { + case .user: + return .user + case .assistant: + return .assistant + case .system: + return .system + case .tool: + return .tool + } + }(), + content: message.content, + name: message.name, + tool_calls: message.toolCalls?.map { tool in + MessageToolCall( + id: tool.id, + type: tool.type, + function: MessageFunctionCall( + name: tool.function.name, + arguments: tool.function.arguments + ) + ) + }, + tool_call_id: message.toolCallId + ) + } + temperature = body.temperature + stream = body.stream + stop = body.stop + max_tokens = body.maxTokens + tool_choice = body.toolChoice + tools = body.tools?.map { + Tool( + type: $0.type, + function: $0.function + ) + } + } +} + diff --git a/Tool/Sources/OpenAIService/APIs/OpenAICompletionAPI.swift b/Tool/Sources/OpenAIService/APIs/OpenAICompletionAPI.swift deleted file mode 100644 index 31e86492..00000000 --- a/Tool/Sources/OpenAIService/APIs/OpenAICompletionAPI.swift +++ /dev/null @@ -1,121 +0,0 @@ -import AIModel -import Foundation -import Preferences - -typealias CompletionAPIBuilder = (String, ChatModel, URL, CompletionRequestBody, ChatGPTPrompt) - -> CompletionAPI - -protocol CompletionAPI { - func callAsFunction() async throws -> CompletionResponseBody -} - -/// https://platform.openai.com/docs/api-reference/chat/create -struct CompletionResponseBody: Codable, Equatable { - struct Message: Codable, Equatable { - /// The role of the message. - var role: ChatMessage.Role - /// The content of the message. - var content: String? - /// When we want to reply to a function call with the result, we have to provide the - /// name of the function call, and include the result in `content`. - /// - /// - important: It's required when the role is `function`. - var name: String? - /// When the bot wants to call a function, it will reply with a function call in format: - /// ```json - /// { - /// "name": "weather", - /// "arguments": "{ \"location\": \"earth\" }" - /// } - /// ``` - var function_call: CompletionRequestBody.MessageFunctionCall? - } - - struct Choice: Codable, Equatable { - var message: Message - var index: Int - var finish_reason: String - } - - struct Usage: Codable, Equatable { - var prompt_tokens: Int - var completion_tokens: Int - var total_tokens: Int - } - - var id: String? 
- var object: String - var model: String - var usage: Usage - var choices: [Choice] -} - -struct CompletionAPIError: Error, Codable, LocalizedError { - struct E: Codable { - var message: String - var type: String - var param: String - var code: String - } - - var error: E - - var errorDescription: String? { error.message } -} - -struct OpenAICompletionAPI: CompletionAPI { - var apiKey: String - var endpoint: URL - var requestBody: CompletionRequestBody - var model: ChatModel - - init( - apiKey: String, - model: ChatModel, - endpoint: URL, - requestBody: CompletionRequestBody - ) { - self.apiKey = apiKey - self.endpoint = endpoint - self.requestBody = requestBody - self.requestBody.stream = false - self.model = model - } - - func callAsFunction() async throws -> CompletionResponseBody { - var request = URLRequest(url: endpoint) - request.httpMethod = "POST" - let encoder = JSONEncoder() - request.httpBody = try encoder.encode(requestBody) - request.setValue("application/json", forHTTPHeaderField: "Content-Type") - if !apiKey.isEmpty { - switch model.format { - case .openAI, .openAICompatible: - request.setValue("Bearer \(apiKey)", forHTTPHeaderField: "Authorization") - case .azureOpenAI: - request.setValue(apiKey, forHTTPHeaderField: "api-key") - case .googleAI: - assert(false, "Unsupported") - } - } - - let (result, response) = try await URLSession.shared.data(for: request) - guard let response = response as? HTTPURLResponse else { - throw ChatGPTServiceError.responseInvalid - } - - guard response.statusCode == 200 else { - let error = try? JSONDecoder().decode(CompletionAPIError.self, from: result) - throw error ?? ChatGPTServiceError - .otherError(String(data: result, encoding: .utf8) ?? "Unknown Error") - } - - do { - return try JSONDecoder().decode(CompletionResponseBody.self, from: result) - } catch { - dump(error) - throw error - } - } -} - diff --git a/Tool/Sources/OpenAIService/APIs/OpenAICompletionStreamAPI.swift b/Tool/Sources/OpenAIService/APIs/OpenAICompletionStreamAPI.swift deleted file mode 100644 index 46c6b1ff..00000000 --- a/Tool/Sources/OpenAIService/APIs/OpenAICompletionStreamAPI.swift +++ /dev/null @@ -1,236 +0,0 @@ -import AIModel -import AsyncAlgorithms -import Foundation -import Preferences - -typealias CompletionStreamAPIBuilder = ( - String, - ChatModel, - URL, - CompletionRequestBody, - ChatGPTPrompt -) -> any CompletionStreamAPI - -protocol CompletionStreamAPI { - func callAsFunction() async throws -> AsyncThrowingStream -} - -public enum FunctionCallStrategy: Codable, Equatable { - /// Forbid the bot to call any function. - case none - /// Let the bot choose what function to call. - case auto - /// Force the bot to call a function with the given name. - case name(String) - - struct CallFunctionNamed: Codable { - var name: String - } - - public func encode(to encoder: Encoder) throws { - var container = encoder.singleValueContainer() - switch self { - case .none: - try container.encode("none") - case .auto: - try container.encode("auto") - case let .name(name): - try container.encode(CallFunctionNamed(name: name)) - } - } -} - -/// https://platform.openai.com/docs/api-reference/chat/create -struct CompletionRequestBody: Codable, Equatable { - struct Message: Codable, Equatable { - /// The role of the message. - var role: ChatMessage.Role - /// The content of the message. - var content: String - /// When we want to reply to a function call with the result, we have to provide the - /// name of the function call, and include the result in `content`. 
- /// - /// - important: It's required when the role is `function`. - var name: String? - /// When the bot wants to call a function, it will reply with a function call in format: - /// ```json - /// { - /// "name": "weather", - /// "arguments": "{ \"location\": \"earth\" }" - /// } - /// ``` - var function_call: CompletionRequestBody.MessageFunctionCall? - } - - struct MessageFunctionCall: Codable, Equatable { - /// The name of the - var name: String - /// A JSON string. - var arguments: String? - } - - struct Function: Codable { - var name: String - var description: String - /// JSON schema. - var arguments: String - } - - var model: String - var messages: [Message] - var temperature: Double? - var top_p: Double? - var n: Double? - var stream: Bool? - var stop: [String]? - var max_tokens: Int? - var presence_penalty: Double? - var frequency_penalty: Double? - var logit_bias: [String: Double]? - var user: String? - /// Pass nil to let the bot decide. - var function_call: FunctionCallStrategy? - var functions: [ChatGPTFunctionSchema]? - - init( - model: String, - messages: [Message], - temperature: Double? = nil, - top_p: Double? = nil, - n: Double? = nil, - stream: Bool? = nil, - stop: [String]? = nil, - max_tokens: Int? = nil, - presence_penalty: Double? = nil, - frequency_penalty: Double? = nil, - logit_bias: [String: Double]? = nil, - user: String? = nil, - function_call: FunctionCallStrategy? = nil, - functions: [ChatGPTFunctionSchema] = [] - ) { - self.model = model - self.messages = messages - self.temperature = temperature - self.top_p = top_p - self.n = n - self.stream = stream - self.stop = stop - self.max_tokens = max_tokens - self.presence_penalty = presence_penalty - self.frequency_penalty = frequency_penalty - self.logit_bias = logit_bias - self.user = user - if UserDefaults.shared.value(for: \.disableFunctionCalling) { - self.function_call = nil - self.functions = nil - } else { - self.function_call = function_call - self.functions = functions.isEmpty ? nil : functions - } - } -} - -struct CompletionStreamDataChunk: Codable { - var id: String? - var object: String? - var model: String? - var choices: [Choice]? - - struct Choice: Codable { - var delta: Delta? - var index: Int? - var finish_reason: String? - - struct Delta: Codable { - struct FunctionCall: Codable { - var name: String? - var arguments: String? - } - - var role: ChatMessage.Role? - var content: String? - var function_call: FunctionCall? 
- } - } -} - -struct OpenAICompletionStreamAPI: CompletionStreamAPI { - var apiKey: String - var endpoint: URL - var requestBody: CompletionRequestBody - var model: ChatModel - - init( - apiKey: String, - model: ChatModel, - endpoint: URL, - requestBody: CompletionRequestBody - ) { - self.apiKey = apiKey - self.endpoint = endpoint - self.requestBody = requestBody - self.requestBody.stream = true - self.model = model - } - - func callAsFunction() async throws -> AsyncThrowingStream { - var request = URLRequest(url: endpoint) - request.httpMethod = "POST" - let encoder = JSONEncoder() - request.httpBody = try encoder.encode(requestBody) - request.setValue("application/json", forHTTPHeaderField: "Content-Type") - if !apiKey.isEmpty { - switch model.format { - case .openAI, .openAICompatible: - request.setValue("Bearer \(apiKey)", forHTTPHeaderField: "Authorization") - case .azureOpenAI: - request.setValue(apiKey, forHTTPHeaderField: "api-key") - case .googleAI: - assertionFailure("Unsupported") - } - } - - let (result, response) = try await URLSession.shared.bytes(for: request) - guard let response = response as? HTTPURLResponse else { - throw ChatGPTServiceError.responseInvalid - } - - guard response.statusCode == 200 else { - let text = try await result.lines.reduce(into: "") { partialResult, current in - partialResult += current - } - guard let data = text.data(using: .utf8) - else { throw ChatGPTServiceError.responseInvalid } - let decoder = JSONDecoder() - let error = try? decoder.decode(ChatGPTError.self, from: data) - throw error ?? ChatGPTServiceError.responseInvalid - } - - let stream = AsyncThrowingStream { continuation in - let task = Task { - do { - for try await line in result.lines { - if Task.isCancelled { break } - let prefix = "data: " - guard line.hasPrefix(prefix), - let content = line.dropFirst(prefix.count).data(using: .utf8), - let chunk = try? 
JSONDecoder()
-                    .decode(CompletionStreamDataChunk.self, from: content)
-                else { continue }
-                continuation.yield(chunk)
-            }
-            continuation.finish()
-        } catch {
-            continuation.finish(throwing: error)
-        }
-    }
-    continuation.onTermination = { _ in
-        task.cancel()
-        result.task.cancel()
-    }
-}
-
-        return stream
-    }
-}
-
 diff --git a/Tool/Sources/OpenAIService/APIs/OpenAIEmbeddingService.swift b/Tool/Sources/OpenAIService/APIs/OpenAIEmbeddingService.swift
new file mode 100644
index 00000000..140e9d09
--- /dev/null
+++ b/Tool/Sources/OpenAIService/APIs/OpenAIEmbeddingService.swift
@@ -0,0 +1,136 @@
+import AIModel
+import Foundation
+import Logger
+
+struct OpenAIEmbeddingService: EmbeddingAPI {
+    struct EmbeddingRequestBody: Encodable {
+        var input: [String]
+        var model: String
+    }
+
+    struct EmbeddingFromTokensRequestBody: Encodable {
+        var input: [[Int]]
+        var model: String
+    }
+
+    let apiKey: String
+    let model: EmbeddingModel
+    let endpoint: String
+
+    public func embed(text: String) async throws -> EmbeddingResponse {
+        return try await embed(texts: [text])
+    }
+
+    public func embed(texts text: [String]) async throws -> EmbeddingResponse {
+        guard let url = URL(string: endpoint) else { throw ChatGPTServiceError.endpointIncorrect }
+        var request = URLRequest(url: url)
+        request.httpMethod = "POST"
+        let encoder = JSONEncoder()
+        request.httpBody = try encoder.encode(EmbeddingRequestBody(
+            input: text,
+            model: model.info.modelName
+        ))
+        request.setValue("application/json", forHTTPHeaderField: "Content-Type")
+        if !apiKey.isEmpty {
+            switch model.format {
+            case .openAI:
+                if !model.info.openAIInfo.organizationID.isEmpty {
+                    request.setValue(
+                        model.info.openAIInfo.organizationID,
+                        forHTTPHeaderField: "OpenAI-Organization"
+                    )
+                }
+                request.setValue("Bearer \(apiKey)", forHTTPHeaderField: "Authorization")
+            case .openAICompatible:
+                request.setValue("Bearer \(apiKey)", forHTTPHeaderField: "Authorization")
+            case .azureOpenAI:
+                request.setValue(apiKey, forHTTPHeaderField: "api-key")
+            case .ollama:
+                assertionFailure("Unsupported")
+            }
+        }
+
+        let (result, response) = try await URLSession.shared.data(for: request)
+        guard let response = response as? HTTPURLResponse else {
+            throw ChatGPTServiceError.responseInvalid
+        }
+
+        guard response.statusCode == 200 else {
+            let error = try? JSONDecoder().decode(
+                OpenAIChatCompletionsService.CompletionAPIError.self,
+                from: result
+            )
+            throw error ?? ChatGPTServiceError
+                .otherError(String(data: result, encoding: .utf8) ?? "Unknown Error")
+        }
+
+        let embeddingResponse = try JSONDecoder().decode(EmbeddingResponse.self, from: result)
+        #if DEBUG
+        Logger.service.info("""
+        Embedding usage
+        - number of strings: \(text.count)
+        - prompt tokens: \(embeddingResponse.usage.prompt_tokens)
+        - total tokens: \(embeddingResponse.usage.total_tokens)
+
+        """)
+        #endif
+        return embeddingResponse
+    }
+
+    public func embed(tokens: [[Int]]) async throws -> EmbeddingResponse {
+        guard let url = URL(string: endpoint) else { throw ChatGPTServiceError.endpointIncorrect }
+        var request = URLRequest(url: url)
+        request.httpMethod = "POST"
+        let encoder = JSONEncoder()
+        request.httpBody = try encoder.encode(EmbeddingFromTokensRequestBody(
+            input: tokens,
+            model: model.info.modelName
+        ))
+        request.setValue("application/json", forHTTPHeaderField: "Content-Type")
+        if !apiKey.isEmpty {
+            switch model.format {
+            case .openAI:
+                if !model.info.openAIInfo.organizationID.isEmpty {
+                    request.setValue(
+                        model.info.openAIInfo.organizationID,
+                        forHTTPHeaderField: "OpenAI-Organization"
+                    )
+                }
+                request.setValue("Bearer \(apiKey)", forHTTPHeaderField: "Authorization")
+            case .openAICompatible:
+                request.setValue("Bearer \(apiKey)", forHTTPHeaderField: "Authorization")
+            case .azureOpenAI:
+                request.setValue(apiKey, forHTTPHeaderField: "api-key")
+            case .ollama:
+                assertionFailure("Unsupported")
+            }
+        }
+
+        let (result, response) = try await URLSession.shared.data(for: request)
+        guard let response = response as? HTTPURLResponse else {
+            throw ChatGPTServiceError.responseInvalid
+        }
+
+        guard response.statusCode == 200 else {
+            let error = try? JSONDecoder().decode(
+                OpenAIChatCompletionsService.CompletionAPIError.self,
+                from: result
+            )
+            throw error ?? ChatGPTServiceError
+                .otherError(String(data: result, encoding: .utf8) ?? "Unknown Error")
+        }
+
+        let embeddingResponse = try JSONDecoder().decode(EmbeddingResponse.self, from: result)
+        #if DEBUG
+        Logger.service.info("""
+        Embedding usage
+        - number of strings: \(tokens.count)
+        - prompt tokens: \(embeddingResponse.usage.prompt_tokens)
+        - total tokens: \(embeddingResponse.usage.total_tokens)
+
+        """)
+        #endif
+        return embeddingResponse
+    }
+}
+
 diff --git a/Tool/Sources/OpenAIService/APIs/ResponseStream.swift b/Tool/Sources/OpenAIService/APIs/ResponseStream.swift
new file mode 100644
index 00000000..ce28b7f9
--- /dev/null
+++ b/Tool/Sources/OpenAIService/APIs/ResponseStream.swift
@@ -0,0 +1,45 @@
+import Foundation
+
+struct ResponseStream<Chunk>: AsyncSequence {
+    func makeAsyncIterator() -> Stream.AsyncIterator {
+        stream.makeAsyncIterator()
+    }
+
+    typealias Stream = AsyncThrowingStream<Chunk, Error>
+    typealias AsyncIterator = Stream.AsyncIterator
+    typealias Element = Chunk
+
+    struct LineContent {
+        let chunk: Chunk?
+ let done: Bool + } + + let stream: Stream + + init(result: URLSession.AsyncBytes, lineExtractor: @escaping (String) throws -> LineContent) { + stream = AsyncThrowingStream { continuation in + let task = Task { + do { + for try await line in result.lines { + if Task.isCancelled { break } + let content = try lineExtractor(line) + if let chunk = content.chunk { + continuation.yield(chunk) + } + + if content.done { break } + } + continuation.finish() + } catch { + continuation.finish(throwing: error) + result.task.cancel() + } + } + continuation.onTermination = { _ in + task.cancel() + result.task.cancel() + } + } + } +} + diff --git a/Tool/Sources/OpenAIService/ChatGPTService.swift b/Tool/Sources/OpenAIService/ChatGPTService.swift index 5d1480ca..61970a5e 100644 --- a/Tool/Sources/OpenAIService/ChatGPTService.swift +++ b/Tool/Sources/OpenAIService/ChatGPTService.swift @@ -1,6 +1,8 @@ +import AIModel import AsyncAlgorithms import Dependencies import Foundation +import IdentifiedCollections import Preferences public protocol ChatGPTServiceType { @@ -63,24 +65,47 @@ public struct ChatGPTError: Error, Codable, LocalizedError { } } +typealias ChatCompletionsStreamAPIBuilder = ( + String, + ChatModel, + URL, + ChatCompletionsRequestBody, + ChatGPTPrompt +) -> any ChatCompletionsStreamAPI + +typealias ChatCompletionsAPIBuilder = ( + String, + ChatModel, + URL, + ChatCompletionsRequestBody, + ChatGPTPrompt +) -> any ChatCompletionsAPI + public class ChatGPTService: ChatGPTServiceType { public var memory: ChatGPTMemory public var configuration: ChatGPTConfiguration public var functionProvider: ChatGPTFunctionProvider var runningTask: Task? - var buildCompletionStreamAPI: CompletionStreamAPIBuilder = { + var buildCompletionStreamAPI: ChatCompletionsStreamAPIBuilder = { apiKey, model, endpoint, requestBody, prompt in switch model.format { case .googleAI: - return GoogleCompletionStreamAPI( + return GoogleAIChatCompletionsService( apiKey: apiKey, model: model, requestBody: requestBody, prompt: prompt ) case .openAI, .openAICompatible, .azureOpenAI: - return OpenAICompletionStreamAPI( + return OpenAIChatCompletionsService( + apiKey: apiKey, + model: model, + endpoint: endpoint, + requestBody: requestBody + ) + case .ollama: + return OllamaChatCompletionsService( apiKey: apiKey, model: model, endpoint: endpoint, @@ -89,18 +114,25 @@ public class ChatGPTService: ChatGPTServiceType { } } - var buildCompletionAPI: CompletionAPIBuilder = { + var buildCompletionAPI: ChatCompletionsAPIBuilder = { apiKey, model, endpoint, requestBody, prompt in switch model.format { case .googleAI: - return GoogleCompletionAPI( + return GoogleAIChatCompletionsService( apiKey: apiKey, model: model, requestBody: requestBody, prompt: prompt ) case .openAI, .openAICompatible, .azureOpenAI: - return OpenAICompletionAPI( + return OpenAIChatCompletionsService( + apiKey: apiKey, + model: model, + endpoint: endpoint, + requestBody: requestBody + ) + case .ollama: + return OllamaChatCompletionsService( apiKey: apiKey, model: model, endpoint: endpoint, @@ -137,7 +169,7 @@ public class ChatGPTService: ChatGPTServiceType { role: .user, content: content, name: nil, - functionCall: nil, + toolCalls: nil, summary: summary, references: [] ) @@ -148,20 +180,24 @@ public class ChatGPTService: ChatGPTServiceType { AsyncThrowingStream { continuation in let task = Task(priority: .userInitiated) { do { - var functionCall: ChatMessage.FunctionCall? 
- var functionCallMessageID = "" + var pendingToolCalls = [ChatMessage.ToolCall]() + var sourceMessageId = "" var isInitialCall = true - loop: while functionCall != nil || isInitialCall { + loop: while !pendingToolCalls.isEmpty || isInitialCall { try Task.checkCancellation() isInitialCall = false - if let call = functionCall { + for toolCall in pendingToolCalls { if !configuration.runFunctionsAutomatically { break loop } - functionCall = nil - await runFunctionCall(call, messageId: functionCallMessageID) + await runFunctionCall( + toolCall, + sourceMessageId: sourceMessageId + ) } - let stream = try await sendMemory() + sourceMessageId = uuid() + .uuidString + String(date().timeIntervalSince1970) + let stream = try await sendMemory(proposedId: sourceMessageId) #if DEBUG var reply = "" @@ -175,20 +211,19 @@ public class ChatGPTService: ChatGPTServiceType { #if DEBUG reply.append(text) #endif - case let .functionCall(call): - if functionCall == nil { - functionCallMessageID = uuid().uuidString - functionCall = call - } else { - functionCall?.name.append(call.name) - functionCall?.arguments.append(call.arguments) - } + + case let .toolCall(toolCall): await prepareFunctionCall( - call, - messageId: functionCallMessageID + toolCall, + sourceMessageId: sourceMessageId ) } } + + pendingToolCalls = await memory.history + .last { $0.id == sourceMessageId }? + .toolCalls ?? [] + #if DEBUG Debugger.didReceiveResponse(content: reply) #endif @@ -226,17 +261,19 @@ public class ChatGPTService: ChatGPTServiceType { return try await Debugger.$id.withValue(.init()) { let message = try await sendMemoryAndWait() var finalResult = message?.content - var functionCall = message?.functionCall - while let call = functionCall { + var toolCalls = message?.toolCalls + while let sourceMessageId = message?.id, let calls = toolCalls, !calls.isEmpty { try Task.checkCancellation() if !configuration.runFunctionsAutomatically { break } - functionCall = nil - await runFunctionCall(call) + toolCalls = nil + for call in calls { + await runFunctionCall(call, sourceMessageId: sourceMessageId) + } guard let nextMessage = try await sendMemoryAndWait() else { break } finalResult = nextMessage.content - functionCall = nextMessage.functionCall + toolCalls = nextMessage.toolCalls } #if DEBUG @@ -248,7 +285,7 @@ public class ChatGPTService: ChatGPTServiceType { } } - #warning("TODO: remove this and let the concurrency system handle it") + #warning("TODO: Move the cancellation up to the caller.") public func stopReceivingMessage() { runningTask?.cancel() runningTask = nil @@ -260,11 +297,11 @@ public class ChatGPTService: ChatGPTServiceType { extension ChatGPTService { enum StreamContent { case text(String) - case functionCall(ChatMessage.FunctionCall) + case toolCall(ChatMessage.ToolCall) } /// Send the memory as prompt to ChatGPT, with stream enabled. - func sendMemory() async throws -> AsyncThrowingStream { + func sendMemory(proposedId: String) async throws -> AsyncThrowingStream { let prompt = await memory.generatePrompt() guard let model = configuration.model else { @@ -274,42 +311,7 @@ extension ChatGPTService { throw ChatGPTServiceError.endpointIncorrect } - let messages = prompt.history.map { - CompletionRequestBody.Message( - role: $0.role, - content: $0.content ?? 
"", - name: $0.name, - function_call: $0.functionCall.map { - .init(name: $0.name, arguments: $0.arguments) - } - ) - } - let remainingTokens = prompt.remainingTokenCount - - let requestBody = CompletionRequestBody( - model: model.info.modelName, - messages: messages, - temperature: configuration.temperature, - stream: true, - stop: configuration.stop.isEmpty ? nil : configuration.stop, - max_tokens: maxTokenForReply( - maxToken: model.info.maxTokens, - remainingTokens: remainingTokens - ), - function_call: model.info.supportsFunctionCalling - ? functionProvider.functionCallStrategy - : nil, - functions: - model.info.supportsFunctionCalling - ? functionProvider.functions.map { - ChatGPTFunctionSchema( - name: $0.name, - description: $0.description, - parameters: $0.argumentSchema - ) - } - : [] - ) + let requestBody = createRequestBody(prompt: prompt, model: model, stream: true) let api = buildCompletionStreamAPI( configuration.apiKey, @@ -323,13 +325,12 @@ extension ChatGPTService { Debugger.didSendRequestBody(body: requestBody) #endif - let proposedId = uuid().uuidString + String(date().timeIntervalSince1970) - return AsyncThrowingStream { continuation in let task = Task { do { await memory.streamMessage( id: proposedId, + role: .assistant, references: prompt.references ) let chunks = try await api() @@ -337,28 +338,35 @@ extension ChatGPTService { if Task.isCancelled { throw CancellationError() } - guard let delta = chunk.choices?.first?.delta else { continue } + guard let delta = chunk.message else { continue } // The api will always return a function call with JSON object. // The first round will contain the function name and an empty argument. // e.g. {"name":"weather","arguments":""} // The other rounds will contain part of the arguments. - let functionCall = delta.function_call.map { - ChatMessage.FunctionCall( - name: $0.name ?? "", - arguments: $0.arguments ?? "" - ) - } + let toolCalls = delta.toolCalls? + .reduce(into: [Int: ChatMessage.ToolCall]()) { + $0[$1.index ?? 0] = ChatMessage.ToolCall( + id: $1.id ?? "", + type: $1.type ?? "", + function: .init( + name: $1.function?.name ?? "", + arguments: $1.function?.arguments ?? "" + ) + ) + } await memory.streamMessage( id: proposedId, - role: delta.role, + role: delta.role?.asChatMessageRole, content: delta.content, - functionCall: functionCall + toolCalls: toolCalls ) - if let functionCall { - continuation.yield(.functionCall(functionCall)) + if let toolCalls { + for toolCall in toolCalls.values { + continuation.yield(.toolCall(toolCall)) + } } if let content = delta.content { @@ -402,42 +410,7 @@ extension ChatGPTService { throw ChatGPTServiceError.endpointIncorrect } - let messages = prompt.history.map { - CompletionRequestBody.Message( - role: $0.role, - content: $0.content ?? "", - name: $0.name, - function_call: $0.functionCall.map { - .init(name: $0.name, arguments: $0.arguments) - } - ) - } - let remainingTokens = prompt.remainingTokenCount - - let requestBody = CompletionRequestBody( - model: model.info.modelName, - messages: messages, - temperature: configuration.temperature, - stream: true, - stop: configuration.stop.isEmpty ? nil : configuration.stop, - max_tokens: maxTokenForReply( - maxToken: model.info.maxTokens, - remainingTokens: remainingTokens - ), - function_call: model.info.supportsFunctionCalling - ? functionProvider.functionCallStrategy - : nil, - functions: - model.info.supportsFunctionCalling - ? 
functionProvider.functions.map { - ChatGPTFunctionSchema( - name: $0.name, - description: $0.description, - parameters: $0.argumentSchema - ) - } - : [] - ) + let requestBody = createRequestBody(prompt: prompt, model: model, stream: false) let api = buildCompletionAPI( configuration.apiKey, @@ -453,14 +426,24 @@ extension ChatGPTService { let response = try await api() - guard let choice = response.choices.first else { return nil } + let choice = response.message let message = ChatMessage( id: proposedId, - role: choice.message.role, - content: choice.message.content, - name: choice.message.name, - functionCall: choice.message.function_call.map { - ChatMessage.FunctionCall(name: $0.name, arguments: $0.arguments ?? "") + role: { + switch choice.role { + case .system: .system + case .user: .user + case .assistant: .assistant + case .tool: .user + } + }(), + content: choice.content, + name: choice.name, + toolCalls: choice.toolCalls?.map { + ChatMessage.ToolCall(id: $0.id, type: $0.type, function: .init( + name: $0.function.name, + arguments: $0.function.arguments ?? "" + )) }, references: prompt.references ) @@ -470,50 +453,57 @@ extension ChatGPTService { /// When a function call is detected, but arguments are not yet ready, we can call this /// to insert a message placeholder in memory. - func prepareFunctionCall(_ call: ChatMessage.FunctionCall, messageId: String) async { - guard let function = functionProvider.function(named: call.name) else { return } - await memory.streamMessage(id: messageId, role: .function, name: call.name) + func prepareFunctionCall(_ call: ChatMessage.ToolCall, sourceMessageId: String) async { + guard let function = functionProvider.function(named: call.function.name) else { return } + await memory.streamToolCallResponse(id: sourceMessageId, toolCallId: call.id) await function.prepare { [weak self] summary in - await self?.memory.updateMessage(id: messageId) { message in - message.summary = summary - } + await self?.memory.streamToolCallResponse( + id: sourceMessageId, + toolCallId: call.id, + summary: summary + ) } } /// Run a function call from the bot, and insert the result in memory. @discardableResult func runFunctionCall( - _ call: ChatMessage.FunctionCall, - messageId: String? = nil + _ call: ChatMessage.ToolCall, + sourceMessageId: String ) async -> String { #if DEBUG - Debugger.didReceiveFunction(name: call.name, arguments: call.arguments) + Debugger.didReceiveFunction(name: call.function.name, arguments: call.function.arguments) #endif - let messageId = messageId ?? 
uuid().uuidString - - guard let function = functionProvider.function(named: call.name) else { - return await fallbackFunctionCall(call, messageId: messageId) + guard let function = functionProvider.function(named: call.function.name) else { + return await fallbackFunctionCall(call, sourceMessageId: sourceMessageId) } - await memory.streamMessage(id: messageId, role: .function, name: call.name) + await memory.streamToolCallResponse( + id: sourceMessageId, + toolCallId: call.id + ) do { // Run the function - let result = try await function.call(argumentsJsonString: call.arguments) { + let result = try await function.call(argumentsJsonString: call.function.arguments) { [weak self] summary in - await self?.memory.updateMessage(id: messageId) { message in - message.summary = summary - } + await self?.memory.streamToolCallResponse( + id: sourceMessageId, + toolCallId: call.id, + summary: summary + ) } #if DEBUG Debugger.didReceiveFunctionResult(result: result.botReadableContent) #endif - await memory.updateMessage(id: messageId) { message in - message.content = result.botReadableContent - } + await memory.streamToolCallResponse( + id: sourceMessageId, + toolCallId: call.id, + content: result.botReadableContent + ) return result.botReadableContent } catch { @@ -524,20 +514,22 @@ extension ChatGPTService { Debugger.didReceiveFunctionResult(result: content) #endif - await memory.updateMessage(id: messageId) { message in - message.content = content - } + await memory.streamToolCallResponse( + id: sourceMessageId, + toolCallId: call.id, + content: content + ) return content } } /// Mock a function call result when the bot is calling a function that is not implemented. func fallbackFunctionCall( - _ call: ChatMessage.FunctionCall, - messageId: String + _ call: ChatMessage.ToolCall, + sourceMessageId: String ) async -> String { let memory = ConversationChatGPTMemory(systemPrompt: { - if call.name == "python" { + if call.function.name == "python" { return """ Act like a Python interpreter. I will give you Python code and you will execute it. @@ -545,7 +537,7 @@ extension ChatGPTService { """ } else { return """ - You are a function simulator. Your name is \(call.name). + You are a function simulator. Your name is \(call.function.name). Act like a function. I will send you the arguments. Reply with output of the function and tell me it's an answer generated by LLM. @@ -564,25 +556,113 @@ extension ChatGPTService { let content: String = await { do { return try await service.sendAndWait(content: """ - \(call.arguments) + \(call.function.arguments) """) ?? "No result." } catch { return "No result." } }() - await memory.streamMessage( - id: messageId, - role: .function, + await memory.streamToolCallResponse( + id: sourceMessageId, + toolCallId: call.id, content: content, - name: call.name, summary: "Finished running function." ) return content } + + func createRequestBody( + prompt: ChatGPTPrompt, + model: ChatModel, + stream: Bool + ) -> ChatCompletionsRequestBody { + let serviceSupportsFunctionCalling = switch model.format { + case .openAI, .openAICompatible, .azureOpenAI: + model.info.supportsFunctionCalling + case .ollama, .googleAI: + false + } + + let messages = prompt.history.flatMap { chatMessage in + var all = [ChatCompletionsRequestBody.Message]() + all.append(ChatCompletionsRequestBody.Message( + role: { + switch chatMessage.role { + case .system: .system + case .user: .user + case .assistant: .assistant + } + }(), + content: chatMessage.content ?? 
"", + name: chatMessage.name, + toolCalls: { + if serviceSupportsFunctionCalling { + chatMessage.toolCalls?.map { + .init( + id: $0.id, + type: $0.type, + function: .init( + name: $0.function.name, + arguments: $0.function.arguments + ) + ) + } + } else { + nil + } + }() + )) + + for call in chatMessage.toolCalls ?? [] { + if serviceSupportsFunctionCalling { + all.append(ChatCompletionsRequestBody.Message( + role: .tool, + content: call.response.content, + toolCallId: call.id + )) + } else { + all.append(ChatCompletionsRequestBody.Message( + role: .user, + content: call.response.content + )) + } + } + + return all + } + + let remainingTokens = prompt.remainingTokenCount + + let requestBody = ChatCompletionsRequestBody( + model: model.info.modelName, + messages: messages, + temperature: configuration.temperature, + stream: stream, + stop: configuration.stop.isEmpty ? nil : configuration.stop, + maxTokens: maxTokenForReply( + maxToken: model.info.maxTokens, + remainingTokens: remainingTokens + ), + toolChoice: serviceSupportsFunctionCalling + ? functionProvider.functionCallStrategy + : nil, + tools: serviceSupportsFunctionCalling + ? functionProvider.functions.map { + .init(function: ChatGPTFunctionSchema( + name: $0.name, + description: $0.description, + parameters: $0.argumentSchema + )) + } + : [] + ) + + return requestBody + } } extension ChatGPTService { - func changeBuildCompletionStreamAPI(_ builder: @escaping CompletionStreamAPIBuilder) { + func changeBuildCompletionStreamAPI(_ builder: @escaping ChatCompletionsStreamAPIBuilder) { buildCompletionStreamAPI = builder } } diff --git a/Tool/Sources/OpenAIService/Debug/Debug.swift b/Tool/Sources/OpenAIService/Debug/Debug.swift index b27358e0..31864964 100644 --- a/Tool/Sources/OpenAIService/Debug/Debug.swift +++ b/Tool/Sources/OpenAIService/Debug/Debug.swift @@ -6,7 +6,7 @@ enum Debugger { static var id: UUID? 
#if DEBUG - static func didSendRequestBody(body: CompletionRequestBody) { + static func didSendRequestBody(body: ChatCompletionsRequestBody) { do { let json = try JSONEncoder().encode(body) let center = NotificationCenter.default diff --git a/Tool/Sources/OpenAIService/EmbeddingService.swift b/Tool/Sources/OpenAIService/EmbeddingService.swift index d3bd1c8d..d5bf2f41 100644 --- a/Tool/Sources/OpenAIService/EmbeddingService.swift +++ b/Tool/Sources/OpenAIService/EmbeddingService.swift @@ -1,34 +1,6 @@ import Foundation import Logger -public struct EmbeddingResponse: Decodable { - public struct Object: Decodable { - public var embedding: [Float] - public var index: Int - public var object: String - } - - public var data: [Object] - public var model: String - - public struct Usage: Decodable { - public var prompt_tokens: Int - public var total_tokens: Int - } - - public var usage: Usage -} - -struct EmbeddingRequestBody: Encodable { - var input: [String] - var model: String -} - -struct EmbeddingFromTokensRequestBody: Encodable { - var input: [[Int]] - var model: String -} - public struct EmbeddingService { public let configuration: EmbeddingConfiguration @@ -37,48 +9,55 @@ public struct EmbeddingService { } public func embed(text: String) async throws -> EmbeddingResponse { - return try await embed(text: [text]) - } - - public func embed(text: [String]) async throws -> EmbeddingResponse { guard let model = configuration.model else { throw ChatGPTServiceError.embeddingModelNotAvailable } - guard let url = URL(string: configuration.endpoint) else { - throw ChatGPTServiceError.endpointIncorrect - } - var request = URLRequest(url: url) - request.httpMethod = "POST" - let encoder = JSONEncoder() - request.httpBody = try encoder.encode(EmbeddingRequestBody( - input: text, - model: model.info.modelName - )) - request.setValue("application/json", forHTTPHeaderField: "Content-Type") - if !configuration.apiKey.isEmpty { - switch model.format { - case .openAI, .openAICompatible: - request.setValue( - "Bearer \(configuration.apiKey)", - forHTTPHeaderField: "Authorization" - ) - case .azureOpenAI: - request.setValue(configuration.apiKey, forHTTPHeaderField: "api-key") - } + let embeddingResponse: EmbeddingResponse + switch model.format { + case .openAI, .openAICompatible, .azureOpenAI: + embeddingResponse = try await OpenAIEmbeddingService( + apiKey: configuration.apiKey, + model: model, + endpoint: configuration.endpoint + ).embed(text: text) + case .ollama: + embeddingResponse = try await OllamaEmbeddingService( + model: model, + endpoint: configuration.endpoint + ).embed(text: text) } - let (result, response) = try await URLSession.shared.data(for: request) - guard let response = response as? HTTPURLResponse else { - throw ChatGPTServiceError.responseInvalid - } + #if DEBUG + Logger.service.info(""" + Embedding usage + - number of strings: \(text.count) + - prompt tokens: \(embeddingResponse.usage.prompt_tokens) + - total tokens: \(embeddingResponse.usage.total_tokens) + + """) + #endif + return embeddingResponse + } - guard response.statusCode == 200 else { - let error = try? JSONDecoder().decode(CompletionAPIError.self, from: result) - throw error ?? ChatGPTServiceError - .otherError(String(data: result, encoding: .utf8) ?? 
"Unknown Error") + public func embed(text: [String]) async throws -> EmbeddingResponse { + guard let model = configuration.model else { + throw ChatGPTServiceError.embeddingModelNotAvailable + } + let embeddingResponse: EmbeddingResponse + switch model.format { + case .openAI, .openAICompatible, .azureOpenAI: + embeddingResponse = try await OpenAIEmbeddingService( + apiKey: configuration.apiKey, + model: model, + endpoint: configuration.endpoint + ).embed(texts: text) + case .ollama: + embeddingResponse = try await OllamaEmbeddingService( + model: model, + endpoint: configuration.endpoint + ).embed(texts: text) } - let embeddingResponse = try JSONDecoder().decode(EmbeddingResponse.self, from: result) #if DEBUG Logger.service.info(""" Embedding usage @@ -95,41 +74,21 @@ public struct EmbeddingService { guard let model = configuration.model else { throw ChatGPTServiceError.embeddingModelNotAvailable } - guard let url = URL(string: configuration.endpoint) else { - throw ChatGPTServiceError.endpointIncorrect - } - var request = URLRequest(url: url) - request.httpMethod = "POST" - let encoder = JSONEncoder() - request.httpBody = try encoder.encode(EmbeddingFromTokensRequestBody( - input: tokens, - model: model.info.modelName - )) - request.setValue("application/json", forHTTPHeaderField: "Content-Type") - if !configuration.apiKey.isEmpty { - switch model.format { - case .openAI, .openAICompatible: - request.setValue( - "Bearer \(configuration.apiKey)", - forHTTPHeaderField: "Authorization" - ) - case .azureOpenAI: - request.setValue(configuration.apiKey, forHTTPHeaderField: "api-key") - } - } - - let (result, response) = try await URLSession.shared.data(for: request) - guard let response = response as? HTTPURLResponse else { - throw ChatGPTServiceError.responseInvalid - } - - guard response.statusCode == 200 else { - let error = try? JSONDecoder().decode(CompletionAPIError.self, from: result) - throw error ?? ChatGPTServiceError - .otherError(String(data: result, encoding: .utf8) ?? "Unknown Error") + let embeddingResponse: EmbeddingResponse + switch model.format { + case .openAI, .openAICompatible, .azureOpenAI: + embeddingResponse = try await OpenAIEmbeddingService( + apiKey: configuration.apiKey, + model: model, + endpoint: configuration.endpoint + ).embed(tokens: tokens) + case .ollama: + embeddingResponse = try await OllamaEmbeddingService( + model: model, + endpoint: configuration.endpoint + ).embed(tokens: tokens) } - let embeddingResponse = try JSONDecoder().decode(EmbeddingResponse.self, from: result) #if DEBUG Logger.service.info(""" Embedding usage diff --git a/Tool/Sources/OpenAIService/FucntionCall/ChatGPTFunction.swift b/Tool/Sources/OpenAIService/FucntionCall/ChatGPTFunction.swift index 420fc180..d2d8aaad 100644 --- a/Tool/Sources/OpenAIService/FucntionCall/ChatGPTFunction.swift +++ b/Tool/Sources/OpenAIService/FucntionCall/ChatGPTFunction.swift @@ -43,9 +43,14 @@ public extension ChatGPTFunction { argumentsJsonString: String, reportProgress: @escaping ReportProgress ) async throws -> Result { - let arguments = try JSONDecoder() - .decode(Arguments.self, from: argumentsJsonString.data(using: .utf8) ?? Data()) - return try await call(arguments: arguments, reportProgress: reportProgress) + do { + let arguments = try JSONDecoder() + .decode(Arguments.self, from: argumentsJsonString.data(using: .utf8) ?? Data()) + return try await call(arguments: arguments, reportProgress: reportProgress) + } catch { + await reportProgress("Error: Failed to decode arguments. 
\(error.localizedDescription)") + throw error + } } } diff --git a/Tool/Sources/OpenAIService/Memory/AutoManagedChatGPTMemoryStrategy/AutoManagedChatGPTMemoryOpenAIStrategy.swift b/Tool/Sources/OpenAIService/Memory/AutoManagedChatGPTMemoryStrategy/AutoManagedChatGPTMemoryOpenAIStrategy.swift index 07d72acb..3619a7e9 100644 --- a/Tool/Sources/OpenAIService/Memory/AutoManagedChatGPTMemoryStrategy/AutoManagedChatGPTMemoryOpenAIStrategy.swift +++ b/Tool/Sources/OpenAIService/Memory/AutoManagedChatGPTMemoryStrategy/AutoManagedChatGPTMemoryOpenAIStrategy.swift @@ -37,9 +37,16 @@ extension TokenEncoder { encodingContent.append(name) total += 1 } - if let functionCall = message.functionCall { - encodingContent.append(functionCall.name) - encodingContent.append(functionCall.arguments) + if let toolCalls = message.toolCalls { + for toolCall in toolCalls { + encodingContent.append(toolCall.id) + encodingContent.append(toolCall.type) + encodingContent.append(toolCall.function.name) + encodingContent.append(toolCall.function.arguments) + total += 4 + encodingContent.append(toolCall.response.content) + encodingContent.append(toolCall.id) + } } total += await withTaskGroup(of: Int.self, body: { group in for content in encodingContent { diff --git a/Tool/Sources/OpenAIService/Memory/ChatGPTMemory.swift b/Tool/Sources/OpenAIService/Memory/ChatGPTMemory.swift index d27569d6..33300ee8 100644 --- a/Tool/Sources/OpenAIService/Memory/ChatGPTMemory.swift +++ b/Tool/Sources/OpenAIService/Memory/ChatGPTMemory.swift @@ -23,7 +23,7 @@ public protocol ChatGPTMemory { func mutateHistory(_ update: (inout [ChatMessage]) -> Void) async /// Generate prompt that would be send through the API. /// - /// A memory should make sure that the history in the prompt + /// A memory should make sure that the history in the prompt /// doesn't exceed the maximum token count. /// /// The history can be different from the actual history. @@ -58,52 +58,95 @@ public extension ChatGPTMemory { } } + func streamToolCallResponse( + id: String, + toolCallId: String, + content: String? = nil, + summary: String? = nil + ) async { + await updateMessage(id: id) { message in + if let index = message.toolCalls?.firstIndex(where: { + $0.id == toolCallId + }) { + if let content { + message.toolCalls?[index].response.content = content + } + if let summary { + message.toolCalls?[index].response.summary = summary + } + } + } + } + /// Stream a message to the history. func streamMessage( id: String, role: ChatMessage.Role? = nil, content: String? = nil, name: String? = nil, - functionCall: ChatMessage.FunctionCall? = nil, + toolCalls: [Int: ChatMessage.ToolCall]? = nil, summary: String? = nil, references: [ChatMessage.Reference]? 
= nil ) async { - await mutateHistory { history in - if let index = history.firstIndex(where: { $0.id == id }) { + if await history.contains(where: { $0.id == id }) { + await updateMessage(id: id) { message in if let content { - if history[index].content == nil { - history[index].content = content + if message.content == nil { + message.content = content } else { - history[index].content?.append(content) + message.content?.append(content) } } if let role { - history[index].role = role + message.role = role } - if let functionCall { - if history[index].functionCall == nil { - history[index].functionCall = functionCall + if let toolCalls { + if var existedToolCalls = message.toolCalls { + for pair in toolCalls.sorted(by: { $0.key <= $1.key }) { + let (proposedIndex, toolCall) = pair + let index = { + if toolCall.id.isEmpty { return proposedIndex } + return existedToolCalls.lastIndex(where: { $0.id == toolCall.id }) + ?? proposedIndex + }() + if index < existedToolCalls.endIndex { + if !toolCall.id.isEmpty { + existedToolCalls[index].id = toolCall.id + } + if !toolCall.type.isEmpty { + existedToolCalls[index].type = toolCall.type + } + existedToolCalls[index].function.name + .append(toolCall.function.name) + existedToolCalls[index].function.arguments + .append(toolCall.function.arguments) + } else { + existedToolCalls.append(toolCall) + } + } + message.toolCalls = existedToolCalls } else { - history[index].functionCall?.name.append(functionCall.name) - history[index].functionCall?.arguments.append(functionCall.arguments) + message.toolCalls = toolCalls.sorted(by: { $0.key <= $1.key }).map(\.value) } } if let summary { - history[index].summary = summary + message.summary = summary } if let references { - history[index].references.append(contentsOf: references) + message.references.append(contentsOf: references) } if let name { - history[index].name = name + message.name = name } - } else { + } + } else { + await mutateHistory { history in history.append(.init( id: id, role: role ?? .system, content: content, name: name, - functionCall: functionCall, + toolCalls: toolCalls?.sorted(by: { $0.key <= $1.key }).map(\.value), summary: summary, references: references ?? [] )) diff --git a/Tool/Sources/OpenAIService/Models.swift b/Tool/Sources/OpenAIService/Models.swift index 8ff25b96..af95a6e5 100644 --- a/Tool/Sources/OpenAIService/Models.swift +++ b/Tool/Sources/OpenAIService/Models.swift @@ -15,7 +15,6 @@ public struct ChatMessage: Equatable, Codable { case system case user case assistant - case function } public struct FunctionCall: Codable, Equatable { @@ -27,6 +26,33 @@ public struct ChatMessage: Equatable, Codable { } } + public struct ToolCall: Codable, Equatable, Identifiable { + public var id: String + public var type: String + public var function: FunctionCall + public var response: ToolCallResponse + public init( + id: String, + type: String, + function: FunctionCall, + response: ToolCallResponse? = nil + ) { + self.id = id + self.type = type + self.function = function + self.response = response ?? .init(content: "", summary: nil) + } + } + + public struct ToolCallResponse: Codable, Equatable { + public var content: String + public var summary: String? + public init(content: String, summary: String?) 
{
+            self.content = content
+            self.summary = summary
+        }
+    }
+
     public struct Reference: Codable, Equatable {
         public enum Kind: String, Codable {
             case `class`
@@ -44,7 +70,7 @@ public struct ChatMessage: Equatable, Codable {
             case webpage
             case other
         }
-
+
         public var title: String
         public var subTitle: String
         public var uri: String
@@ -74,6 +100,7 @@ public struct ChatMessage: Equatable, Codable {
     }
 
     /// The role of a message.
+    @FallbackDecoding<ChatMessageRoleFallback>
     public var role: Role
 
     /// The content of the message, either the chat message, or a result of a function call.
@@ -82,7 +109,7 @@
     }
 
     /// A function call from the bot.
-    public var functionCall: FunctionCall? {
+    public var toolCalls: [ToolCall]? {
         didSet { tokensCount = nil }
     }
@@ -107,7 +134,7 @@
     /// Is the message considered empty.
     var isEmpty: Bool {
         if let content, !content.isEmpty { return false }
-        if let functionCall, !functionCall.name.isEmpty { return false }
+        if let toolCalls, !toolCalls.isEmpty { return false }
         if let name, !name.isEmpty { return false }
         return true
     }
@@ -117,7 +144,7 @@
         role: Role,
         content: String?,
         name: String? = nil,
-        functionCall: FunctionCall? = nil,
+        toolCalls: [ToolCall]? = nil,
         summary: String? = nil,
         tokenCount: Int? = nil,
         references: [Reference] = []
@@ -125,7 +152,7 @@
         self.role = role
         self.content = content
         self.name = name
-        self.functionCall = functionCall
+        self.toolCalls = toolCalls
         self.summary = summary
         self.id = id
         tokensCount = tokenCount
@@ -137,3 +164,7 @@
 public struct ReferenceKindFallback: FallbackValueProvider {
     public static var defaultValue: ChatMessage.Reference.Kind { .other }
 }
+public struct ChatMessageRoleFallback: FallbackValueProvider {
+    public static var defaultValue: ChatMessage.Role { .user }
+}
+
 diff --git a/Tool/Tests/OpenAIServiceTests/ChatGPTStreamTests.swift b/Tool/Tests/OpenAIServiceTests/ChatGPTStreamTests.swift
index 5349f85e..38bebbbe 100644
--- a/Tool/Tests/OpenAIServiceTests/ChatGPTStreamTests.swift
+++ b/Tool/Tests/OpenAIServiceTests/ChatGPTStreamTests.swift
@@ -14,7 +14,7 @@ final class ChatGPTStreamTests: XCTestCase {
             configuration: configuration,
             functionProvider: functionProvider
         )
-        var requestBody: CompletionRequestBody?
+        var requestBody: ChatCompletionsRequestBody?
         service.changeBuildCompletionStreamAPI { _, _, _, _requestBody, _ in
             requestBody = _requestBody
             return MockCompletionStreamAPI_Message()
         }
@@ -60,7 +60,7 @@
             ),
         ], "History is not updated")
-        XCTAssertEqual(requestBody?.functions, nil, "Function schema is not submitted")
+        XCTAssertEqual(requestBody?.tools, nil, "Function schema is not submitted")
     }
 }
@@ -75,7 +75,7 @@
             configuration: configuration,
             functionProvider: functionProvider
         )
-        var requestBody: CompletionRequestBody?
+        var requestBody: ChatCompletionsRequestBody?
         service.changeBuildCompletionStreamAPI { _, _, _, _requestBody, _ in
             requestBody = _requestBody
             if _requestBody.messages.count <= 2 {
@@ -93,7 +93,7 @@
         for try await text in stream {
             all.append(text)
             let history = await memory.history
-            XCTAssertEqual(history.last?.id, "00000000-0000-0000-0000-0000000000040.0")
+            XCTAssertEqual(history.last?.id, "00000000-0000-0000-0000-0000000000030.0")
             XCTAssertTrue(
                 history.last?.content?.hasPrefix(all.joined()) ??
false, "History is not updated" @@ -105,9 +105,14 @@ final class ChatGPTStreamTests: XCTestCase { .init(role: .user, content: "Hello"), .init( role: .assistant, content: "", - function_call: .init(name: "function", arguments: "{\n\"foo\": 1\n}") + toolCalls: [ + .init( + id: "id", + type: "function", + function: .init(name: "function", arguments: "{\n\"foo\": 1\n}") + )] ), - .init(role: .function, content: "Function is called.", name: "function"), + .init(role: .tool, content: "Function is called.", toolCallId: "id"), ], "System prompt is not included") XCTAssertEqual(all, ["hello", "my", "friends"], "Text stream is not correct") @@ -123,26 +128,33 @@ final class ChatGPTStreamTests: XCTestCase { id: "00000000-0000-0000-0000-0000000000010.0", role: .assistant, content: nil, - functionCall: .init(name: "function", arguments: "{\n\"foo\": 1\n}") - ), - .init( - id: "00000000-0000-0000-0000-000000000003", - role: .function, - content: "Function is called.", - name: "function", - summary: nil + toolCalls: [ + .init( + id: "id", + type: "function", + function: .init(name: "function", arguments: "{\n\"foo\": 1\n}"), + response: .init(content: "Function is called.", summary: nil) + ), + ] ), .init( - id: "00000000-0000-0000-0000-0000000000040.0", + id: "00000000-0000-0000-0000-0000000000030.0", role: .assistant, content: "hellomyfriends" ), ], "History is not updated") - XCTAssertEqual(requestBody?.functions, [ + XCTAssertEqual(requestBody?.tools, [ EmptyFunction(), ].map { - .init(name: $0.name, description: $0.description, parameters: $0.argumentSchema) + .init( + type: "function", + function: .init( + name: $0.name, + description: $0.description, + parameters: $0.argumentSchema + ) + ) }, "Function schema is not submitted") } } @@ -158,12 +170,12 @@ final class ChatGPTStreamTests: XCTestCase { configuration: configuration, functionProvider: functionProvider ) - var requestBody: CompletionRequestBody? + var requestBody: ChatCompletionsRequestBody? service.changeBuildCompletionStreamAPI { _, _, _, _requestBody, _ in requestBody = _requestBody if _requestBody.messages.count <= 4 { - return MockCompletionStreamAPI_Function() + return MockCompletionStreamAPI_Function(count: 3) } return MockCompletionStreamAPI_Message() } @@ -177,7 +189,7 @@ final class ChatGPTStreamTests: XCTestCase { for try await text in stream { all.append(text) let history = await memory.history - XCTAssertEqual(history.last?.id, "00000000-0000-0000-0000-0000000000070.0") + XCTAssertEqual(history.last?.id, "00000000-0000-0000-0000-0000000000030.0") XCTAssertTrue( history.last?.content?.hasPrefix(all.joined()) ?? 
false, "History is not updated" @@ -189,14 +201,39 @@ final class ChatGPTStreamTests: XCTestCase { .init(role: .user, content: "Hello"), .init( role: .assistant, content: "", - function_call: .init(name: "function", arguments: "{\n\"foo\": 1\n}") + toolCalls: [ + .init( + id: "id", + type: "function", + function: .init(name: "function", arguments: "{\n\"foo\": 1\n}") + ), + .init( + id: "id2", + type: "function", + function: .init(name: "function", arguments: "{\n\"foo\": 1\n}") + ), + .init( + id: "id3", + type: "function", + function: .init(name: "function", arguments: "{\n\"foo\": 1\n}") + ), + ] ), - .init(role: .function, content: "Function is called.", name: "function"), .init( - role: .assistant, content: "", - function_call: .init(name: "function", arguments: "{\n\"foo\": 1\n}") + role: .tool, + content: "Function is called.", + toolCallId: "id" + ), + .init( + role: .tool, + content: "Function is called.", + toolCallId: "id2" + ), + .init( + role: .tool, + content: "Function is called.", + toolCallId: "id3" ), - .init(role: .function, content: "Function is called.", name: "function"), ], "System prompt is not included") XCTAssertEqual(all, ["hello", "my", "friends"], "Text stream is not correct") @@ -212,39 +249,45 @@ final class ChatGPTStreamTests: XCTestCase { id: "00000000-0000-0000-0000-0000000000010.0", role: .assistant, content: nil, - functionCall: .init(name: "function", arguments: "{\n\"foo\": 1\n}") - ), - .init( - id: "00000000-0000-0000-0000-000000000003", - role: .function, - content: "Function is called.", - name: "function", - summary: nil - ), - .init( - id: "00000000-0000-0000-0000-0000000000040.0", - role: .assistant, - content: nil, - functionCall: .init(name: "function", arguments: "{\n\"foo\": 1\n}") - ), - .init( - id: "00000000-0000-0000-0000-000000000006", - role: .function, - content: "Function is called.", - name: "function", - summary: nil + toolCalls: [ + .init( + id: "id", + type: "function", + function: .init(name: "function", arguments: "{\n\"foo\": 1\n}"), + response: .init(content: "Function is called.", summary: nil) + ), + .init( + id: "id2", + type: "function", + function: .init(name: "function", arguments: "{\n\"foo\": 1\n}"), + response: .init(content: "Function is called.", summary: nil) + ), + .init( + id: "id3", + type: "function", + function: .init(name: "function", arguments: "{\n\"foo\": 1\n}"), + response: .init(content: "Function is called.", summary: nil) + ), + ] ), .init( - id: "00000000-0000-0000-0000-0000000000070.0", + id: "00000000-0000-0000-0000-0000000000030.0", role: .assistant, content: "hellomyfriends" ), ], "History is not updated") - XCTAssertEqual(requestBody?.functions, [ + XCTAssertEqual(requestBody?.tools, [ EmptyFunction(), ].map { - .init(name: $0.name, description: $0.description, parameters: $0.argumentSchema) + .init( + type: "function", + function: .init( + name: $0.name, + description: $0.description, + parameters: $0.argumentSchema + ) + ) }, "Function schema is not submitted") } } @@ -265,7 +308,7 @@ final class ChatGPTStreamTests: XCTestCase { configuration: configuration, functionProvider: functionProvider ) - var requestBody: CompletionRequestBody? + var requestBody: ChatCompletionsRequestBody? 
         service.changeBuildCompletionStreamAPI { _, _, _, _requestBody, _ in
             requestBody = _requestBody
             if _requestBody.messages.count <= 2 {
@@ -283,7 +326,7 @@ final class ChatGPTStreamTests: XCTestCase {
             for try await text in stream {
                 all.append(text)
                 let history = await memory.history
-                XCTAssertEqual(history.last?.id, "00000000-0000-0000-0000-0000000000040.0")
+                XCTAssertEqual(history.last?.id, "00000000-0000-0000-0000-0000000000030.0")
                 XCTAssertTrue(
                     history.last?.content?.hasPrefix(all.joined()) ?? false,
                     "History is not updated"
@@ -294,10 +337,9 @@ final class ChatGPTStreamTests: XCTestCase {
                 .init(role: .system, content: "system"),
                 .init(role: .user, content: "Hello"),
                 .init(
-                    role: .assistant, content: "",
-                    function_call: .init(name: "function", arguments: "{\n\"foo\": 1\n}")
+                    role: .assistant, content: ""
                 ),
-                .init(role: .function, content: "Function is called.", name: "function"),
+                .init(role: .user, content: "Function is called."),
             ], "System prompt is not included")
 
             XCTAssertEqual(all, ["hello", "my", "friends"], "Text stream is not correct")
@@ -313,48 +355,64 @@ final class ChatGPTStreamTests: XCTestCase {
                     id: "00000000-0000-0000-0000-0000000000010.0",
                     role: .assistant,
                     content: nil,
-                    functionCall: .init(name: "function", arguments: "{\n\"foo\": 1\n}")
-                ),
-                .init(
-                    id: "00000000-0000-0000-0000-000000000003",
-                    role: .function,
-                    content: "Function is called.",
-                    name: "function",
-                    summary: nil
+                    toolCalls: [
+                        .init(
+                            id: "id",
+                            type: "function",
+                            function: .init(name: "function", arguments: "{\n\"foo\": 1\n}"),
+                            response: .init(content: "Function is called.", summary: nil)
+                        ),
+                    ]
                 ),
                 .init(
-                    id: "00000000-0000-0000-0000-0000000000040.0",
+                    id: "00000000-0000-0000-0000-0000000000030.0",
                     role: .assistant,
                     content: "hellomyfriends"
                 ),
             ], "History is not updated")
 
-            XCTAssertEqual(requestBody?.functions, nil, "Functions should be nil")
+            XCTAssertEqual(requestBody?.tools, nil, "Functions should be nil")
         }
     }
 }
 
 extension ChatGPTStreamTests {
-    struct MockCompletionStreamAPI_Message: CompletionStreamAPI {
+    struct MockCompletionStreamAPI_Message: ChatCompletionsStreamAPI {
         @Dependency(\.uuid) var uuid
         func callAsFunction() async throws
-            -> AsyncThrowingStream<CompletionStreamDataChunk, Error>
+            -> AsyncThrowingStream<ChatCompletionsStreamDataChunk, Error>
         {
             let id = uuid().uuidString
-            return AsyncThrowingStream<CompletionStreamDataChunk, Error> { continuation in
-                let chunks: [CompletionStreamDataChunk] = [
-                    .init(id: id, object: "", model: "", choices: [
-                        .init(delta: .init(role: .assistant), index: 0, finish_reason: ""),
-                    ]),
-                    .init(id: id, object: "", model: "", choices: [
-                        .init(delta: .init(content: "hello"), index: 0, finish_reason: ""),
-                    ]),
-                    .init(id: id, object: "", model: "", choices: [
-                        .init(delta: .init(content: "my"), index: 0, finish_reason: ""),
-                    ]),
-                    .init(id: id, object: "", model: "", choices: [
-                        .init(delta: .init(content: "friends"), index: 0, finish_reason: ""),
-                    ]),
+            return AsyncThrowingStream<ChatCompletionsStreamDataChunk, Error> { continuation in
+                let chunks: [ChatCompletionsStreamDataChunk] = [
+                    .init(
+                        id: id,
+                        object: "",
+                        model: "",
+                        message: .init(role: .assistant),
+                        finishReason: ""
+                    ),
+                    .init(
+                        id: id,
+                        object: "",
+                        model: "",
+                        message: .init(content: "hello"),
+                        finishReason: ""
+                    ),
+                    .init(
+                        id: id,
+                        object: "",
+                        model: "",
+                        message: .init(content: "my"),
+                        finishReason: ""
+                    ),
+                    .init(
+                        id: id,
+                        object: "",
+                        model: "",
+                        message: .init(content: "friends"),
+                        finishReason: ""
+                    ),
                 ]
                 for chunk in chunks {
                     continuation.yield(chunk)
@@ -364,53 +422,89 @@ extension ChatGPTStreamTests {
         }
     }
 
-    struct MockCompletionStreamAPI_Function: CompletionStreamAPI {
+    struct MockCompletionStreamAPI_Function: ChatCompletionsStreamAPI {
         @Dependency(\.uuid) var uuid
+        var count: Int = 1
         func callAsFunction() async throws
-            -> AsyncThrowingStream<CompletionStreamDataChunk, Error>
+            -> AsyncThrowingStream<ChatCompletionsStreamDataChunk, Error>
         {
             let id = uuid().uuidString
-            return AsyncThrowingStream<CompletionStreamDataChunk, Error> { continuation in
-                let chunks: [CompletionStreamDataChunk] = [
-                    .init(id: id, object: "", model: "", choices: [
+            return AsyncThrowingStream<ChatCompletionsStreamDataChunk, Error> { continuation in
+                for i in 0..<count {

[rest of MockCompletionStreamAPI_Function truncated in the source]

[appcast.xml diff, XML markup not recoverable: adds the 0.31.0 release entry to the "Copilot for Xcode" Sparkle feed; pubDate Fri, 08 Mar 2024 14:48:32 +0800, version 333, short version string 0.31.0, minimum system version 12.0, link: https://github.com/intitni/CopilotForXcode/releases/tag/0.31.0]
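For reference, the request-body change the updated tests assert: the legacy top-level `functions` array of bare schemas becomes a `tools` array in which each function is wrapped in a typed envelope, built in the tests via `.init(type: "function", function: .init(name:description:parameters:))`. A standalone Codable sketch of that envelope shape (local stand-in types, not the project's `ChatCompletionsRequestBody`; `parameters` omitted for brevity):

```swift
import Foundation

// Local stand-ins mirroring the `tools` envelope asserted in the tests above.
struct FunctionSchema: Codable {
    var name: String
    var description: String
}

struct Tool: Codable {
    var type = "function"  // the only tool type used here
    var function: FunctionSchema
}

let tools = [Tool(function: .init(name: "search", description: "Search the web."))]
let data = try JSONEncoder().encode(["tools": tools])
// Prints something like:
// {"tools":[{"type":"function","function":{"name":"search","description":"Search the web."}}]}
print(String(data: data, encoding: .utf8)!)
```

Note that function calling remains gated per model: when it is disabled, the third test above expects `tools` to stay `nil` and the tool responses to be replayed as plain `user` messages.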