From 9de477454cf658fe79a86615ca035135a09ab901 Mon Sep 17 00:00:00 2001 From: Shx Guo Date: Fri, 1 Mar 2024 21:34:05 +0800 Subject: [PATCH 01/37] Add ResponseStream --- .../OpenAIService/APIs/ResponseStream.swift | 45 +++++++++++++++++++ 1 file changed, 45 insertions(+) create mode 100644 Tool/Sources/OpenAIService/APIs/ResponseStream.swift diff --git a/Tool/Sources/OpenAIService/APIs/ResponseStream.swift b/Tool/Sources/OpenAIService/APIs/ResponseStream.swift new file mode 100644 index 00000000..ce28b7f9 --- /dev/null +++ b/Tool/Sources/OpenAIService/APIs/ResponseStream.swift @@ -0,0 +1,45 @@ +import Foundation + +struct ResponseStream: AsyncSequence { + func makeAsyncIterator() -> Stream.AsyncIterator { + stream.makeAsyncIterator() + } + + typealias Stream = AsyncThrowingStream + typealias AsyncIterator = Stream.AsyncIterator + typealias Element = Chunk + + struct LineContent { + let chunk: Chunk? + let done: Bool + } + + let stream: Stream + + init(result: URLSession.AsyncBytes, lineExtractor: @escaping (String) throws -> LineContent) { + stream = AsyncThrowingStream { continuation in + let task = Task { + do { + for try await line in result.lines { + if Task.isCancelled { break } + let content = try lineExtractor(line) + if let chunk = content.chunk { + continuation.yield(chunk) + } + + if content.done { break } + } + continuation.finish() + } catch { + continuation.finish(throwing: error) + result.task.cancel() + } + } + continuation.onTermination = { _ in + task.cancel() + result.task.cancel() + } + } + } +} + From 99dbf8f3e9707b246b3d3bea7cc78f5242e139de Mon Sep 17 00:00:00 2001 From: Shx Guo Date: Fri, 1 Mar 2024 23:31:27 +0800 Subject: [PATCH 02/37] Move definitions to CompletionsAPIDefinition --- .../APIs/CompletionsAPIDefinition.swift | 209 ++++++++++++++++++ .../APIs/OpenAICompletionAPI.swift | 50 +---- .../APIs/OpenAICompletionStreamAPI.swift | 151 ------------- 3 files changed, 210 insertions(+), 200 deletions(-) create mode 100644 Tool/Sources/OpenAIService/APIs/CompletionsAPIDefinition.swift diff --git a/Tool/Sources/OpenAIService/APIs/CompletionsAPIDefinition.swift b/Tool/Sources/OpenAIService/APIs/CompletionsAPIDefinition.swift new file mode 100644 index 00000000..e8b0cd72 --- /dev/null +++ b/Tool/Sources/OpenAIService/APIs/CompletionsAPIDefinition.swift @@ -0,0 +1,209 @@ +import AIModel +import Foundation +import Preferences + +/// https://platform.openai.com/docs/api-reference/chat/create +struct CompletionRequestBody: Codable, Equatable { + struct Message: Codable, Equatable { + /// The role of the message. + var role: ChatMessage.Role + /// The content of the message. + var content: String + /// When we want to reply to a function call with the result, we have to provide the + /// name of the function call, and include the result in `content`. + /// + /// - important: It's required when the role is `function`. + var name: String? + /// When the bot wants to call a function, it will reply with a function call in format: + /// ```json + /// { + /// "name": "weather", + /// "arguments": "{ \"location\": \"earth\" }" + /// } + /// ``` + var function_call: CompletionRequestBody.MessageFunctionCall? + } + + struct MessageFunctionCall: Codable, Equatable { + /// The name of the + var name: String + /// A JSON string. + var arguments: String? + } + + struct Function: Codable { + var name: String + var description: String + /// JSON schema. + var arguments: String + } + + var model: String + var messages: [Message] + var temperature: Double? + var top_p: Double? + var n: Double? + var stream: Bool? + var stop: [String]? + var max_tokens: Int? + var presence_penalty: Double? + var frequency_penalty: Double? + var logit_bias: [String: Double]? + var user: String? + /// Pass nil to let the bot decide. + var function_call: FunctionCallStrategy? + var functions: [ChatGPTFunctionSchema]? + + init( + model: String, + messages: [Message], + temperature: Double? = nil, + top_p: Double? = nil, + n: Double? = nil, + stream: Bool? = nil, + stop: [String]? = nil, + max_tokens: Int? = nil, + presence_penalty: Double? = nil, + frequency_penalty: Double? = nil, + logit_bias: [String: Double]? = nil, + user: String? = nil, + function_call: FunctionCallStrategy? = nil, + functions: [ChatGPTFunctionSchema] = [] + ) { + self.model = model + self.messages = messages + self.temperature = temperature + self.top_p = top_p + self.n = n + self.stream = stream + self.stop = stop + self.max_tokens = max_tokens + self.presence_penalty = presence_penalty + self.frequency_penalty = frequency_penalty + self.logit_bias = logit_bias + self.user = user + if UserDefaults.shared.value(for: \.disableFunctionCalling) { + self.function_call = nil + self.functions = nil + } else { + self.function_call = function_call + self.functions = functions.isEmpty ? nil : functions + } + } +} + +public enum FunctionCallStrategy: Codable, Equatable { + /// Forbid the bot to call any function. + case none + /// Let the bot choose what function to call. + case auto + /// Force the bot to call a function with the given name. + case name(String) + + struct CallFunctionNamed: Codable { + var name: String + } + + public func encode(to encoder: Encoder) throws { + var container = encoder.singleValueContainer() + switch self { + case .none: + try container.encode("none") + case .auto: + try container.encode("auto") + case let .name(name): + try container.encode(CallFunctionNamed(name: name)) + } + } +} + +// MARK: - Stream API + +typealias CompletionStreamAPIBuilder = ( + String, + ChatModel, + URL, + CompletionRequestBody, + ChatGPTPrompt +) -> any CompletionStreamAPI + +protocol CompletionStreamAPI { + associatedtype CompletionSequence: AsyncSequence + where CompletionSequence.Element == CompletionStreamDataChunk + func callAsFunction() async throws -> CompletionSequence +} + +struct CompletionStreamDataChunk: Codable { + var id: String? + var object: String? + var model: String? + var choices: [Choice]? + + struct Choice: Codable { + var delta: Delta? + var index: Int? + var finish_reason: String? + + struct Delta: Codable { + struct FunctionCall: Codable { + var name: String? + var arguments: String? + } + + var role: ChatMessage.Role? + var content: String? + var function_call: FunctionCall? + } + } +} + +// MARK: - Non Stream API + +typealias CompletionAPIBuilder = (String, ChatModel, URL, CompletionRequestBody, ChatGPTPrompt) + -> any CompletionAPI + +protocol CompletionAPI { + func callAsFunction() async throws -> CompletionResponseBody +} + +/// https://platform.openai.com/docs/api-reference/chat/create +struct CompletionResponseBody: Codable, Equatable { + struct Message: Codable, Equatable { + /// The role of the message. + var role: ChatMessage.Role + /// The content of the message. + var content: String? + /// When we want to reply to a function call with the result, we have to provide the + /// name of the function call, and include the result in `content`. + /// + /// - important: It's required when the role is `function`. + var name: String? + /// When the bot wants to call a function, it will reply with a function call in format: + /// ```json + /// { + /// "name": "weather", + /// "arguments": "{ \"location\": \"earth\" }" + /// } + /// ``` + var function_call: CompletionRequestBody.MessageFunctionCall? + } + + struct Choice: Codable, Equatable { + var message: Message + var index: Int + var finish_reason: String + } + + struct Usage: Codable, Equatable { + var prompt_tokens: Int + var completion_tokens: Int + var total_tokens: Int + } + + var id: String? + var object: String + var model: String + var usage: Usage + var choices: [Choice] +} + diff --git a/Tool/Sources/OpenAIService/APIs/OpenAICompletionAPI.swift b/Tool/Sources/OpenAIService/APIs/OpenAICompletionAPI.swift index 31e86492..2a058f16 100644 --- a/Tool/Sources/OpenAIService/APIs/OpenAICompletionAPI.swift +++ b/Tool/Sources/OpenAIService/APIs/OpenAICompletionAPI.swift @@ -2,54 +2,6 @@ import AIModel import Foundation import Preferences -typealias CompletionAPIBuilder = (String, ChatModel, URL, CompletionRequestBody, ChatGPTPrompt) - -> CompletionAPI - -protocol CompletionAPI { - func callAsFunction() async throws -> CompletionResponseBody -} - -/// https://platform.openai.com/docs/api-reference/chat/create -struct CompletionResponseBody: Codable, Equatable { - struct Message: Codable, Equatable { - /// The role of the message. - var role: ChatMessage.Role - /// The content of the message. - var content: String? - /// When we want to reply to a function call with the result, we have to provide the - /// name of the function call, and include the result in `content`. - /// - /// - important: It's required when the role is `function`. - var name: String? - /// When the bot wants to call a function, it will reply with a function call in format: - /// ```json - /// { - /// "name": "weather", - /// "arguments": "{ \"location\": \"earth\" }" - /// } - /// ``` - var function_call: CompletionRequestBody.MessageFunctionCall? - } - - struct Choice: Codable, Equatable { - var message: Message - var index: Int - var finish_reason: String - } - - struct Usage: Codable, Equatable { - var prompt_tokens: Int - var completion_tokens: Int - var total_tokens: Int - } - - var id: String? - var object: String - var model: String - var usage: Usage - var choices: [Choice] -} - struct CompletionAPIError: Error, Codable, LocalizedError { struct E: Codable { var message: String @@ -95,7 +47,7 @@ struct OpenAICompletionAPI: CompletionAPI { case .azureOpenAI: request.setValue(apiKey, forHTTPHeaderField: "api-key") case .googleAI: - assert(false, "Unsupported") + assertionFailure("Unsupported") } } diff --git a/Tool/Sources/OpenAIService/APIs/OpenAICompletionStreamAPI.swift b/Tool/Sources/OpenAIService/APIs/OpenAICompletionStreamAPI.swift index 46c6b1ff..afaafe9e 100644 --- a/Tool/Sources/OpenAIService/APIs/OpenAICompletionStreamAPI.swift +++ b/Tool/Sources/OpenAIService/APIs/OpenAICompletionStreamAPI.swift @@ -3,157 +3,6 @@ import AsyncAlgorithms import Foundation import Preferences -typealias CompletionStreamAPIBuilder = ( - String, - ChatModel, - URL, - CompletionRequestBody, - ChatGPTPrompt -) -> any CompletionStreamAPI - -protocol CompletionStreamAPI { - func callAsFunction() async throws -> AsyncThrowingStream -} - -public enum FunctionCallStrategy: Codable, Equatable { - /// Forbid the bot to call any function. - case none - /// Let the bot choose what function to call. - case auto - /// Force the bot to call a function with the given name. - case name(String) - - struct CallFunctionNamed: Codable { - var name: String - } - - public func encode(to encoder: Encoder) throws { - var container = encoder.singleValueContainer() - switch self { - case .none: - try container.encode("none") - case .auto: - try container.encode("auto") - case let .name(name): - try container.encode(CallFunctionNamed(name: name)) - } - } -} - -/// https://platform.openai.com/docs/api-reference/chat/create -struct CompletionRequestBody: Codable, Equatable { - struct Message: Codable, Equatable { - /// The role of the message. - var role: ChatMessage.Role - /// The content of the message. - var content: String - /// When we want to reply to a function call with the result, we have to provide the - /// name of the function call, and include the result in `content`. - /// - /// - important: It's required when the role is `function`. - var name: String? - /// When the bot wants to call a function, it will reply with a function call in format: - /// ```json - /// { - /// "name": "weather", - /// "arguments": "{ \"location\": \"earth\" }" - /// } - /// ``` - var function_call: CompletionRequestBody.MessageFunctionCall? - } - - struct MessageFunctionCall: Codable, Equatable { - /// The name of the - var name: String - /// A JSON string. - var arguments: String? - } - - struct Function: Codable { - var name: String - var description: String - /// JSON schema. - var arguments: String - } - - var model: String - var messages: [Message] - var temperature: Double? - var top_p: Double? - var n: Double? - var stream: Bool? - var stop: [String]? - var max_tokens: Int? - var presence_penalty: Double? - var frequency_penalty: Double? - var logit_bias: [String: Double]? - var user: String? - /// Pass nil to let the bot decide. - var function_call: FunctionCallStrategy? - var functions: [ChatGPTFunctionSchema]? - - init( - model: String, - messages: [Message], - temperature: Double? = nil, - top_p: Double? = nil, - n: Double? = nil, - stream: Bool? = nil, - stop: [String]? = nil, - max_tokens: Int? = nil, - presence_penalty: Double? = nil, - frequency_penalty: Double? = nil, - logit_bias: [String: Double]? = nil, - user: String? = nil, - function_call: FunctionCallStrategy? = nil, - functions: [ChatGPTFunctionSchema] = [] - ) { - self.model = model - self.messages = messages - self.temperature = temperature - self.top_p = top_p - self.n = n - self.stream = stream - self.stop = stop - self.max_tokens = max_tokens - self.presence_penalty = presence_penalty - self.frequency_penalty = frequency_penalty - self.logit_bias = logit_bias - self.user = user - if UserDefaults.shared.value(for: \.disableFunctionCalling) { - self.function_call = nil - self.functions = nil - } else { - self.function_call = function_call - self.functions = functions.isEmpty ? nil : functions - } - } -} - -struct CompletionStreamDataChunk: Codable { - var id: String? - var object: String? - var model: String? - var choices: [Choice]? - - struct Choice: Codable { - var delta: Delta? - var index: Int? - var finish_reason: String? - - struct Delta: Codable { - struct FunctionCall: Codable { - var name: String? - var arguments: String? - } - - var role: ChatMessage.Role? - var content: String? - var function_call: FunctionCall? - } - } -} - struct OpenAICompletionStreamAPI: CompletionStreamAPI { var apiKey: String var endpoint: URL From 1b213afaea143e4353d2b454f0e1de4a0be68e98 Mon Sep 17 00:00:00 2001 From: Shx Guo Date: Fri, 1 Mar 2024 23:31:45 +0800 Subject: [PATCH 03/37] Rename file --- ...ionsAPIDefinition.swift => ChatCompletionsAPIDefinition.swift} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename Tool/Sources/OpenAIService/APIs/{CompletionsAPIDefinition.swift => ChatCompletionsAPIDefinition.swift} (100%) diff --git a/Tool/Sources/OpenAIService/APIs/CompletionsAPIDefinition.swift b/Tool/Sources/OpenAIService/APIs/ChatCompletionsAPIDefinition.swift similarity index 100% rename from Tool/Sources/OpenAIService/APIs/CompletionsAPIDefinition.swift rename to Tool/Sources/OpenAIService/APIs/ChatCompletionsAPIDefinition.swift From fd7f92d04f75a1983fec9fa9cf2e64048b5c30cc Mon Sep 17 00:00:00 2001 From: Shx Guo Date: Fri, 1 Mar 2024 23:35:51 +0800 Subject: [PATCH 04/37] Rename types --- Pro | 2 +- .../APIs/ChatCompletionsAPIDefinition.swift | 28 +-- .../APIs/GoogleAICompletionAPI.swift | 6 +- .../APIs/GoogleAICompletionStreamAPI.swift | 10 +- .../OpenAIService/APIs/OlamaService.swift | 185 ++++++++++++++++++ .../APIs/OpenAICompletionAPI.swift | 10 +- .../APIs/OpenAICompletionStreamAPI.swift | 12 +- .../OpenAIService/ChatGPTService.swift | 14 +- Tool/Sources/OpenAIService/Debug/Debug.swift | 2 +- .../ChatGPTStreamTests.swift | 24 +-- 10 files changed, 239 insertions(+), 54 deletions(-) create mode 100644 Tool/Sources/OpenAIService/APIs/OlamaService.swift diff --git a/Pro b/Pro index 908dd291..5f1f1dd2 160000 --- a/Pro +++ b/Pro @@ -1 +1 @@ -Subproject commit 908dd2919e3da89cb05e6d57cad8228c2df08846 +Subproject commit 5f1f1dd24c3a6ec27acc3d7f252b8775ae9beea6 diff --git a/Tool/Sources/OpenAIService/APIs/ChatCompletionsAPIDefinition.swift b/Tool/Sources/OpenAIService/APIs/ChatCompletionsAPIDefinition.swift index e8b0cd72..582f7296 100644 --- a/Tool/Sources/OpenAIService/APIs/ChatCompletionsAPIDefinition.swift +++ b/Tool/Sources/OpenAIService/APIs/ChatCompletionsAPIDefinition.swift @@ -3,7 +3,7 @@ import Foundation import Preferences /// https://platform.openai.com/docs/api-reference/chat/create -struct CompletionRequestBody: Codable, Equatable { +struct ChatCompletionsRequestBody: Codable, Equatable { struct Message: Codable, Equatable { /// The role of the message. var role: ChatMessage.Role @@ -21,7 +21,7 @@ struct CompletionRequestBody: Codable, Equatable { /// "arguments": "{ \"location\": \"earth\" }" /// } /// ``` - var function_call: CompletionRequestBody.MessageFunctionCall? + var function_call: ChatCompletionsRequestBody.MessageFunctionCall? } struct MessageFunctionCall: Codable, Equatable { @@ -119,21 +119,21 @@ public enum FunctionCallStrategy: Codable, Equatable { // MARK: - Stream API -typealias CompletionStreamAPIBuilder = ( +typealias ChatCompletionsStreamAPIBuilder = ( String, ChatModel, URL, - CompletionRequestBody, + ChatCompletionsRequestBody, ChatGPTPrompt -) -> any CompletionStreamAPI +) -> any ChatCompletionsStreamAPI -protocol CompletionStreamAPI { +protocol ChatCompletionsStreamAPI { associatedtype CompletionSequence: AsyncSequence - where CompletionSequence.Element == CompletionStreamDataChunk + where CompletionSequence.Element == ChatCompletionsStreamDataChunk func callAsFunction() async throws -> CompletionSequence } -struct CompletionStreamDataChunk: Codable { +struct ChatCompletionsStreamDataChunk: Codable { var id: String? var object: String? var model: String? @@ -159,15 +159,15 @@ struct CompletionStreamDataChunk: Codable { // MARK: - Non Stream API -typealias CompletionAPIBuilder = (String, ChatModel, URL, CompletionRequestBody, ChatGPTPrompt) - -> any CompletionAPI +typealias ChatCompletionsAPIBuilder = (String, ChatModel, URL, ChatCompletionsRequestBody, ChatGPTPrompt) + -> any ChatCompletionsAPI -protocol CompletionAPI { - func callAsFunction() async throws -> CompletionResponseBody +protocol ChatCompletionsAPI { + func callAsFunction() async throws -> ChatCompletionResponseBody } /// https://platform.openai.com/docs/api-reference/chat/create -struct CompletionResponseBody: Codable, Equatable { +struct ChatCompletionResponseBody: Codable, Equatable { struct Message: Codable, Equatable { /// The role of the message. var role: ChatMessage.Role @@ -185,7 +185,7 @@ struct CompletionResponseBody: Codable, Equatable { /// "arguments": "{ \"location\": \"earth\" }" /// } /// ``` - var function_call: CompletionRequestBody.MessageFunctionCall? + var function_call: ChatCompletionsRequestBody.MessageFunctionCall? } struct Choice: Codable, Equatable { diff --git a/Tool/Sources/OpenAIService/APIs/GoogleAICompletionAPI.swift b/Tool/Sources/OpenAIService/APIs/GoogleAICompletionAPI.swift index ded6e372..8cd916db 100644 --- a/Tool/Sources/OpenAIService/APIs/GoogleAICompletionAPI.swift +++ b/Tool/Sources/OpenAIService/APIs/GoogleAICompletionAPI.swift @@ -3,13 +3,13 @@ import Foundation import GoogleGenerativeAI import Preferences -struct GoogleCompletionAPI: CompletionAPI { +struct GoogleCompletionAPI: ChatCompletionsAPI { let apiKey: String let model: ChatModel - var requestBody: CompletionRequestBody + var requestBody: ChatCompletionsRequestBody let prompt: ChatGPTPrompt - func callAsFunction() async throws -> CompletionResponseBody { + func callAsFunction() async throws -> ChatCompletionResponseBody { let aiModel = GenerativeModel( name: model.info.modelName, apiKey: apiKey, diff --git a/Tool/Sources/OpenAIService/APIs/GoogleAICompletionStreamAPI.swift b/Tool/Sources/OpenAIService/APIs/GoogleAICompletionStreamAPI.swift index 47492340..ee2b4895 100644 --- a/Tool/Sources/OpenAIService/APIs/GoogleAICompletionStreamAPI.swift +++ b/Tool/Sources/OpenAIService/APIs/GoogleAICompletionStreamAPI.swift @@ -3,13 +3,13 @@ import Foundation import GoogleGenerativeAI import Preferences -struct GoogleCompletionStreamAPI: CompletionStreamAPI { +struct GoogleCompletionStreamAPI: ChatCompletionsStreamAPI { let apiKey: String let model: ChatModel - var requestBody: CompletionRequestBody + var requestBody: ChatCompletionsRequestBody let prompt: ChatGPTPrompt - func callAsFunction() async throws -> AsyncThrowingStream { + func callAsFunction() async throws -> AsyncThrowingStream { let aiModel = GenerativeModel( name: model.info.modelName, apiKey: apiKey, @@ -31,13 +31,13 @@ struct GoogleCompletionStreamAPI: CompletionStreamAPI { ) } - let stream = AsyncThrowingStream { continuation in + let stream = AsyncThrowingStream { continuation in let stream = aiModel.generateContentStream(history) let task = Task { do { for try await response in stream { if Task.isCancelled { break } - let chunk = CompletionStreamDataChunk( + let chunk = ChatCompletionsStreamDataChunk( object: "", model: model.info.modelName, choices: response.candidates.map { candidate in diff --git a/Tool/Sources/OpenAIService/APIs/OlamaService.swift b/Tool/Sources/OpenAIService/APIs/OlamaService.swift new file mode 100644 index 00000000..0a4ec0c9 --- /dev/null +++ b/Tool/Sources/OpenAIService/APIs/OlamaService.swift @@ -0,0 +1,185 @@ +import AIModel +import Foundation +import Preferences + +public actor OllamaService { + var apiKey: String + var endpoint: URL + var requestBody: ChatCompletionsRequestBody + var model: ChatModel + + public enum ResponseFormat: String { + case none = "" + case json + } + + init( + apiKey: String, + model: ChatModel, + endpoint: URL, + requestBody: ChatCompletionsRequestBody + ) { + self.apiKey = apiKey + self.endpoint = endpoint + self.requestBody = requestBody + self.model = model + } +} + +extension OllamaService: ChatCompletionsAPI { + func callAsFunction() async throws -> ChatCompletionResponseBody { + fatalError() + } +} + +extension OllamaService: ChatCompletionsStreamAPI { + typealias CompletionSequence = AsyncMapSequence, ChatCompletionsStreamDataChunk> + + func callAsFunction() async throws -> CompletionSequence { + let requestBody = ChatCompletionRequestBody( + model: model.info.modelName, + messages: requestBody.messages.map { message in + .init(role: { + switch message.role { + case .assistant: + return .assistant + case .user: + return .user + case .system: + return .system + case .function: + return .user + } + }(), content: message.content) + }, + stream: true, + options: .init( + temperature: requestBody.temperature, + stop: requestBody.stop, + num_predict: requestBody.max_tokens + ), + keep_alive: nil, + format: nil + ) + + var request = URLRequest(url: endpoint) + request.httpMethod = "POST" + let encoder = JSONEncoder() + request.httpBody = try encoder.encode(requestBody) + request.setValue("application/json", forHTTPHeaderField: "Content-Type") + let (result, response) = try await URLSession.shared.bytes(for: request) + + guard let response = response as? HTTPURLResponse else { + throw CancellationError() + } + + guard response.statusCode == 200 else { + let text = try await result.lines.reduce(into: "") { partialResult, current in + partialResult += current + } + throw Error.otherError(text) + } + + let stream = ResponseStream(result: result) { + let chunk = try JSONDecoder().decode( + ChatCompletionResponseChunk.self, + from: $0.data(using: .utf8) ?? Data() + ) + return .init(chunk: chunk, done: chunk.done) + } + + let sequence = stream.map { chunk in + ChatCompletionsStreamDataChunk( + id: UUID().uuidString, + object: chunk.model, + model: chunk.model, + choices: [ + .init( + delta: .init( + role: { + switch chunk.message?.role { + case .none: + return nil + case .assistant: + return .assistant + case .user: + return .user + case .system: + return .system + } + }(), + content: chunk.message?.content + ) + ), + ] + ) + } + + return sequence + } +} + +extension OllamaService { + struct Message: Codable, Equatable { + public enum Role: String, Codable { + case user + case assistant + case system + } + + /// The role of the message. + public var role: Role + /// The content of the message. + public var content: String + } + + enum Error: Swift.Error, LocalizedError { + case decodeError(Swift.Error) + case otherError(String) + + public var errorDescription: String? { + switch self { + case let .decodeError(error): + return error.localizedDescription + case let .otherError(message): + return message + } + } + } +} + +// MARK: - Chat Completion API + +/// https://github.com/ollama/ollama/blob/main/docs/api.md#chat-request-streaming +extension OllamaService { + struct ChatCompletionRequestBody: Codable { + struct Options: Codable { + var temperature: Double? + var stop: [String]? + var num_predict: Int? + var top_k: Int? + var top_p: Double? + } + + var model: String + var messages: [Message] + var stream: Bool + var options: Options + var keep_alive: String? + var format: String? + } + + struct ChatCompletionResponseChunk: Decodable { + var model: String + var message: Message? + var response: String? + var done: Bool + var total_duration: Int64? + var load_duration: Int64? + var prompt_eval_count: Int? + var prompt_eval_duration: Int64? + var eval_count: Int? + var eval_duration: Int64? + } +} + diff --git a/Tool/Sources/OpenAIService/APIs/OpenAICompletionAPI.swift b/Tool/Sources/OpenAIService/APIs/OpenAICompletionAPI.swift index 2a058f16..2426a9d0 100644 --- a/Tool/Sources/OpenAIService/APIs/OpenAICompletionAPI.swift +++ b/Tool/Sources/OpenAIService/APIs/OpenAICompletionAPI.swift @@ -15,17 +15,17 @@ struct CompletionAPIError: Error, Codable, LocalizedError { var errorDescription: String? { error.message } } -struct OpenAICompletionAPI: CompletionAPI { +struct OpenAICompletionAPI: ChatCompletionsAPI { var apiKey: String var endpoint: URL - var requestBody: CompletionRequestBody + var requestBody: ChatCompletionsRequestBody var model: ChatModel init( apiKey: String, model: ChatModel, endpoint: URL, - requestBody: CompletionRequestBody + requestBody: ChatCompletionsRequestBody ) { self.apiKey = apiKey self.endpoint = endpoint @@ -34,7 +34,7 @@ struct OpenAICompletionAPI: CompletionAPI { self.model = model } - func callAsFunction() async throws -> CompletionResponseBody { + func callAsFunction() async throws -> ChatCompletionResponseBody { var request = URLRequest(url: endpoint) request.httpMethod = "POST" let encoder = JSONEncoder() @@ -63,7 +63,7 @@ struct OpenAICompletionAPI: CompletionAPI { } do { - return try JSONDecoder().decode(CompletionResponseBody.self, from: result) + return try JSONDecoder().decode(ChatCompletionResponseBody.self, from: result) } catch { dump(error) throw error diff --git a/Tool/Sources/OpenAIService/APIs/OpenAICompletionStreamAPI.swift b/Tool/Sources/OpenAIService/APIs/OpenAICompletionStreamAPI.swift index afaafe9e..c662724b 100644 --- a/Tool/Sources/OpenAIService/APIs/OpenAICompletionStreamAPI.swift +++ b/Tool/Sources/OpenAIService/APIs/OpenAICompletionStreamAPI.swift @@ -3,17 +3,17 @@ import AsyncAlgorithms import Foundation import Preferences -struct OpenAICompletionStreamAPI: CompletionStreamAPI { +struct OpenAICompletionStreamAPI: ChatCompletionsStreamAPI { var apiKey: String var endpoint: URL - var requestBody: CompletionRequestBody + var requestBody: ChatCompletionsRequestBody var model: ChatModel init( apiKey: String, model: ChatModel, endpoint: URL, - requestBody: CompletionRequestBody + requestBody: ChatCompletionsRequestBody ) { self.apiKey = apiKey self.endpoint = endpoint @@ -22,7 +22,7 @@ struct OpenAICompletionStreamAPI: CompletionStreamAPI { self.model = model } - func callAsFunction() async throws -> AsyncThrowingStream { + func callAsFunction() async throws -> AsyncThrowingStream { var request = URLRequest(url: endpoint) request.httpMethod = "POST" let encoder = JSONEncoder() @@ -55,7 +55,7 @@ struct OpenAICompletionStreamAPI: CompletionStreamAPI { throw error ?? ChatGPTServiceError.responseInvalid } - let stream = AsyncThrowingStream { continuation in + let stream = AsyncThrowingStream { continuation in let task = Task { do { for try await line in result.lines { @@ -64,7 +64,7 @@ struct OpenAICompletionStreamAPI: CompletionStreamAPI { guard line.hasPrefix(prefix), let content = line.dropFirst(prefix.count).data(using: .utf8), let chunk = try? JSONDecoder() - .decode(CompletionStreamDataChunk.self, from: content) + .decode(ChatCompletionsStreamDataChunk.self, from: content) else { continue } continuation.yield(chunk) } diff --git a/Tool/Sources/OpenAIService/ChatGPTService.swift b/Tool/Sources/OpenAIService/ChatGPTService.swift index 5d1480ca..430e361e 100644 --- a/Tool/Sources/OpenAIService/ChatGPTService.swift +++ b/Tool/Sources/OpenAIService/ChatGPTService.swift @@ -69,7 +69,7 @@ public class ChatGPTService: ChatGPTServiceType { public var functionProvider: ChatGPTFunctionProvider var runningTask: Task? - var buildCompletionStreamAPI: CompletionStreamAPIBuilder = { + var buildCompletionStreamAPI: ChatCompletionsStreamAPIBuilder = { apiKey, model, endpoint, requestBody, prompt in switch model.format { case .googleAI: @@ -89,7 +89,7 @@ public class ChatGPTService: ChatGPTServiceType { } } - var buildCompletionAPI: CompletionAPIBuilder = { + var buildCompletionAPI: ChatCompletionsAPIBuilder = { apiKey, model, endpoint, requestBody, prompt in switch model.format { case .googleAI: @@ -275,7 +275,7 @@ extension ChatGPTService { } let messages = prompt.history.map { - CompletionRequestBody.Message( + ChatCompletionsRequestBody.Message( role: $0.role, content: $0.content ?? "", name: $0.name, @@ -286,7 +286,7 @@ extension ChatGPTService { } let remainingTokens = prompt.remainingTokenCount - let requestBody = CompletionRequestBody( + let requestBody = ChatCompletionsRequestBody( model: model.info.modelName, messages: messages, temperature: configuration.temperature, @@ -403,7 +403,7 @@ extension ChatGPTService { } let messages = prompt.history.map { - CompletionRequestBody.Message( + ChatCompletionsRequestBody.Message( role: $0.role, content: $0.content ?? "", name: $0.name, @@ -414,7 +414,7 @@ extension ChatGPTService { } let remainingTokens = prompt.remainingTokenCount - let requestBody = CompletionRequestBody( + let requestBody = ChatCompletionsRequestBody( model: model.info.modelName, messages: messages, temperature: configuration.temperature, @@ -582,7 +582,7 @@ extension ChatGPTService { } extension ChatGPTService { - func changeBuildCompletionStreamAPI(_ builder: @escaping CompletionStreamAPIBuilder) { + func changeBuildCompletionStreamAPI(_ builder: @escaping ChatCompletionsStreamAPIBuilder) { buildCompletionStreamAPI = builder } } diff --git a/Tool/Sources/OpenAIService/Debug/Debug.swift b/Tool/Sources/OpenAIService/Debug/Debug.swift index b27358e0..31864964 100644 --- a/Tool/Sources/OpenAIService/Debug/Debug.swift +++ b/Tool/Sources/OpenAIService/Debug/Debug.swift @@ -6,7 +6,7 @@ enum Debugger { static var id: UUID? #if DEBUG - static func didSendRequestBody(body: CompletionRequestBody) { + static func didSendRequestBody(body: ChatCompletionsRequestBody) { do { let json = try JSONEncoder().encode(body) let center = NotificationCenter.default diff --git a/Tool/Tests/OpenAIServiceTests/ChatGPTStreamTests.swift b/Tool/Tests/OpenAIServiceTests/ChatGPTStreamTests.swift index 5349f85e..7e90445b 100644 --- a/Tool/Tests/OpenAIServiceTests/ChatGPTStreamTests.swift +++ b/Tool/Tests/OpenAIServiceTests/ChatGPTStreamTests.swift @@ -14,7 +14,7 @@ final class ChatGPTStreamTests: XCTestCase { configuration: configuration, functionProvider: functionProvider ) - var requestBody: CompletionRequestBody? + var requestBody: ChatCompletionsRequestBody? service.changeBuildCompletionStreamAPI { _, _, _, _requestBody, _ in requestBody = _requestBody return MockCompletionStreamAPI_Message() @@ -75,7 +75,7 @@ final class ChatGPTStreamTests: XCTestCase { configuration: configuration, functionProvider: functionProvider ) - var requestBody: CompletionRequestBody? + var requestBody: ChatCompletionsRequestBody? service.changeBuildCompletionStreamAPI { _, _, _, _requestBody, _ in requestBody = _requestBody if _requestBody.messages.count <= 2 { @@ -158,7 +158,7 @@ final class ChatGPTStreamTests: XCTestCase { configuration: configuration, functionProvider: functionProvider ) - var requestBody: CompletionRequestBody? + var requestBody: ChatCompletionsRequestBody? service.changeBuildCompletionStreamAPI { _, _, _, _requestBody, _ in requestBody = _requestBody @@ -265,7 +265,7 @@ final class ChatGPTStreamTests: XCTestCase { configuration: configuration, functionProvider: functionProvider ) - var requestBody: CompletionRequestBody? + var requestBody: ChatCompletionsRequestBody? service.changeBuildCompletionStreamAPI { _, _, _, _requestBody, _ in requestBody = _requestBody if _requestBody.messages.count <= 2 { @@ -335,14 +335,14 @@ final class ChatGPTStreamTests: XCTestCase { } extension ChatGPTStreamTests { - struct MockCompletionStreamAPI_Message: CompletionStreamAPI { + struct MockCompletionStreamAPI_Message: ChatCompletionsStreamAPI { @Dependency(\.uuid) var uuid func callAsFunction() async throws - -> AsyncThrowingStream + -> AsyncThrowingStream { let id = uuid().uuidString - return AsyncThrowingStream { continuation in - let chunks: [CompletionStreamDataChunk] = [ + return AsyncThrowingStream { continuation in + let chunks: [ChatCompletionsStreamDataChunk] = [ .init(id: id, object: "", model: "", choices: [ .init(delta: .init(role: .assistant), index: 0, finish_reason: ""), ]), @@ -364,14 +364,14 @@ extension ChatGPTStreamTests { } } - struct MockCompletionStreamAPI_Function: CompletionStreamAPI { + struct MockCompletionStreamAPI_Function: ChatCompletionsStreamAPI { @Dependency(\.uuid) var uuid func callAsFunction() async throws - -> AsyncThrowingStream + -> AsyncThrowingStream { let id = uuid().uuidString - return AsyncThrowingStream { continuation in - let chunks: [CompletionStreamDataChunk] = [ + return AsyncThrowingStream { continuation in + let chunks: [ChatCompletionsStreamDataChunk] = [ .init(id: id, object: "", model: "", choices: [ .init( delta: .init( From d3c02a12df565642398a71b5daf11ba6e51cc539 Mon Sep 17 00:00:00 2001 From: Shx Guo Date: Fri, 1 Mar 2024 23:38:54 +0800 Subject: [PATCH 05/37] Convert API types to OpenAIService --- .../APIs/OpenAICompletionAPI.swift | 73 ------------------- ...ionStreamAPI.swift => OpenAIService.swift} | 58 ++++++++++++++- 2 files changed, 55 insertions(+), 76 deletions(-) delete mode 100644 Tool/Sources/OpenAIService/APIs/OpenAICompletionAPI.swift rename Tool/Sources/OpenAIService/APIs/{OpenAICompletionStreamAPI.swift => OpenAIService.swift} (59%) diff --git a/Tool/Sources/OpenAIService/APIs/OpenAICompletionAPI.swift b/Tool/Sources/OpenAIService/APIs/OpenAICompletionAPI.swift deleted file mode 100644 index 2426a9d0..00000000 --- a/Tool/Sources/OpenAIService/APIs/OpenAICompletionAPI.swift +++ /dev/null @@ -1,73 +0,0 @@ -import AIModel -import Foundation -import Preferences - -struct CompletionAPIError: Error, Codable, LocalizedError { - struct E: Codable { - var message: String - var type: String - var param: String - var code: String - } - - var error: E - - var errorDescription: String? { error.message } -} - -struct OpenAICompletionAPI: ChatCompletionsAPI { - var apiKey: String - var endpoint: URL - var requestBody: ChatCompletionsRequestBody - var model: ChatModel - - init( - apiKey: String, - model: ChatModel, - endpoint: URL, - requestBody: ChatCompletionsRequestBody - ) { - self.apiKey = apiKey - self.endpoint = endpoint - self.requestBody = requestBody - self.requestBody.stream = false - self.model = model - } - - func callAsFunction() async throws -> ChatCompletionResponseBody { - var request = URLRequest(url: endpoint) - request.httpMethod = "POST" - let encoder = JSONEncoder() - request.httpBody = try encoder.encode(requestBody) - request.setValue("application/json", forHTTPHeaderField: "Content-Type") - if !apiKey.isEmpty { - switch model.format { - case .openAI, .openAICompatible: - request.setValue("Bearer \(apiKey)", forHTTPHeaderField: "Authorization") - case .azureOpenAI: - request.setValue(apiKey, forHTTPHeaderField: "api-key") - case .googleAI: - assertionFailure("Unsupported") - } - } - - let (result, response) = try await URLSession.shared.data(for: request) - guard let response = response as? HTTPURLResponse else { - throw ChatGPTServiceError.responseInvalid - } - - guard response.statusCode == 200 else { - let error = try? JSONDecoder().decode(CompletionAPIError.self, from: result) - throw error ?? ChatGPTServiceError - .otherError(String(data: result, encoding: .utf8) ?? "Unknown Error") - } - - do { - return try JSONDecoder().decode(ChatCompletionResponseBody.self, from: result) - } catch { - dump(error) - throw error - } - } -} - diff --git a/Tool/Sources/OpenAIService/APIs/OpenAICompletionStreamAPI.swift b/Tool/Sources/OpenAIService/APIs/OpenAIService.swift similarity index 59% rename from Tool/Sources/OpenAIService/APIs/OpenAICompletionStreamAPI.swift rename to Tool/Sources/OpenAIService/APIs/OpenAIService.swift index c662724b..afd5785e 100644 --- a/Tool/Sources/OpenAIService/APIs/OpenAICompletionStreamAPI.swift +++ b/Tool/Sources/OpenAIService/APIs/OpenAIService.swift @@ -3,7 +3,20 @@ import AsyncAlgorithms import Foundation import Preferences -struct OpenAICompletionStreamAPI: ChatCompletionsStreamAPI { +actor OpenAIService: ChatCompletionsStreamAPI, ChatCompletionsAPI { + struct CompletionAPIError: Error, Codable, LocalizedError { + struct E: Codable { + var message: String + var type: String + var param: String + var code: String + } + + var error: E + + var errorDescription: String? { error.message } + } + var apiKey: String var endpoint: URL var requestBody: ChatCompletionsRequestBody @@ -18,11 +31,13 @@ struct OpenAICompletionStreamAPI: ChatCompletionsStreamAPI { self.apiKey = apiKey self.endpoint = endpoint self.requestBody = requestBody - self.requestBody.stream = true self.model = model } - func callAsFunction() async throws -> AsyncThrowingStream { + func callAsFunction() async throws + -> AsyncThrowingStream + { + requestBody.stream = true var request = URLRequest(url: endpoint) request.httpMethod = "POST" let encoder = JSONEncoder() @@ -81,5 +96,42 @@ struct OpenAICompletionStreamAPI: ChatCompletionsStreamAPI { return stream } + + func callAsFunction() async throws -> ChatCompletionResponseBody { + requestBody.stream = false + var request = URLRequest(url: endpoint) + request.httpMethod = "POST" + let encoder = JSONEncoder() + request.httpBody = try encoder.encode(requestBody) + request.setValue("application/json", forHTTPHeaderField: "Content-Type") + if !apiKey.isEmpty { + switch model.format { + case .openAI, .openAICompatible: + request.setValue("Bearer \(apiKey)", forHTTPHeaderField: "Authorization") + case .azureOpenAI: + request.setValue(apiKey, forHTTPHeaderField: "api-key") + case .googleAI: + assertionFailure("Unsupported") + } + } + + let (result, response) = try await URLSession.shared.data(for: request) + guard let response = response as? HTTPURLResponse else { + throw ChatGPTServiceError.responseInvalid + } + + guard response.statusCode == 200 else { + let error = try? JSONDecoder().decode(CompletionAPIError.self, from: result) + throw error ?? ChatGPTServiceError + .otherError(String(data: result, encoding: .utf8) ?? "Unknown Error") + } + + do { + return try JSONDecoder().decode(ChatCompletionResponseBody.self, from: result) + } catch { + dump(error) + throw error + } + } } From 6d6fc492bbfa60025ef697fec5dcacc505e00049 Mon Sep 17 00:00:00 2001 From: Shx Guo Date: Fri, 1 Mar 2024 23:41:08 +0800 Subject: [PATCH 06/37] Convert API types to GoogleAIService --- .../APIs/GoogleAICompletionStreamAPI.swift | 84 ------------------ ...pletionAPI.swift => GoogleAIService.swift} | 88 ++++++++++++++++++- 2 files changed, 87 insertions(+), 85 deletions(-) delete mode 100644 Tool/Sources/OpenAIService/APIs/GoogleAICompletionStreamAPI.swift rename Tool/Sources/OpenAIService/APIs/{GoogleAICompletionAPI.swift => GoogleAIService.swift} (66%) diff --git a/Tool/Sources/OpenAIService/APIs/GoogleAICompletionStreamAPI.swift b/Tool/Sources/OpenAIService/APIs/GoogleAICompletionStreamAPI.swift deleted file mode 100644 index ee2b4895..00000000 --- a/Tool/Sources/OpenAIService/APIs/GoogleAICompletionStreamAPI.swift +++ /dev/null @@ -1,84 +0,0 @@ -import AIModel -import Foundation -import GoogleGenerativeAI -import Preferences - -struct GoogleCompletionStreamAPI: ChatCompletionsStreamAPI { - let apiKey: String - let model: ChatModel - var requestBody: ChatCompletionsRequestBody - let prompt: ChatGPTPrompt - - func callAsFunction() async throws -> AsyncThrowingStream { - let aiModel = GenerativeModel( - name: model.info.modelName, - apiKey: apiKey, - generationConfig: .init(GenerationConfig( - temperature: requestBody.temperature.map(Float.init), - topP: requestBody.top_p.map(Float.init) - )) - ) - let history = prompt.googleAICompatible.history.map { message in - ModelContent( - ChatMessage( - role: message.role, - content: message.content, - name: message.name, - functionCall: message.functionCall.map { - .init(name: $0.name, arguments: $0.arguments) - } - ) - ) - } - - let stream = AsyncThrowingStream { continuation in - let stream = aiModel.generateContentStream(history) - let task = Task { - do { - for try await response in stream { - if Task.isCancelled { break } - let chunk = ChatCompletionsStreamDataChunk( - object: "", - model: model.info.modelName, - choices: response.candidates.map { candidate in - .init(delta: .init( - role: .assistant, - content: candidate.content.parts - .first(where: { $0.text != nil })?.text ?? "" - )) - } - ) - continuation.yield(chunk) - } - continuation.finish() - } catch let error as GenerateContentError { - struct ErrorWrapper: Error, LocalizedError { - let error: Error - var errorDescription: String? { - var s = "" - dump(error, to: &s) - return "Internal Error: \(s)" - } - } - - switch error { - case let .internalError(underlying): - continuation.finish(throwing: ErrorWrapper(error: underlying)) - case .promptBlocked: - continuation.finish(throwing: error) - case .responseStoppedEarly: - continuation.finish(throwing: error) - } - } catch { - continuation.finish(throwing: error) - } - } - continuation.onTermination = { _ in - task.cancel() - } - } - - return stream - } -} - diff --git a/Tool/Sources/OpenAIService/APIs/GoogleAICompletionAPI.swift b/Tool/Sources/OpenAIService/APIs/GoogleAIService.swift similarity index 66% rename from Tool/Sources/OpenAIService/APIs/GoogleAICompletionAPI.swift rename to Tool/Sources/OpenAIService/APIs/GoogleAIService.swift index 8cd916db..7eb349cb 100644 --- a/Tool/Sources/OpenAIService/APIs/GoogleAICompletionAPI.swift +++ b/Tool/Sources/OpenAIService/APIs/GoogleAIService.swift @@ -3,12 +3,24 @@ import Foundation import GoogleGenerativeAI import Preferences -struct GoogleCompletionAPI: ChatCompletionsAPI { +actor GoogleAIService: ChatCompletionsAPI, ChatCompletionsStreamAPI { let apiKey: String let model: ChatModel var requestBody: ChatCompletionsRequestBody let prompt: ChatGPTPrompt + init( + apiKey: String, + model: ChatModel, + requestBody: ChatCompletionsRequestBody, + prompt: ChatGPTPrompt + ) { + self.apiKey = apiKey + self.model = model + self.requestBody = requestBody + self.prompt = prompt + } + func callAsFunction() async throws -> ChatCompletionResponseBody { let aiModel = GenerativeModel( name: model.info.modelName, @@ -78,6 +90,80 @@ struct GoogleCompletionAPI: ChatCompletionsAPI { throw error } } + + func callAsFunction() async throws + -> AsyncThrowingStream + { + let aiModel = GenerativeModel( + name: model.info.modelName, + apiKey: apiKey, + generationConfig: .init(GenerationConfig( + temperature: requestBody.temperature.map(Float.init), + topP: requestBody.top_p.map(Float.init) + )) + ) + let history = prompt.googleAICompatible.history.map { message in + ModelContent( + ChatMessage( + role: message.role, + content: message.content, + name: message.name, + functionCall: message.functionCall.map { + .init(name: $0.name, arguments: $0.arguments) + } + ) + ) + } + + let stream = AsyncThrowingStream { continuation in + let stream = aiModel.generateContentStream(history) + let task = Task { + do { + for try await response in stream { + if Task.isCancelled { break } + let chunk = ChatCompletionsStreamDataChunk( + object: "", + model: model.info.modelName, + choices: response.candidates.map { candidate in + .init(delta: .init( + role: .assistant, + content: candidate.content.parts + .first(where: { $0.text != nil })?.text ?? "" + )) + } + ) + continuation.yield(chunk) + } + continuation.finish() + } catch let error as GenerateContentError { + struct ErrorWrapper: Error, LocalizedError { + let error: Error + var errorDescription: String? { + var s = "" + dump(error, to: &s) + return "Internal Error: \(s)" + } + } + + switch error { + case let .internalError(underlying): + continuation.finish(throwing: ErrorWrapper(error: underlying)) + case .promptBlocked: + continuation.finish(throwing: error) + case .responseStoppedEarly: + continuation.finish(throwing: error) + } + } catch { + continuation.finish(throwing: error) + } + } + continuation.onTermination = { _ in + task.cancel() + } + } + + return stream + } } extension ChatGPTPrompt { From d741333167ecffdde55c3d6347e638722ba1cf49 Mon Sep 17 00:00:00 2001 From: Shx Guo Date: Fri, 1 Mar 2024 23:54:58 +0800 Subject: [PATCH 07/37] Reset ChatCompletionsStreamAPI to return AsyncThrowingStream --- Pro | 2 +- .../APIs/ChatCompletionsAPIDefinition.swift | 33 ++++++++++++++++--- .../OpenAIService/APIs/OlamaService.swift | 8 ++--- .../OpenAIService/ChatGPTService.swift | 8 ++--- .../OpenAIService/EmbeddingService.swift | 10 ++++-- 5 files changed, 46 insertions(+), 15 deletions(-) diff --git a/Pro b/Pro index 5f1f1dd2..322e9455 160000 --- a/Pro +++ b/Pro @@ -1 +1 @@ -Subproject commit 5f1f1dd24c3a6ec27acc3d7f252b8775ae9beea6 +Subproject commit 322e945557e02cf6643131eed6f8c1576296ab01 diff --git a/Tool/Sources/OpenAIService/APIs/ChatCompletionsAPIDefinition.swift b/Tool/Sources/OpenAIService/APIs/ChatCompletionsAPIDefinition.swift index 582f7296..c106329c 100644 --- a/Tool/Sources/OpenAIService/APIs/ChatCompletionsAPIDefinition.swift +++ b/Tool/Sources/OpenAIService/APIs/ChatCompletionsAPIDefinition.swift @@ -128,9 +128,28 @@ typealias ChatCompletionsStreamAPIBuilder = ( ) -> any ChatCompletionsStreamAPI protocol ChatCompletionsStreamAPI { - associatedtype CompletionSequence: AsyncSequence - where CompletionSequence.Element == ChatCompletionsStreamDataChunk - func callAsFunction() async throws -> CompletionSequence + func callAsFunction() async throws -> AsyncThrowingStream +} + +extension AsyncSequence { + func toStream() -> AsyncThrowingStream { + AsyncThrowingStream { continuation in + let task = Task { + do { + for try await element in self { + continuation.yield(element) + } + continuation.finish() + } catch { + continuation.finish(throwing: error) + } + } + + continuation.onTermination = { _ in + task.cancel() + } + } + } } struct ChatCompletionsStreamDataChunk: Codable { @@ -159,7 +178,13 @@ struct ChatCompletionsStreamDataChunk: Codable { // MARK: - Non Stream API -typealias ChatCompletionsAPIBuilder = (String, ChatModel, URL, ChatCompletionsRequestBody, ChatGPTPrompt) +typealias ChatCompletionsAPIBuilder = ( + String, + ChatModel, + URL, + ChatCompletionsRequestBody, + ChatGPTPrompt +) -> any ChatCompletionsAPI protocol ChatCompletionsAPI { diff --git a/Tool/Sources/OpenAIService/APIs/OlamaService.swift b/Tool/Sources/OpenAIService/APIs/OlamaService.swift index 0a4ec0c9..f6950df7 100644 --- a/Tool/Sources/OpenAIService/APIs/OlamaService.swift +++ b/Tool/Sources/OpenAIService/APIs/OlamaService.swift @@ -33,9 +33,9 @@ extension OllamaService: ChatCompletionsAPI { } extension OllamaService: ChatCompletionsStreamAPI { - typealias CompletionSequence = AsyncMapSequence, ChatCompletionsStreamDataChunk> - - func callAsFunction() async throws -> CompletionSequence { + func callAsFunction() async throws + -> AsyncThrowingStream + { let requestBody = ChatCompletionRequestBody( model: model.info.modelName, messages: requestBody.messages.map { message in @@ -115,7 +115,7 @@ extension OllamaService: ChatCompletionsStreamAPI { ) } - return sequence + return sequence.toStream() } } diff --git a/Tool/Sources/OpenAIService/ChatGPTService.swift b/Tool/Sources/OpenAIService/ChatGPTService.swift index 430e361e..14e37b29 100644 --- a/Tool/Sources/OpenAIService/ChatGPTService.swift +++ b/Tool/Sources/OpenAIService/ChatGPTService.swift @@ -73,14 +73,14 @@ public class ChatGPTService: ChatGPTServiceType { apiKey, model, endpoint, requestBody, prompt in switch model.format { case .googleAI: - return GoogleCompletionStreamAPI( + return GoogleAIService( apiKey: apiKey, model: model, requestBody: requestBody, prompt: prompt ) case .openAI, .openAICompatible, .azureOpenAI: - return OpenAICompletionStreamAPI( + return OpenAIService( apiKey: apiKey, model: model, endpoint: endpoint, @@ -93,14 +93,14 @@ public class ChatGPTService: ChatGPTServiceType { apiKey, model, endpoint, requestBody, prompt in switch model.format { case .googleAI: - return GoogleCompletionAPI( + return GoogleAIService( apiKey: apiKey, model: model, requestBody: requestBody, prompt: prompt ) case .openAI, .openAICompatible, .azureOpenAI: - return OpenAICompletionAPI( + return OpenAIService( apiKey: apiKey, model: model, endpoint: endpoint, diff --git a/Tool/Sources/OpenAIService/EmbeddingService.swift b/Tool/Sources/OpenAIService/EmbeddingService.swift index d3bd1c8d..ed8c2c24 100644 --- a/Tool/Sources/OpenAIService/EmbeddingService.swift +++ b/Tool/Sources/OpenAIService/EmbeddingService.swift @@ -73,7 +73,10 @@ public struct EmbeddingService { } guard response.statusCode == 200 else { - let error = try? JSONDecoder().decode(CompletionAPIError.self, from: result) + let error = try? JSONDecoder().decode( + OpenAIService.CompletionAPIError.self, + from: result + ) throw error ?? ChatGPTServiceError .otherError(String(data: result, encoding: .utf8) ?? "Unknown Error") } @@ -124,7 +127,10 @@ public struct EmbeddingService { } guard response.statusCode == 200 else { - let error = try? JSONDecoder().decode(CompletionAPIError.self, from: result) + let error = try? JSONDecoder().decode( + OpenAIService.CompletionAPIError.self, + from: result + ) throw error ?? ChatGPTServiceError .otherError(String(data: result, encoding: .utf8) ?? "Unknown Error") } From 16a770e57bccf810a1fff227622cf23cdc547bc8 Mon Sep 17 00:00:00 2001 From: Shx Guo Date: Sat, 2 Mar 2024 00:20:26 +0800 Subject: [PATCH 08/37] Add UI for ollama models --- .../ChatModelManagement/ChatModelEdit.swift | 11 +++--- .../ChatModelEditView.swift | 36 +++++++++++++++++++ .../ChatModelManagement.swift | 1 + .../EmbeddingModelEdit.swift | 5 ++- .../EmbeddingModelEditView.swift | 36 +++++++++++++++++++ .../EmbeddingModelManagement.swift | 1 + Tool/Sources/AIModel/ChatModel.swift | 11 +++++- Tool/Sources/AIModel/EmbeddingModel.swift | 11 +++++- .../OpenAIService/APIs/OpenAIService.swift | 4 +++ .../OpenAIService/ChatGPTService.swift | 14 ++++++++ .../OpenAIService/EmbeddingService.swift | 6 ++++ 11 files changed, 129 insertions(+), 7 deletions(-) diff --git a/Core/Sources/HostApp/AccountSettings/ChatModelManagement/ChatModelEdit.swift b/Core/Sources/HostApp/AccountSettings/ChatModelManagement/ChatModelEdit.swift index 342ef862..6eaa6b5f 100644 --- a/Core/Sources/HostApp/AccountSettings/ChatModelManagement/ChatModelEdit.swift +++ b/Core/Sources/HostApp/AccountSettings/ChatModelManagement/ChatModelEdit.swift @@ -14,6 +14,7 @@ struct ChatModelEdit: ReducerProtocol { @BindingState var maxTokens: Int = 4000 @BindingState var supportsFunctionCalling: Bool = true @BindingState var modelName: String = "" + @BindingState var ollamaKeepAlive: String = "" var apiKeyName: String { apiKeySelection.apiKeyName } var baseURL: String { baseURLSelection.baseURL } var isFullURL: Bool { baseURLSelection.isFullURL } @@ -48,7 +49,7 @@ struct ChatModelEdit: ReducerProtocol { Scope(state: \.apiKeySelection, action: /Action.apiKeySelection) { APIKeySelection() } - + Scope(state: \.baseURLSelection, action: /Action.baseURLSelection) { BaseURLSelection() } @@ -135,10 +136,10 @@ struct ChatModelEdit: ReducerProtocol { state.suggestedMaxTokens = nil return .none } - + case .apiKeySelection: return .none - + case .baseURLSelection: return .none @@ -169,6 +170,7 @@ extension ChatModelEdit.State { maxTokens: model.info.maxTokens, supportsFunctionCalling: model.info.supportsFunctionCalling, modelName: model.info.modelName, + ollamaKeepAlive: model.info.ollamaKeepAlive, apiKeySelection: .init( apiKeyName: model.info.apiKeyName, apiKeyManagement: .init(availableAPIKeyNames: [model.info.apiKeyName]) @@ -195,7 +197,8 @@ extension ChatModel { } return state.supportsFunctionCalling }(), - modelName: state.modelName.trimmingCharacters(in: .whitespacesAndNewlines) + modelName: state.modelName.trimmingCharacters(in: .whitespacesAndNewlines), + ollamaKeepAlive: state.ollamaKeepAlive ) ) } diff --git a/Core/Sources/HostApp/AccountSettings/ChatModelManagement/ChatModelEditView.swift b/Core/Sources/HostApp/AccountSettings/ChatModelManagement/ChatModelEditView.swift index b46a0baf..81ef2a93 100644 --- a/Core/Sources/HostApp/AccountSettings/ChatModelManagement/ChatModelEditView.swift +++ b/Core/Sources/HostApp/AccountSettings/ChatModelManagement/ChatModelEditView.swift @@ -24,6 +24,8 @@ struct ChatModelEditView: View { openAICompatible case .googleAI: googleAI + case .ollama: + ollama } } } @@ -92,6 +94,8 @@ struct ChatModelEditView: View { Text("OpenAI Compatible").tag(format) case .googleAI: Text("Google Generative AI").tag(format) + case .ollama: + Text("Ollama").tag(format) } } }, @@ -344,6 +348,38 @@ struct ChatModelEditView: View { maxTokensTextField } + + @ViewBuilder + var ollama: some View { + baseURLTextField(prompt: Text("http://127.0.0.1:11434")) { + Text("/api/chat") + } + + WithViewStore( + store, + removeDuplicates: { $0.modelName == $1.modelName } + ) { viewStore in + TextField("Model Name", text: viewStore.$modelName) + } + + maxTokensTextField + + WithViewStore( + store, + removeDuplicates: { $0.ollamaKeepAlive == $1.ollamaKeepAlive } + ) { viewStore in + TextField(text: viewStore.$ollamaKeepAlive, prompt: Text("Default Value")) { + Text("Keep Alive") + } + } + + VStack(alignment: .leading, spacing: 8) { + Text(Image(systemName: "exclamationmark.triangle.fill")) + Text( + " For more details, please visit [https://ollama.com](https://ollama.com)." + ) + } + .padding(.vertical) + } } #Preview("OpenAI") { diff --git a/Core/Sources/HostApp/AccountSettings/ChatModelManagement/ChatModelManagement.swift b/Core/Sources/HostApp/AccountSettings/ChatModelManagement/ChatModelManagement.swift index 6f34bdc6..1bbb109b 100644 --- a/Core/Sources/HostApp/AccountSettings/ChatModelManagement/ChatModelManagement.swift +++ b/Core/Sources/HostApp/AccountSettings/ChatModelManagement/ChatModelManagement.swift @@ -11,6 +11,7 @@ extension ChatModel: ManageableAIModel { case .azureOpenAI: return "Azure OpenAI" case .openAICompatible: return "OpenAI Compatible" case .googleAI: return "Google Generative AI" + case .ollama: return "Ollama" } } diff --git a/Core/Sources/HostApp/AccountSettings/EmbeddingModelManagement/EmbeddingModelEdit.swift b/Core/Sources/HostApp/AccountSettings/EmbeddingModelManagement/EmbeddingModelEdit.swift index df8dbf22..d8ec83b2 100644 --- a/Core/Sources/HostApp/AccountSettings/EmbeddingModelManagement/EmbeddingModelEdit.swift +++ b/Core/Sources/HostApp/AccountSettings/EmbeddingModelManagement/EmbeddingModelEdit.swift @@ -13,6 +13,7 @@ struct EmbeddingModelEdit: ReducerProtocol { @BindingState var format: EmbeddingModel.Format @BindingState var maxTokens: Int = 8191 @BindingState var modelName: String = "" + @BindingState var ollamaKeepAlive: String = "" var apiKeyName: String { apiKeySelection.apiKeyName } var baseURL: String { baseURLSelection.baseURL } var isFullURL: Bool { baseURLSelection.isFullURL } @@ -155,6 +156,7 @@ extension EmbeddingModelEdit.State { format: model.format, maxTokens: model.info.maxTokens, modelName: model.info.modelName, + ollamaKeepAlive: model.info.ollamaKeepAlive, apiKeySelection: .init( apiKeyName: model.info.apiKeyName, apiKeyManagement: .init(availableAPIKeyNames: [model.info.apiKeyName]) @@ -175,7 +177,8 @@ extension EmbeddingModel { baseURL: state.baseURL.trimmingCharacters(in: .whitespacesAndNewlines), isFullURL: state.isFullURL, maxTokens: state.maxTokens, - modelName: state.modelName.trimmingCharacters(in: .whitespacesAndNewlines) + modelName: state.modelName.trimmingCharacters(in: .whitespacesAndNewlines), + ollamaKeepAlive: state.ollamaKeepAlive ) ) } diff --git a/Core/Sources/HostApp/AccountSettings/EmbeddingModelManagement/EmbeddingModelEditView.swift b/Core/Sources/HostApp/AccountSettings/EmbeddingModelManagement/EmbeddingModelEditView.swift index c9c7b452..2bad443f 100644 --- a/Core/Sources/HostApp/AccountSettings/EmbeddingModelManagement/EmbeddingModelEditView.swift +++ b/Core/Sources/HostApp/AccountSettings/EmbeddingModelManagement/EmbeddingModelEditView.swift @@ -22,6 +22,8 @@ struct EmbeddingModelEditView: View { azureOpenAI case .openAICompatible: openAICompatible + case .ollama: + ollama } } } @@ -88,6 +90,8 @@ struct EmbeddingModelEditView: View { Text("Azure OpenAI").tag(format) case .openAICompatible: Text("OpenAI Compatible").tag(format) + case .ollama: + Text("Ollama").tag(format) } } }, @@ -289,6 +293,38 @@ struct EmbeddingModelEditView: View { maxTokensTextField } + + @ViewBuilder + var ollama: some View { + baseURLTextField(prompt: Text("http://127.0.0.1:11434")) { + Text("/api/embeddings") + } + + WithViewStore( + store, + removeDuplicates: { $0.modelName == $1.modelName } + ) { viewStore in + TextField("Model Name", text: viewStore.$modelName) + } + + maxTokensTextField + + WithViewStore( + store, + removeDuplicates: { $0.ollamaKeepAlive == $1.ollamaKeepAlive } + ) { viewStore in + TextField(text: viewStore.$ollamaKeepAlive, prompt: Text("Default Value")) { + Text("Keep Alive") + } + } + + VStack(alignment: .leading, spacing: 8) { + Text(Image(systemName: "exclamationmark.triangle.fill")) + Text( + " For more details, please visit [https://ollama.com](https://ollama.com)." + ) + } + .padding(.vertical) + } } class EmbeddingModelManagementView_Editing_Previews: PreviewProvider { diff --git a/Core/Sources/HostApp/AccountSettings/EmbeddingModelManagement/EmbeddingModelManagement.swift b/Core/Sources/HostApp/AccountSettings/EmbeddingModelManagement/EmbeddingModelManagement.swift index eda907d3..71b0d4a5 100644 --- a/Core/Sources/HostApp/AccountSettings/EmbeddingModelManagement/EmbeddingModelManagement.swift +++ b/Core/Sources/HostApp/AccountSettings/EmbeddingModelManagement/EmbeddingModelManagement.swift @@ -10,6 +10,7 @@ extension EmbeddingModel: ManageableAIModel { case .openAI: return "OpenAI" case .azureOpenAI: return "Azure OpenAI" case .openAICompatible: return "OpenAI Compatible" + case .ollama: return "Ollama" } } diff --git a/Tool/Sources/AIModel/ChatModel.swift b/Tool/Sources/AIModel/ChatModel.swift index 344c996b..88af31dd 100644 --- a/Tool/Sources/AIModel/ChatModel.swift +++ b/Tool/Sources/AIModel/ChatModel.swift @@ -21,6 +21,7 @@ public struct ChatModel: Codable, Equatable, Identifiable { case azureOpenAI case openAICompatible case googleAI + case ollama } public struct Info: Codable, Equatable { @@ -42,6 +43,8 @@ public struct ChatModel: Codable, Equatable, Identifiable { get { modelName } set { modelName = newValue } } + @FallbackDecoding + public var ollamaKeepAlive: String public init( apiKeyName: String = "", @@ -50,7 +53,8 @@ public struct ChatModel: Codable, Equatable, Identifiable { maxTokens: Int = 4000, supportsFunctionCalling: Bool = true, supportsOpenAIAPI2023_11: Bool = false, - modelName: String = "" + modelName: String = "", + ollamaKeepAlive: String = "" ) { self.apiKeyName = apiKeyName self.baseURL = baseURL @@ -59,6 +63,7 @@ public struct ChatModel: Codable, Equatable, Identifiable { self.supportsFunctionCalling = supportsFunctionCalling self.supportsOpenAIAPI2023_11 = supportsOpenAIAPI2023_11 self.modelName = modelName + self.ollamaKeepAlive = ollamaKeepAlive } } @@ -83,6 +88,10 @@ public struct ChatModel: Codable, Equatable, Identifiable { let baseURL = info.baseURL if baseURL.isEmpty { return "https://generativelanguage.googleapis.com/v1" } return "\(baseURL)/v1/chat/completions" + case .ollama: + let baseURL = info.baseURL + if baseURL.isEmpty { return "http://localhost:11434/api/chat" } + return "\(baseURL)/api/chat" } } } diff --git a/Tool/Sources/AIModel/EmbeddingModel.swift b/Tool/Sources/AIModel/EmbeddingModel.swift index c942be9a..cd88cd3d 100644 --- a/Tool/Sources/AIModel/EmbeddingModel.swift +++ b/Tool/Sources/AIModel/EmbeddingModel.swift @@ -20,6 +20,7 @@ public struct EmbeddingModel: Codable, Equatable, Identifiable { case openAI case azureOpenAI case openAICompatible + case ollama } public struct Info: Codable, Equatable { @@ -39,6 +40,8 @@ public struct EmbeddingModel: Codable, Equatable, Identifiable { get { modelName } set { modelName = newValue } } + @FallbackDecoding + public var ollamaKeepAlive: String public init( apiKeyName: String = "", @@ -46,7 +49,8 @@ public struct EmbeddingModel: Codable, Equatable, Identifiable { isFullURL: Bool = false, maxTokens: Int = 8192, dimensions: Int = 1536, - modelName: String = "" + modelName: String = "", + ollamaKeepAlive: String = "" ) { self.apiKeyName = apiKeyName self.baseURL = baseURL @@ -54,6 +58,7 @@ public struct EmbeddingModel: Codable, Equatable, Identifiable { self.maxTokens = maxTokens self.dimensions = dimensions self.modelName = modelName + self.ollamaKeepAlive = ollamaKeepAlive } } @@ -74,6 +79,10 @@ public struct EmbeddingModel: Codable, Equatable, Identifiable { let version = "2024-02-15-preview" if baseURL.isEmpty { return "" } return "\(baseURL)/openai/deployments/\(deployment)/embeddings?api-version=\(version)" + case .ollama: + let baseURL = info.baseURL + if baseURL.isEmpty { return "http://localhost:11434/api/embeddings" } + return "\(baseURL)/api/embeddings" } } } diff --git a/Tool/Sources/OpenAIService/APIs/OpenAIService.swift b/Tool/Sources/OpenAIService/APIs/OpenAIService.swift index afd5785e..4dd3f059 100644 --- a/Tool/Sources/OpenAIService/APIs/OpenAIService.swift +++ b/Tool/Sources/OpenAIService/APIs/OpenAIService.swift @@ -51,6 +51,8 @@ actor OpenAIService: ChatCompletionsStreamAPI, ChatCompletionsAPI { request.setValue(apiKey, forHTTPHeaderField: "api-key") case .googleAI: assertionFailure("Unsupported") + case .ollama: + assertionFailure("Unsupported") } } @@ -112,6 +114,8 @@ actor OpenAIService: ChatCompletionsStreamAPI, ChatCompletionsAPI { request.setValue(apiKey, forHTTPHeaderField: "api-key") case .googleAI: assertionFailure("Unsupported") + case .ollama: + assertionFailure("Unsupported") } } diff --git a/Tool/Sources/OpenAIService/ChatGPTService.swift b/Tool/Sources/OpenAIService/ChatGPTService.swift index 14e37b29..14835513 100644 --- a/Tool/Sources/OpenAIService/ChatGPTService.swift +++ b/Tool/Sources/OpenAIService/ChatGPTService.swift @@ -86,6 +86,13 @@ public class ChatGPTService: ChatGPTServiceType { endpoint: endpoint, requestBody: requestBody ) + case .ollama: + return OllamaService( + apiKey: apiKey, + model: model, + endpoint: endpoint, + requestBody: requestBody + ) } } @@ -106,6 +113,13 @@ public class ChatGPTService: ChatGPTServiceType { endpoint: endpoint, requestBody: requestBody ) + case .ollama: + return OllamaService( + apiKey: apiKey, + model: model, + endpoint: endpoint, + requestBody: requestBody + ) } } diff --git a/Tool/Sources/OpenAIService/EmbeddingService.swift b/Tool/Sources/OpenAIService/EmbeddingService.swift index ed8c2c24..44b28b39 100644 --- a/Tool/Sources/OpenAIService/EmbeddingService.swift +++ b/Tool/Sources/OpenAIService/EmbeddingService.swift @@ -64,6 +64,9 @@ public struct EmbeddingService { ) case .azureOpenAI: request.setValue(configuration.apiKey, forHTTPHeaderField: "api-key") + case .ollama: + #warning("MUSTDO:") + fatalError() } } @@ -118,6 +121,9 @@ public struct EmbeddingService { ) case .azureOpenAI: request.setValue(configuration.apiKey, forHTTPHeaderField: "api-key") + case .ollama: + #warning("MUSTDO:") + fatalError() } } From ef6dd45cec40870abac5aec085f7fe28f9280a2d Mon Sep 17 00:00:00 2001 From: Shx Guo Date: Sat, 2 Mar 2024 00:36:50 +0800 Subject: [PATCH 09/37] Implement Ollama non stream chat API --- .../OpenAIService/APIs/OlamaService.swift | 75 ++++++++++++++++++- 1 file changed, 74 insertions(+), 1 deletion(-) diff --git a/Tool/Sources/OpenAIService/APIs/OlamaService.swift b/Tool/Sources/OpenAIService/APIs/OlamaService.swift index f6950df7..ced8f662 100644 --- a/Tool/Sources/OpenAIService/APIs/OlamaService.swift +++ b/Tool/Sources/OpenAIService/APIs/OlamaService.swift @@ -28,7 +28,80 @@ public actor OllamaService { extension OllamaService: ChatCompletionsAPI { func callAsFunction() async throws -> ChatCompletionResponseBody { - fatalError() + let requestBody = ChatCompletionRequestBody( + model: model.info.modelName, + messages: requestBody.messages.map { message in + .init(role: { + switch message.role { + case .assistant: + return .assistant + case .user: + return .user + case .system: + return .system + case .function: + return .user + } + }(), content: message.content) + }, + stream: false, + options: .init( + temperature: requestBody.temperature, + stop: requestBody.stop, + num_predict: requestBody.max_tokens + ), + keep_alive: nil, + format: nil + ) + + var request = URLRequest(url: endpoint) + request.httpMethod = "POST" + let encoder = JSONEncoder() + request.httpBody = try encoder.encode(requestBody) + request.setValue("application/json", forHTTPHeaderField: "Content-Type") + let (result, response) = try await URLSession.shared.data(for: request) + + guard let response = response as? HTTPURLResponse else { + throw CancellationError() + } + + guard response.statusCode == 200 else { + let text = String(data: result, encoding: .utf8) + throw Error.otherError(text ?? "Unknown error") + } + + let body = try JSONDecoder().decode( + ChatCompletionResponseChunk.self, + from: result + ) + + return .init( + object: body.model, + model: body.model, + usage: .init( + prompt_tokens: body.prompt_eval_count ?? 0, + completion_tokens: body.eval_count ?? 0, + total_tokens: (body.eval_count ?? 0) + (body.prompt_eval_count ?? 0) + ), + choices: [ + .init( + message: body.message.map { message in + .init(role: { + switch message.role { + case .assistant: + return .assistant + case .user: + return .user + case .system: + return .system + } + }(), content: message.content) + } ?? .init(role: .assistant), + index: 0, + finish_reason: "" + ), + ] + ) } } From 58483b0780cae284e63506a1be72f814d8766486 Mon Sep 17 00:00:00 2001 From: Shx Guo Date: Sat, 2 Mar 2024 00:38:07 +0800 Subject: [PATCH 10/37] Update --- .../APIs/ChatCompletionsAPIDefinition.swift | 17 ----------------- Tool/Sources/OpenAIService/ChatGPTService.swift | 17 +++++++++++++++++ 2 files changed, 17 insertions(+), 17 deletions(-) diff --git a/Tool/Sources/OpenAIService/APIs/ChatCompletionsAPIDefinition.swift b/Tool/Sources/OpenAIService/APIs/ChatCompletionsAPIDefinition.swift index c106329c..ecdc9f93 100644 --- a/Tool/Sources/OpenAIService/APIs/ChatCompletionsAPIDefinition.swift +++ b/Tool/Sources/OpenAIService/APIs/ChatCompletionsAPIDefinition.swift @@ -119,14 +119,6 @@ public enum FunctionCallStrategy: Codable, Equatable { // MARK: - Stream API -typealias ChatCompletionsStreamAPIBuilder = ( - String, - ChatModel, - URL, - ChatCompletionsRequestBody, - ChatGPTPrompt -) -> any ChatCompletionsStreamAPI - protocol ChatCompletionsStreamAPI { func callAsFunction() async throws -> AsyncThrowingStream } @@ -178,15 +170,6 @@ struct ChatCompletionsStreamDataChunk: Codable { // MARK: - Non Stream API -typealias ChatCompletionsAPIBuilder = ( - String, - ChatModel, - URL, - ChatCompletionsRequestBody, - ChatGPTPrompt -) - -> any ChatCompletionsAPI - protocol ChatCompletionsAPI { func callAsFunction() async throws -> ChatCompletionResponseBody } diff --git a/Tool/Sources/OpenAIService/ChatGPTService.swift b/Tool/Sources/OpenAIService/ChatGPTService.swift index 14835513..985f7b8f 100644 --- a/Tool/Sources/OpenAIService/ChatGPTService.swift +++ b/Tool/Sources/OpenAIService/ChatGPTService.swift @@ -1,3 +1,4 @@ +import AIModel import AsyncAlgorithms import Dependencies import Foundation @@ -63,6 +64,22 @@ public struct ChatGPTError: Error, Codable, LocalizedError { } } +typealias ChatCompletionsStreamAPIBuilder = ( + String, + ChatModel, + URL, + ChatCompletionsRequestBody, + ChatGPTPrompt +) -> any ChatCompletionsStreamAPI + +typealias ChatCompletionsAPIBuilder = ( + String, + ChatModel, + URL, + ChatCompletionsRequestBody, + ChatGPTPrompt +) -> any ChatCompletionsAPI + public class ChatGPTService: ChatGPTServiceType { public var memory: ChatGPTMemory public var configuration: ChatGPTConfiguration From ed811e865761d4c9c0a769fcf9c70be5a3720d39 Mon Sep 17 00:00:00 2001 From: Shx Guo Date: Sat, 2 Mar 2024 00:55:51 +0800 Subject: [PATCH 11/37] Add OpenAIEmbeddingService --- Pro | 2 +- .../APIs/EmbeddingAPIDefinitions.swift | 30 ++++ ...t => GoogleAIChatCompletionsService.swift} | 2 +- ...wift => OlamaChatCompletionsService.swift} | 10 +- ...ift => OpenAIChatCompletionsService.swift} | 2 +- .../APIs/OpenAIEmbeddingService.swift | 126 ++++++++++++++ .../OpenAIService/ChatGPTService.swift | 12 +- .../OpenAIService/EmbeddingService.swift | 155 ++++++------------ 8 files changed, 218 insertions(+), 121 deletions(-) create mode 100644 Tool/Sources/OpenAIService/APIs/EmbeddingAPIDefinitions.swift rename Tool/Sources/OpenAIService/APIs/{GoogleAIService.swift => GoogleAIChatCompletionsService.swift} (99%) rename Tool/Sources/OpenAIService/APIs/{OlamaService.swift => OlamaChatCompletionsService.swift} (96%) rename Tool/Sources/OpenAIService/APIs/{OpenAIService.swift => OpenAIChatCompletionsService.swift} (98%) create mode 100644 Tool/Sources/OpenAIService/APIs/OpenAIEmbeddingService.swift diff --git a/Pro b/Pro index 322e9455..de3b6ef6 160000 --- a/Pro +++ b/Pro @@ -1 +1 @@ -Subproject commit 322e945557e02cf6643131eed6f8c1576296ab01 +Subproject commit de3b6ef6303f29264d347a52dc5c38dfb4dfde6a diff --git a/Tool/Sources/OpenAIService/APIs/EmbeddingAPIDefinitions.swift b/Tool/Sources/OpenAIService/APIs/EmbeddingAPIDefinitions.swift new file mode 100644 index 00000000..22003d36 --- /dev/null +++ b/Tool/Sources/OpenAIService/APIs/EmbeddingAPIDefinitions.swift @@ -0,0 +1,30 @@ +import AIModel +import Foundation +import Preferences + +protocol EmbeddingAPI { + func embed(text: String) async throws -> EmbeddingResponse + func embed(texts: [String]) async throws -> EmbeddingResponse + func embed(tokens: [[Int]]) async throws -> EmbeddingResponse +} + +public struct EmbeddingResponse: Decodable { + public struct Object: Decodable { + public var embedding: [Float] + public var index: Int + public var object: String + } + + public var data: [Object] + public var model: String + + public struct Usage: Decodable { + public var prompt_tokens: Int + public var total_tokens: Int + } + + public var usage: Usage +} + + + diff --git a/Tool/Sources/OpenAIService/APIs/GoogleAIService.swift b/Tool/Sources/OpenAIService/APIs/GoogleAIChatCompletionsService.swift similarity index 99% rename from Tool/Sources/OpenAIService/APIs/GoogleAIService.swift rename to Tool/Sources/OpenAIService/APIs/GoogleAIChatCompletionsService.swift index 7eb349cb..7608fe10 100644 --- a/Tool/Sources/OpenAIService/APIs/GoogleAIService.swift +++ b/Tool/Sources/OpenAIService/APIs/GoogleAIChatCompletionsService.swift @@ -3,7 +3,7 @@ import Foundation import GoogleGenerativeAI import Preferences -actor GoogleAIService: ChatCompletionsAPI, ChatCompletionsStreamAPI { +actor GoogleAIChatCompletionsService: ChatCompletionsAPI, ChatCompletionsStreamAPI { let apiKey: String let model: ChatModel var requestBody: ChatCompletionsRequestBody diff --git a/Tool/Sources/OpenAIService/APIs/OlamaService.swift b/Tool/Sources/OpenAIService/APIs/OlamaChatCompletionsService.swift similarity index 96% rename from Tool/Sources/OpenAIService/APIs/OlamaService.swift rename to Tool/Sources/OpenAIService/APIs/OlamaChatCompletionsService.swift index ced8f662..b5046868 100644 --- a/Tool/Sources/OpenAIService/APIs/OlamaService.swift +++ b/Tool/Sources/OpenAIService/APIs/OlamaChatCompletionsService.swift @@ -2,7 +2,7 @@ import AIModel import Foundation import Preferences -public actor OllamaService { +public actor OllamaChatCompletionsService { var apiKey: String var endpoint: URL var requestBody: ChatCompletionsRequestBody @@ -26,7 +26,7 @@ public actor OllamaService { } } -extension OllamaService: ChatCompletionsAPI { +extension OllamaChatCompletionsService: ChatCompletionsAPI { func callAsFunction() async throws -> ChatCompletionResponseBody { let requestBody = ChatCompletionRequestBody( model: model.info.modelName, @@ -105,7 +105,7 @@ extension OllamaService: ChatCompletionsAPI { } } -extension OllamaService: ChatCompletionsStreamAPI { +extension OllamaChatCompletionsService: ChatCompletionsStreamAPI { func callAsFunction() async throws -> AsyncThrowingStream { @@ -192,7 +192,7 @@ extension OllamaService: ChatCompletionsStreamAPI { } } -extension OllamaService { +extension OllamaChatCompletionsService { struct Message: Codable, Equatable { public enum Role: String, Codable { case user @@ -224,7 +224,7 @@ extension OllamaService { // MARK: - Chat Completion API /// https://github.com/ollama/ollama/blob/main/docs/api.md#chat-request-streaming -extension OllamaService { +extension OllamaChatCompletionsService { struct ChatCompletionRequestBody: Codable { struct Options: Codable { var temperature: Double? diff --git a/Tool/Sources/OpenAIService/APIs/OpenAIService.swift b/Tool/Sources/OpenAIService/APIs/OpenAIChatCompletionsService.swift similarity index 98% rename from Tool/Sources/OpenAIService/APIs/OpenAIService.swift rename to Tool/Sources/OpenAIService/APIs/OpenAIChatCompletionsService.swift index 4dd3f059..f66ba80d 100644 --- a/Tool/Sources/OpenAIService/APIs/OpenAIService.swift +++ b/Tool/Sources/OpenAIService/APIs/OpenAIChatCompletionsService.swift @@ -3,7 +3,7 @@ import AsyncAlgorithms import Foundation import Preferences -actor OpenAIService: ChatCompletionsStreamAPI, ChatCompletionsAPI { +actor OpenAIChatCompletionsService: ChatCompletionsStreamAPI, ChatCompletionsAPI { struct CompletionAPIError: Error, Codable, LocalizedError { struct E: Codable { var message: String diff --git a/Tool/Sources/OpenAIService/APIs/OpenAIEmbeddingService.swift b/Tool/Sources/OpenAIService/APIs/OpenAIEmbeddingService.swift new file mode 100644 index 00000000..989dcd52 --- /dev/null +++ b/Tool/Sources/OpenAIService/APIs/OpenAIEmbeddingService.swift @@ -0,0 +1,126 @@ +import AIModel +import Foundation +import Logger + +struct OpenAIEmbeddingService: EmbeddingAPI { + struct EmbeddingRequestBody: Encodable { + var input: [String] + var model: String + } + + struct EmbeddingFromTokensRequestBody: Encodable { + var input: [[Int]] + var model: String + } + + let apiKey: String + let model: EmbeddingModel + let endpoint: String + + public func embed(text: String) async throws -> EmbeddingResponse { + return try await embed(texts: [text]) + } + + public func embed(texts text: [String]) async throws -> EmbeddingResponse { + guard let url = URL(string: endpoint) else { throw ChatGPTServiceError.endpointIncorrect } + var request = URLRequest(url: url) + request.httpMethod = "POST" + let encoder = JSONEncoder() + request.httpBody = try encoder.encode(EmbeddingRequestBody( + input: text, + model: model.info.modelName + )) + request.setValue("application/json", forHTTPHeaderField: "Content-Type") + if !apiKey.isEmpty { + switch model.format { + case .openAI, .openAICompatible: + request.setValue( + "Bearer \(apiKey)", + forHTTPHeaderField: "Authorization" + ) + case .azureOpenAI: + request.setValue(apiKey, forHTTPHeaderField: "api-key") + case .ollama: + assertionFailure("Unsupported") + } + } + + let (result, response) = try await URLSession.shared.data(for: request) + guard let response = response as? HTTPURLResponse else { + throw ChatGPTServiceError.responseInvalid + } + + guard response.statusCode == 200 else { + let error = try? JSONDecoder().decode( + OpenAIChatCompletionsService.CompletionAPIError.self, + from: result + ) + throw error ?? ChatGPTServiceError + .otherError(String(data: result, encoding: .utf8) ?? "Unknown Error") + } + + let embeddingResponse = try JSONDecoder().decode(EmbeddingResponse.self, from: result) + #if DEBUG + Logger.service.info(""" + Embedding usage + - number of strings: \(text.count) + - prompt tokens: \(embeddingResponse.usage.prompt_tokens) + - total tokens: \(embeddingResponse.usage.total_tokens) + + """) + #endif + return embeddingResponse + } + + public func embed(tokens: [[Int]]) async throws -> EmbeddingResponse { + guard let url = URL(string: endpoint) else { throw ChatGPTServiceError.endpointIncorrect } + var request = URLRequest(url: url) + request.httpMethod = "POST" + let encoder = JSONEncoder() + request.httpBody = try encoder.encode(EmbeddingFromTokensRequestBody( + input: tokens, + model: model.info.modelName + )) + request.setValue("application/json", forHTTPHeaderField: "Content-Type") + if !apiKey.isEmpty { + switch model.format { + case .openAI, .openAICompatible: + request.setValue( + "Bearer \(apiKey)", + forHTTPHeaderField: "Authorization" + ) + case .azureOpenAI: + request.setValue(apiKey, forHTTPHeaderField: "api-key") + case .ollama: + assertionFailure("Unsupported") + } + } + + let (result, response) = try await URLSession.shared.data(for: request) + guard let response = response as? HTTPURLResponse else { + throw ChatGPTServiceError.responseInvalid + } + + guard response.statusCode == 200 else { + let error = try? JSONDecoder().decode( + OpenAIChatCompletionsService.CompletionAPIError.self, + from: result + ) + throw error ?? ChatGPTServiceError + .otherError(String(data: result, encoding: .utf8) ?? "Unknown Error") + } + + let embeddingResponse = try JSONDecoder().decode(EmbeddingResponse.self, from: result) + #if DEBUG + Logger.service.info(""" + Embedding usage + - number of strings: \(tokens.count) + - prompt tokens: \(embeddingResponse.usage.prompt_tokens) + - total tokens: \(embeddingResponse.usage.total_tokens) + + """) + #endif + return embeddingResponse + } +} + diff --git a/Tool/Sources/OpenAIService/ChatGPTService.swift b/Tool/Sources/OpenAIService/ChatGPTService.swift index 985f7b8f..9ccfe244 100644 --- a/Tool/Sources/OpenAIService/ChatGPTService.swift +++ b/Tool/Sources/OpenAIService/ChatGPTService.swift @@ -90,21 +90,21 @@ public class ChatGPTService: ChatGPTServiceType { apiKey, model, endpoint, requestBody, prompt in switch model.format { case .googleAI: - return GoogleAIService( + return GoogleAIChatCompletionsService( apiKey: apiKey, model: model, requestBody: requestBody, prompt: prompt ) case .openAI, .openAICompatible, .azureOpenAI: - return OpenAIService( + return OpenAIChatCompletionsService( apiKey: apiKey, model: model, endpoint: endpoint, requestBody: requestBody ) case .ollama: - return OllamaService( + return OllamaChatCompletionsService( apiKey: apiKey, model: model, endpoint: endpoint, @@ -117,21 +117,21 @@ public class ChatGPTService: ChatGPTServiceType { apiKey, model, endpoint, requestBody, prompt in switch model.format { case .googleAI: - return GoogleAIService( + return GoogleAIChatCompletionsService( apiKey: apiKey, model: model, requestBody: requestBody, prompt: prompt ) case .openAI, .openAICompatible, .azureOpenAI: - return OpenAIService( + return OpenAIChatCompletionsService( apiKey: apiKey, model: model, endpoint: endpoint, requestBody: requestBody ) case .ollama: - return OllamaService( + return OllamaChatCompletionsService( apiKey: apiKey, model: model, endpoint: endpoint, diff --git a/Tool/Sources/OpenAIService/EmbeddingService.swift b/Tool/Sources/OpenAIService/EmbeddingService.swift index 44b28b39..34d88ca6 100644 --- a/Tool/Sources/OpenAIService/EmbeddingService.swift +++ b/Tool/Sources/OpenAIService/EmbeddingService.swift @@ -1,34 +1,6 @@ import Foundation import Logger -public struct EmbeddingResponse: Decodable { - public struct Object: Decodable { - public var embedding: [Float] - public var index: Int - public var object: String - } - - public var data: [Object] - public var model: String - - public struct Usage: Decodable { - public var prompt_tokens: Int - public var total_tokens: Int - } - - public var usage: Usage -} - -struct EmbeddingRequestBody: Encodable { - var input: [String] - var model: String -} - -struct EmbeddingFromTokensRequestBody: Encodable { - var input: [[Int]] - var model: String -} - public struct EmbeddingService { public let configuration: EmbeddingConfiguration @@ -37,54 +9,51 @@ public struct EmbeddingService { } public func embed(text: String) async throws -> EmbeddingResponse { - return try await embed(text: [text]) - } - - public func embed(text: [String]) async throws -> EmbeddingResponse { guard let model = configuration.model else { throw ChatGPTServiceError.embeddingModelNotAvailable } - guard let url = URL(string: configuration.endpoint) else { - throw ChatGPTServiceError.endpointIncorrect - } - var request = URLRequest(url: url) - request.httpMethod = "POST" - let encoder = JSONEncoder() - request.httpBody = try encoder.encode(EmbeddingRequestBody( - input: text, - model: model.info.modelName - )) - request.setValue("application/json", forHTTPHeaderField: "Content-Type") - if !configuration.apiKey.isEmpty { - switch model.format { - case .openAI, .openAICompatible: - request.setValue( - "Bearer \(configuration.apiKey)", - forHTTPHeaderField: "Authorization" - ) - case .azureOpenAI: - request.setValue(configuration.apiKey, forHTTPHeaderField: "api-key") - case .ollama: - #warning("MUSTDO:") - fatalError() - } + let embeddingResponse: EmbeddingResponse + switch model.format { + case .openAI, .openAICompatible, .azureOpenAI: + embeddingResponse = try await OpenAIEmbeddingService( + apiKey: configuration.apiKey, + model: model, + endpoint: configuration.endpoint + ).embed(text: text) + case .ollama: + #warning("MUSTDO:") + fatalError() } - let (result, response) = try await URLSession.shared.data(for: request) - guard let response = response as? HTTPURLResponse else { - throw ChatGPTServiceError.responseInvalid - } + #if DEBUG + Logger.service.info(""" + Embedding usage + - number of strings: \(text.count) + - prompt tokens: \(embeddingResponse.usage.prompt_tokens) + - total tokens: \(embeddingResponse.usage.total_tokens) + + """) + #endif + return embeddingResponse + } - guard response.statusCode == 200 else { - let error = try? JSONDecoder().decode( - OpenAIService.CompletionAPIError.self, - from: result - ) - throw error ?? ChatGPTServiceError - .otherError(String(data: result, encoding: .utf8) ?? "Unknown Error") + public func embed(text: [String]) async throws -> EmbeddingResponse { + guard let model = configuration.model else { + throw ChatGPTServiceError.embeddingModelNotAvailable + } + let embeddingResponse: EmbeddingResponse + switch model.format { + case .openAI, .openAICompatible, .azureOpenAI: + embeddingResponse = try await OpenAIEmbeddingService( + apiKey: configuration.apiKey, + model: model, + endpoint: configuration.endpoint + ).embed(texts: text) + case .ollama: + #warning("MUSTDO:") + fatalError() } - let embeddingResponse = try JSONDecoder().decode(EmbeddingResponse.self, from: result) #if DEBUG Logger.service.info(""" Embedding usage @@ -101,47 +70,19 @@ public struct EmbeddingService { guard let model = configuration.model else { throw ChatGPTServiceError.embeddingModelNotAvailable } - guard let url = URL(string: configuration.endpoint) else { - throw ChatGPTServiceError.endpointIncorrect - } - var request = URLRequest(url: url) - request.httpMethod = "POST" - let encoder = JSONEncoder() - request.httpBody = try encoder.encode(EmbeddingFromTokensRequestBody( - input: tokens, - model: model.info.modelName - )) - request.setValue("application/json", forHTTPHeaderField: "Content-Type") - if !configuration.apiKey.isEmpty { - switch model.format { - case .openAI, .openAICompatible: - request.setValue( - "Bearer \(configuration.apiKey)", - forHTTPHeaderField: "Authorization" - ) - case .azureOpenAI: - request.setValue(configuration.apiKey, forHTTPHeaderField: "api-key") - case .ollama: - #warning("MUSTDO:") - fatalError() - } - } - - let (result, response) = try await URLSession.shared.data(for: request) - guard let response = response as? HTTPURLResponse else { - throw ChatGPTServiceError.responseInvalid - } - - guard response.statusCode == 200 else { - let error = try? JSONDecoder().decode( - OpenAIService.CompletionAPIError.self, - from: result - ) - throw error ?? ChatGPTServiceError - .otherError(String(data: result, encoding: .utf8) ?? "Unknown Error") + let embeddingResponse: EmbeddingResponse + switch model.format { + case .openAI, .openAICompatible, .azureOpenAI: + embeddingResponse = try await OpenAIEmbeddingService( + apiKey: configuration.apiKey, + model: model, + endpoint: configuration.endpoint + ).embed(tokens: tokens) + case .ollama: + #warning("MUSTDO:") + fatalError() } - let embeddingResponse = try JSONDecoder().decode(EmbeddingResponse.self, from: result) #if DEBUG Logger.service.info(""" Embedding usage From 9f137eec048c65564f09edbdec53b3aeeef6b493 Mon Sep 17 00:00:00 2001 From: Shx Guo Date: Sat, 2 Mar 2024 01:07:01 +0800 Subject: [PATCH 12/37] Add OllamaEmbeddingService --- Pro | 2 +- .../APIs/EmbeddingAPIDefinitions.swift | 2 - .../APIs/OllamaEmbeddingService.swift | 92 +++++++++++++++++++ .../OpenAIService/EmbeddingService.swift | 18 ++-- 4 files changed, 105 insertions(+), 9 deletions(-) create mode 100644 Tool/Sources/OpenAIService/APIs/OllamaEmbeddingService.swift diff --git a/Pro b/Pro index de3b6ef6..13a9fde5 160000 --- a/Pro +++ b/Pro @@ -1 +1 @@ -Subproject commit de3b6ef6303f29264d347a52dc5c38dfb4dfde6a +Subproject commit 13a9fde5ea17fda4bd39927428bc7267add18244 diff --git a/Tool/Sources/OpenAIService/APIs/EmbeddingAPIDefinitions.swift b/Tool/Sources/OpenAIService/APIs/EmbeddingAPIDefinitions.swift index 22003d36..0715e0f1 100644 --- a/Tool/Sources/OpenAIService/APIs/EmbeddingAPIDefinitions.swift +++ b/Tool/Sources/OpenAIService/APIs/EmbeddingAPIDefinitions.swift @@ -26,5 +26,3 @@ public struct EmbeddingResponse: Decodable { public var usage: Usage } - - diff --git a/Tool/Sources/OpenAIService/APIs/OllamaEmbeddingService.swift b/Tool/Sources/OpenAIService/APIs/OllamaEmbeddingService.swift new file mode 100644 index 00000000..fda6ef42 --- /dev/null +++ b/Tool/Sources/OpenAIService/APIs/OllamaEmbeddingService.swift @@ -0,0 +1,92 @@ +import AIModel +import Foundation +import Logger + +struct OllamaEmbeddingService: EmbeddingAPI { + struct EmbeddingRequestBody: Encodable { + var prompt: String + var model: String + } + + struct ResponseBody: Decodable { + var embedding: [Float] + } + + let model: EmbeddingModel + let endpoint: String + + public func embed(text: String) async throws -> EmbeddingResponse { + guard let url = URL(string: endpoint) else { throw ChatGPTServiceError.endpointIncorrect } + var request = URLRequest(url: url) + request.httpMethod = "POST" + let encoder = JSONEncoder() + request.httpBody = try encoder.encode(EmbeddingRequestBody( + prompt: text, + model: model.info.modelName + )) + request.setValue("application/json", forHTTPHeaderField: "Content-Type") + + let (result, response) = try await URLSession.shared.data(for: request) + guard let response = response as? HTTPURLResponse else { + throw ChatGPTServiceError.responseInvalid + } + + guard response.statusCode == 200 else { + let error = try? JSONDecoder().decode( + OpenAIChatCompletionsService.CompletionAPIError.self, + from: result + ) + throw error ?? ChatGPTServiceError + .otherError(String(data: result, encoding: .utf8) ?? "Unknown Error") + } + + let embeddingResponse = try JSONDecoder().decode(ResponseBody.self, from: result) + #if DEBUG + Logger.service.info(""" + Embedding usage + - number of strings: \(text.count) + - prompt tokens: N/A + - total tokens: \(embeddingResponse.embedding.count) + + """) + #endif + return .init( + data: [.init( + embedding: embeddingResponse.embedding, + index: 0, + object: model.info.modelName + )], + model: model.info.modelName, + usage: .init(prompt_tokens: 0, total_tokens: embeddingResponse.embedding.count) + ) + } + + public func embed(texts: [String]) async throws -> EmbeddingResponse { + try await withThrowingTaskGroup(of: EmbeddingResponse.self) { group in + for text in texts { + _ = group.addTaskUnlessCancelled { + try await self.embed(text: text) + } + } + + var result = EmbeddingResponse( + data: [], + model: model.info.modelName, + usage: .init(prompt_tokens: 0, total_tokens: 0) + ) + + for try await response in group { + result.data.append(contentsOf: response.data) + result.usage.prompt_tokens += response.usage.prompt_tokens + result.usage.total_tokens += response.usage.total_tokens + } + + return result + } + } + + public func embed(tokens: [[Int]]) async throws -> EmbeddingResponse { + throw CancellationError() + } +} + diff --git a/Tool/Sources/OpenAIService/EmbeddingService.swift b/Tool/Sources/OpenAIService/EmbeddingService.swift index 34d88ca6..d5bf2f41 100644 --- a/Tool/Sources/OpenAIService/EmbeddingService.swift +++ b/Tool/Sources/OpenAIService/EmbeddingService.swift @@ -21,8 +21,10 @@ public struct EmbeddingService { endpoint: configuration.endpoint ).embed(text: text) case .ollama: - #warning("MUSTDO:") - fatalError() + embeddingResponse = try await OllamaEmbeddingService( + model: model, + endpoint: configuration.endpoint + ).embed(text: text) } #if DEBUG @@ -50,8 +52,10 @@ public struct EmbeddingService { endpoint: configuration.endpoint ).embed(texts: text) case .ollama: - #warning("MUSTDO:") - fatalError() + embeddingResponse = try await OllamaEmbeddingService( + model: model, + endpoint: configuration.endpoint + ).embed(texts: text) } #if DEBUG @@ -79,8 +83,10 @@ public struct EmbeddingService { endpoint: configuration.endpoint ).embed(tokens: tokens) case .ollama: - #warning("MUSTDO:") - fatalError() + embeddingResponse = try await OllamaEmbeddingService( + model: model, + endpoint: configuration.endpoint + ).embed(tokens: tokens) } #if DEBUG From 6282aa13c59f4d9e7bdaa567560777addaa64958 Mon Sep 17 00:00:00 2001 From: Shx Guo Date: Sat, 2 Mar 2024 01:10:17 +0800 Subject: [PATCH 13/37] Update test success message --- .../EmbeddingModelEdit.swift | 17 ++++++++--------- .../APIs/OllamaEmbeddingService.swift | 6 +++--- 2 files changed, 11 insertions(+), 12 deletions(-) diff --git a/Core/Sources/HostApp/AccountSettings/EmbeddingModelManagement/EmbeddingModelEdit.swift b/Core/Sources/HostApp/AccountSettings/EmbeddingModelManagement/EmbeddingModelEdit.swift index d8ec83b2..5154a917 100644 --- a/Core/Sources/HostApp/AccountSettings/EmbeddingModelManagement/EmbeddingModelEdit.swift +++ b/Core/Sources/HostApp/AccountSettings/EmbeddingModelManagement/EmbeddingModelEdit.swift @@ -84,14 +84,13 @@ struct EmbeddingModelEdit: ReducerProtocol { ) return .run { send in do { - let tokenUsage = - try await EmbeddingService( - configuration: UserPreferenceEmbeddingConfiguration() - .overriding { - $0.model = model - } - ).embed(text: "Hello").usage.total_tokens - await send(.testSucceeded("Used \(tokenUsage) tokens.")) + _ = try await EmbeddingService( + configuration: UserPreferenceEmbeddingConfiguration() + .overriding { + $0.model = model + } + ).embed(text: "Hello") + await send(.testSucceeded("Succeeded!")) } catch { await send(.testFailed(error.localizedDescription)) } @@ -156,7 +155,7 @@ extension EmbeddingModelEdit.State { format: model.format, maxTokens: model.info.maxTokens, modelName: model.info.modelName, - ollamaKeepAlive: model.info.ollamaKeepAlive, + ollamaKeepAlive: model.info.ollamaKeepAlive, apiKeySelection: .init( apiKeyName: model.info.apiKeyName, apiKeyManagement: .init(availableAPIKeyNames: [model.info.apiKeyName]) diff --git a/Tool/Sources/OpenAIService/APIs/OllamaEmbeddingService.swift b/Tool/Sources/OpenAIService/APIs/OllamaEmbeddingService.swift index fda6ef42..dfd170cc 100644 --- a/Tool/Sources/OpenAIService/APIs/OllamaEmbeddingService.swift +++ b/Tool/Sources/OpenAIService/APIs/OllamaEmbeddingService.swift @@ -44,9 +44,9 @@ struct OllamaEmbeddingService: EmbeddingAPI { #if DEBUG Logger.service.info(""" Embedding usage - - number of strings: \(text.count) + - number of strings: 1 - prompt tokens: N/A - - total tokens: \(embeddingResponse.embedding.count) + - total tokens: N/A """) #endif @@ -57,7 +57,7 @@ struct OllamaEmbeddingService: EmbeddingAPI { object: model.info.modelName )], model: model.info.modelName, - usage: .init(prompt_tokens: 0, total_tokens: embeddingResponse.embedding.count) + usage: .init(prompt_tokens: 0, total_tokens: 0) ) } From 96bf018fb0120cf9843f0e5e17e1f85016af203f Mon Sep 17 00:00:00 2001 From: Shx Guo Date: Sat, 2 Mar 2024 15:24:23 +0800 Subject: [PATCH 14/37] Change Max Tokens to Context Window for better clearity --- .../AccountSettings/ChatModelManagement/ChatModelEditView.swift | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Core/Sources/HostApp/AccountSettings/ChatModelManagement/ChatModelEditView.swift b/Core/Sources/HostApp/AccountSettings/ChatModelManagement/ChatModelEditView.swift index 81ef2a93..fd6b1e21 100644 --- a/Core/Sources/HostApp/AccountSettings/ChatModelManagement/ChatModelEditView.swift +++ b/Core/Sources/HostApp/AccountSettings/ChatModelManagement/ChatModelEditView.swift @@ -175,7 +175,7 @@ struct ChatModelEditView: View { ) TextField(text: textFieldBinding) { - Text("Max Tokens (Including Reply)") + Text("Context Window") .multilineTextAlignment(.trailing) } .overlay(alignment: .trailing) { From 37111d2b1dd9d63c33780404033e59608aa619e7 Mon Sep 17 00:00:00 2001 From: Shx Guo Date: Sat, 2 Mar 2024 15:26:00 +0800 Subject: [PATCH 15/37] Update --- Pro | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Pro b/Pro index 13a9fde5..fbb89b80 160000 --- a/Pro +++ b/Pro @@ -1 +1 @@ -Subproject commit 13a9fde5ea17fda4bd39927428bc7267add18244 +Subproject commit fbb89b803e55e4e27d3346fd5a6b49a3b108280d From ecf740e66d4cf08ca7544bf5079b2870e8f27923 Mon Sep 17 00:00:00 2001 From: Shx Guo Date: Sat, 2 Mar 2024 15:40:17 +0800 Subject: [PATCH 16/37] Migrate service specific info to their own structs --- Tool/Sources/AIModel/ChatModel.swift | 45 +++++++++++++++---- Tool/Sources/AIModel/EmbeddingModel.swift | 28 +++++++----- .../APIs/OpenAIChatCompletionsService.swift | 20 ++++++++- .../APIs/OpenAIEmbeddingService.swift | 30 ++++++++----- 4 files changed, 90 insertions(+), 33 deletions(-) diff --git a/Tool/Sources/AIModel/ChatModel.swift b/Tool/Sources/AIModel/ChatModel.swift index 88af31dd..f3aea30b 100644 --- a/Tool/Sources/AIModel/ChatModel.swift +++ b/Tool/Sources/AIModel/ChatModel.swift @@ -25,6 +25,24 @@ public struct ChatModel: Codable, Equatable, Identifiable { } public struct Info: Codable, Equatable { + public struct OllamaInfo: Codable, Equatable { + @FallbackDecoding + public var keepAlive: String + + public init(keepAlive: String = "") { + self.keepAlive = keepAlive + } + } + + public struct OpenAIInfo: Codable, Equatable { + @FallbackDecoding + public var organizationID: String + + public init(organizationID: String = "") { + self.organizationID = organizationID + } + } + @FallbackDecoding public var apiKeyName: String @FallbackDecoding @@ -39,12 +57,11 @@ public struct ChatModel: Codable, Equatable, Identifiable { public var supportsOpenAIAPI2023_11: Bool @FallbackDecoding public var modelName: String - public var azureOpenAIDeploymentName: String { - get { modelName } - set { modelName = newValue } - } - @FallbackDecoding - public var ollamaKeepAlive: String + + @FallbackDecoding + public var openAIInfo: OpenAIInfo + @FallbackDecoding + public var ollamaInfo: OllamaInfo public init( apiKeyName: String = "", @@ -54,7 +71,8 @@ public struct ChatModel: Codable, Equatable, Identifiable { supportsFunctionCalling: Bool = true, supportsOpenAIAPI2023_11: Bool = false, modelName: String = "", - ollamaKeepAlive: String = "" + openAIInfo: OpenAIInfo = OpenAIInfo(), + ollamaInfo: OllamaInfo = OllamaInfo() ) { self.apiKeyName = apiKeyName self.baseURL = baseURL @@ -63,7 +81,8 @@ public struct ChatModel: Codable, Equatable, Identifiable { self.supportsFunctionCalling = supportsFunctionCalling self.supportsOpenAIAPI2023_11 = supportsOpenAIAPI2023_11 self.modelName = modelName - self.ollamaKeepAlive = ollamaKeepAlive + self.openAIInfo = openAIInfo + self.ollamaInfo = ollamaInfo } } @@ -80,7 +99,7 @@ public struct ChatModel: Codable, Equatable, Identifiable { return "\(baseURL)/v1/chat/completions" case .azureOpenAI: let baseURL = info.baseURL - let deployment = info.azureOpenAIDeploymentName + let deployment = info.modelName let version = "2024-02-15-preview" if baseURL.isEmpty { return "" } return "\(baseURL)/openai/deployments/\(deployment)/chat/completions?api-version=\(version)" @@ -104,3 +123,11 @@ public struct EmptyChatModelFormat: FallbackValueProvider { public static var defaultValue: ChatModel.Format { .openAI } } +public struct EmptyChatModelOllamaInfo: FallbackValueProvider { + public static var defaultValue: ChatModel.Info.OllamaInfo { .init() } +} + +public struct EmptyChatModelOpenAIInfo: FallbackValueProvider { + public static var defaultValue: ChatModel.Info.OpenAIInfo { .init() } +} + diff --git a/Tool/Sources/AIModel/EmbeddingModel.swift b/Tool/Sources/AIModel/EmbeddingModel.swift index cd88cd3d..d86650fa 100644 --- a/Tool/Sources/AIModel/EmbeddingModel.swift +++ b/Tool/Sources/AIModel/EmbeddingModel.swift @@ -1,5 +1,5 @@ -import Foundation import CodableWrappers +import Foundation public struct EmbeddingModel: Codable, Equatable, Identifiable { public var id: String @@ -24,6 +24,9 @@ public struct EmbeddingModel: Codable, Equatable, Identifiable { } public struct Info: Codable, Equatable { + public typealias OllamaInfo = ChatModel.Info.OllamaInfo + public typealias OpenAIInfo = ChatModel.Info.OpenAIInfo + @FallbackDecoding public var apiKeyName: String @FallbackDecoding @@ -36,12 +39,11 @@ public struct EmbeddingModel: Codable, Equatable, Identifiable { public var dimensions: Int @FallbackDecoding public var modelName: String - public var azureOpenAIDeploymentName: String { - get { modelName } - set { modelName = newValue } - } - @FallbackDecoding - public var ollamaKeepAlive: String + + @FallbackDecoding + public var openAIInfo: OpenAIInfo + @FallbackDecoding + public var ollamaInfo: OllamaInfo public init( apiKeyName: String = "", @@ -50,7 +52,8 @@ public struct EmbeddingModel: Codable, Equatable, Identifiable { maxTokens: Int = 8192, dimensions: Int = 1536, modelName: String = "", - ollamaKeepAlive: String = "" + openAIInfo: OpenAIInfo = OpenAIInfo(), + ollamaInfo: OllamaInfo = OllamaInfo() ) { self.apiKeyName = apiKeyName self.baseURL = baseURL @@ -58,10 +61,11 @@ public struct EmbeddingModel: Codable, Equatable, Identifiable { self.maxTokens = maxTokens self.dimensions = dimensions self.modelName = modelName - self.ollamaKeepAlive = ollamaKeepAlive + self.openAIInfo = openAIInfo + self.ollamaInfo = ollamaInfo } } - + public var endpoint: String { switch format { case .openAI: @@ -75,7 +79,7 @@ public struct EmbeddingModel: Codable, Equatable, Identifiable { return "\(baseURL)/v1/embeddings" case .azureOpenAI: let baseURL = info.baseURL - let deployment = info.azureOpenAIDeploymentName + let deployment = info.modelName let version = "2024-02-15-preview" if baseURL.isEmpty { return "" } return "\(baseURL)/openai/deployments/\(deployment)/embeddings?api-version=\(version)" @@ -87,7 +91,6 @@ public struct EmbeddingModel: Codable, Equatable, Identifiable { } } - public struct EmptyEmbeddingModelInfo: FallbackValueProvider { public static var defaultValue: EmbeddingModel.Info { .init() } } @@ -95,3 +98,4 @@ public struct EmptyEmbeddingModelInfo: FallbackValueProvider { public struct EmptyEmbeddingModelFormat: FallbackValueProvider { public static var defaultValue: EmbeddingModel.Format { .openAI } } + diff --git a/Tool/Sources/OpenAIService/APIs/OpenAIChatCompletionsService.swift b/Tool/Sources/OpenAIService/APIs/OpenAIChatCompletionsService.swift index f66ba80d..4cbdf2bd 100644 --- a/Tool/Sources/OpenAIService/APIs/OpenAIChatCompletionsService.swift +++ b/Tool/Sources/OpenAIService/APIs/OpenAIChatCompletionsService.swift @@ -45,7 +45,15 @@ actor OpenAIChatCompletionsService: ChatCompletionsStreamAPI, ChatCompletionsAPI request.setValue("application/json", forHTTPHeaderField: "Content-Type") if !apiKey.isEmpty { switch model.format { - case .openAI, .openAICompatible: + case .openAI: + if !model.info.openAIInfo.organizationID.isEmpty { + request.setValue( + "OpenAI-Organization", + forHTTPHeaderField: model.info.openAIInfo.organizationID + ) + } + request.setValue("Bearer \(apiKey)", forHTTPHeaderField: "Authorization") + case .openAICompatible: request.setValue("Bearer \(apiKey)", forHTTPHeaderField: "Authorization") case .azureOpenAI: request.setValue(apiKey, forHTTPHeaderField: "api-key") @@ -108,7 +116,15 @@ actor OpenAIChatCompletionsService: ChatCompletionsStreamAPI, ChatCompletionsAPI request.setValue("application/json", forHTTPHeaderField: "Content-Type") if !apiKey.isEmpty { switch model.format { - case .openAI, .openAICompatible: + case .openAI: + if !model.info.openAIInfo.organizationID.isEmpty { + request.setValue( + "OpenAI-Organization", + forHTTPHeaderField: model.info.openAIInfo.organizationID + ) + } + request.setValue("Bearer \(apiKey)", forHTTPHeaderField: "Authorization") + case .openAICompatible: request.setValue("Bearer \(apiKey)", forHTTPHeaderField: "Authorization") case .azureOpenAI: request.setValue(apiKey, forHTTPHeaderField: "api-key") diff --git a/Tool/Sources/OpenAIService/APIs/OpenAIEmbeddingService.swift b/Tool/Sources/OpenAIService/APIs/OpenAIEmbeddingService.swift index 989dcd52..140e9d09 100644 --- a/Tool/Sources/OpenAIService/APIs/OpenAIEmbeddingService.swift +++ b/Tool/Sources/OpenAIService/APIs/OpenAIEmbeddingService.swift @@ -33,11 +33,16 @@ struct OpenAIEmbeddingService: EmbeddingAPI { request.setValue("application/json", forHTTPHeaderField: "Content-Type") if !apiKey.isEmpty { switch model.format { - case .openAI, .openAICompatible: - request.setValue( - "Bearer \(apiKey)", - forHTTPHeaderField: "Authorization" - ) + case .openAI: + if model.info.openAIInfo.organizationID.isEmpty { + request.setValue( + "OpenAI-Organization", + forHTTPHeaderField: model.info.openAIInfo.organizationID + ) + } + request.setValue("Bearer \(apiKey)", forHTTPHeaderField: "Authorization") + case .openAICompatible: + request.setValue("Bearer \(apiKey)", forHTTPHeaderField: "Authorization") case .azureOpenAI: request.setValue(apiKey, forHTTPHeaderField: "api-key") case .ollama: @@ -84,11 +89,16 @@ struct OpenAIEmbeddingService: EmbeddingAPI { request.setValue("application/json", forHTTPHeaderField: "Content-Type") if !apiKey.isEmpty { switch model.format { - case .openAI, .openAICompatible: - request.setValue( - "Bearer \(apiKey)", - forHTTPHeaderField: "Authorization" - ) + case .openAI: + if model.info.openAIInfo.organizationID.isEmpty { + request.setValue( + "OpenAI-Organization", + forHTTPHeaderField: model.info.openAIInfo.organizationID + ) + } + request.setValue("Bearer \(apiKey)", forHTTPHeaderField: "Authorization") + case .openAICompatible: + request.setValue("Bearer \(apiKey)", forHTTPHeaderField: "Authorization") case .azureOpenAI: request.setValue(apiKey, forHTTPHeaderField: "api-key") case .ollama: From 42a8efb13517ef647301eabf21db20295acfb972 Mon Sep 17 00:00:00 2001 From: Shx Guo Date: Sun, 3 Mar 2024 00:42:53 +0800 Subject: [PATCH 17/37] Support tool call --- .../WebChatContextCollector.swift | 2 +- Core/Sources/ChatGPTChatTab/Chat.swift | 1 + Core/Sources/ChatService/ChatService.swift | 4 +- .../ChatModelManagement/ChatModelEdit.swift | 4 +- .../EmbeddingModelEdit.swift | 4 +- Pro | 2 +- Tool/Sources/LangChain/Chains/LLMChain.swift | 7 +- .../Chains/RefineDocumentChain.swift | 4 +- .../RelevantInformationExtractionChain.swift | 10 +- .../StructuredOutputChatModelChain.swift | 4 +- .../APIs/ChatCompletionsAPIDefinition.swift | 170 +++++------ .../APIs/GoogleAIChatCompletionsService.swift | 136 ++++----- .../APIs/OlamaChatCompletionsService.swift | 80 +++--- .../APIs/OpenAIChatCompletionsService.swift | 266 +++++++++++++++++- .../OpenAIService/ChatGPTService.swift | 264 +++++++++-------- ...toManagedChatGPTMemoryOpenAIStrategy.swift | 10 +- .../OpenAIService/Memory/ChatGPTMemory.swift | 37 ++- Tool/Sources/OpenAIService/Models.swift | 27 +- 18 files changed, 666 insertions(+), 366 deletions(-) diff --git a/Core/Sources/ChatContextCollectors/WebChatContextCollector/WebChatContextCollector.swift b/Core/Sources/ChatContextCollectors/WebChatContextCollector/WebChatContextCollector.swift index 29327cea..851fdcf7 100644 --- a/Core/Sources/ChatContextCollectors/WebChatContextCollector/WebChatContextCollector.swift +++ b/Core/Sources/ChatContextCollectors/WebChatContextCollector/WebChatContextCollector.swift @@ -32,7 +32,7 @@ extension WebChatContextCollector { static func detectLinks(from messages: [ChatMessage]) -> [String] { return messages.lazy .compactMap { - $0.content ?? $0.functionCall?.arguments + $0.content ?? $0.toolCalls?.map(\.function.arguments).joined(separator: " ") ?? "" } .map(detectLinks(from:)) .flatMap { $0 } diff --git a/Core/Sources/ChatGPTChatTab/Chat.swift b/Core/Sources/ChatGPTChatTab/Chat.swift index 868266bc..6b61adb5 100644 --- a/Core/Sources/ChatGPTChatTab/Chat.swift +++ b/Core/Sources/ChatGPTChatTab/Chat.swift @@ -313,6 +313,7 @@ struct Chat: ReducerProtocol { } return .ignored case .function: return .function + case .tool: return .function } }(), text: message.summary ?? message.content ?? "", diff --git a/Core/Sources/ChatService/ChatService.swift b/Core/Sources/ChatService/ChatService.swift index e8b7dae0..4bb74639 100644 --- a/Core/Sources/ChatService/ChatService.swift +++ b/Core/Sources/ChatService/ChatService.swift @@ -124,9 +124,9 @@ public final class ChatService: ObservableObject { await chatGPTService.stopReceivingMessage() isReceivingMessage = false - // if it's stopped before the function finishes, remove the function call. + // if it's stopped before the tool calls finish, remove the message. await memory.mutateHistory { history in - if history.last?.role == .assistant, history.last?.functionCall != nil { + if history.last?.role == .assistant, history.last?.toolCalls != nil { history.removeLast() } } diff --git a/Core/Sources/HostApp/AccountSettings/ChatModelManagement/ChatModelEdit.swift b/Core/Sources/HostApp/AccountSettings/ChatModelManagement/ChatModelEdit.swift index 6eaa6b5f..8ef06b86 100644 --- a/Core/Sources/HostApp/AccountSettings/ChatModelManagement/ChatModelEdit.swift +++ b/Core/Sources/HostApp/AccountSettings/ChatModelManagement/ChatModelEdit.swift @@ -170,7 +170,7 @@ extension ChatModelEdit.State { maxTokens: model.info.maxTokens, supportsFunctionCalling: model.info.supportsFunctionCalling, modelName: model.info.modelName, - ollamaKeepAlive: model.info.ollamaKeepAlive, + ollamaKeepAlive: model.info.ollamaInfo.keepAlive, apiKeySelection: .init( apiKeyName: model.info.apiKeyName, apiKeyManagement: .init(availableAPIKeyNames: [model.info.apiKeyName]) @@ -198,7 +198,7 @@ extension ChatModel { return state.supportsFunctionCalling }(), modelName: state.modelName.trimmingCharacters(in: .whitespacesAndNewlines), - ollamaKeepAlive: state.ollamaKeepAlive + ollamaInfo: .init(keepAlive: state.ollamaKeepAlive) ) ) } diff --git a/Core/Sources/HostApp/AccountSettings/EmbeddingModelManagement/EmbeddingModelEdit.swift b/Core/Sources/HostApp/AccountSettings/EmbeddingModelManagement/EmbeddingModelEdit.swift index 5154a917..c5d1378e 100644 --- a/Core/Sources/HostApp/AccountSettings/EmbeddingModelManagement/EmbeddingModelEdit.swift +++ b/Core/Sources/HostApp/AccountSettings/EmbeddingModelManagement/EmbeddingModelEdit.swift @@ -155,7 +155,7 @@ extension EmbeddingModelEdit.State { format: model.format, maxTokens: model.info.maxTokens, modelName: model.info.modelName, - ollamaKeepAlive: model.info.ollamaKeepAlive, + ollamaKeepAlive: model.info.ollamaInfo.keepAlive, apiKeySelection: .init( apiKeyName: model.info.apiKeyName, apiKeyManagement: .init(availableAPIKeyNames: [model.info.apiKeyName]) @@ -177,7 +177,7 @@ extension EmbeddingModel { isFullURL: state.isFullURL, maxTokens: state.maxTokens, modelName: state.modelName.trimmingCharacters(in: .whitespacesAndNewlines), - ollamaKeepAlive: state.ollamaKeepAlive + ollamaInfo: .init(keepAlive: state.ollamaKeepAlive) ) ) } diff --git a/Pro b/Pro index fbb89b80..c6cace85 160000 --- a/Pro +++ b/Pro @@ -1 +1 @@ -Subproject commit fbb89b803e55e4e27d3346fd5a6b49a3b108280d +Subproject commit c6cace85afbe7cc277462e852de7f86e708160d2 diff --git a/Tool/Sources/LangChain/Chains/LLMChain.swift b/Tool/Sources/LangChain/Chains/LLMChain.swift index 1201bd45..2ba4aef4 100644 --- a/Tool/Sources/LangChain/Chains/LLMChain.swift +++ b/Tool/Sources/LangChain/Chains/LLMChain.swift @@ -33,10 +33,11 @@ public class ChatModelChain: Chain { public func parseOutput(_ output: Output) -> String { if let content = output.content { return content - } else if let functionCall = output.functionCall { - return "\(functionCall.name): \(functionCall.arguments)" + } else if let toolCalls = output.toolCalls { + return toolCalls.map { "[\($0.id)] \($0.function.name): \($0.function.arguments)" } + .joined(separator: "\n") } - + return "" } } diff --git a/Tool/Sources/LangChain/Chains/RefineDocumentChain.swift b/Tool/Sources/LangChain/Chains/RefineDocumentChain.swift index 3c38eb3a..3b24e6ad 100644 --- a/Tool/Sources/LangChain/Chains/RefineDocumentChain.swift +++ b/Tool/Sources/LangChain/Chains/RefineDocumentChain.swift @@ -42,7 +42,7 @@ public final class RefineDocumentChain: Chain { } class FunctionProvider: ChatGPTFunctionProvider { - var functionCallStrategy: FunctionCallStrategy? = .name("respond") + var functionCallStrategy: FunctionCallStrategy? = .function(name: "respond") var functions: [any ChatGPTFunction] = [RespondFunction()] } @@ -153,7 +153,7 @@ public final class RefineDocumentChain: Chain { } func extractAnswer(_ chatMessage: ChatMessage) -> IntermediateAnswer { - if let functionCall = chatMessage.functionCall { + for functionCall in chatMessage.toolCalls?.map(\.function) ?? [] { do { let intermediateAnswer = try JSONDecoder().decode( IntermediateAnswer.self, diff --git a/Tool/Sources/LangChain/Chains/RelevantInformationExtractionChain.swift b/Tool/Sources/LangChain/Chains/RelevantInformationExtractionChain.swift index d55cafd4..4c9f696a 100644 --- a/Tool/Sources/LangChain/Chains/RelevantInformationExtractionChain.swift +++ b/Tool/Sources/LangChain/Chains/RelevantInformationExtractionChain.swift @@ -15,7 +15,7 @@ public final class RelevantInformationExtractionChain: Chain { public typealias Output = String class FunctionProvider: ChatGPTFunctionProvider { - var functionCallStrategy: FunctionCallStrategy? = .name("saveFinalAnswer") + var functionCallStrategy: FunctionCallStrategy? = .function(name: "saveFinalAnswer") var functions: [any ChatGPTFunction] = [FinalAnswer()] } @@ -103,8 +103,10 @@ public final class RelevantInformationExtractionChain: Chain { taskInput, callbackManagers: callbackManagers ) - - if let functionCall = output.functionCall { + + if let functionCall = output.toolCalls? + .first(where: { $0.function.name == FinalAnswer().name })?.function + { do { let arguments = try JSONDecoder().decode( FinalAnswer.Arguments.self, @@ -118,7 +120,7 @@ public final class RelevantInformationExtractionChain: Chain { return output.content ?? "" } } - + return output.content ?? "" } diff --git a/Tool/Sources/LangChain/Chains/StructuredOutputChatModelChain.swift b/Tool/Sources/LangChain/Chains/StructuredOutputChatModelChain.swift index 9c938cce..6ea1dbb5 100644 --- a/Tool/Sources/LangChain/Chains/StructuredOutputChatModelChain.swift +++ b/Tool/Sources/LangChain/Chains/StructuredOutputChatModelChain.swift @@ -53,7 +53,7 @@ public class StructuredOutputChatModelChain: Chain { } var functionCallStrategy: FunctionCallStrategy? { - .name(endFunction.name) + .function(name: endFunction.name) } } @@ -108,7 +108,7 @@ public class StructuredOutputChatModelChain: Chain { } public func parseOutput(_ message: ChatMessage) async -> Output? { - if let functionCall = message.functionCall { + if let functionCall = message.toolCalls?.first?.function { do { let result = try JSONDecoder().decode( EndFunction.Arguments.self, diff --git a/Tool/Sources/OpenAIService/APIs/ChatCompletionsAPIDefinition.swift b/Tool/Sources/OpenAIService/APIs/ChatCompletionsAPIDefinition.swift index ecdc9f93..324cc554 100644 --- a/Tool/Sources/OpenAIService/APIs/ChatCompletionsAPIDefinition.swift +++ b/Tool/Sources/OpenAIService/APIs/ChatCompletionsAPIDefinition.swift @@ -1,8 +1,8 @@ import AIModel +import CodableWrappers import Foundation import Preferences -/// https://platform.openai.com/docs/api-reference/chat/create struct ChatCompletionsRequestBody: Codable, Equatable { struct Message: Codable, Equatable { /// The role of the message. @@ -14,14 +14,12 @@ struct ChatCompletionsRequestBody: Codable, Equatable { /// /// - important: It's required when the role is `function`. var name: String? - /// When the bot wants to call a function, it will reply with a function call in format: - /// ```json - /// { - /// "name": "weather", - /// "arguments": "{ \"location\": \"earth\" }" - /// } - /// ``` - var function_call: ChatCompletionsRequestBody.MessageFunctionCall? + /// Tool calls in an assistant message. + var toolCalls: [MessageToolCall]? + /// When we want to call a tool, we have to provide the id of the call. + /// + /// - important: It's required when the role is `tool`. + var toolCallId: String? } struct MessageFunctionCall: Codable, Equatable { @@ -31,77 +29,76 @@ struct ChatCompletionsRequestBody: Codable, Equatable { var arguments: String? } - struct Function: Codable { - var name: String - var description: String - /// JSON schema. - var arguments: String + struct MessageToolCall: Codable, Equatable { + /// The id of the tool call. + var id: String + /// The type of the tool. + var type: String + /// The function call. + var function: MessageFunctionCall + } + + struct Tool: Codable, Equatable { + var type: String = "function" + var function: ChatGPTFunctionSchema } var model: String var messages: [Message] var temperature: Double? - var top_p: Double? - var n: Double? var stream: Bool? var stop: [String]? - var max_tokens: Int? - var presence_penalty: Double? - var frequency_penalty: Double? - var logit_bias: [String: Double]? - var user: String? + var maxTokens: Int? /// Pass nil to let the bot decide. - var function_call: FunctionCallStrategy? - var functions: [ChatGPTFunctionSchema]? + var toolChoice: FunctionCallStrategy? + var tools: [Tool]? init( model: String, messages: [Message], temperature: Double? = nil, - top_p: Double? = nil, - n: Double? = nil, stream: Bool? = nil, stop: [String]? = nil, - max_tokens: Int? = nil, - presence_penalty: Double? = nil, - frequency_penalty: Double? = nil, - logit_bias: [String: Double]? = nil, - user: String? = nil, - function_call: FunctionCallStrategy? = nil, - functions: [ChatGPTFunctionSchema] = [] + maxTokens: Int? = nil, + toolChoice: FunctionCallStrategy? = nil, + tools: [Tool] = [] ) { self.model = model self.messages = messages self.temperature = temperature - self.top_p = top_p - self.n = n self.stream = stream self.stop = stop - self.max_tokens = max_tokens - self.presence_penalty = presence_penalty - self.frequency_penalty = frequency_penalty - self.logit_bias = logit_bias - self.user = user + self.maxTokens = maxTokens if UserDefaults.shared.value(for: \.disableFunctionCalling) { - self.function_call = nil - self.functions = nil + self.toolChoice = nil + self.tools = nil } else { - self.function_call = function_call - self.functions = functions.isEmpty ? nil : functions + self.toolChoice = toolChoice + self.tools = tools.isEmpty ? nil : tools } } } +struct EmptyMessageFunctionCall: FallbackValueProvider { + static var defaultValue: ChatCompletionsRequestBody.MessageFunctionCall { + .init(name: "") + } +} + public enum FunctionCallStrategy: Codable, Equatable { /// Forbid the bot to call any function. case none /// Let the bot choose what function to call. case auto /// Force the bot to call a function with the given name. - case name(String) + case function(name: String) struct CallFunctionNamed: Codable { - var name: String + var type = "function" + let function: Function + struct Function: Codable { + var name: String + } } public func encode(to encoder: Encoder) throws { @@ -111,8 +108,8 @@ public enum FunctionCallStrategy: Codable, Equatable { try container.encode("none") case .auto: try container.encode("auto") - case let .name(name): - try container.encode(CallFunctionNamed(name: name)) + case let .function(name): + try container.encode(CallFunctionNamed(function: .init(name: name))) } } } @@ -144,28 +141,29 @@ extension AsyncSequence { } } -struct ChatCompletionsStreamDataChunk: Codable { - var id: String? - var object: String? - var model: String? - var choices: [Choice]? - - struct Choice: Codable { - var delta: Delta? - var index: Int? - var finish_reason: String? - - struct Delta: Codable { - struct FunctionCall: Codable { - var name: String? - var arguments: String? - } +struct ChatCompletionsStreamDataChunk { + struct Delta { + struct FunctionCall { + var name: String? + var arguments: String? + } - var role: ChatMessage.Role? - var content: String? - var function_call: FunctionCall? + struct ToolCall { + var id: String? + var type: String? + var function: FunctionCall? } + + var role: ChatMessage.Role? + var content: String? + var toolCalls: [ToolCall]? } + + var id: String? + var object: String? + var model: String? + var message: Delta? + var finishReason: String? } // MARK: - Non Stream API @@ -174,44 +172,14 @@ protocol ChatCompletionsAPI { func callAsFunction() async throws -> ChatCompletionResponseBody } -/// https://platform.openai.com/docs/api-reference/chat/create struct ChatCompletionResponseBody: Codable, Equatable { - struct Message: Codable, Equatable { - /// The role of the message. - var role: ChatMessage.Role - /// The content of the message. - var content: String? - /// When we want to reply to a function call with the result, we have to provide the - /// name of the function call, and include the result in `content`. - /// - /// - important: It's required when the role is `function`. - var name: String? - /// When the bot wants to call a function, it will reply with a function call in format: - /// ```json - /// { - /// "name": "weather", - /// "arguments": "{ \"location\": \"earth\" }" - /// } - /// ``` - var function_call: ChatCompletionsRequestBody.MessageFunctionCall? - } - - struct Choice: Codable, Equatable { - var message: Message - var index: Int - var finish_reason: String - } - - struct Usage: Codable, Equatable { - var prompt_tokens: Int - var completion_tokens: Int - var total_tokens: Int - } - + typealias Message = ChatCompletionsRequestBody.Message + var id: String? var object: String var model: String - var usage: Usage - var choices: [Choice] + var message: Message + var otherChoices: [Message] + var finishReason: String } diff --git a/Tool/Sources/OpenAIService/APIs/GoogleAIChatCompletionsService.swift b/Tool/Sources/OpenAIService/APIs/GoogleAIChatCompletionsService.swift index 7608fe10..4b779168 100644 --- a/Tool/Sources/OpenAIService/APIs/GoogleAIChatCompletionsService.swift +++ b/Tool/Sources/OpenAIService/APIs/GoogleAIChatCompletionsService.swift @@ -26,48 +26,16 @@ actor GoogleAIChatCompletionsService: ChatCompletionsAPI, ChatCompletionsStreamA name: model.info.modelName, apiKey: apiKey, generationConfig: .init(GenerationConfig( - temperature: requestBody.temperature.map(Float.init), - topP: requestBody.top_p.map(Float.init) + temperature: requestBody.temperature.map(Float.init) )) ) let history = prompt.googleAICompatible.history.map { message in - ModelContent( - ChatMessage( - role: message.role, - content: message.content, - name: message.name, - functionCall: message.functionCall.map { - .init(name: $0.name, arguments: $0.arguments) - } - ) - ) + ModelContent(message) } do { let response = try await aiModel.generateContent(history) - - return .init( - object: "chat.completion", - model: model.info.modelName, - usage: .init(prompt_tokens: 0, completion_tokens: 0, total_tokens: 0), - choices: response.candidates.enumerated().map { - let (index, candidate) = $0 - return .init( - message: .init( - role: .assistant, - content: candidate.content.parts.first(where: { part in - if let text = part.text { - return !text.isEmpty - } else { - return false - } - })?.text ?? "" - ), - index: index, - finish_reason: candidate.finishReason?.rawValue ?? "" - ) - } - ) + return response.formalized() } catch let error as GenerateContentError { struct ErrorWrapper: Error, LocalizedError { let error: Error @@ -98,21 +66,11 @@ actor GoogleAIChatCompletionsService: ChatCompletionsAPI, ChatCompletionsStreamA name: model.info.modelName, apiKey: apiKey, generationConfig: .init(GenerationConfig( - temperature: requestBody.temperature.map(Float.init), - topP: requestBody.top_p.map(Float.init) + temperature: requestBody.temperature.map(Float.init) )) ) let history = prompt.googleAICompatible.history.map { message in - ModelContent( - ChatMessage( - role: message.role, - content: message.content, - name: message.name, - functionCall: message.functionCall.map { - .init(name: $0.name, arguments: $0.arguments) - } - ) - ) + ModelContent(message) } let stream = AsyncThrowingStream { continuation in @@ -121,17 +79,7 @@ actor GoogleAIChatCompletionsService: ChatCompletionsAPI, ChatCompletionsStreamA do { for try await response in stream { if Task.isCancelled { break } - let chunk = ChatCompletionsStreamDataChunk( - object: "", - model: model.info.modelName, - choices: response.candidates.map { candidate in - .init(delta: .init( - role: .assistant, - content: candidate.content.parts - .first(where: { $0.text != nil })?.text ?? "" - )) - } - ) + let chunk = response.formalizedAsChunk() continuation.yield(chunk) } continuation.finish() @@ -246,7 +194,7 @@ extension ChatGPTPrompt { extension ModelContent { static func convertRole(_ role: ChatMessage.Role) -> String { switch role { - case .user, .system, .function: + case .user, .system, .function, .tool: return "user" case .assistant: return "model" @@ -263,12 +211,18 @@ extension ModelContent { return """ Result of \(message.name ?? "function"): \(message.content ?? "N/A") """ + case .tool: + return """ + Result of \(message.toolCallId ?? "tool"): \(message.content ?? "N/A") + """ case .assistant: - if let functionCall = message.functionCall { - return """ - Call function: \(functionCall.name) - Arguments: \(functionCall.arguments) - """ + if let toolCalls = message.toolCalls { + return toolCalls.map { + """ + Call function: \($0.function.name) - \($0.id) + Arguments: \($0.function.arguments) + """ + }.joined(separator: "\n") } else { return message.content ?? " " } @@ -282,3 +236,57 @@ extension ModelContent { } } +extension GenerateContentResponse { + func formalized() -> ChatCompletionResponseBody { + let message: ChatCompletionResponseBody.Message + let otherMessages: [ChatCompletionResponseBody.Message] + + func convertMessage(_ candidate: CandidateResponse) -> ChatCompletionResponseBody.Message { + .init( + role: .assistant, + content: candidate.content.parts.first(where: { part in + if let text = part.text { + return !text.isEmpty + } else { + return false + } + })?.text ?? "" + ) + } + + if let first = candidates.first { + message = convertMessage(first) + otherMessages = candidates.dropFirst().map { convertMessage($0) } + } else { + message = .init(role: .assistant, content: "") + otherMessages = [] + } + + return .init( + object: "chat.completion", + model: "", + message: message, + otherChoices: otherMessages, + finishReason: candidates.first?.finishReason?.rawValue ?? "" + ) + } + + func formalizedAsChunk() -> ChatCompletionsStreamDataChunk { + func convertMessage( + _ candidate: CandidateResponse + ) -> ChatCompletionsStreamDataChunk.Delta { + .init( + role: .assistant, + content: candidate.content.parts + .first(where: { $0.text != nil })?.text ?? "" + ) + } + + return .init( + object: "", + model: "", + message: candidates.first.map(convertMessage) + ) + } +} + diff --git a/Tool/Sources/OpenAIService/APIs/OlamaChatCompletionsService.swift b/Tool/Sources/OpenAIService/APIs/OlamaChatCompletionsService.swift index b5046868..d65d3cca 100644 --- a/Tool/Sources/OpenAIService/APIs/OlamaChatCompletionsService.swift +++ b/Tool/Sources/OpenAIService/APIs/OlamaChatCompletionsService.swift @@ -41,6 +41,8 @@ extension OllamaChatCompletionsService: ChatCompletionsAPI { return .system case .function: return .user + case .tool: + return .user } }(), content: message.content) }, @@ -48,7 +50,7 @@ extension OllamaChatCompletionsService: ChatCompletionsAPI { options: .init( temperature: requestBody.temperature, stop: requestBody.stop, - num_predict: requestBody.max_tokens + num_predict: requestBody.maxTokens ), keep_alive: nil, format: nil @@ -78,29 +80,23 @@ extension OllamaChatCompletionsService: ChatCompletionsAPI { return .init( object: body.model, model: body.model, - usage: .init( - prompt_tokens: body.prompt_eval_count ?? 0, - completion_tokens: body.eval_count ?? 0, - total_tokens: (body.eval_count ?? 0) + (body.prompt_eval_count ?? 0) - ), - choices: [ + message: body.message.map { message in .init( - message: body.message.map { message in - .init(role: { - switch message.role { - case .assistant: - return .assistant - case .user: - return .user - case .system: - return .system - } - }(), content: message.content) - } ?? .init(role: .assistant), - index: 0, - finish_reason: "" - ), - ] + role: { + switch message.role { + case .assistant: + return .assistant + case .user: + return .user + case .system: + return .system + } + }(), + content: message.content + ) + } ?? .init(role: .assistant, content: ""), + otherChoices: [], + finishReason: "" ) } } @@ -122,6 +118,8 @@ extension OllamaChatCompletionsService: ChatCompletionsStreamAPI { return .system case .function: return .user + case .tool: + return .user } }(), content: message.content) }, @@ -129,7 +127,7 @@ extension OllamaChatCompletionsService: ChatCompletionsStreamAPI { options: .init( temperature: requestBody.temperature, stop: requestBody.stop, - num_predict: requestBody.max_tokens + num_predict: requestBody.maxTokens ), keep_alive: nil, format: nil @@ -166,25 +164,21 @@ extension OllamaChatCompletionsService: ChatCompletionsStreamAPI { id: UUID().uuidString, object: chunk.model, model: chunk.model, - choices: [ - .init( - delta: .init( - role: { - switch chunk.message?.role { - case .none: - return nil - case .assistant: - return .assistant - case .user: - return .user - case .system: - return .system - } - }(), - content: chunk.message?.content - ) - ), - ] + message: .init( + role: { + switch chunk.message?.role { + case .none: + return nil + case .assistant: + return .assistant + case .user: + return .user + case .system: + return .system + } + }(), + content: chunk.message?.content + ) ) } diff --git a/Tool/Sources/OpenAIService/APIs/OpenAIChatCompletionsService.swift b/Tool/Sources/OpenAIService/APIs/OpenAIChatCompletionsService.swift index 4cbdf2bd..2c868404 100644 --- a/Tool/Sources/OpenAIService/APIs/OpenAIChatCompletionsService.swift +++ b/Tool/Sources/OpenAIService/APIs/OpenAIChatCompletionsService.swift @@ -3,6 +3,7 @@ import AsyncAlgorithms import Foundation import Preferences +/// https://platform.openai.com/docs/api-reference/chat/create actor OpenAIChatCompletionsService: ChatCompletionsStreamAPI, ChatCompletionsAPI { struct CompletionAPIError: Error, Codable, LocalizedError { struct E: Codable { @@ -17,9 +18,121 @@ actor OpenAIChatCompletionsService: ChatCompletionsStreamAPI, ChatCompletionsAPI var errorDescription: String? { error.message } } + struct StreamDataChunk: Codable { + var id: String? + var object: String? + var model: String? + var choices: [Choice]? + + struct Choice: Codable { + var delta: Delta? + var index: Int? + var finish_reason: String? + + struct Delta: Codable { + var role: ChatMessage.Role? + var content: String? + var function_call: RequestBody.MessageFunctionCall? + var tool_calls: [RequestBody.MessageToolCall]? + } + } + } + + struct ResponseBody: Codable, Equatable { + struct Message: Codable, Equatable { + /// The role of the message. + var role: ChatMessage.Role + /// The content of the message. + var content: String? + /// When we want to reply to a function call with the result, we have to provide the + /// name of the function call, and include the result in `content`. + /// + /// - important: It's required when the role is `function`. + var name: String? + /// When the bot wants to call a function, it will reply with a function call in format: + /// ```json + /// { + /// "name": "weather", + /// "arguments": "{ \"location\": \"earth\" }" + /// } + /// ``` + var function_call: RequestBody.MessageFunctionCall? + /// Tool calls in an assistant message. + var tool_calls: [RequestBody.MessageToolCall]? + } + + struct Choice: Codable, Equatable { + var message: Message + var index: Int + var finish_reason: String + } + + struct Usage: Codable, Equatable { + var prompt_tokens: Int + var completion_tokens: Int + var total_tokens: Int + } + + var id: String? + var object: String + var model: String + var usage: Usage + var choices: [Choice] + } + + struct RequestBody: Codable, Equatable { + struct Message: Codable, Equatable { + /// The role of the message. + var role: ChatMessage.Role + /// The content of the message. + var content: String + /// When we want to reply to a function call with the result, we have to provide the + /// name of the function call, and include the result in `content`. + /// + /// - important: It's required when the role is `function`. + var name: String? + /// Tool calls in an assistant message. + var tool_calls: [MessageToolCall]? + /// When we want to call a tool, we have to provide the id of the call. + /// + /// - important: It's required when the role is `tool`. + var tool_call_id: String? + } + + struct MessageFunctionCall: Codable, Equatable { + /// The name of the + var name: String + /// A JSON string. + var arguments: String? + } + + struct MessageToolCall: Codable, Equatable { + /// The id of the tool call. + var id: String + /// The type of the tool. + var type: String + /// The function call. + var function: MessageFunctionCall + } + + struct Tool: Codable, Equatable { + var type: String = "function" + var function: ChatGPTFunctionSchema + } + + var model: String + var messages: [Message] + var temperature: Double? + var stream: Bool? + var stop: [String]? + var max_tokens: Int? + var tool_choice: FunctionCallStrategy? + var tools: [Tool]? + } + var apiKey: String var endpoint: URL - var requestBody: ChatCompletionsRequestBody + var requestBody: RequestBody var model: ChatModel init( @@ -30,7 +143,7 @@ actor OpenAIChatCompletionsService: ChatCompletionsStreamAPI, ChatCompletionsAPI ) { self.apiKey = apiKey self.endpoint = endpoint - self.requestBody = requestBody + self.requestBody = .init(requestBody) self.model = model } @@ -89,9 +202,9 @@ actor OpenAIChatCompletionsService: ChatCompletionsStreamAPI, ChatCompletionsAPI guard line.hasPrefix(prefix), let content = line.dropFirst(prefix.count).data(using: .utf8), let chunk = try? JSONDecoder() - .decode(ChatCompletionsStreamDataChunk.self, from: content) + .decode(StreamDataChunk.self, from: content) else { continue } - continuation.yield(chunk) + continuation.yield(chunk.formalized()) } continuation.finish() } catch { @@ -147,7 +260,8 @@ actor OpenAIChatCompletionsService: ChatCompletionsStreamAPI, ChatCompletionsAPI } do { - return try JSONDecoder().decode(ChatCompletionResponseBody.self, from: result) + let body = try JSONDecoder().decode(ResponseBody.self, from: result) + return body.formalized() } catch { dump(error) throw error @@ -155,3 +269,145 @@ actor OpenAIChatCompletionsService: ChatCompletionsStreamAPI, ChatCompletionsAPI } } +extension OpenAIChatCompletionsService.ResponseBody { + func formalized() -> ChatCompletionResponseBody { + let message: ChatCompletionResponseBody.Message + let otherMessages: [ChatCompletionResponseBody.Message] + + func convertMessage(_ message: Message) -> ChatCompletionResponseBody.Message { + .init( + role: message.role, + content: message.content ?? "", + toolCalls: { + if let toolCalls = message.tool_calls { + return toolCalls.map { toolCall in + .init( + id: toolCall.id, + type: toolCall.type, + function: .init( + name: toolCall.function.name, + arguments: toolCall.function.arguments + ) + ) + } + } else if let functionCall = message.function_call { + return [ + .init( + id: functionCall.name, + type: "function", + function: .init( + name: functionCall.name, + arguments: functionCall.arguments + ) + ), + ] + } else { + return [] + } + }() + ) + } + + if let first = choices.first?.message { + message = convertMessage(first) + otherMessages = choices.dropFirst().map { convertMessage($0.message) } + } else { + message = .init(role: .assistant, content: "") + otherMessages = [] + } + + return .init( + id: id, + object: object, + model: model, + message: message, + otherChoices: otherMessages, + finishReason: choices.first?.finish_reason ?? "" + ) + } +} + +extension OpenAIChatCompletionsService.StreamDataChunk { + func formalized() -> ChatCompletionsStreamDataChunk { + .init( + id: id, + object: object, + model: model, + message: { + if let choice = self.choices?.first { + return .init( + role: choice.delta?.role, + content: choice.delta?.content, + toolCalls: { + if let toolCalls = choice.delta?.tool_calls { + return toolCalls.map { + .init( + id: $0.id, + type: $0.type, + function: .init( + name: $0.function.name, + arguments: $0.function.arguments + ) + ) + } + } + + if let functionCall = choice.delta?.function_call { + return [ + .init( + id: functionCall.name, + type: "function", + function: .init( + name: functionCall.name, + arguments: functionCall.arguments + ) + ), + ] + } + + return nil + }() + ) + } + return nil + }(), + finishReason: choices?.first?.finish_reason + ) + } +} + +extension OpenAIChatCompletionsService.RequestBody { + init(_ body: ChatCompletionsRequestBody) { + model = body.model + messages = body.messages.map { + .init( + role: $0.role, + content: $0.content, + name: $0.name, + tool_calls: $0.toolCalls?.map { tool in + MessageToolCall( + id: tool.id, + type: tool.type, + function: MessageFunctionCall( + name: tool.function.name, + arguments: tool.function.arguments + ) + ) + }, + tool_call_id: $0.toolCallId + ) + } + temperature = body.temperature + stream = body.stream + stop = body.stop + max_tokens = body.maxTokens + tool_choice = body.toolChoice + tools = body.tools?.map { + Tool( + type: $0.type, + function: $0.function + ) + } + } +} + diff --git a/Tool/Sources/OpenAIService/ChatGPTService.swift b/Tool/Sources/OpenAIService/ChatGPTService.swift index 9ccfe244..9e3a9fdb 100644 --- a/Tool/Sources/OpenAIService/ChatGPTService.swift +++ b/Tool/Sources/OpenAIService/ChatGPTService.swift @@ -2,6 +2,7 @@ import AIModel import AsyncAlgorithms import Dependencies import Foundation +import IdentifiedCollections import Preferences public protocol ChatGPTServiceType { @@ -168,7 +169,7 @@ public class ChatGPTService: ChatGPTServiceType { role: .user, content: content, name: nil, - functionCall: nil, + toolCalls: nil, summary: summary, references: [] ) @@ -179,18 +180,20 @@ public class ChatGPTService: ChatGPTServiceType { AsyncThrowingStream { continuation in let task = Task(priority: .userInitiated) { do { - var functionCall: ChatMessage.FunctionCall? - var functionCallMessageID = "" + var pendingToolCalls = IdentifiedArrayOf() + var functionCallMessageIDs = [String: String]() var isInitialCall = true - loop: while functionCall != nil || isInitialCall { + loop: while !pendingToolCalls.isEmpty || isInitialCall { try Task.checkCancellation() isInitialCall = false - if let call = functionCall { + for toolCall in pendingToolCalls { if !configuration.runFunctionsAutomatically { break loop } - functionCall = nil - await runFunctionCall(call, messageId: functionCallMessageID) + await runFunctionCall( + toolCall, + messageId: functionCallMessageIDs[toolCall.id] + ) } let stream = try await sendMemory() @@ -206,18 +209,15 @@ public class ChatGPTService: ChatGPTServiceType { #if DEBUG reply.append(text) #endif - case let .functionCall(call): - if functionCall == nil { - functionCallMessageID = uuid().uuidString - functionCall = call - } else { - functionCall?.name.append(call.name) - functionCall?.arguments.append(call.arguments) - } - await prepareFunctionCall( - call, - messageId: functionCallMessageID + + case let .toolCall(toolCall): + let id = storeToolCallsChunks( + chunk: toolCall, + into: &pendingToolCalls, + messageIds: &functionCallMessageIDs ) + + await prepareFunctionCall(toolCall, messageId: id) } } #if DEBUG @@ -257,17 +257,19 @@ public class ChatGPTService: ChatGPTServiceType { return try await Debugger.$id.withValue(.init()) { let message = try await sendMemoryAndWait() var finalResult = message?.content - var functionCall = message?.functionCall - while let call = functionCall { + var toolCalls = message?.toolCalls + while let calls = toolCalls, !calls.isEmpty { try Task.checkCancellation() if !configuration.runFunctionsAutomatically { break } - functionCall = nil - await runFunctionCall(call) + toolCalls = nil + for call in calls { + await runFunctionCall(call) + } guard let nextMessage = try await sendMemoryAndWait() else { break } finalResult = nextMessage.content - functionCall = nextMessage.functionCall + toolCalls = nextMessage.toolCalls } #if DEBUG @@ -291,7 +293,7 @@ public class ChatGPTService: ChatGPTServiceType { extension ChatGPTService { enum StreamContent { case text(String) - case functionCall(ChatMessage.FunctionCall) + case toolCall(ChatMessage.ToolCall) } /// Send the memory as prompt to ChatGPT, with stream enabled. @@ -305,42 +307,7 @@ extension ChatGPTService { throw ChatGPTServiceError.endpointIncorrect } - let messages = prompt.history.map { - ChatCompletionsRequestBody.Message( - role: $0.role, - content: $0.content ?? "", - name: $0.name, - function_call: $0.functionCall.map { - .init(name: $0.name, arguments: $0.arguments) - } - ) - } - let remainingTokens = prompt.remainingTokenCount - - let requestBody = ChatCompletionsRequestBody( - model: model.info.modelName, - messages: messages, - temperature: configuration.temperature, - stream: true, - stop: configuration.stop.isEmpty ? nil : configuration.stop, - max_tokens: maxTokenForReply( - maxToken: model.info.maxTokens, - remainingTokens: remainingTokens - ), - function_call: model.info.supportsFunctionCalling - ? functionProvider.functionCallStrategy - : nil, - functions: - model.info.supportsFunctionCalling - ? functionProvider.functions.map { - ChatGPTFunctionSchema( - name: $0.name, - description: $0.description, - parameters: $0.argumentSchema - ) - } - : [] - ) + let requestBody = createRequestBody(prompt: prompt, model: model, stream: true) let api = buildCompletionStreamAPI( configuration.apiKey, @@ -368,16 +335,20 @@ extension ChatGPTService { if Task.isCancelled { throw CancellationError() } - guard let delta = chunk.choices?.first?.delta else { continue } + guard let delta = chunk.message else { continue } // The api will always return a function call with JSON object. // The first round will contain the function name and an empty argument. // e.g. {"name":"weather","arguments":""} // The other rounds will contain part of the arguments. - let functionCall = delta.function_call.map { - ChatMessage.FunctionCall( - name: $0.name ?? "", - arguments: $0.arguments ?? "" + let toolCalls = delta.toolCalls?.map { + ChatMessage.ToolCall( + id: $0.id ?? "", + type: $0.type ?? "", + function: .init( + name: $0.function?.name ?? "", + arguments: $0.function?.arguments ?? "" + ) ) } @@ -385,11 +356,13 @@ extension ChatGPTService { id: proposedId, role: delta.role, content: delta.content, - functionCall: functionCall + toolCalls: toolCalls ) - if let functionCall { - continuation.yield(.functionCall(functionCall)) + if let toolCalls { + for toolCall in toolCalls { + continuation.yield(.toolCall(toolCall)) + } } if let content = delta.content { @@ -433,42 +406,7 @@ extension ChatGPTService { throw ChatGPTServiceError.endpointIncorrect } - let messages = prompt.history.map { - ChatCompletionsRequestBody.Message( - role: $0.role, - content: $0.content ?? "", - name: $0.name, - function_call: $0.functionCall.map { - .init(name: $0.name, arguments: $0.arguments) - } - ) - } - let remainingTokens = prompt.remainingTokenCount - - let requestBody = ChatCompletionsRequestBody( - model: model.info.modelName, - messages: messages, - temperature: configuration.temperature, - stream: true, - stop: configuration.stop.isEmpty ? nil : configuration.stop, - max_tokens: maxTokenForReply( - maxToken: model.info.maxTokens, - remainingTokens: remainingTokens - ), - function_call: model.info.supportsFunctionCalling - ? functionProvider.functionCallStrategy - : nil, - functions: - model.info.supportsFunctionCalling - ? functionProvider.functions.map { - ChatGPTFunctionSchema( - name: $0.name, - description: $0.description, - parameters: $0.argumentSchema - ) - } - : [] - ) + let requestBody = createRequestBody(prompt: prompt, model: model, stream: false) let api = buildCompletionAPI( configuration.apiKey, @@ -484,14 +422,17 @@ extension ChatGPTService { let response = try await api() - guard let choice = response.choices.first else { return nil } + let choice = response.message let message = ChatMessage( id: proposedId, - role: choice.message.role, - content: choice.message.content, - name: choice.message.name, - functionCall: choice.message.function_call.map { - ChatMessage.FunctionCall(name: $0.name, arguments: $0.arguments ?? "") + role: choice.role, + content: choice.content, + name: choice.name, + toolCalls: choice.toolCalls?.map { + ChatMessage.ToolCall(id: $0.id, type: $0.type, function: .init( + name: $0.function.name, + arguments: $0.function.arguments ?? "" + )) }, references: prompt.references ) @@ -499,11 +440,40 @@ extension ChatGPTService { return message } + func storeToolCallsChunks( + chunk toolCall: ChatMessage.ToolCall, + into toolCalls: inout IdentifiedArrayOf, + messageIds: inout [String: String] + ) -> String { + if let index = toolCalls.firstIndex(where: { $0.id == toolCall.id }) { + if !toolCall.id.isEmpty { + toolCalls[index].id = toolCall.id + } + if !toolCall.type.isEmpty { + toolCalls[index].type = toolCall.type + } + toolCalls[index].function.name.append(toolCall.function.name) + toolCalls[index].function.arguments.append(toolCall.function.arguments) + + } else { + toolCalls.append(toolCall) + } + + let id = messageIds[toolCall.id] ?? UUID().uuidString + messageIds[toolCall.id] = id + return id + } + /// When a function call is detected, but arguments are not yet ready, we can call this /// to insert a message placeholder in memory. - func prepareFunctionCall(_ call: ChatMessage.FunctionCall, messageId: String) async { - guard let function = functionProvider.function(named: call.name) else { return } - await memory.streamMessage(id: messageId, role: .function, name: call.name) + func prepareFunctionCall(_ call: ChatMessage.ToolCall, messageId: String) async { + guard let function = functionProvider.function(named: call.function.name) else { return } + await memory.streamMessage( + id: messageId, + role: .tool, + name: call.function.name, + toolCallId: call.id + ) await function.prepare { [weak self] summary in await self?.memory.updateMessage(id: messageId) { message in message.summary = summary @@ -514,24 +484,29 @@ extension ChatGPTService { /// Run a function call from the bot, and insert the result in memory. @discardableResult func runFunctionCall( - _ call: ChatMessage.FunctionCall, + _ call: ChatMessage.ToolCall, messageId: String? = nil ) async -> String { #if DEBUG - Debugger.didReceiveFunction(name: call.name, arguments: call.arguments) + Debugger.didReceiveFunction(name: call.function.name, arguments: call.function.arguments) #endif let messageId = messageId ?? uuid().uuidString - guard let function = functionProvider.function(named: call.name) else { - return await fallbackFunctionCall(call, messageId: messageId) + guard let function = functionProvider.function(named: call.function.name) else { + return await fallbackFunctionCall(call.function, messageId: messageId) } - await memory.streamMessage(id: messageId, role: .function, name: call.name) + await memory.streamMessage( + id: messageId, + role: .function, + name: call.function.name, + toolCallId: call.id + ) do { // Run the function - let result = try await function.call(argumentsJsonString: call.arguments) { + let result = try await function.call(argumentsJsonString: call.function.arguments) { [weak self] summary in await self?.memory.updateMessage(id: messageId) { message in message.summary = summary @@ -610,6 +585,57 @@ extension ChatGPTService { ) return content } + + func createRequestBody( + prompt: ChatGPTPrompt, + model: ChatModel, + stream: Bool + ) -> ChatCompletionsRequestBody { + let messages = prompt.history.map { + ChatCompletionsRequestBody.Message( + role: $0.role, + content: $0.content ?? "", + name: $0.name, + toolCalls: $0.toolCalls?.map { + .init( + id: $0.id, + type: $0.type, + function: .init( + name: $0.function.name, + arguments: $0.function.arguments + ) + ) + } + ) + } + let remainingTokens = prompt.remainingTokenCount + + let requestBody = ChatCompletionsRequestBody( + model: model.info.modelName, + messages: messages, + temperature: configuration.temperature, + stream: stream, + stop: configuration.stop.isEmpty ? nil : configuration.stop, + maxTokens: maxTokenForReply( + maxToken: model.info.maxTokens, + remainingTokens: remainingTokens + ), + toolChoice: model.info.supportsFunctionCalling + ? functionProvider.functionCallStrategy + : nil, + tools: model.info.supportsFunctionCalling + ? functionProvider.functions.map { + .init(function: ChatGPTFunctionSchema( + name: $0.name, + description: $0.description, + parameters: $0.argumentSchema + )) + } + : [] + ) + + return requestBody + } } extension ChatGPTService { diff --git a/Tool/Sources/OpenAIService/Memory/AutoManagedChatGPTMemoryStrategy/AutoManagedChatGPTMemoryOpenAIStrategy.swift b/Tool/Sources/OpenAIService/Memory/AutoManagedChatGPTMemoryStrategy/AutoManagedChatGPTMemoryOpenAIStrategy.swift index 07d72acb..0ed9873d 100644 --- a/Tool/Sources/OpenAIService/Memory/AutoManagedChatGPTMemoryStrategy/AutoManagedChatGPTMemoryOpenAIStrategy.swift +++ b/Tool/Sources/OpenAIService/Memory/AutoManagedChatGPTMemoryStrategy/AutoManagedChatGPTMemoryOpenAIStrategy.swift @@ -37,9 +37,13 @@ extension TokenEncoder { encodingContent.append(name) total += 1 } - if let functionCall = message.functionCall { - encodingContent.append(functionCall.name) - encodingContent.append(functionCall.arguments) + if let toolCalls = message.toolCalls { + for toolCall in toolCalls { + encodingContent.append(toolCall.id) + encodingContent.append(toolCall.type) + encodingContent.append(toolCall.function.name) + encodingContent.append(toolCall.function.arguments) + } } total += await withTaskGroup(of: Int.self, body: { group in for content in encodingContent { diff --git a/Tool/Sources/OpenAIService/Memory/ChatGPTMemory.swift b/Tool/Sources/OpenAIService/Memory/ChatGPTMemory.swift index d27569d6..62b8f369 100644 --- a/Tool/Sources/OpenAIService/Memory/ChatGPTMemory.swift +++ b/Tool/Sources/OpenAIService/Memory/ChatGPTMemory.swift @@ -23,7 +23,7 @@ public protocol ChatGPTMemory { func mutateHistory(_ update: (inout [ChatMessage]) -> Void) async /// Generate prompt that would be send through the API. /// - /// A memory should make sure that the history in the prompt + /// A memory should make sure that the history in the prompt /// doesn't exceed the maximum token count. /// /// The history can be different from the actual history. @@ -64,7 +64,8 @@ public extension ChatGPTMemory { role: ChatMessage.Role? = nil, content: String? = nil, name: String? = nil, - functionCall: ChatMessage.FunctionCall? = nil, + toolCallId: String? = nil, + toolCalls: [ChatMessage.ToolCall]? = nil, summary: String? = nil, references: [ChatMessage.Reference]? = nil ) async { @@ -80,12 +81,28 @@ public extension ChatGPTMemory { if let role { history[index].role = role } - if let functionCall { - if history[index].functionCall == nil { - history[index].functionCall = functionCall + if let toolCalls { + if history[index].toolCalls == nil { + history[index].toolCalls = toolCalls } else { - history[index].functionCall?.name.append(functionCall.name) - history[index].functionCall?.arguments.append(functionCall.arguments) + for toolCall in toolCalls { + if let index = history[index].toolCalls? + .firstIndex(where: { $0.id == toolCall.id }) + { + if !toolCall.id.isEmpty { + history[index].toolCalls?[index].id = toolCall.id + } + if !toolCall.type.isEmpty { + history[index].toolCalls?[index].type = toolCall.type + } + history[index].toolCalls?[index].function.name + .append(toolCall.function.name) + history[index].toolCalls?[index].function.arguments + .append(toolCall.function.arguments) + } else { + history[index].toolCalls?.append(toolCall) + } + } } } if let summary { @@ -97,13 +114,17 @@ public extension ChatGPTMemory { if let name { history[index].name = name } + if let toolCallId { + history[index].toolCallId = toolCallId + } } else { history.append(.init( id: id, role: role ?? .system, content: content, name: name, - functionCall: functionCall, + toolCallId: toolCallId, + toolCalls: toolCalls, summary: summary, references: references ?? [] )) diff --git a/Tool/Sources/OpenAIService/Models.swift b/Tool/Sources/OpenAIService/Models.swift index 8ff25b96..f4f0fff8 100644 --- a/Tool/Sources/OpenAIService/Models.swift +++ b/Tool/Sources/OpenAIService/Models.swift @@ -16,6 +16,7 @@ public struct ChatMessage: Equatable, Codable { case user case assistant case function + case tool } public struct FunctionCall: Codable, Equatable { @@ -26,6 +27,17 @@ public struct ChatMessage: Equatable, Codable { self.arguments = arguments } } + + public struct ToolCall: Codable, Equatable, Identifiable { + public var id: String + public var type: String + public var function: FunctionCall + public init(id: String, type: String, function: FunctionCall) { + self.id = id + self.type = type + self.function = function + } + } public struct Reference: Codable, Equatable { public enum Kind: String, Codable { @@ -82,7 +94,7 @@ public struct ChatMessage: Equatable, Codable { } /// A function call from the bot. - public var functionCall: FunctionCall? { + public var toolCalls: [ToolCall]? { didSet { tokensCount = nil } } @@ -90,6 +102,11 @@ public struct ChatMessage: Equatable, Codable { public var name: String? { didSet { tokensCount = nil } } + + /// The tool id of a reply to a tool call. + public var toolCallId: String? { + didSet { tokensCount = nil } + } /// The summary of a message that is used for display. public var summary: String? @@ -107,7 +124,7 @@ public struct ChatMessage: Equatable, Codable { /// Is the message considered empty. var isEmpty: Bool { if let content, !content.isEmpty { return false } - if let functionCall, !functionCall.name.isEmpty { return false } + if let toolCalls, !toolCalls.isEmpty { return false } if let name, !name.isEmpty { return false } return true } @@ -117,7 +134,8 @@ public struct ChatMessage: Equatable, Codable { role: Role, content: String?, name: String? = nil, - functionCall: FunctionCall? = nil, + toolCallId: String? = nil, + toolCalls: [ToolCall]? = nil, summary: String? = nil, tokenCount: Int? = nil, references: [Reference] = [] @@ -125,7 +143,8 @@ public struct ChatMessage: Equatable, Codable { self.role = role self.content = content self.name = name - self.functionCall = functionCall + self.toolCallId = toolCallId + self.toolCalls = toolCalls self.summary = summary self.id = id tokensCount = tokenCount From c962212c0be6df19351c9b2b691d6f11fbf5a2a5 Mon Sep 17 00:00:00 2001 From: Shx Guo Date: Sun, 3 Mar 2024 00:46:25 +0800 Subject: [PATCH 18/37] Remove function role --- Core/Sources/ChatGPTChatTab/Chat.swift | 5 ++--- Core/Sources/ChatGPTChatTab/ChatPanel.swift | 4 ++-- .../ChatService/ContextAwareAutoManagedChatGPTMemory.swift | 2 +- Pro | 2 +- .../OpenAIService/APIs/GoogleAIChatCompletionsService.swift | 6 +----- .../OpenAIService/APIs/OlamaChatCompletionsService.swift | 4 ---- Tool/Sources/OpenAIService/ChatGPTService.swift | 4 ++-- Tool/Sources/OpenAIService/Models.swift | 1 - 8 files changed, 9 insertions(+), 19 deletions(-) diff --git a/Core/Sources/ChatGPTChatTab/Chat.swift b/Core/Sources/ChatGPTChatTab/Chat.swift index 6b61adb5..5a0df30c 100644 --- a/Core/Sources/ChatGPTChatTab/Chat.swift +++ b/Core/Sources/ChatGPTChatTab/Chat.swift @@ -9,7 +9,7 @@ public struct DisplayedChatMessage: Equatable { public enum Role: Equatable { case user case assistant - case function + case tool case ignored } @@ -312,8 +312,7 @@ struct Chat: ReducerProtocol { return .assistant } return .ignored - case .function: return .function - case .tool: return .function + case .tool: return .tool } }(), text: message.summary ?? message.content ?? "", diff --git a/Core/Sources/ChatGPTChatTab/ChatPanel.swift b/Core/Sources/ChatGPTChatTab/ChatPanel.swift index 7a729bfd..2729b5e4 100644 --- a/Core/Sources/ChatGPTChatTab/ChatPanel.swift +++ b/Core/Sources/ChatGPTChatTab/ChatPanel.swift @@ -258,7 +258,7 @@ struct ChatHistory: View { trailing: -8 )) .padding(.vertical, 4) - case .function: + case .tool: FunctionMessage(id: message.id, text: text) case .ignored: EmptyView() @@ -453,7 +453,7 @@ struct ChatPanel_Preview: PreviewProvider { ), .init( id: "6", - role: .function, + role: .tool, text: """ Searching for something... - abc diff --git a/Core/Sources/ChatService/ContextAwareAutoManagedChatGPTMemory.swift b/Core/Sources/ChatService/ContextAwareAutoManagedChatGPTMemory.swift index 9f4a53e1..66804271 100644 --- a/Core/Sources/ChatService/ContextAwareAutoManagedChatGPTMemory.swift +++ b/Core/Sources/ChatService/ContextAwareAutoManagedChatGPTMemory.swift @@ -39,7 +39,7 @@ public final class ContextAwareAutoManagedChatGPTMemory: ChatGPTMemory { public func generatePrompt() async -> ChatGPTPrompt { let content = (await memory.history) - .last(where: { $0.role == .user || $0.role == .function })?.content + .last(where: { $0.role == .user || $0.role == .tool })?.content try? await contextController.collectContextInformation( systemPrompt: """ \(chatService?.systemPrompt ?? "") diff --git a/Pro b/Pro index c6cace85..49bbd4a4 160000 --- a/Pro +++ b/Pro @@ -1 +1 @@ -Subproject commit c6cace85afbe7cc277462e852de7f86e708160d2 +Subproject commit 49bbd4a40e033ca059a256782039cef07ff63ff9 diff --git a/Tool/Sources/OpenAIService/APIs/GoogleAIChatCompletionsService.swift b/Tool/Sources/OpenAIService/APIs/GoogleAIChatCompletionsService.swift index 4b779168..02f26d21 100644 --- a/Tool/Sources/OpenAIService/APIs/GoogleAIChatCompletionsService.swift +++ b/Tool/Sources/OpenAIService/APIs/GoogleAIChatCompletionsService.swift @@ -194,7 +194,7 @@ extension ChatGPTPrompt { extension ModelContent { static func convertRole(_ role: ChatMessage.Role) -> String { switch role { - case .user, .system, .function, .tool: + case .user, .system, .tool: return "user" case .assistant: return "model" @@ -207,10 +207,6 @@ extension ModelContent { return "System Prompt:\n\(message.content ?? " ")" case .user: return message.content ?? " " - case .function: - return """ - Result of \(message.name ?? "function"): \(message.content ?? "N/A") - """ case .tool: return """ Result of \(message.toolCallId ?? "tool"): \(message.content ?? "N/A") diff --git a/Tool/Sources/OpenAIService/APIs/OlamaChatCompletionsService.swift b/Tool/Sources/OpenAIService/APIs/OlamaChatCompletionsService.swift index d65d3cca..e2ef4d5a 100644 --- a/Tool/Sources/OpenAIService/APIs/OlamaChatCompletionsService.swift +++ b/Tool/Sources/OpenAIService/APIs/OlamaChatCompletionsService.swift @@ -39,8 +39,6 @@ extension OllamaChatCompletionsService: ChatCompletionsAPI { return .user case .system: return .system - case .function: - return .user case .tool: return .user } @@ -116,8 +114,6 @@ extension OllamaChatCompletionsService: ChatCompletionsStreamAPI { return .user case .system: return .system - case .function: - return .user case .tool: return .user } diff --git a/Tool/Sources/OpenAIService/ChatGPTService.swift b/Tool/Sources/OpenAIService/ChatGPTService.swift index 9e3a9fdb..3b1c4f57 100644 --- a/Tool/Sources/OpenAIService/ChatGPTService.swift +++ b/Tool/Sources/OpenAIService/ChatGPTService.swift @@ -499,7 +499,7 @@ extension ChatGPTService { await memory.streamMessage( id: messageId, - role: .function, + role: .tool, name: call.function.name, toolCallId: call.id ) @@ -578,7 +578,7 @@ extension ChatGPTService { }() await memory.streamMessage( id: messageId, - role: .function, + role: .tool, content: content, name: call.name, summary: "Finished running function." diff --git a/Tool/Sources/OpenAIService/Models.swift b/Tool/Sources/OpenAIService/Models.swift index f4f0fff8..02901672 100644 --- a/Tool/Sources/OpenAIService/Models.swift +++ b/Tool/Sources/OpenAIService/Models.swift @@ -15,7 +15,6 @@ public struct ChatMessage: Equatable, Codable { case system case user case assistant - case function case tool } From a412c23f09cf94512ec2ac58488d7c226f890679 Mon Sep 17 00:00:00 2001 From: Shx Guo Date: Sun, 3 Mar 2024 21:34:52 +0800 Subject: [PATCH 19/37] Move tool call responses into its source assistant message --- .../WebChatContextCollector.swift | 3 +- Core/Sources/ChatGPTChatTab/Chat.swift | 27 ++- Core/Sources/ChatService/ChatService.swift | 2 +- ...ContextAwareAutoManagedChatGPTMemory.swift | 2 +- Pro | 2 +- Tool/Sources/LangChain/Chains/LLMChain.swift | 2 +- .../Chains/RefineDocumentChain.swift | 2 +- .../RelevantInformationExtractionChain.swift | 2 +- .../StructuredOutputChatModelChain.swift | 2 +- .../APIs/ChatCompletionsAPIDefinition.swift | 27 ++- .../APIs/GoogleAIChatCompletionsService.swift | 18 +- .../APIs/OpenAIChatCompletionsService.swift | 98 +++++--- .../OpenAIService/ChatGPTService.swift | 210 +++++++++--------- .../FucntionCall/ChatGPTFunction.swift | 11 +- ...toManagedChatGPTMemoryOpenAIStrategy.swift | 10 +- .../OpenAIService/Memory/ChatGPTMemory.swift | 94 +++++--- Tool/Sources/OpenAIService/Models.swift | 36 ++- 17 files changed, 345 insertions(+), 203 deletions(-) diff --git a/Core/Sources/ChatContextCollectors/WebChatContextCollector/WebChatContextCollector.swift b/Core/Sources/ChatContextCollectors/WebChatContextCollector/WebChatContextCollector.swift index 851fdcf7..c35a03ae 100644 --- a/Core/Sources/ChatContextCollectors/WebChatContextCollector/WebChatContextCollector.swift +++ b/Core/Sources/ChatContextCollectors/WebChatContextCollector/WebChatContextCollector.swift @@ -32,7 +32,8 @@ extension WebChatContextCollector { static func detectLinks(from messages: [ChatMessage]) -> [String] { return messages.lazy .compactMap { - $0.content ?? $0.toolCalls?.map(\.function.arguments).joined(separator: " ") ?? "" + $0.content ?? $0.toolCallContext?.toolCalls.map(\.function.arguments) + .joined(separator: " ") ?? "" } .map(detectLinks(from:)) .flatMap { $0 } diff --git a/Core/Sources/ChatGPTChatTab/Chat.swift b/Core/Sources/ChatGPTChatTab/Chat.swift index 5a0df30c..05b50642 100644 --- a/Core/Sources/ChatGPTChatTab/Chat.swift +++ b/Core/Sources/ChatGPTChatTab/Chat.swift @@ -15,7 +15,7 @@ public struct DisplayedChatMessage: Equatable { public struct Reference: Equatable { public typealias Kind = ChatMessage.Reference.Kind - + public var title: String public var subtitle: String public var uri: String @@ -135,7 +135,7 @@ struct Chat: ReducerProtocol { await send(.focusOnTextField) await send(.refresh) } - + case .refresh: return .run { send in await send(.chatMenu(.refresh)) @@ -298,8 +298,9 @@ struct Chat: ReducerProtocol { }.cancellable(id: CancelID.observeDefaultScopesChange(id), cancelInFlight: true) case .historyChanged: - state.history = service.chatHistory.map { message in - .init( + state.history = service.chatHistory.flatMap { message in + var all = [DisplayedChatMessage]() + all.append(.init( id: message.id, role: { switch message.role { @@ -312,7 +313,6 @@ struct Chat: ReducerProtocol { return .assistant } return .ignored - case .tool: return .tool } }(), text: message.summary ?? message.content ?? "", @@ -325,7 +325,20 @@ struct Chat: ReducerProtocol { kind: $0.kind ) } - ) + )) + + if let responses = message.toolCallContext?.responses { + for response in responses { + all.append(.init( + id: message.id + response.id, + role: .tool, + text: response.summary ?? response.content, + references: [] + )) + } + } + + return all } state.title = { @@ -401,7 +414,7 @@ struct ChatMenu: ReducerProtocol { return .run { await $0(.refresh) } - + case .refresh: state.temperatureOverride = service.configuration.overriding.temperature state.chatModelIdOverride = service.configuration.overriding.modelId diff --git a/Core/Sources/ChatService/ChatService.swift b/Core/Sources/ChatService/ChatService.swift index 4bb74639..145473da 100644 --- a/Core/Sources/ChatService/ChatService.swift +++ b/Core/Sources/ChatService/ChatService.swift @@ -126,7 +126,7 @@ public final class ChatService: ObservableObject { // if it's stopped before the tool calls finish, remove the message. await memory.mutateHistory { history in - if history.last?.role == .assistant, history.last?.toolCalls != nil { + if history.last?.role == .assistant, history.last?.toolCallContext?.toolCalls != nil { history.removeLast() } } diff --git a/Core/Sources/ChatService/ContextAwareAutoManagedChatGPTMemory.swift b/Core/Sources/ChatService/ContextAwareAutoManagedChatGPTMemory.swift index 66804271..ac44d87c 100644 --- a/Core/Sources/ChatService/ContextAwareAutoManagedChatGPTMemory.swift +++ b/Core/Sources/ChatService/ContextAwareAutoManagedChatGPTMemory.swift @@ -39,7 +39,7 @@ public final class ContextAwareAutoManagedChatGPTMemory: ChatGPTMemory { public func generatePrompt() async -> ChatGPTPrompt { let content = (await memory.history) - .last(where: { $0.role == .user || $0.role == .tool })?.content + .last(where: { $0.role == .user })?.content try? await contextController.collectContextInformation( systemPrompt: """ \(chatService?.systemPrompt ?? "") diff --git a/Pro b/Pro index 49bbd4a4..a2e8aa56 160000 --- a/Pro +++ b/Pro @@ -1 +1 @@ -Subproject commit 49bbd4a40e033ca059a256782039cef07ff63ff9 +Subproject commit a2e8aa56ff9b30bb3f3ae50a01b183cfbabb14cb diff --git a/Tool/Sources/LangChain/Chains/LLMChain.swift b/Tool/Sources/LangChain/Chains/LLMChain.swift index 2ba4aef4..fd8ef05d 100644 --- a/Tool/Sources/LangChain/Chains/LLMChain.swift +++ b/Tool/Sources/LangChain/Chains/LLMChain.swift @@ -33,7 +33,7 @@ public class ChatModelChain: Chain { public func parseOutput(_ output: Output) -> String { if let content = output.content { return content - } else if let toolCalls = output.toolCalls { + } else if let toolCalls = output.toolCallContext?.toolCalls { return toolCalls.map { "[\($0.id)] \($0.function.name): \($0.function.arguments)" } .joined(separator: "\n") } diff --git a/Tool/Sources/LangChain/Chains/RefineDocumentChain.swift b/Tool/Sources/LangChain/Chains/RefineDocumentChain.swift index 3b24e6ad..bbf0f764 100644 --- a/Tool/Sources/LangChain/Chains/RefineDocumentChain.swift +++ b/Tool/Sources/LangChain/Chains/RefineDocumentChain.swift @@ -153,7 +153,7 @@ public final class RefineDocumentChain: Chain { } func extractAnswer(_ chatMessage: ChatMessage) -> IntermediateAnswer { - for functionCall in chatMessage.toolCalls?.map(\.function) ?? [] { + for functionCall in chatMessage.toolCallContext?.toolCalls.map(\.function) ?? [] { do { let intermediateAnswer = try JSONDecoder().decode( IntermediateAnswer.self, diff --git a/Tool/Sources/LangChain/Chains/RelevantInformationExtractionChain.swift b/Tool/Sources/LangChain/Chains/RelevantInformationExtractionChain.swift index 4c9f696a..445c75ee 100644 --- a/Tool/Sources/LangChain/Chains/RelevantInformationExtractionChain.swift +++ b/Tool/Sources/LangChain/Chains/RelevantInformationExtractionChain.swift @@ -104,7 +104,7 @@ public final class RelevantInformationExtractionChain: Chain { callbackManagers: callbackManagers ) - if let functionCall = output.toolCalls? + if let functionCall = output.toolCallContext?.toolCalls .first(where: { $0.function.name == FinalAnswer().name })?.function { do { diff --git a/Tool/Sources/LangChain/Chains/StructuredOutputChatModelChain.swift b/Tool/Sources/LangChain/Chains/StructuredOutputChatModelChain.swift index 6ea1dbb5..103f3244 100644 --- a/Tool/Sources/LangChain/Chains/StructuredOutputChatModelChain.swift +++ b/Tool/Sources/LangChain/Chains/StructuredOutputChatModelChain.swift @@ -108,7 +108,7 @@ public class StructuredOutputChatModelChain: Chain { } public func parseOutput(_ message: ChatMessage) async -> Output? { - if let functionCall = message.toolCalls?.first?.function { + if let functionCall = message.toolCallContext?.toolCalls.first?.function { do { let result = try JSONDecoder().decode( EndFunction.Arguments.self, diff --git a/Tool/Sources/OpenAIService/APIs/ChatCompletionsAPIDefinition.swift b/Tool/Sources/OpenAIService/APIs/ChatCompletionsAPIDefinition.swift index 324cc554..a86aba7b 100644 --- a/Tool/Sources/OpenAIService/APIs/ChatCompletionsAPIDefinition.swift +++ b/Tool/Sources/OpenAIService/APIs/ChatCompletionsAPIDefinition.swift @@ -5,8 +5,28 @@ import Preferences struct ChatCompletionsRequestBody: Codable, Equatable { struct Message: Codable, Equatable { + enum Role: String, Codable, Equatable { + case system + case user + case assistant + case tool + + var asChatMessageRole: ChatMessage.Role { + switch self { + case .system: + return .system + case .user: + return .user + case .assistant: + return .assistant + case .tool: + return .user + } + } + } + /// The role of the message. - var role: ChatMessage.Role + var role: Role /// The content of the message. var content: String /// When we want to reply to a function call with the result, we have to provide the @@ -149,12 +169,13 @@ struct ChatCompletionsStreamDataChunk { } struct ToolCall { + var index: Int? var id: String? var type: String? var function: FunctionCall? } - var role: ChatMessage.Role? + var role: ChatCompletionsRequestBody.Message.Role? var content: String? var toolCalls: [ToolCall]? } @@ -174,7 +195,7 @@ protocol ChatCompletionsAPI { struct ChatCompletionResponseBody: Codable, Equatable { typealias Message = ChatCompletionsRequestBody.Message - + var id: String? var object: String var model: String diff --git a/Tool/Sources/OpenAIService/APIs/GoogleAIChatCompletionsService.swift b/Tool/Sources/OpenAIService/APIs/GoogleAIChatCompletionsService.swift index 02f26d21..1152f69a 100644 --- a/Tool/Sources/OpenAIService/APIs/GoogleAIChatCompletionsService.swift +++ b/Tool/Sources/OpenAIService/APIs/GoogleAIChatCompletionsService.swift @@ -194,7 +194,7 @@ extension ChatGPTPrompt { extension ModelContent { static func convertRole(_ role: ChatMessage.Role) -> String { switch role { - case .user, .system, .tool: + case .user, .system: return "user" case .assistant: return "model" @@ -207,16 +207,14 @@ extension ModelContent { return "System Prompt:\n\(message.content ?? " ")" case .user: return message.content ?? " " - case .tool: - return """ - Result of \(message.toolCallId ?? "tool"): \(message.content ?? "N/A") - """ case .assistant: - if let toolCalls = message.toolCalls { - return toolCalls.map { - """ - Call function: \($0.function.name) - \($0.id) - Arguments: \($0.function.arguments) + if let toolCallContext = message.toolCallContext { + return toolCallContext.toolCalls.map { call in + let response = toolCallContext.responses.first(where: { $0.id == call.id }) + return """ + Call function: \(call.function.name) + Arguments: \(call.function.arguments) + Result: \(response?.content ?? "N/A") """ }.joined(separator: "\n") } else { diff --git a/Tool/Sources/OpenAIService/APIs/OpenAIChatCompletionsService.swift b/Tool/Sources/OpenAIService/APIs/OpenAIChatCompletionsService.swift index 2c868404..48d77a17 100644 --- a/Tool/Sources/OpenAIService/APIs/OpenAIChatCompletionsService.swift +++ b/Tool/Sources/OpenAIService/APIs/OpenAIChatCompletionsService.swift @@ -1,6 +1,7 @@ import AIModel import AsyncAlgorithms import Foundation +import Logger import Preferences /// https://platform.openai.com/docs/api-reference/chat/create @@ -18,6 +19,24 @@ actor OpenAIChatCompletionsService: ChatCompletionsStreamAPI, ChatCompletionsAPI var errorDescription: String? { error.message } } + enum MessageRole: String, Codable { + case system + case user + case assistant + case function + case tool + + var formalized: ChatCompletionsRequestBody.Message.Role { + switch self { + case .system: return .system + case .user: return .user + case .assistant: return .assistant + case .function: return .tool + case .tool: return .tool + } + } + } + struct StreamDataChunk: Codable { var id: String? var object: String? @@ -30,7 +49,7 @@ actor OpenAIChatCompletionsService: ChatCompletionsStreamAPI, ChatCompletionsAPI var finish_reason: String? struct Delta: Codable { - var role: ChatMessage.Role? + var role: MessageRole? var content: String? var function_call: RequestBody.MessageFunctionCall? var tool_calls: [RequestBody.MessageToolCall]? @@ -41,7 +60,7 @@ actor OpenAIChatCompletionsService: ChatCompletionsStreamAPI, ChatCompletionsAPI struct ResponseBody: Codable, Equatable { struct Message: Codable, Equatable { /// The role of the message. - var role: ChatMessage.Role + var role: MessageRole /// The content of the message. var content: String? /// When we want to reply to a function call with the result, we have to provide the @@ -83,7 +102,7 @@ actor OpenAIChatCompletionsService: ChatCompletionsStreamAPI, ChatCompletionsAPI struct RequestBody: Codable, Equatable { struct Message: Codable, Equatable { /// The role of the message. - var role: ChatMessage.Role + var role: MessageRole /// The content of the message. var content: String /// When we want to reply to a function call with the result, we have to provide the @@ -97,22 +116,28 @@ actor OpenAIChatCompletionsService: ChatCompletionsStreamAPI, ChatCompletionsAPI /// /// - important: It's required when the role is `tool`. var tool_call_id: String? + /// When the bot wants to call a function, it will reply with a function call. + /// + /// Deprecated. + var function_call: MessageFunctionCall? } struct MessageFunctionCall: Codable, Equatable { /// The name of the - var name: String + var name: String? /// A JSON string. var arguments: String? } struct MessageToolCall: Codable, Equatable { + /// When it's returned as a data chunk, use the index to identify the tool call. + var index: Int? /// The id of the tool call. - var id: String + var id: String? /// The type of the tool. - var type: String + var type: String? /// The function call. - var function: MessageFunctionCall + var function: MessageFunctionCall? } struct Tool: Codable, Equatable { @@ -200,11 +225,17 @@ actor OpenAIChatCompletionsService: ChatCompletionsStreamAPI, ChatCompletionsAPI if Task.isCancelled { break } let prefix = "data: " guard line.hasPrefix(prefix), - let content = line.dropFirst(prefix.count).data(using: .utf8), - let chunk = try? JSONDecoder() - .decode(StreamDataChunk.self, from: content) + let content = line.dropFirst(prefix.count).data(using: .utf8) else { continue } - continuation.yield(chunk.formalized()) + do { + let chunk = try JSONDecoder().decode( + StreamDataChunk.self, + from: content + ) + continuation.yield(chunk.formalized()) + } catch { + Logger.service.error("Error decoding stream data: \(error)") + } } continuation.finish() } catch { @@ -276,27 +307,27 @@ extension OpenAIChatCompletionsService.ResponseBody { func convertMessage(_ message: Message) -> ChatCompletionResponseBody.Message { .init( - role: message.role, + role: message.role.formalized, content: message.content ?? "", toolCalls: { if let toolCalls = message.tool_calls { return toolCalls.map { toolCall in .init( - id: toolCall.id, - type: toolCall.type, + id: toolCall.id ?? "", + type: toolCall.type ?? "function", function: .init( - name: toolCall.function.name, - arguments: toolCall.function.arguments + name: toolCall.function?.name ?? "", + arguments: toolCall.function?.arguments ) ) } } else if let functionCall = message.function_call { return [ .init( - id: functionCall.name, + id: functionCall.name ?? "", type: "function", function: .init( - name: functionCall.name, + name: functionCall.name ?? "", arguments: functionCall.arguments ) ), @@ -336,17 +367,18 @@ extension OpenAIChatCompletionsService.StreamDataChunk { message: { if let choice = self.choices?.first { return .init( - role: choice.delta?.role, + role: choice.delta?.role?.formalized, content: choice.delta?.content, toolCalls: { if let toolCalls = choice.delta?.tool_calls { return toolCalls.map { .init( + index: $0.index, id: $0.id, type: $0.type, function: .init( - name: $0.function.name, - arguments: $0.function.arguments + name: $0.function?.name, + arguments: $0.function?.arguments ) ) } @@ -355,6 +387,7 @@ extension OpenAIChatCompletionsService.StreamDataChunk { if let functionCall = choice.delta?.function_call { return [ .init( + index: 0, id: functionCall.name, type: "function", function: .init( @@ -379,12 +412,23 @@ extension OpenAIChatCompletionsService.StreamDataChunk { extension OpenAIChatCompletionsService.RequestBody { init(_ body: ChatCompletionsRequestBody) { model = body.model - messages = body.messages.map { + messages = body.messages.map { message in .init( - role: $0.role, - content: $0.content, - name: $0.name, - tool_calls: $0.toolCalls?.map { tool in + role: { + switch message.role { + case .user: + return .user + case .assistant: + return .assistant + case .system: + return .system + case .tool: + return .tool + } + }(), + content: message.content, + name: message.name, + tool_calls: message.toolCalls?.map { tool in MessageToolCall( id: tool.id, type: tool.type, @@ -394,7 +438,7 @@ extension OpenAIChatCompletionsService.RequestBody { ) ) }, - tool_call_id: $0.toolCallId + tool_call_id: message.toolCallId ) } temperature = body.temperature diff --git a/Tool/Sources/OpenAIService/ChatGPTService.swift b/Tool/Sources/OpenAIService/ChatGPTService.swift index 3b1c4f57..c3fb10a9 100644 --- a/Tool/Sources/OpenAIService/ChatGPTService.swift +++ b/Tool/Sources/OpenAIService/ChatGPTService.swift @@ -169,7 +169,7 @@ public class ChatGPTService: ChatGPTServiceType { role: .user, content: content, name: nil, - toolCalls: nil, + toolCallContext: nil, summary: summary, references: [] ) @@ -180,8 +180,8 @@ public class ChatGPTService: ChatGPTServiceType { AsyncThrowingStream { continuation in let task = Task(priority: .userInitiated) { do { - var pendingToolCalls = IdentifiedArrayOf() - var functionCallMessageIDs = [String: String]() + var pendingToolCalls = [ChatMessage.ToolCall]() + var sourceMessageId = "" var isInitialCall = true loop: while !pendingToolCalls.isEmpty || isInitialCall { try Task.checkCancellation() @@ -192,10 +192,12 @@ public class ChatGPTService: ChatGPTServiceType { } await runFunctionCall( toolCall, - messageId: functionCallMessageIDs[toolCall.id] + sourceMessageId: sourceMessageId ) } - let stream = try await sendMemory() + sourceMessageId = uuid() + .uuidString + String(date().timeIntervalSince1970) + let stream = try await sendMemory(proposedId: sourceMessageId) #if DEBUG var reply = "" @@ -211,15 +213,17 @@ public class ChatGPTService: ChatGPTServiceType { #endif case let .toolCall(toolCall): - let id = storeToolCallsChunks( - chunk: toolCall, - into: &pendingToolCalls, - messageIds: &functionCallMessageIDs + await prepareFunctionCall( + toolCall, + sourceMessageId: sourceMessageId ) - - await prepareFunctionCall(toolCall, messageId: id) } } + + pendingToolCalls = await memory.history + .last { $0.id == sourceMessageId }? + .toolCallContext?.toolCalls ?? [] + #if DEBUG Debugger.didReceiveResponse(content: reply) #endif @@ -257,19 +261,19 @@ public class ChatGPTService: ChatGPTServiceType { return try await Debugger.$id.withValue(.init()) { let message = try await sendMemoryAndWait() var finalResult = message?.content - var toolCalls = message?.toolCalls - while let calls = toolCalls, !calls.isEmpty { + var toolCalls = message?.toolCallContext?.toolCalls + while let sourceMessageId = message?.id, let calls = toolCalls, !calls.isEmpty { try Task.checkCancellation() if !configuration.runFunctionsAutomatically { break } toolCalls = nil for call in calls { - await runFunctionCall(call) + await runFunctionCall(call, sourceMessageId: sourceMessageId) } guard let nextMessage = try await sendMemoryAndWait() else { break } finalResult = nextMessage.content - toolCalls = nextMessage.toolCalls + toolCalls = nextMessage.toolCallContext?.toolCalls } #if DEBUG @@ -297,7 +301,7 @@ extension ChatGPTService { } /// Send the memory as prompt to ChatGPT, with stream enabled. - func sendMemory() async throws -> AsyncThrowingStream { + func sendMemory(proposedId: String) async throws -> AsyncThrowingStream { let prompt = await memory.generatePrompt() guard let model = configuration.model else { @@ -321,8 +325,6 @@ extension ChatGPTService { Debugger.didSendRequestBody(body: requestBody) #endif - let proposedId = uuid().uuidString + String(date().timeIntervalSince1970) - return AsyncThrowingStream { continuation in let task = Task { do { @@ -341,26 +343,27 @@ extension ChatGPTService { // The first round will contain the function name and an empty argument. // e.g. {"name":"weather","arguments":""} // The other rounds will contain part of the arguments. - let toolCalls = delta.toolCalls?.map { - ChatMessage.ToolCall( - id: $0.id ?? "", - type: $0.type ?? "", - function: .init( - name: $0.function?.name ?? "", - arguments: $0.function?.arguments ?? "" + let toolCalls = delta.toolCalls? + .reduce(into: [Int: ChatMessage.ToolCall]()) { + $0[$1.index ?? 0] = ChatMessage.ToolCall( + id: $1.id ?? "", + type: $1.type ?? "", + function: .init( + name: $1.function?.name ?? "", + arguments: $1.function?.arguments ?? "" + ) ) - ) - } + } await memory.streamMessage( id: proposedId, - role: delta.role, + role: delta.role?.asChatMessageRole, content: delta.content, toolCalls: toolCalls ) if let toolCalls { - for toolCall in toolCalls { + for toolCall in toolCalls.values { continuation.yield(.toolCall(toolCall)) } } @@ -425,14 +428,23 @@ extension ChatGPTService { let choice = response.message let message = ChatMessage( id: proposedId, - role: choice.role, + role: { + switch choice.role { + case .system: .system + case .user: .user + case .assistant: .assistant + case .tool: .user + } + }(), content: choice.content, name: choice.name, - toolCalls: choice.toolCalls?.map { - ChatMessage.ToolCall(id: $0.id, type: $0.type, function: .init( - name: $0.function.name, - arguments: $0.function.arguments ?? "" - )) + toolCallContext: choice.toolCalls.map { + .init(toolCalls: $0.map { + ChatMessage.ToolCall(id: $0.id, type: $0.type, function: .init( + name: $0.function.name, + arguments: $0.function.arguments ?? "" + )) + }, responses: []) }, references: prompt.references ) @@ -440,44 +452,17 @@ extension ChatGPTService { return message } - func storeToolCallsChunks( - chunk toolCall: ChatMessage.ToolCall, - into toolCalls: inout IdentifiedArrayOf, - messageIds: inout [String: String] - ) -> String { - if let index = toolCalls.firstIndex(where: { $0.id == toolCall.id }) { - if !toolCall.id.isEmpty { - toolCalls[index].id = toolCall.id - } - if !toolCall.type.isEmpty { - toolCalls[index].type = toolCall.type - } - toolCalls[index].function.name.append(toolCall.function.name) - toolCalls[index].function.arguments.append(toolCall.function.arguments) - - } else { - toolCalls.append(toolCall) - } - - let id = messageIds[toolCall.id] ?? UUID().uuidString - messageIds[toolCall.id] = id - return id - } - /// When a function call is detected, but arguments are not yet ready, we can call this /// to insert a message placeholder in memory. - func prepareFunctionCall(_ call: ChatMessage.ToolCall, messageId: String) async { + func prepareFunctionCall(_ call: ChatMessage.ToolCall, sourceMessageId: String) async { guard let function = functionProvider.function(named: call.function.name) else { return } - await memory.streamMessage( - id: messageId, - role: .tool, - name: call.function.name, - toolCallId: call.id - ) + await memory.streamToolCallResponse(id: sourceMessageId, toolCallId: call.id) await function.prepare { [weak self] summary in - await self?.memory.updateMessage(id: messageId) { message in - message.summary = summary - } + await self?.memory.streamToolCallResponse( + id: sourceMessageId, + toolCallId: call.id, + summary: summary + ) } } @@ -485,22 +470,18 @@ extension ChatGPTService { @discardableResult func runFunctionCall( _ call: ChatMessage.ToolCall, - messageId: String? = nil + sourceMessageId: String ) async -> String { #if DEBUG Debugger.didReceiveFunction(name: call.function.name, arguments: call.function.arguments) #endif - let messageId = messageId ?? uuid().uuidString - guard let function = functionProvider.function(named: call.function.name) else { - return await fallbackFunctionCall(call.function, messageId: messageId) + return await fallbackFunctionCall(call, sourceMessageId: sourceMessageId) } - await memory.streamMessage( - id: messageId, - role: .tool, - name: call.function.name, + await memory.streamToolCallResponse( + id: sourceMessageId, toolCallId: call.id ) @@ -508,18 +489,22 @@ extension ChatGPTService { // Run the function let result = try await function.call(argumentsJsonString: call.function.arguments) { [weak self] summary in - await self?.memory.updateMessage(id: messageId) { message in - message.summary = summary - } + await self?.memory.streamToolCallResponse( + id: sourceMessageId, + toolCallId: call.id, + summary: summary + ) } #if DEBUG Debugger.didReceiveFunctionResult(result: result.botReadableContent) #endif - await memory.updateMessage(id: messageId) { message in - message.content = result.botReadableContent - } + await memory.streamToolCallResponse( + id: sourceMessageId, + toolCallId: call.id, + content: result.botReadableContent + ) return result.botReadableContent } catch { @@ -530,20 +515,22 @@ extension ChatGPTService { Debugger.didReceiveFunctionResult(result: content) #endif - await memory.updateMessage(id: messageId) { message in - message.content = content - } + await memory.streamToolCallResponse( + id: sourceMessageId, + toolCallId: call.id, + content: content + ) return content } } /// Mock a function call result when the bot is calling a function that is not implemented. func fallbackFunctionCall( - _ call: ChatMessage.FunctionCall, - messageId: String + _ call: ChatMessage.ToolCall, + sourceMessageId: String ) async -> String { let memory = ConversationChatGPTMemory(systemPrompt: { - if call.name == "python" { + if call.function.name == "python" { return """ Act like a Python interpreter. I will give you Python code and you will execute it. @@ -551,7 +538,7 @@ extension ChatGPTService { """ } else { return """ - You are a function simulator. Your name is \(call.name). + You are a function simulator. Your name is \(call.function.name). Act like a function. I will send you the arguments. Reply with output of the function and tell me it's an answer generated by LLM. @@ -570,17 +557,16 @@ extension ChatGPTService { let content: String = await { do { return try await service.sendAndWait(content: """ - \(call.arguments) + \(call.function.arguments) """) ?? "No result." } catch { return "No result." } }() - await memory.streamMessage( - id: messageId, - role: .tool, + await memory.streamToolCallResponse( + id: sourceMessageId, + toolCallId: call.id, content: content, - name: call.name, summary: "Finished running function." ) return content @@ -591,12 +577,19 @@ extension ChatGPTService { model: ChatModel, stream: Bool ) -> ChatCompletionsRequestBody { - let messages = prompt.history.map { - ChatCompletionsRequestBody.Message( - role: $0.role, - content: $0.content ?? "", - name: $0.name, - toolCalls: $0.toolCalls?.map { + let messages = prompt.history.flatMap { chatMessage in + var all = [ChatCompletionsRequestBody.Message]() + all.append(ChatCompletionsRequestBody.Message( + role: { + switch chatMessage.role { + case .system: .system + case .user: .user + case .assistant: .assistant + } + }(), + content: chatMessage.content ?? "", + name: chatMessage.name, + toolCalls: chatMessage.toolCallContext?.toolCalls.map { .init( id: $0.id, type: $0.type, @@ -606,8 +599,21 @@ extension ChatGPTService { ) ) } - ) + )) + + if let responses = chatMessage.toolCallContext?.responses { + for response in responses { + all.append(ChatCompletionsRequestBody.Message( + role: .tool, + content: response.content, + toolCallId: response.id + )) + } + } + + return all } + let remainingTokens = prompt.remainingTokenCount let requestBody = ChatCompletionsRequestBody( diff --git a/Tool/Sources/OpenAIService/FucntionCall/ChatGPTFunction.swift b/Tool/Sources/OpenAIService/FucntionCall/ChatGPTFunction.swift index 420fc180..d2d8aaad 100644 --- a/Tool/Sources/OpenAIService/FucntionCall/ChatGPTFunction.swift +++ b/Tool/Sources/OpenAIService/FucntionCall/ChatGPTFunction.swift @@ -43,9 +43,14 @@ public extension ChatGPTFunction { argumentsJsonString: String, reportProgress: @escaping ReportProgress ) async throws -> Result { - let arguments = try JSONDecoder() - .decode(Arguments.self, from: argumentsJsonString.data(using: .utf8) ?? Data()) - return try await call(arguments: arguments, reportProgress: reportProgress) + do { + let arguments = try JSONDecoder() + .decode(Arguments.self, from: argumentsJsonString.data(using: .utf8) ?? Data()) + return try await call(arguments: arguments, reportProgress: reportProgress) + } catch { + await reportProgress("Error: Failed to decode arguments. \(error.localizedDescription)") + throw error + } } } diff --git a/Tool/Sources/OpenAIService/Memory/AutoManagedChatGPTMemoryStrategy/AutoManagedChatGPTMemoryOpenAIStrategy.swift b/Tool/Sources/OpenAIService/Memory/AutoManagedChatGPTMemoryStrategy/AutoManagedChatGPTMemoryOpenAIStrategy.swift index 0ed9873d..7fb4c323 100644 --- a/Tool/Sources/OpenAIService/Memory/AutoManagedChatGPTMemoryStrategy/AutoManagedChatGPTMemoryOpenAIStrategy.swift +++ b/Tool/Sources/OpenAIService/Memory/AutoManagedChatGPTMemoryStrategy/AutoManagedChatGPTMemoryOpenAIStrategy.swift @@ -37,13 +37,19 @@ extension TokenEncoder { encodingContent.append(name) total += 1 } - if let toolCalls = message.toolCalls { - for toolCall in toolCalls { + if let toolCallContext = message.toolCallContext { + for toolCall in toolCallContext.toolCalls { encodingContent.append(toolCall.id) encodingContent.append(toolCall.type) encodingContent.append(toolCall.function.name) encodingContent.append(toolCall.function.arguments) } + + for response in toolCallContext.responses { + total += 4 + encodingContent.append(response.content) + encodingContent.append(response.id) + } } total += await withTaskGroup(of: Int.self, body: { group in for content in encodingContent { diff --git a/Tool/Sources/OpenAIService/Memory/ChatGPTMemory.swift b/Tool/Sources/OpenAIService/Memory/ChatGPTMemory.swift index 62b8f369..409b0ff8 100644 --- a/Tool/Sources/OpenAIService/Memory/ChatGPTMemory.swift +++ b/Tool/Sources/OpenAIService/Memory/ChatGPTMemory.swift @@ -58,73 +58,109 @@ public extension ChatGPTMemory { } } + func streamToolCallResponse( + id: String, + toolCallId: String, + content: String? = nil, + summary: String? = nil + ) async { + await updateMessage(id: id) { message in + if let index = message.toolCallContext?.responses.firstIndex(where: { + $0.id == toolCallId + }) { + if let content { + message.toolCallContext?.responses[index].content = content + } + if let summary { + message.toolCallContext?.responses[index].summary = summary + } + } else { + message.toolCallContext?.responses.append(.init( + id: toolCallId, + content: content ?? "", + summary: summary ?? "" + )) + } + } + } + /// Stream a message to the history. func streamMessage( id: String, role: ChatMessage.Role? = nil, content: String? = nil, name: String? = nil, - toolCallId: String? = nil, - toolCalls: [ChatMessage.ToolCall]? = nil, + toolCalls: [Int: ChatMessage.ToolCall]? = nil, summary: String? = nil, references: [ChatMessage.Reference]? = nil ) async { - await mutateHistory { history in - if let index = history.firstIndex(where: { $0.id == id }) { + if await history.contains(where: { $0.id == id }) { + await updateMessage(id: id) { message in if let content { - if history[index].content == nil { - history[index].content = content + if message.content == nil { + message.content = content } else { - history[index].content?.append(content) + message.content?.append(content) } } if let role { - history[index].role = role + message.role = role } if let toolCalls { - if history[index].toolCalls == nil { - history[index].toolCalls = toolCalls - } else { - for toolCall in toolCalls { - if let index = history[index].toolCalls? - .firstIndex(where: { $0.id == toolCall.id }) - { + if var existedToolCalls = message.toolCallContext?.toolCalls { + for pair in toolCalls.sorted(by: { $0.key <= $1.key }) { + let (proposedIndex, toolCall) = pair + let index = { + if toolCall.id.isEmpty { return proposedIndex } + return existedToolCalls.lastIndex(where: { $0.id == toolCall.id }) + ?? proposedIndex + }() + if index < existedToolCalls.endIndex { if !toolCall.id.isEmpty { - history[index].toolCalls?[index].id = toolCall.id + existedToolCalls[index].id = toolCall.id } if !toolCall.type.isEmpty { - history[index].toolCalls?[index].type = toolCall.type + existedToolCalls[index].type = toolCall.type } - history[index].toolCalls?[index].function.name + existedToolCalls[index].function.name .append(toolCall.function.name) - history[index].toolCalls?[index].function.arguments + existedToolCalls[index].function.arguments .append(toolCall.function.arguments) } else { - history[index].toolCalls?.append(toolCall) + existedToolCalls.append(toolCall) } } + message.toolCallContext?.toolCalls = existedToolCalls + } else { + message.toolCallContext = .init( + toolCalls: toolCalls.sorted(by: { $0.key <= $1.key }).map(\.value), + responses: [] + ) } } if let summary { - history[index].summary = summary + message.summary = summary } if let references { - history[index].references.append(contentsOf: references) + message.references.append(contentsOf: references) } if let name { - history[index].name = name - } - if let toolCallId { - history[index].toolCallId = toolCallId + message.name = name } - } else { + } + } else { + await mutateHistory { history in history.append(.init( id: id, role: role ?? .system, content: content, name: name, - toolCallId: toolCallId, - toolCalls: toolCalls, + toolCallContext: toolCalls.map { calls in + .init( + toolCalls: calls.sorted(by: { $0.key <= $1.key }).map(\.value), + responses: [] + ) + }, summary: summary, references: references ?? [] )) diff --git a/Tool/Sources/OpenAIService/Models.swift b/Tool/Sources/OpenAIService/Models.swift index 02901672..1412850e 100644 --- a/Tool/Sources/OpenAIService/Models.swift +++ b/Tool/Sources/OpenAIService/Models.swift @@ -15,7 +15,6 @@ public struct ChatMessage: Equatable, Codable { case system case user case assistant - case tool } public struct FunctionCall: Codable, Equatable { @@ -37,6 +36,22 @@ public struct ChatMessage: Equatable, Codable { self.function = function } } + + public struct ToolCallResponse: Codable, Equatable { + public var id: String + public var content: String + public var summary: String? + public init(id: String, content: String, summary: String?) { + self.id = id + self.content = content + self.summary = summary + } + } + + public struct ToolCallContext: Codable, Equatable { + public var toolCalls: [ToolCall] + public var responses: [ToolCallResponse] + } public struct Reference: Codable, Equatable { public enum Kind: String, Codable { @@ -85,6 +100,7 @@ public struct ChatMessage: Equatable, Codable { } /// The role of a message. + @FallbackDecoding public var role: Role /// The content of the message, either the chat message, or a result of a function call. @@ -93,7 +109,7 @@ public struct ChatMessage: Equatable, Codable { } /// A function call from the bot. - public var toolCalls: [ToolCall]? { + public var toolCallContext: ToolCallContext? { didSet { tokensCount = nil } } @@ -101,11 +117,6 @@ public struct ChatMessage: Equatable, Codable { public var name: String? { didSet { tokensCount = nil } } - - /// The tool id of a reply to a tool call. - public var toolCallId: String? { - didSet { tokensCount = nil } - } /// The summary of a message that is used for display. public var summary: String? @@ -123,7 +134,7 @@ public struct ChatMessage: Equatable, Codable { /// Is the message considered empty. var isEmpty: Bool { if let content, !content.isEmpty { return false } - if let toolCalls, !toolCalls.isEmpty { return false } + if let toolCallContext, !toolCallContext.toolCalls.isEmpty { return false } if let name, !name.isEmpty { return false } return true } @@ -133,8 +144,7 @@ public struct ChatMessage: Equatable, Codable { role: Role, content: String?, name: String? = nil, - toolCallId: String? = nil, - toolCalls: [ToolCall]? = nil, + toolCallContext: ToolCallContext? = nil, summary: String? = nil, tokenCount: Int? = nil, references: [Reference] = [] @@ -142,8 +152,7 @@ public struct ChatMessage: Equatable, Codable { self.role = role self.content = content self.name = name - self.toolCallId = toolCallId - self.toolCalls = toolCalls + self.toolCallContext = toolCallContext self.summary = summary self.id = id tokensCount = tokenCount @@ -155,3 +164,6 @@ public struct ReferenceKindFallback: FallbackValueProvider { public static var defaultValue: ChatMessage.Reference.Kind { .other } } +public struct ChatMessageRoleFallback: FallbackValueProvider { + public static var defaultValue: ChatMessage.Role { .user } +} From ca9e9816175be772b89c6500245a634fc4e78d4c Mon Sep 17 00:00:00 2001 From: Shx Guo Date: Sun, 3 Mar 2024 21:41:08 +0800 Subject: [PATCH 20/37] Use ResponseStream to handle stream responses --- .../APIs/OpenAIChatCompletionsService.swift | 45 ++++++++----------- 1 file changed, 18 insertions(+), 27 deletions(-) diff --git a/Tool/Sources/OpenAIService/APIs/OpenAIChatCompletionsService.swift b/Tool/Sources/OpenAIService/APIs/OpenAIChatCompletionsService.swift index 48d77a17..77d7cfac 100644 --- a/Tool/Sources/OpenAIService/APIs/OpenAIChatCompletionsService.swift +++ b/Tool/Sources/OpenAIService/APIs/OpenAIChatCompletionsService.swift @@ -218,37 +218,28 @@ actor OpenAIChatCompletionsService: ChatCompletionsStreamAPI, ChatCompletionsAPI throw error ?? ChatGPTServiceError.responseInvalid } - let stream = AsyncThrowingStream { continuation in - let task = Task { - do { - for try await line in result.lines { - if Task.isCancelled { break } - let prefix = "data: " - guard line.hasPrefix(prefix), - let content = line.dropFirst(prefix.count).data(using: .utf8) - else { continue } - do { - let chunk = try JSONDecoder().decode( - StreamDataChunk.self, - from: content - ) - continuation.yield(chunk.formalized()) - } catch { - Logger.service.error("Error decoding stream data: \(error)") - } - } - continuation.finish() - } catch { - continuation.finish(throwing: error) - } + let stream = ResponseStream(result: result) { + var line = $0 + let prefix = "data: " + if line.hasPrefix(prefix) { + line.removeFirst(prefix.count) } - continuation.onTermination = { _ in - task.cancel() - result.task.cancel() + + if line == "[DONE]" { return .init(chunk: nil, done: true) } + + do { + let chunk = try JSONDecoder().decode( + StreamDataChunk.self, + from: line.data(using: .utf8) ?? Data() + ) + return .init(chunk: chunk, done: false) + } catch { + Logger.service.error("Error decoding stream data: \(error)") + throw error } } - return stream + return stream.map { $0.formalized() }.toStream() } func callAsFunction() async throws -> ChatCompletionResponseBody { From 21d7ae4ce481e01f23954d1cadbe806c64d00548 Mon Sep 17 00:00:00 2001 From: Shx Guo Date: Sun, 3 Mar 2024 22:00:32 +0800 Subject: [PATCH 21/37] Put tool call responses into the tool call struct --- .../WebChatContextCollector.swift | 3 +- Core/Sources/ChatGPTChatTab/Chat.swift | 16 ++++----- Core/Sources/ChatService/ChatService.swift | 2 +- Pro | 2 +- Tool/Sources/LangChain/Chains/LLMChain.swift | 2 +- .../Chains/RefineDocumentChain.swift | 2 +- .../RelevantInformationExtractionChain.swift | 2 +- .../StructuredOutputChatModelChain.swift | 2 +- .../APIs/GoogleAIChatCompletionsService.swift | 8 ++--- .../OpenAIService/ChatGPTService.swift | 36 +++++++++---------- ...toManagedChatGPTMemoryOpenAIStrategy.swift | 11 +++--- .../OpenAIService/Memory/ChatGPTMemory.swift | 28 ++++----------- Tool/Sources/OpenAIService/Models.swift | 15 ++++---- 13 files changed, 51 insertions(+), 78 deletions(-) diff --git a/Core/Sources/ChatContextCollectors/WebChatContextCollector/WebChatContextCollector.swift b/Core/Sources/ChatContextCollectors/WebChatContextCollector/WebChatContextCollector.swift index c35a03ae..851fdcf7 100644 --- a/Core/Sources/ChatContextCollectors/WebChatContextCollector/WebChatContextCollector.swift +++ b/Core/Sources/ChatContextCollectors/WebChatContextCollector/WebChatContextCollector.swift @@ -32,8 +32,7 @@ extension WebChatContextCollector { static func detectLinks(from messages: [ChatMessage]) -> [String] { return messages.lazy .compactMap { - $0.content ?? $0.toolCallContext?.toolCalls.map(\.function.arguments) - .joined(separator: " ") ?? "" + $0.content ?? $0.toolCalls?.map(\.function.arguments).joined(separator: " ") ?? "" } .map(detectLinks(from:)) .flatMap { $0 } diff --git a/Core/Sources/ChatGPTChatTab/Chat.swift b/Core/Sources/ChatGPTChatTab/Chat.swift index 05b50642..7a1ccc78 100644 --- a/Core/Sources/ChatGPTChatTab/Chat.swift +++ b/Core/Sources/ChatGPTChatTab/Chat.swift @@ -327,15 +327,13 @@ struct Chat: ReducerProtocol { } )) - if let responses = message.toolCallContext?.responses { - for response in responses { - all.append(.init( - id: message.id + response.id, - role: .tool, - text: response.summary ?? response.content, - references: [] - )) - } + for call in message.toolCalls ?? [] { + all.append(.init( + id: message.id + call.response.id, + role: .tool, + text: call.response.summary ?? call.response.content, + references: [] + )) } return all diff --git a/Core/Sources/ChatService/ChatService.swift b/Core/Sources/ChatService/ChatService.swift index 145473da..4bb74639 100644 --- a/Core/Sources/ChatService/ChatService.swift +++ b/Core/Sources/ChatService/ChatService.swift @@ -126,7 +126,7 @@ public final class ChatService: ObservableObject { // if it's stopped before the tool calls finish, remove the message. await memory.mutateHistory { history in - if history.last?.role == .assistant, history.last?.toolCallContext?.toolCalls != nil { + if history.last?.role == .assistant, history.last?.toolCalls != nil { history.removeLast() } } diff --git a/Pro b/Pro index a2e8aa56..ede561a8 160000 --- a/Pro +++ b/Pro @@ -1 +1 @@ -Subproject commit a2e8aa56ff9b30bb3f3ae50a01b183cfbabb14cb +Subproject commit ede561a8f50276d915ee672e8fad59226d349b08 diff --git a/Tool/Sources/LangChain/Chains/LLMChain.swift b/Tool/Sources/LangChain/Chains/LLMChain.swift index fd8ef05d..2ba4aef4 100644 --- a/Tool/Sources/LangChain/Chains/LLMChain.swift +++ b/Tool/Sources/LangChain/Chains/LLMChain.swift @@ -33,7 +33,7 @@ public class ChatModelChain: Chain { public func parseOutput(_ output: Output) -> String { if let content = output.content { return content - } else if let toolCalls = output.toolCallContext?.toolCalls { + } else if let toolCalls = output.toolCalls { return toolCalls.map { "[\($0.id)] \($0.function.name): \($0.function.arguments)" } .joined(separator: "\n") } diff --git a/Tool/Sources/LangChain/Chains/RefineDocumentChain.swift b/Tool/Sources/LangChain/Chains/RefineDocumentChain.swift index bbf0f764..3b24e6ad 100644 --- a/Tool/Sources/LangChain/Chains/RefineDocumentChain.swift +++ b/Tool/Sources/LangChain/Chains/RefineDocumentChain.swift @@ -153,7 +153,7 @@ public final class RefineDocumentChain: Chain { } func extractAnswer(_ chatMessage: ChatMessage) -> IntermediateAnswer { - for functionCall in chatMessage.toolCallContext?.toolCalls.map(\.function) ?? [] { + for functionCall in chatMessage.toolCalls?.map(\.function) ?? [] { do { let intermediateAnswer = try JSONDecoder().decode( IntermediateAnswer.self, diff --git a/Tool/Sources/LangChain/Chains/RelevantInformationExtractionChain.swift b/Tool/Sources/LangChain/Chains/RelevantInformationExtractionChain.swift index 445c75ee..4c9f696a 100644 --- a/Tool/Sources/LangChain/Chains/RelevantInformationExtractionChain.swift +++ b/Tool/Sources/LangChain/Chains/RelevantInformationExtractionChain.swift @@ -104,7 +104,7 @@ public final class RelevantInformationExtractionChain: Chain { callbackManagers: callbackManagers ) - if let functionCall = output.toolCallContext?.toolCalls + if let functionCall = output.toolCalls? .first(where: { $0.function.name == FinalAnswer().name })?.function { do { diff --git a/Tool/Sources/LangChain/Chains/StructuredOutputChatModelChain.swift b/Tool/Sources/LangChain/Chains/StructuredOutputChatModelChain.swift index 103f3244..6ea1dbb5 100644 --- a/Tool/Sources/LangChain/Chains/StructuredOutputChatModelChain.swift +++ b/Tool/Sources/LangChain/Chains/StructuredOutputChatModelChain.swift @@ -108,7 +108,7 @@ public class StructuredOutputChatModelChain: Chain { } public func parseOutput(_ message: ChatMessage) async -> Output? { - if let functionCall = message.toolCallContext?.toolCalls.first?.function { + if let functionCall = message.toolCalls?.first?.function { do { let result = try JSONDecoder().decode( EndFunction.Arguments.self, diff --git a/Tool/Sources/OpenAIService/APIs/GoogleAIChatCompletionsService.swift b/Tool/Sources/OpenAIService/APIs/GoogleAIChatCompletionsService.swift index 1152f69a..80a5c9b8 100644 --- a/Tool/Sources/OpenAIService/APIs/GoogleAIChatCompletionsService.swift +++ b/Tool/Sources/OpenAIService/APIs/GoogleAIChatCompletionsService.swift @@ -208,13 +208,13 @@ extension ModelContent { case .user: return message.content ?? " " case .assistant: - if let toolCallContext = message.toolCallContext { - return toolCallContext.toolCalls.map { call in - let response = toolCallContext.responses.first(where: { $0.id == call.id }) + if let toolCalls = message.toolCalls { + return toolCalls.map { call in + let response = call.response return """ Call function: \(call.function.name) Arguments: \(call.function.arguments) - Result: \(response?.content ?? "N/A") + Result: \(response.content) """ }.joined(separator: "\n") } else { diff --git a/Tool/Sources/OpenAIService/ChatGPTService.swift b/Tool/Sources/OpenAIService/ChatGPTService.swift index c3fb10a9..83e2cc13 100644 --- a/Tool/Sources/OpenAIService/ChatGPTService.swift +++ b/Tool/Sources/OpenAIService/ChatGPTService.swift @@ -169,7 +169,7 @@ public class ChatGPTService: ChatGPTServiceType { role: .user, content: content, name: nil, - toolCallContext: nil, + toolCalls: nil, summary: summary, references: [] ) @@ -222,7 +222,7 @@ public class ChatGPTService: ChatGPTServiceType { pendingToolCalls = await memory.history .last { $0.id == sourceMessageId }? - .toolCallContext?.toolCalls ?? [] + .toolCalls ?? [] #if DEBUG Debugger.didReceiveResponse(content: reply) @@ -261,7 +261,7 @@ public class ChatGPTService: ChatGPTServiceType { return try await Debugger.$id.withValue(.init()) { let message = try await sendMemoryAndWait() var finalResult = message?.content - var toolCalls = message?.toolCallContext?.toolCalls + var toolCalls = message?.toolCalls while let sourceMessageId = message?.id, let calls = toolCalls, !calls.isEmpty { try Task.checkCancellation() if !configuration.runFunctionsAutomatically { @@ -273,7 +273,7 @@ public class ChatGPTService: ChatGPTServiceType { } guard let nextMessage = try await sendMemoryAndWait() else { break } finalResult = nextMessage.content - toolCalls = nextMessage.toolCallContext?.toolCalls + toolCalls = nextMessage.toolCalls } #if DEBUG @@ -438,13 +438,11 @@ extension ChatGPTService { }(), content: choice.content, name: choice.name, - toolCallContext: choice.toolCalls.map { - .init(toolCalls: $0.map { - ChatMessage.ToolCall(id: $0.id, type: $0.type, function: .init( - name: $0.function.name, - arguments: $0.function.arguments ?? "" - )) - }, responses: []) + toolCalls: choice.toolCalls?.map { + ChatMessage.ToolCall(id: $0.id, type: $0.type, function: .init( + name: $0.function.name, + arguments: $0.function.arguments ?? "" + )) }, references: prompt.references ) @@ -589,7 +587,7 @@ extension ChatGPTService { }(), content: chatMessage.content ?? "", name: chatMessage.name, - toolCalls: chatMessage.toolCallContext?.toolCalls.map { + toolCalls: chatMessage.toolCalls?.map { .init( id: $0.id, type: $0.type, @@ -601,14 +599,12 @@ extension ChatGPTService { } )) - if let responses = chatMessage.toolCallContext?.responses { - for response in responses { - all.append(ChatCompletionsRequestBody.Message( - role: .tool, - content: response.content, - toolCallId: response.id - )) - } + for call in chatMessage.toolCalls ?? [] { + all.append(ChatCompletionsRequestBody.Message( + role: .tool, + content: call.response.content, + toolCallId: call.response.id + )) } return all diff --git a/Tool/Sources/OpenAIService/Memory/AutoManagedChatGPTMemoryStrategy/AutoManagedChatGPTMemoryOpenAIStrategy.swift b/Tool/Sources/OpenAIService/Memory/AutoManagedChatGPTMemoryStrategy/AutoManagedChatGPTMemoryOpenAIStrategy.swift index 7fb4c323..be68b9b4 100644 --- a/Tool/Sources/OpenAIService/Memory/AutoManagedChatGPTMemoryStrategy/AutoManagedChatGPTMemoryOpenAIStrategy.swift +++ b/Tool/Sources/OpenAIService/Memory/AutoManagedChatGPTMemoryStrategy/AutoManagedChatGPTMemoryOpenAIStrategy.swift @@ -37,18 +37,15 @@ extension TokenEncoder { encodingContent.append(name) total += 1 } - if let toolCallContext = message.toolCallContext { - for toolCall in toolCallContext.toolCalls { + if let toolCalls = message.toolCalls { + for toolCall in toolCalls { encodingContent.append(toolCall.id) encodingContent.append(toolCall.type) encodingContent.append(toolCall.function.name) encodingContent.append(toolCall.function.arguments) - } - - for response in toolCallContext.responses { total += 4 - encodingContent.append(response.content) - encodingContent.append(response.id) + encodingContent.append(toolCall.response.content) + encodingContent.append(toolCall.response.id) } } total += await withTaskGroup(of: Int.self, body: { group in diff --git a/Tool/Sources/OpenAIService/Memory/ChatGPTMemory.swift b/Tool/Sources/OpenAIService/Memory/ChatGPTMemory.swift index 409b0ff8..33300ee8 100644 --- a/Tool/Sources/OpenAIService/Memory/ChatGPTMemory.swift +++ b/Tool/Sources/OpenAIService/Memory/ChatGPTMemory.swift @@ -65,21 +65,15 @@ public extension ChatGPTMemory { summary: String? = nil ) async { await updateMessage(id: id) { message in - if let index = message.toolCallContext?.responses.firstIndex(where: { + if let index = message.toolCalls?.firstIndex(where: { $0.id == toolCallId }) { if let content { - message.toolCallContext?.responses[index].content = content + message.toolCalls?[index].response.content = content } if let summary { - message.toolCallContext?.responses[index].summary = summary + message.toolCalls?[index].response.summary = summary } - } else { - message.toolCallContext?.responses.append(.init( - id: toolCallId, - content: content ?? "", - summary: summary ?? "" - )) } } } @@ -107,7 +101,7 @@ public extension ChatGPTMemory { message.role = role } if let toolCalls { - if var existedToolCalls = message.toolCallContext?.toolCalls { + if var existedToolCalls = message.toolCalls { for pair in toolCalls.sorted(by: { $0.key <= $1.key }) { let (proposedIndex, toolCall) = pair let index = { @@ -130,12 +124,9 @@ public extension ChatGPTMemory { existedToolCalls.append(toolCall) } } - message.toolCallContext?.toolCalls = existedToolCalls + message.toolCalls = existedToolCalls } else { - message.toolCallContext = .init( - toolCalls: toolCalls.sorted(by: { $0.key <= $1.key }).map(\.value), - responses: [] - ) + message.toolCalls = toolCalls.sorted(by: { $0.key <= $1.key }).map(\.value) } } if let summary { @@ -155,12 +146,7 @@ public extension ChatGPTMemory { role: role ?? .system, content: content, name: name, - toolCallContext: toolCalls.map { calls in - .init( - toolCalls: calls.sorted(by: { $0.key <= $1.key }).map(\.value), - responses: [] - ) - }, + toolCalls: toolCalls?.sorted(by: { $0.key <= $1.key }).map(\.value), summary: summary, references: references ?? [] )) diff --git a/Tool/Sources/OpenAIService/Models.swift b/Tool/Sources/OpenAIService/Models.swift index 1412850e..3e606e50 100644 --- a/Tool/Sources/OpenAIService/Models.swift +++ b/Tool/Sources/OpenAIService/Models.swift @@ -30,10 +30,12 @@ public struct ChatMessage: Equatable, Codable { public var id: String public var type: String public var function: FunctionCall + public var response: ToolCallResponse public init(id: String, type: String, function: FunctionCall) { self.id = id self.type = type self.function = function + response = .init(id: id, content: "", summary: nil) } } @@ -47,11 +49,6 @@ public struct ChatMessage: Equatable, Codable { self.summary = summary } } - - public struct ToolCallContext: Codable, Equatable { - public var toolCalls: [ToolCall] - public var responses: [ToolCallResponse] - } public struct Reference: Codable, Equatable { public enum Kind: String, Codable { @@ -109,7 +106,7 @@ public struct ChatMessage: Equatable, Codable { } /// A function call from the bot. - public var toolCallContext: ToolCallContext? { + public var toolCalls: [ToolCall]? { didSet { tokensCount = nil } } @@ -134,7 +131,7 @@ public struct ChatMessage: Equatable, Codable { /// Is the message considered empty. var isEmpty: Bool { if let content, !content.isEmpty { return false } - if let toolCallContext, !toolCallContext.toolCalls.isEmpty { return false } + if let toolCalls, !toolCalls.isEmpty { return false } if let name, !name.isEmpty { return false } return true } @@ -144,7 +141,7 @@ public struct ChatMessage: Equatable, Codable { role: Role, content: String?, name: String? = nil, - toolCallContext: ToolCallContext? = nil, + toolCalls: [ToolCall]? = nil, summary: String? = nil, tokenCount: Int? = nil, references: [Reference] = [] @@ -152,7 +149,7 @@ public struct ChatMessage: Equatable, Codable { self.role = role self.content = content self.name = name - self.toolCallContext = toolCallContext + self.toolCalls = toolCalls self.summary = summary self.id = id tokensCount = tokenCount From 143a92f1047ddff2466fd9ac81558aa45ab08567 Mon Sep 17 00:00:00 2001 From: Shx Guo Date: Sun, 3 Mar 2024 22:20:23 +0800 Subject: [PATCH 22/37] Parse Mistral.AI errors --- .../APIs/OpenAIChatCompletionsService.swift | 58 +++++++++++++++++-- 1 file changed, 53 insertions(+), 5 deletions(-) diff --git a/Tool/Sources/OpenAIService/APIs/OpenAIChatCompletionsService.swift b/Tool/Sources/OpenAIService/APIs/OpenAIChatCompletionsService.swift index 77d7cfac..49a04727 100644 --- a/Tool/Sources/OpenAIService/APIs/OpenAIChatCompletionsService.swift +++ b/Tool/Sources/OpenAIService/APIs/OpenAIChatCompletionsService.swift @@ -6,17 +6,65 @@ import Preferences /// https://platform.openai.com/docs/api-reference/chat/create actor OpenAIChatCompletionsService: ChatCompletionsStreamAPI, ChatCompletionsAPI { - struct CompletionAPIError: Error, Codable, LocalizedError { - struct E: Codable { + struct CompletionAPIError: Error, Decodable, LocalizedError { + struct ErrorDetail: Decodable { var message: String var type: String var param: String var code: String } - var error: E + struct MistralAIErrorMessage: Decodable { + struct Detail: Decodable { + var msg: String? + } + + var message: String? + var msg: String? + var detail: [Detail]? + } + + enum Message { + case raw(String) + case mistralAI(MistralAIErrorMessage) + } + + var error: ErrorDetail? + var message: Message + + var errorDescription: String? { + if let message = error?.message { return message } + switch message { + case let .raw(string): + return string + case let .mistralAI(mistralAIErrorMessage): + return mistralAIErrorMessage.message + ?? mistralAIErrorMessage.msg + ?? mistralAIErrorMessage.detail?.first?.msg + ?? "Unknown Error" + } + } + + enum CodingKeys: String, CodingKey { + case error + case message + } + + init(from decoder: Decoder) throws { + let container: KeyedDecodingContainer = try decoder + .container(keyedBy: CodingKeys.self) - var errorDescription: String? { error.message } + error = try? container.decode(ErrorDetail.self, forKey: .error) + message = { + if let e = try? container.decode(MistralAIErrorMessage.self, forKey: .message) { + return CompletionAPIError.Message.mistralAI(e) + } + if let e = try? container.decode(String.self, forKey: .message) { + return .raw(e) + } + return .raw("Unknown Error") + }() + } } enum MessageRole: String, Codable { @@ -214,7 +262,7 @@ actor OpenAIChatCompletionsService: ChatCompletionsStreamAPI, ChatCompletionsAPI guard let data = text.data(using: .utf8) else { throw ChatGPTServiceError.responseInvalid } let decoder = JSONDecoder() - let error = try? decoder.decode(ChatGPTError.self, from: data) + let error = try? decoder.decode(CompletionAPIError.self, from: data) throw error ?? ChatGPTServiceError.responseInvalid } From ef396be97d18f0914966c21b8900e62f3748e94f Mon Sep 17 00:00:00 2001 From: Shx Guo Date: Sun, 3 Mar 2024 23:22:31 +0800 Subject: [PATCH 23/37] Remove useless fields --- Tool/Sources/AIModel/ChatModel.swift | 4 ---- 1 file changed, 4 deletions(-) diff --git a/Tool/Sources/AIModel/ChatModel.swift b/Tool/Sources/AIModel/ChatModel.swift index f3aea30b..325592a7 100644 --- a/Tool/Sources/AIModel/ChatModel.swift +++ b/Tool/Sources/AIModel/ChatModel.swift @@ -53,8 +53,6 @@ public struct ChatModel: Codable, Equatable, Identifiable { public var maxTokens: Int @FallbackDecoding public var supportsFunctionCalling: Bool - @FallbackDecoding - public var supportsOpenAIAPI2023_11: Bool @FallbackDecoding public var modelName: String @@ -69,7 +67,6 @@ public struct ChatModel: Codable, Equatable, Identifiable { isFullURL: Bool = false, maxTokens: Int = 4000, supportsFunctionCalling: Bool = true, - supportsOpenAIAPI2023_11: Bool = false, modelName: String = "", openAIInfo: OpenAIInfo = OpenAIInfo(), ollamaInfo: OllamaInfo = OllamaInfo() @@ -79,7 +76,6 @@ public struct ChatModel: Codable, Equatable, Identifiable { self.isFullURL = isFullURL self.maxTokens = maxTokens self.supportsFunctionCalling = supportsFunctionCalling - self.supportsOpenAIAPI2023_11 = supportsOpenAIAPI2023_11 self.modelName = modelName self.openAIInfo = openAIInfo self.ollamaInfo = ollamaInfo From 2ad4e6d2a77f4f39bf036451b5ff9cc5fe134cf8 Mon Sep 17 00:00:00 2001 From: Shx Guo Date: Sun, 3 Mar 2024 23:23:00 +0800 Subject: [PATCH 24/37] Remove tool calling content if function calling is not supported --- .../OpenAIService/ChatGPTService.swift | 43 ++++++++++++------- 1 file changed, 28 insertions(+), 15 deletions(-) diff --git a/Tool/Sources/OpenAIService/ChatGPTService.swift b/Tool/Sources/OpenAIService/ChatGPTService.swift index 83e2cc13..ae932a91 100644 --- a/Tool/Sources/OpenAIService/ChatGPTService.swift +++ b/Tool/Sources/OpenAIService/ChatGPTService.swift @@ -587,24 +587,37 @@ extension ChatGPTService { }(), content: chatMessage.content ?? "", name: chatMessage.name, - toolCalls: chatMessage.toolCalls?.map { - .init( - id: $0.id, - type: $0.type, - function: .init( - name: $0.function.name, - arguments: $0.function.arguments - ) - ) - } + toolCalls: { + if model.info.supportsFunctionCalling { + chatMessage.toolCalls?.map { + .init( + id: $0.id, + type: $0.type, + function: .init( + name: $0.function.name, + arguments: $0.function.arguments + ) + ) + } + } else { + nil + } + }() )) for call in chatMessage.toolCalls ?? [] { - all.append(ChatCompletionsRequestBody.Message( - role: .tool, - content: call.response.content, - toolCallId: call.response.id - )) + if model.info.supportsFunctionCalling { + all.append(ChatCompletionsRequestBody.Message( + role: .tool, + content: call.response.content, + toolCallId: call.response.id + )) + } else { + all.append(ChatCompletionsRequestBody.Message( + role: .user, + content: call.response.content + )) + } } return all From 36f51bb192c83cdb84f1b27ab6629ea8e8103dea Mon Sep 17 00:00:00 2001 From: Shx Guo Date: Sun, 3 Mar 2024 23:23:17 +0800 Subject: [PATCH 25/37] Fix unit test --- Tool/Sources/OpenAIService/Models.swift | 16 +- .../ChatGPTStreamTests.swift | 294 ++++++++++++------ ...matPromptToBeGoogleAICompatibleTests.swift | 14 +- 3 files changed, 216 insertions(+), 108 deletions(-) diff --git a/Tool/Sources/OpenAIService/Models.swift b/Tool/Sources/OpenAIService/Models.swift index 3e606e50..a27fde87 100644 --- a/Tool/Sources/OpenAIService/Models.swift +++ b/Tool/Sources/OpenAIService/Models.swift @@ -25,20 +25,25 @@ public struct ChatMessage: Equatable, Codable { self.arguments = arguments } } - + public struct ToolCall: Codable, Equatable, Identifiable { public var id: String public var type: String public var function: FunctionCall public var response: ToolCallResponse - public init(id: String, type: String, function: FunctionCall) { + public init( + id: String, + type: String, + function: FunctionCall, + response: ToolCallResponse? = nil + ) { self.id = id self.type = type self.function = function - response = .init(id: id, content: "", summary: nil) + self.response = response ?? .init(id: id, content: "", summary: nil) } } - + public struct ToolCallResponse: Codable, Equatable { public var id: String public var content: String @@ -67,7 +72,7 @@ public struct ChatMessage: Equatable, Codable { case webpage case other } - + public var title: String public var subTitle: String public var uri: String @@ -164,3 +169,4 @@ public struct ReferenceKindFallback: FallbackValueProvider { public struct ChatMessageRoleFallback: FallbackValueProvider { public static var defaultValue: ChatMessage.Role { .user } } + diff --git a/Tool/Tests/OpenAIServiceTests/ChatGPTStreamTests.swift b/Tool/Tests/OpenAIServiceTests/ChatGPTStreamTests.swift index 7e90445b..b6018ef9 100644 --- a/Tool/Tests/OpenAIServiceTests/ChatGPTStreamTests.swift +++ b/Tool/Tests/OpenAIServiceTests/ChatGPTStreamTests.swift @@ -60,7 +60,7 @@ final class ChatGPTStreamTests: XCTestCase { ), ], "History is not updated") - XCTAssertEqual(requestBody?.functions, nil, "Function schema is not submitted") + XCTAssertEqual(requestBody?.tools, nil, "Function schema is not submitted") } } @@ -93,7 +93,7 @@ final class ChatGPTStreamTests: XCTestCase { for try await text in stream { all.append(text) let history = await memory.history - XCTAssertEqual(history.last?.id, "00000000-0000-0000-0000-0000000000040.0") + XCTAssertEqual(history.last?.id, "00000000-0000-0000-0000-0000000000030.0") XCTAssertTrue( history.last?.content?.hasPrefix(all.joined()) ?? false, "History is not updated" @@ -105,9 +105,14 @@ final class ChatGPTStreamTests: XCTestCase { .init(role: .user, content: "Hello"), .init( role: .assistant, content: "", - function_call: .init(name: "function", arguments: "{\n\"foo\": 1\n}") + toolCalls: [ + .init( + id: "id", + type: "function", + function: .init(name: "function", arguments: "{\n\"foo\": 1\n}") + )] ), - .init(role: .function, content: "Function is called.", name: "function"), + .init(role: .tool, content: "Function is called.", toolCallId: "id"), ], "System prompt is not included") XCTAssertEqual(all, ["hello", "my", "friends"], "Text stream is not correct") @@ -123,26 +128,33 @@ final class ChatGPTStreamTests: XCTestCase { id: "00000000-0000-0000-0000-0000000000010.0", role: .assistant, content: nil, - functionCall: .init(name: "function", arguments: "{\n\"foo\": 1\n}") - ), - .init( - id: "00000000-0000-0000-0000-000000000003", - role: .function, - content: "Function is called.", - name: "function", - summary: nil + toolCalls: [ + .init( + id: "id", + type: "function", + function: .init(name: "function", arguments: "{\n\"foo\": 1\n}"), + response: .init(id: "id", content: "Function is called.", summary: nil) + ), + ] ), .init( - id: "00000000-0000-0000-0000-0000000000040.0", + id: "00000000-0000-0000-0000-0000000000030.0", role: .assistant, content: "hellomyfriends" ), ], "History is not updated") - XCTAssertEqual(requestBody?.functions, [ + XCTAssertEqual(requestBody?.tools, [ EmptyFunction(), ].map { - .init(name: $0.name, description: $0.description, parameters: $0.argumentSchema) + .init( + type: "function", + function: .init( + name: $0.name, + description: $0.description, + parameters: $0.argumentSchema + ) + ) }, "Function schema is not submitted") } } @@ -163,7 +175,7 @@ final class ChatGPTStreamTests: XCTestCase { service.changeBuildCompletionStreamAPI { _, _, _, _requestBody, _ in requestBody = _requestBody if _requestBody.messages.count <= 4 { - return MockCompletionStreamAPI_Function() + return MockCompletionStreamAPI_Function(count: 3) } return MockCompletionStreamAPI_Message() } @@ -177,7 +189,7 @@ final class ChatGPTStreamTests: XCTestCase { for try await text in stream { all.append(text) let history = await memory.history - XCTAssertEqual(history.last?.id, "00000000-0000-0000-0000-0000000000070.0") + XCTAssertEqual(history.last?.id, "00000000-0000-0000-0000-0000000000030.0") XCTAssertTrue( history.last?.content?.hasPrefix(all.joined()) ?? false, "History is not updated" @@ -189,14 +201,39 @@ final class ChatGPTStreamTests: XCTestCase { .init(role: .user, content: "Hello"), .init( role: .assistant, content: "", - function_call: .init(name: "function", arguments: "{\n\"foo\": 1\n}") + toolCalls: [ + .init( + id: "id", + type: "function", + function: .init(name: "function", arguments: "{\n\"foo\": 1\n}") + ), + .init( + id: "id2", + type: "function", + function: .init(name: "function", arguments: "{\n\"foo\": 1\n}") + ), + .init( + id: "id3", + type: "function", + function: .init(name: "function", arguments: "{\n\"foo\": 1\n}") + ), + ] ), - .init(role: .function, content: "Function is called.", name: "function"), .init( - role: .assistant, content: "", - function_call: .init(name: "function", arguments: "{\n\"foo\": 1\n}") + role: .tool, + content: "Function is called.", + toolCallId: "id" + ), + .init( + role: .tool, + content: "Function is called.", + toolCallId: "id2" + ), + .init( + role: .tool, + content: "Function is called.", + toolCallId: "id3" ), - .init(role: .function, content: "Function is called.", name: "function"), ], "System prompt is not included") XCTAssertEqual(all, ["hello", "my", "friends"], "Text stream is not correct") @@ -212,39 +249,45 @@ final class ChatGPTStreamTests: XCTestCase { id: "00000000-0000-0000-0000-0000000000010.0", role: .assistant, content: nil, - functionCall: .init(name: "function", arguments: "{\n\"foo\": 1\n}") - ), - .init( - id: "00000000-0000-0000-0000-000000000003", - role: .function, - content: "Function is called.", - name: "function", - summary: nil - ), - .init( - id: "00000000-0000-0000-0000-0000000000040.0", - role: .assistant, - content: nil, - functionCall: .init(name: "function", arguments: "{\n\"foo\": 1\n}") - ), - .init( - id: "00000000-0000-0000-0000-000000000006", - role: .function, - content: "Function is called.", - name: "function", - summary: nil + toolCalls: [ + .init( + id: "id", + type: "function", + function: .init(name: "function", arguments: "{\n\"foo\": 1\n}"), + response: .init(id: "id", content: "Function is called.", summary: nil) + ), + .init( + id: "id2", + type: "function", + function: .init(name: "function", arguments: "{\n\"foo\": 1\n}"), + response: .init(id: "id2", content: "Function is called.", summary: nil) + ), + .init( + id: "id3", + type: "function", + function: .init(name: "function", arguments: "{\n\"foo\": 1\n}"), + response: .init(id: "id3", content: "Function is called.", summary: nil) + ), + ] ), .init( - id: "00000000-0000-0000-0000-0000000000070.0", + id: "00000000-0000-0000-0000-0000000000030.0", role: .assistant, content: "hellomyfriends" ), ], "History is not updated") - XCTAssertEqual(requestBody?.functions, [ + XCTAssertEqual(requestBody?.tools, [ EmptyFunction(), ].map { - .init(name: $0.name, description: $0.description, parameters: $0.argumentSchema) + .init( + type: "function", + function: .init( + name: $0.name, + description: $0.description, + parameters: $0.argumentSchema + ) + ) }, "Function schema is not submitted") } } @@ -283,7 +326,7 @@ final class ChatGPTStreamTests: XCTestCase { for try await text in stream { all.append(text) let history = await memory.history - XCTAssertEqual(history.last?.id, "00000000-0000-0000-0000-0000000000040.0") + XCTAssertEqual(history.last?.id, "00000000-0000-0000-0000-0000000000030.0") XCTAssertTrue( history.last?.content?.hasPrefix(all.joined()) ?? false, "History is not updated" @@ -294,10 +337,9 @@ final class ChatGPTStreamTests: XCTestCase { .init(role: .system, content: "system"), .init(role: .user, content: "Hello"), .init( - role: .assistant, content: "", - function_call: .init(name: "function", arguments: "{\n\"foo\": 1\n}") + role: .assistant, content: "" ), - .init(role: .function, content: "Function is called.", name: "function"), + .init(role: .user, content: "Function is called."), ], "System prompt is not included") XCTAssertEqual(all, ["hello", "my", "friends"], "Text stream is not correct") @@ -313,23 +355,23 @@ final class ChatGPTStreamTests: XCTestCase { id: "00000000-0000-0000-0000-0000000000010.0", role: .assistant, content: nil, - functionCall: .init(name: "function", arguments: "{\n\"foo\": 1\n}") - ), - .init( - id: "00000000-0000-0000-0000-000000000003", - role: .function, - content: "Function is called.", - name: "function", - summary: nil + toolCalls: [ + .init( + id: "id", + type: "function", + function: .init(name: "function", arguments: "{\n\"foo\": 1\n}"), + response: .init(id: "id", content: "Function is called.", summary: nil) + ), + ] ), .init( - id: "00000000-0000-0000-0000-0000000000040.0", + id: "00000000-0000-0000-0000-0000000000030.0", role: .assistant, content: "hellomyfriends" ), ], "History is not updated") - XCTAssertEqual(requestBody?.functions, nil, "Functions should be nil") + XCTAssertEqual(requestBody?.tools, nil, "Functions should be nil") } } } @@ -343,18 +385,34 @@ extension ChatGPTStreamTests { let id = uuid().uuidString return AsyncThrowingStream { continuation in let chunks: [ChatCompletionsStreamDataChunk] = [ - .init(id: id, object: "", model: "", choices: [ - .init(delta: .init(role: .assistant), index: 0, finish_reason: ""), - ]), - .init(id: id, object: "", model: "", choices: [ - .init(delta: .init(content: "hello"), index: 0, finish_reason: ""), - ]), - .init(id: id, object: "", model: "", choices: [ - .init(delta: .init(content: "my"), index: 0, finish_reason: ""), - ]), - .init(id: id, object: "", model: "", choices: [ - .init(delta: .init(content: "friends"), index: 0, finish_reason: ""), - ]), + .init( + id: id, + object: "", + model: "", + message: .init(role: .assistant), + finishReason: "" + ), + .init( + id: id, + object: "", + model: "", + message: .init(content: "hello"), + finishReason: "" + ), + .init( + id: id, + object: "", + model: "", + message: .init(content: "my"), + finishReason: "" + ), + .init( + id: id, + object: "", + model: "", + message: .init(content: "friends"), + finishReason: "" + ), ] for chunk in chunks { continuation.yield(chunk) @@ -366,51 +424,87 @@ extension ChatGPTStreamTests { struct MockCompletionStreamAPI_Function: ChatCompletionsStreamAPI { @Dependency(\.uuid) var uuid + var count: Int = 1 func callAsFunction() async throws -> AsyncThrowingStream { let id = uuid().uuidString return AsyncThrowingStream { continuation in - let chunks: [ChatCompletionsStreamDataChunk] = [ - .init(id: id, object: "", model: "", choices: [ + for i in 0.. Date: Sun, 3 Mar 2024 23:26:39 +0800 Subject: [PATCH 26/37] Remove `id` from ToolCallResponse --- Core/Sources/ChatGPTChatTab/Chat.swift | 2 +- Tool/Sources/OpenAIService/ChatGPTService.swift | 2 +- .../AutoManagedChatGPTMemoryOpenAIStrategy.swift | 2 +- Tool/Sources/OpenAIService/Models.swift | 6 ++---- Tool/Tests/OpenAIServiceTests/ChatGPTStreamTests.swift | 10 +++++----- .../ReformatPromptToBeGoogleAICompatibleTests.swift | 2 +- 6 files changed, 11 insertions(+), 13 deletions(-) diff --git a/Core/Sources/ChatGPTChatTab/Chat.swift b/Core/Sources/ChatGPTChatTab/Chat.swift index 7a1ccc78..4d722d68 100644 --- a/Core/Sources/ChatGPTChatTab/Chat.swift +++ b/Core/Sources/ChatGPTChatTab/Chat.swift @@ -329,7 +329,7 @@ struct Chat: ReducerProtocol { for call in message.toolCalls ?? [] { all.append(.init( - id: message.id + call.response.id, + id: message.id + call.id, role: .tool, text: call.response.summary ?? call.response.content, references: [] diff --git a/Tool/Sources/OpenAIService/ChatGPTService.swift b/Tool/Sources/OpenAIService/ChatGPTService.swift index ae932a91..bc7df108 100644 --- a/Tool/Sources/OpenAIService/ChatGPTService.swift +++ b/Tool/Sources/OpenAIService/ChatGPTService.swift @@ -610,7 +610,7 @@ extension ChatGPTService { all.append(ChatCompletionsRequestBody.Message( role: .tool, content: call.response.content, - toolCallId: call.response.id + toolCallId: call.id )) } else { all.append(ChatCompletionsRequestBody.Message( diff --git a/Tool/Sources/OpenAIService/Memory/AutoManagedChatGPTMemoryStrategy/AutoManagedChatGPTMemoryOpenAIStrategy.swift b/Tool/Sources/OpenAIService/Memory/AutoManagedChatGPTMemoryStrategy/AutoManagedChatGPTMemoryOpenAIStrategy.swift index be68b9b4..3619a7e9 100644 --- a/Tool/Sources/OpenAIService/Memory/AutoManagedChatGPTMemoryStrategy/AutoManagedChatGPTMemoryOpenAIStrategy.swift +++ b/Tool/Sources/OpenAIService/Memory/AutoManagedChatGPTMemoryStrategy/AutoManagedChatGPTMemoryOpenAIStrategy.swift @@ -45,7 +45,7 @@ extension TokenEncoder { encodingContent.append(toolCall.function.arguments) total += 4 encodingContent.append(toolCall.response.content) - encodingContent.append(toolCall.response.id) + encodingContent.append(toolCall.id) } } total += await withTaskGroup(of: Int.self, body: { group in diff --git a/Tool/Sources/OpenAIService/Models.swift b/Tool/Sources/OpenAIService/Models.swift index a27fde87..af95a6e5 100644 --- a/Tool/Sources/OpenAIService/Models.swift +++ b/Tool/Sources/OpenAIService/Models.swift @@ -40,16 +40,14 @@ public struct ChatMessage: Equatable, Codable { self.id = id self.type = type self.function = function - self.response = response ?? .init(id: id, content: "", summary: nil) + self.response = response ?? .init(content: "", summary: nil) } } public struct ToolCallResponse: Codable, Equatable { - public var id: String public var content: String public var summary: String? - public init(id: String, content: String, summary: String?) { - self.id = id + public init(content: String, summary: String?) { self.content = content self.summary = summary } diff --git a/Tool/Tests/OpenAIServiceTests/ChatGPTStreamTests.swift b/Tool/Tests/OpenAIServiceTests/ChatGPTStreamTests.swift index b6018ef9..38bebbbe 100644 --- a/Tool/Tests/OpenAIServiceTests/ChatGPTStreamTests.swift +++ b/Tool/Tests/OpenAIServiceTests/ChatGPTStreamTests.swift @@ -133,7 +133,7 @@ final class ChatGPTStreamTests: XCTestCase { id: "id", type: "function", function: .init(name: "function", arguments: "{\n\"foo\": 1\n}"), - response: .init(id: "id", content: "Function is called.", summary: nil) + response: .init(content: "Function is called.", summary: nil) ), ] ), @@ -254,19 +254,19 @@ final class ChatGPTStreamTests: XCTestCase { id: "id", type: "function", function: .init(name: "function", arguments: "{\n\"foo\": 1\n}"), - response: .init(id: "id", content: "Function is called.", summary: nil) + response: .init(content: "Function is called.", summary: nil) ), .init( id: "id2", type: "function", function: .init(name: "function", arguments: "{\n\"foo\": 1\n}"), - response: .init(id: "id2", content: "Function is called.", summary: nil) + response: .init(content: "Function is called.", summary: nil) ), .init( id: "id3", type: "function", function: .init(name: "function", arguments: "{\n\"foo\": 1\n}"), - response: .init(id: "id3", content: "Function is called.", summary: nil) + response: .init(content: "Function is called.", summary: nil) ), ] ), @@ -360,7 +360,7 @@ final class ChatGPTStreamTests: XCTestCase { id: "id", type: "function", function: .init(name: "function", arguments: "{\n\"foo\": 1\n}"), - response: .init(id: "id", content: "Function is called.", summary: nil) + response: .init(content: "Function is called.", summary: nil) ), ] ), diff --git a/Tool/Tests/OpenAIServiceTests/ReformatPromptToBeGoogleAICompatibleTests.swift b/Tool/Tests/OpenAIServiceTests/ReformatPromptToBeGoogleAICompatibleTests.swift index 26e19ee1..e92d5079 100644 --- a/Tool/Tests/OpenAIServiceTests/ReformatPromptToBeGoogleAICompatibleTests.swift +++ b/Tool/Tests/OpenAIServiceTests/ReformatPromptToBeGoogleAICompatibleTests.swift @@ -110,7 +110,7 @@ class ReformatPromptToBeGoogleAICompatibleTests: XCTestCase { id: "id", type: "function", function: .init(name: "ping", arguments: "{ \"ip\": \"127.0.0.1\" }"), - response: .init(id: "id", content: "42ms", summary: nil) + response: .init(content: "42ms", summary: nil) ), ] ), From 21d5d685e52a964f5e8d96dccbbcfaf9567cae18 Mon Sep 17 00:00:00 2001 From: Shx Guo Date: Sun, 3 Mar 2024 23:50:46 +0800 Subject: [PATCH 27/37] Fix error decoding --- .../APIs/OpenAIChatCompletionsService.swift | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/Tool/Sources/OpenAIService/APIs/OpenAIChatCompletionsService.swift b/Tool/Sources/OpenAIService/APIs/OpenAIChatCompletionsService.swift index 49a04727..96e1a97a 100644 --- a/Tool/Sources/OpenAIService/APIs/OpenAIChatCompletionsService.swift +++ b/Tool/Sources/OpenAIService/APIs/OpenAIChatCompletionsService.swift @@ -9,9 +9,9 @@ actor OpenAIChatCompletionsService: ChatCompletionsStreamAPI, ChatCompletionsAPI struct CompletionAPIError: Error, Decodable, LocalizedError { struct ErrorDetail: Decodable { var message: String - var type: String - var param: String - var code: String + var type: String? + var param: String? + var code: String? } struct MistralAIErrorMessage: Decodable { @@ -51,10 +51,14 @@ actor OpenAIChatCompletionsService: ChatCompletionsStreamAPI, ChatCompletionsAPI } init(from decoder: Decoder) throws { - let container: KeyedDecodingContainer = try decoder - .container(keyedBy: CodingKeys.self) + let container = try decoder.container(keyedBy: CodingKeys.self) - error = try? container.decode(ErrorDetail.self, forKey: .error) + do { + error = try container.decode(ErrorDetail.self, forKey: .error) + } catch { + print(error) + self.error = nil + } message = { if let e = try? container.decode(MistralAIErrorMessage.self, forKey: .message) { return CompletionAPIError.Message.mistralAI(e) From e9b88b708e0b057e37d9bb19ec3cdbcce22e8f72 Mon Sep 17 00:00:00 2001 From: Shx Guo Date: Mon, 4 Mar 2024 00:00:16 +0800 Subject: [PATCH 28/37] Remove warnings --- Pro | 2 +- Tool/Sources/GitHubCopilotService/GitHubCopilotService.swift | 4 ++-- Tool/Sources/OpenAIService/ChatGPTService.swift | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/Pro b/Pro index ede561a8..9f891956 160000 --- a/Pro +++ b/Pro @@ -1 +1 @@ -Subproject commit ede561a8f50276d915ee672e8fad59226d349b08 +Subproject commit 9f891956191b9ba542b40fd090e7d4c11dbcc4c2 diff --git a/Tool/Sources/GitHubCopilotService/GitHubCopilotService.swift b/Tool/Sources/GitHubCopilotService/GitHubCopilotService.swift index 01225c22..e18eb58a 100644 --- a/Tool/Sources/GitHubCopilotService/GitHubCopilotService.swift +++ b/Tool/Sources/GitHubCopilotService/GitHubCopilotService.swift @@ -169,11 +169,11 @@ public class GitHubCopilotBaseService { Task { [weak self] in _ = try? await server.sendRequest(GitHubCopilotRequest.SetEditorInfo()) - for await notification in NotificationCenter.default + for await _ in NotificationCenter.default .notifications(named: .gitHubCopilotShouldRefreshEditorInformation) { print("Yes!") - guard let self else { return } + guard self != nil else { return } _ = try? await server.sendRequest(GitHubCopilotRequest.SetEditorInfo()) } } diff --git a/Tool/Sources/OpenAIService/ChatGPTService.swift b/Tool/Sources/OpenAIService/ChatGPTService.swift index bc7df108..a16bc932 100644 --- a/Tool/Sources/OpenAIService/ChatGPTService.swift +++ b/Tool/Sources/OpenAIService/ChatGPTService.swift @@ -285,7 +285,7 @@ public class ChatGPTService: ChatGPTServiceType { } } - #warning("TODO: remove this and let the concurrency system handle it") + #warning("TODO: Move the cancellation up to the caller.") public func stopReceivingMessage() { runningTask?.cancel() runningTask = nil From f2ed55d7211c36fcdd2345c5be99f6eb8e14058c Mon Sep 17 00:00:00 2001 From: Shx Guo Date: Mon, 4 Mar 2024 00:17:47 +0800 Subject: [PATCH 29/37] Fix function calling settings for ollama --- .../ChatModelManagement/ChatModelEdit.swift | 3 +++ Tool/Sources/OpenAIService/ChatGPTService.swift | 15 +++++++++++---- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/Core/Sources/HostApp/AccountSettings/ChatModelManagement/ChatModelEdit.swift b/Core/Sources/HostApp/AccountSettings/ChatModelManagement/ChatModelEdit.swift index 8ef06b86..0e9b2062 100644 --- a/Core/Sources/HostApp/AccountSettings/ChatModelManagement/ChatModelEdit.swift +++ b/Core/Sources/HostApp/AccountSettings/ChatModelManagement/ChatModelEdit.swift @@ -195,6 +195,9 @@ extension ChatModel { if case .googleAI = state.format { return false } + if case .ollama = state.format { + return false + } return state.supportsFunctionCalling }(), modelName: state.modelName.trimmingCharacters(in: .whitespacesAndNewlines), diff --git a/Tool/Sources/OpenAIService/ChatGPTService.swift b/Tool/Sources/OpenAIService/ChatGPTService.swift index a16bc932..c7d7dca8 100644 --- a/Tool/Sources/OpenAIService/ChatGPTService.swift +++ b/Tool/Sources/OpenAIService/ChatGPTService.swift @@ -575,6 +575,13 @@ extension ChatGPTService { model: ChatModel, stream: Bool ) -> ChatCompletionsRequestBody { + let serviceSupportsFunctionCalling = switch model.format { + case .openAI, .openAICompatible, .azureOpenAI: + model.info.supportsFunctionCalling + case .ollama, .googleAI: + false + } + let messages = prompt.history.flatMap { chatMessage in var all = [ChatCompletionsRequestBody.Message]() all.append(ChatCompletionsRequestBody.Message( @@ -588,7 +595,7 @@ extension ChatGPTService { content: chatMessage.content ?? "", name: chatMessage.name, toolCalls: { - if model.info.supportsFunctionCalling { + if serviceSupportsFunctionCalling { chatMessage.toolCalls?.map { .init( id: $0.id, @@ -606,7 +613,7 @@ extension ChatGPTService { )) for call in chatMessage.toolCalls ?? [] { - if model.info.supportsFunctionCalling { + if serviceSupportsFunctionCalling { all.append(ChatCompletionsRequestBody.Message( role: .tool, content: call.response.content, @@ -635,10 +642,10 @@ extension ChatGPTService { maxToken: model.info.maxTokens, remainingTokens: remainingTokens ), - toolChoice: model.info.supportsFunctionCalling + toolChoice: serviceSupportsFunctionCalling ? functionProvider.functionCallStrategy : nil, - tools: model.info.supportsFunctionCalling + tools: serviceSupportsFunctionCalling ? functionProvider.functions.map { .init(function: ChatGPTFunctionSchema( name: $0.name, From 84058eced01c13efb72c2ce8baa46a0c4bf5cb73 Mon Sep 17 00:00:00 2001 From: Shx Guo Date: Mon, 4 Mar 2024 00:19:50 +0800 Subject: [PATCH 30/37] Update --- Pro | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Pro b/Pro index 9f891956..a3792303 160000 --- a/Pro +++ b/Pro @@ -1 +1 @@ -Subproject commit 9f891956191b9ba542b40fd090e7d4c11dbcc4c2 +Subproject commit a37923038f72208000a3829683d7bdfb8d939e28 From 6063d41fa9963ceabee7427e6017a035636db08b Mon Sep 17 00:00:00 2001 From: Shx Guo Date: Mon, 4 Mar 2024 23:19:25 +0800 Subject: [PATCH 31/37] Give response role the default value `assistant` --- Tool/Sources/OpenAIService/ChatGPTService.swift | 1 + 1 file changed, 1 insertion(+) diff --git a/Tool/Sources/OpenAIService/ChatGPTService.swift b/Tool/Sources/OpenAIService/ChatGPTService.swift index c7d7dca8..61970a5e 100644 --- a/Tool/Sources/OpenAIService/ChatGPTService.swift +++ b/Tool/Sources/OpenAIService/ChatGPTService.swift @@ -330,6 +330,7 @@ extension ChatGPTService { do { await memory.streamMessage( id: proposedId, + role: .assistant, references: prompt.references ) let chunks = try await api() From 6fb6360b96cf9a297278f937575c61e19380985e Mon Sep 17 00:00:00 2001 From: Shx Guo Date: Wed, 6 Mar 2024 14:38:20 +0800 Subject: [PATCH 32/37] Hide the circle when Xcode is not active --- .../SuggestionWidget/WidgetWindowsController.swift | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/Core/Sources/SuggestionWidget/WidgetWindowsController.swift b/Core/Sources/SuggestionWidget/WidgetWindowsController.swift index 41a40820..454c79ee 100644 --- a/Core/Sources/SuggestionWidget/WidgetWindowsController.swift +++ b/Core/Sources/SuggestionWidget/WidgetWindowsController.swift @@ -93,6 +93,7 @@ actor WidgetWindowsController: NSObject { let xcodeInspector = self.xcodeInspector let activeApp = await xcodeInspector.safe.activeApplication let latestActiveXcode = await xcodeInspector.safe.latestActiveXcode + let previousActiveApplication = xcodeInspector.previousActiveApplication await MainActor.run { let state = store.withState { $0 } let isChatPanelDetached = state.chatPanelState.chatPanelInASeparateWindow @@ -123,9 +124,17 @@ actor WidgetWindowsController: NSObject { return true }() + let previousAppIsXcode = previousActiveApplication?.isXcode ?? false + windows.sharedPanelWindow.alphaValue = noFocus ? 0 : 1 windows.suggestionPanelWindow.alphaValue = noFocus ? 0 : 1 - windows.widgetWindow.alphaValue = noFocus ? 0 : 1 + windows.widgetWindow.alphaValue = if noFocus { + 0 + } else if previousAppIsXcode { + 1 + } else { + 0 + } windows.toastWindow.alphaValue = noFocus ? 0 : 1 if isChatPanelDetached { windows.chatPanelWindow.isWindowHidden = !hasChat From d6afaab41908c1ed8978a7e7666e519244c3726a Mon Sep 17 00:00:00 2001 From: Shx Guo Date: Wed, 6 Mar 2024 14:38:29 +0800 Subject: [PATCH 33/37] Update --- ExtensionService/AppDelegate+Menu.swift | 1 + 1 file changed, 1 insertion(+) diff --git a/ExtensionService/AppDelegate+Menu.swift b/ExtensionService/AppDelegate+Menu.swift index e655522d..e5567c27 100644 --- a/ExtensionService/AppDelegate+Menu.swift +++ b/ExtensionService/AppDelegate+Menu.swift @@ -20,6 +20,7 @@ extension AppDelegate { .init("sourceEditorDebugMenu") } + @MainActor @objc func buildStatusBarMenu() { let statusBar = NSStatusBar.system statusBarItem = statusBar.statusItem( From a9b25c87fc3ac40e6e0fe3e5edb5541dc5a2be0a Mon Sep 17 00:00:00 2001 From: Shx Guo Date: Thu, 7 Mar 2024 17:14:05 +0800 Subject: [PATCH 34/37] Bump Copilot.vim to 1.25.0 --- .../GitHubCopilotService/GitHubCopilotInstallationManager.swift | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Tool/Sources/GitHubCopilotService/GitHubCopilotInstallationManager.swift b/Tool/Sources/GitHubCopilotService/GitHubCopilotInstallationManager.swift index 92877938..f7013f08 100644 --- a/Tool/Sources/GitHubCopilotService/GitHubCopilotInstallationManager.swift +++ b/Tool/Sources/GitHubCopilotService/GitHubCopilotInstallationManager.swift @@ -10,7 +10,7 @@ public struct GitHubCopilotInstallationManager { return URL(string: link)! } - static let latestSupportedVersion = "1.19.2" + static let latestSupportedVersion = "1.25.0" public init() {} From 1285933dd491df12481de5b849fbc8a5d045410d Mon Sep 17 00:00:00 2001 From: Shx Guo Date: Thu, 7 Mar 2024 17:19:49 +0800 Subject: [PATCH 35/37] Bump Codeium language server to 1.8.5 --- Tool/Sources/CodeiumService/CodeiumInstallationManager.swift | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Tool/Sources/CodeiumService/CodeiumInstallationManager.swift b/Tool/Sources/CodeiumService/CodeiumInstallationManager.swift index 7574c31b..9ea25108 100644 --- a/Tool/Sources/CodeiumService/CodeiumInstallationManager.swift +++ b/Tool/Sources/CodeiumService/CodeiumInstallationManager.swift @@ -3,7 +3,7 @@ import Terminal public struct CodeiumInstallationManager { private static var isInstalling = false - static let latestSupportedVersion = "1.6.9" + static let latestSupportedVersion = "1.8.5" public init() {} From 6f8091b92c14719e7f62bab480d9e41616514e19 Mon Sep 17 00:00:00 2001 From: Shx Guo Date: Mon, 4 Mar 2024 14:42:42 +0800 Subject: [PATCH 36/37] Bump version to 0.31.0 --- Version.xcconfig | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Version.xcconfig b/Version.xcconfig index a4dd52c7..deb81403 100644 --- a/Version.xcconfig +++ b/Version.xcconfig @@ -1,3 +1,3 @@ -APP_VERSION = 0.30.5 -APP_BUILD = 328 +APP_VERSION = 0.31.0 +APP_BUILD = 333 From a1604dd810ff59f0d21619c1196e6aea94a13544 Mon Sep 17 00:00:00 2001 From: Shx Guo Date: Fri, 8 Mar 2024 14:51:52 +0800 Subject: [PATCH 37/37] Update appcast.xml --- appcast.xml | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/appcast.xml b/appcast.xml index ffc6db08..5cb4561e 100644 --- a/appcast.xml +++ b/appcast.xml @@ -2,6 +2,17 @@ Copilot for Xcode + + 0.31.0 + Fri, 08 Mar 2024 14:48:32 +0800 + 333 + 0.31.0 + 12.0 + + https://github.com/intitni/CopilotForXcode/releases/tag/0.31.0 + + + 0.31.0