Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Add Claude support
  • Loading branch information
intitni committed Mar 21, 2024
commit 4ebacf8a833e132750dcfb4f4d34ad8999285c98
Original file line number Diff line number Diff line change
Expand Up @@ -86,27 +86,35 @@ struct ChatModelEdit: ReducerProtocol {
)
return .run { send in
do {
let reply =
try await ChatGPTService(
configuration: UserPreferenceChatGPTConfiguration()
.overriding {
$0.model = model
}
).sendAndWait(content: "Hello")
let service = ChatGPTService(
configuration: UserPreferenceChatGPTConfiguration()
.overriding {
$0.model = model
}
)
let reply = try await service
.sendAndWait(content: "Respond with \"Test succeeded\"")
await send(.testSucceeded(reply ?? "No Message"))
let stream = try await service
.send(content: "Respond with \"Stream response is working\"")
var streamReply = ""
for try await chunk in stream {
streamReply += chunk
}
await send(.testSucceeded(streamReply))
} catch {
await send(.testFailed(error.localizedDescription))
}
}

case let .testSucceeded(message):
state.isTesting = false
toast(message, .info)
toast(message.trimmingCharacters(in: .whitespacesAndNewlines), .info)
return .none

case let .testFailed(message):
state.isTesting = false
toast(message, .error)
toast(message.trimmingCharacters(in: .whitespacesAndNewlines), .error)
return .none

case .refreshAvailableModelNames:
Expand All @@ -132,6 +140,15 @@ struct ChatModelEdit: ReducerProtocol {
state.suggestedMaxTokens = nil
}
return .none
case .claude:
if let knownModel = ClaudeChatCompletionsService
.KnownModel(rawValue: state.modelName)
{
state.suggestedMaxTokens = knownModel.contextWindow
} else {
state.suggestedMaxTokens = nil
}
return .none
default:
state.suggestedMaxTokens = nil
return .none
Expand Down Expand Up @@ -192,13 +209,12 @@ extension ChatModel {
isFullURL: state.isFullURL,
maxTokens: state.maxTokens,
supportsFunctionCalling: {
if case .googleAI = state.format {
return false
}
if case .ollama = state.format {
switch state.format {
case .googleAI, .ollama, .claude:
return false
case .azureOpenAI, .openAI, .openAICompatible:
return state.supportsFunctionCalling
}
return state.supportsFunctionCalling
}(),
modelName: state.modelName.trimmingCharacters(in: .whitespacesAndNewlines),
ollamaInfo: .init(keepAlive: state.ollamaKeepAlive)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import AIModel
import ComposableArchitecture
import OpenAIService
import Preferences
import SwiftUI

Expand All @@ -26,6 +27,8 @@ struct ChatModelEditView: View {
googleAI
case .ollama:
ollama
case .claude:
claude
}
}
}
Expand Down Expand Up @@ -96,6 +99,8 @@ struct ChatModelEditView: View {
Text("Google Generative AI").tag(format)
case .ollama:
Text("Ollama").tag(format)
case .claude:
Text("Claude").tag(format)
}
}
},
Expand Down Expand Up @@ -348,7 +353,7 @@ struct ChatModelEditView: View {

maxTokensTextField
}

@ViewBuilder
var ollama: some View {
baseURLTextField(prompt: Text("http://127.0.0.1:11434")) {
Expand All @@ -363,7 +368,7 @@ struct ChatModelEditView: View {
}

maxTokensTextField

WithViewStore(
store,
removeDuplicates: { $0.ollamaKeepAlive == $1.ollamaKeepAlive }
Expand All @@ -380,6 +385,51 @@ struct ChatModelEditView: View {
}
.padding(.vertical)
}

@ViewBuilder
var claude: some View {
baseURLTextField(prompt: Text("https://api.anthropic.com")) {
Text("/v1/messages")
}

apiKeyNamePicker

WithViewStore(
store,
removeDuplicates: { $0.modelName == $1.modelName }
) { viewStore in
TextField("Model Name", text: viewStore.$modelName)
.overlay(alignment: .trailing) {
Picker(
"",
selection: viewStore.$modelName,
content: {
if ClaudeChatCompletionsService
.KnownModel(rawValue: viewStore.state.modelName) == nil
{
Text("Custom Model").tag(viewStore.state.modelName)
}
ForEach(
ClaudeChatCompletionsService.KnownModel.allCases,
id: \.self
) { model in
Text(model.rawValue).tag(model.rawValue)
}
}
)
.frame(width: 20)
}
}

maxTokensTextField

VStack(alignment: .leading, spacing: 8) {
Text(Image(systemName: "exclamationmark.triangle.fill")) + Text(
" For more details, please visit [https://anthropic.com](https://anthropic.com)."
)
}
.padding(.vertical)
}
}

#Preview("OpenAI") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ extension ChatModel: ManageableAIModel {
case .openAICompatible: return "OpenAI Compatible"
case .googleAI: return "Google Generative AI"
case .ollama: return "Ollama"
case .claude: return "Claude"
}
}

Expand Down
5 changes: 5 additions & 0 deletions Tool/Sources/AIModel/ChatModel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ public struct ChatModel: Codable, Equatable, Identifiable {
case openAICompatible
case googleAI
case ollama
case claude
}

public struct Info: Codable, Equatable {
Expand Down Expand Up @@ -107,6 +108,10 @@ public struct ChatModel: Codable, Equatable, Identifiable {
let baseURL = info.baseURL
if baseURL.isEmpty { return "http://localhost:11434/api/chat" }
return "\(baseURL)/api/chat"
case .claude:
let baseURL = info.baseURL
if baseURL.isEmpty { return "https://api.anthropic.com/v1/messages" }
return "\(baseURL)/v1/messages"
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -238,8 +238,8 @@ actor OpenAIChatCompletionsService: ChatCompletionsStreamAPI, ChatCompletionsAPI
case .openAI:
if !model.info.openAIInfo.organizationID.isEmpty {
request.setValue(
"OpenAI-Organization",
forHTTPHeaderField: model.info.openAIInfo.organizationID
model.info.openAIInfo.organizationID,
forHTTPHeaderField: "OpenAI-Organization"
)
}
request.setValue("Bearer \(apiKey)", forHTTPHeaderField: "Authorization")
Expand All @@ -251,6 +251,8 @@ actor OpenAIChatCompletionsService: ChatCompletionsStreamAPI, ChatCompletionsAPI
assertionFailure("Unsupported")
case .ollama:
assertionFailure("Unsupported")
case .claude:
assertionFailure("Unsupported")
}
}

Expand Down Expand Up @@ -319,6 +321,8 @@ actor OpenAIChatCompletionsService: ChatCompletionsStreamAPI, ChatCompletionsAPI
assertionFailure("Unsupported")
case .ollama:
assertionFailure("Unsupported")
case .claude:
assertionFailure("Unsupported")
}
}

Expand Down Expand Up @@ -376,7 +380,7 @@ extension OpenAIChatCompletionsService.ResponseBody {
),
]
} else {
return []
return nil
}
}()
)
Expand Down
16 changes: 15 additions & 1 deletion Tool/Sources/OpenAIService/ChatGPTService.swift
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,13 @@ public class ChatGPTService: ChatGPTServiceType {
endpoint: endpoint,
requestBody: requestBody
)
case .claude:
return ClaudeChatCompletionsService(
apiKey: apiKey,
model: model,
endpoint: endpoint,
requestBody: requestBody
)
}
}

Expand Down Expand Up @@ -138,6 +145,13 @@ public class ChatGPTService: ChatGPTServiceType {
endpoint: endpoint,
requestBody: requestBody
)
case .claude:
return ClaudeChatCompletionsService(
apiKey: apiKey,
model: model,
endpoint: endpoint,
requestBody: requestBody
)
}
}

Expand Down Expand Up @@ -579,7 +593,7 @@ extension ChatGPTService {
let serviceSupportsFunctionCalling = switch model.format {
case .openAI, .openAICompatible, .azureOpenAI:
model.info.supportsFunctionCalling
case .ollama, .googleAI:
case .ollama, .googleAI, .claude:
false
}

Expand Down