Skip to content

Commit 0e1e9dc

Browse files
committed
Add ClaudeChatCompletionsService
1 parent 1321f93 commit 0e1e9dc

1 file changed

Lines changed: 346 additions & 0 deletions

File tree

Lines changed: 346 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,346 @@
1+
import AIModel
2+
import AsyncAlgorithms
3+
import CodableWrappers
4+
import Foundation
5+
import Logger
6+
import Preferences
7+
8+
/// https://docs.anthropic.com/claude/reference/messages_post
9+
/// A client for Anthropic's Claude messages API, adapting its request and
/// response shapes to the project's `ChatCompletionsStreamAPI` and
/// `ChatCompletionsAPI` protocols in both streaming and single-shot form.
///
/// https://docs.anthropic.com/claude/reference/messages_post
public actor ClaudeChatCompletionsService: ChatCompletionsStreamAPI, ChatCompletionsAPI {
    /// Claude model identifiers known to this service.
    public enum KnownModel: String, CaseIterable {
        case claude3Opus = "claude-3-opus-20240229"
        case claude3Sonnet = "claude-3-sonnet-20240229"
        case claude3Haiku = "claude-3-haiku-20240307"

        /// The model's context window size in tokens.
        /// All Claude 3 models currently share a 200k-token window.
        public var contextWindow: Int {
            switch self {
            case .claude3Opus: return 200_000
            case .claude3Sonnet: return 200_000
            case .claude3Haiku: return 200_000
            }
        }
    }

    /// Error body returned by the API on non-200 responses.
    struct APIError: Error, Decodable, LocalizedError {
        struct ErrorDetail: Decodable {
            var message: String?
            var type: String?
        }

        var error: ErrorDetail?
        var type: String

        var errorDescription: String? {
            error?.message ?? "Unknown Error"
        }
    }

    /// Message roles supported by Claude. Note there is no `system` case:
    /// system prompts travel in the request body's dedicated `system` field.
    enum MessageRole: String, Codable {
        case user
        case assistant

        /// The equivalent role in the shared chat-completions model.
        var formalized: ChatCompletionsRequestBody.Message.Role {
            switch self {
            case .user: return .user
            case .assistant: return .assistant
            }
        }
    }

    /// One decoded server-sent-event payload of a streaming response.
    struct StreamDataChunk: Decodable {
        var type: String
        var message: Message?
        var index: Int?
        var content_block: ContentBlock?
        var delta: Delta?
        var error: APIError?

        struct Message: Decodable {
            var id: String
            var type: String
            var role: MessageRole?
            var content: [ContentBlock]?
            var model: String
            var stop_reason: String?
            var stop_sequence: String?
            var usage: Usage?
        }

        struct ContentBlock: Decodable {
            var type: String
            var text: String?
        }

        struct Delta: Decodable {
            var type: String
            var text: String?
            var stop_reason: String?
            var stop_sequence: String?
            var usage: Usage?
        }

        struct Usage: Decodable {
            var input_tokens: Int?
            var output_tokens: Int?
        }
    }

    /// The body of a non-streaming response.
    struct ResponseBody: Codable, Equatable {
        struct Content: Codable, Equatable {
            enum ContentType: String, Codable, FallbackValueProvider {
                case text
                case unknown
                static var defaultValue: ContentType { .unknown }
            }

            /// The type of the message.
            ///
            /// Currently, the only supported type is `text`; unrecognized
            /// types decode as `.unknown` instead of failing.
            @FallbackDecoding<ContentType>
            var type: ContentType
            /// The content of the message.
            ///
            /// If the request input messages ended with an assistant turn,
            /// then the response content will continue directly from that last turn.
            /// You can use this to constrain the model's output.
            var text: String?
        }

        struct Usage: Codable, Equatable {
            var input_tokens: Int?
            var output_tokens: Int?
        }

        var id: String?
        var model: String
        var type: String
        var usage: Usage
        var role: MessageRole
        var content: [Content]
        var stop_reason: String?
        var stop_sequence: String?
    }

    /// The body of a request to the messages API.
    struct RequestBody: Encodable, Equatable {
        struct MessageContent: Encodable, Equatable {
            enum MessageContentType: String, Encodable, Equatable {
                case text
                case image
            }

            struct ImageSource: Encodable, Equatable {
                var type: String = "base64"
                /// currently support the base64 source type for images,
                /// and the image/jpeg, image/png, image/gif, and image/webp media types.
                var media_type: String = "image/jpeg"
                var data: String
            }

            var type: MessageContentType
            var text: String?
            var source: ImageSource?
        }

        struct Message: Encodable, Equatable {
            /// The role of the message.
            var role: MessageRole
            /// The content of the message.
            var content: [MessageContent]
        }

        var model: String
        var system: String
        var messages: [Message]
        var temperature: Double?
        var stream: Bool?
        var stop_sequences: [String]?
        var max_tokens: Int
    }

    var apiKey: String
    var endpoint: URL
    var requestBody: RequestBody
    var model: ChatModel

    init(
        apiKey: String,
        model: ChatModel,
        endpoint: URL,
        requestBody: ChatCompletionsRequestBody
    ) {
        self.apiKey = apiKey
        self.endpoint = endpoint
        self.requestBody = .init(requestBody)
        self.model = model
    }

    /// Builds a POST request carrying the current `requestBody`.
    ///
    /// Shared by the streaming and non-streaming entry points so the header
    /// and encoding logic lives in one place. Call it only after
    /// `requestBody.stream` has been set, since the body is encoded here.
    private func makeURLRequest() throws -> URLRequest {
        var request = URLRequest(url: endpoint)
        request.httpMethod = "POST"
        request.httpBody = try JSONEncoder().encode(requestBody)
        request.setValue("application/json", forHTTPHeaderField: "Content-Type")
        request.setValue("2023-06-01", forHTTPHeaderField: "anthropic-version")
        if !apiKey.isEmpty {
            request.setValue(apiKey, forHTTPHeaderField: "x-api-key")
        }
        return request
    }

    /// Performs a streaming chat completion.
    ///
    /// - Returns: A stream of formalized chunks; the stream completes when
    ///   the server emits a `message_stop` event.
    /// - Throws: `APIError` when the server reports one, or
    ///   `ChatGPTServiceError.responseInvalid` for malformed responses.
    func callAsFunction() async throws
        -> AsyncThrowingStream<ChatCompletionsStreamDataChunk, Error>
    {
        requestBody.stream = true
        let request = try makeURLRequest()

        let (result, response) = try await URLSession.shared.bytes(for: request)
        guard let response = response as? HTTPURLResponse else {
            throw ChatGPTServiceError.responseInvalid
        }

        guard response.statusCode == 200 else {
            // Collect the whole error payload before attempting to decode it.
            let text = try await result.lines.reduce(into: "") { partialResult, current in
                partialResult += current
            }
            guard let data = text.data(using: .utf8)
            else { throw ChatGPTServiceError.responseInvalid }
            let decoder = JSONDecoder()
            let error = try? decoder.decode(APIError.self, from: data)
            throw error ?? ChatGPTServiceError.responseInvalid
        }

        let stream = ResponseStream<StreamDataChunk>(result: result) {
            var line = $0
            // SSE "event:" lines carry no JSON payload; skip them.
            if line.hasPrefix("event:") {
                return .init(chunk: nil, done: false)
            }

            let prefix = "data: "
            if line.hasPrefix(prefix) {
                line.removeFirst(prefix.count)
            }

            if line == "[DONE]" { return .init(chunk: nil, done: true) }

            do {
                let chunk = try JSONDecoder().decode(
                    StreamDataChunk.self,
                    from: line.data(using: .utf8) ?? Data()
                )
                return .init(chunk: chunk, done: chunk.type == "message_stop")
            } catch {
                // A single malformed chunk should not kill the stream;
                // log it and keep consuming.
                Logger.service.error("Error decoding stream data: \(error)")
                return .init(chunk: nil, done: false)
            }
        }

        return stream.map { $0.formalized() }.toStream()
    }

    /// Performs a non-streaming chat completion and returns the full response.
    ///
    /// - Throws: `APIError` when the server reports one, or a
    ///   `ChatGPTServiceError` for transport or decoding failures.
    func callAsFunction() async throws -> ChatCompletionResponseBody {
        requestBody.stream = false
        let request = try makeURLRequest()

        let (result, response) = try await URLSession.shared.data(for: request)
        guard let response = response as? HTTPURLResponse else {
            throw ChatGPTServiceError.responseInvalid
        }

        guard response.statusCode == 200 else {
            let error = try? JSONDecoder().decode(APIError.self, from: result)
            throw error ?? ChatGPTServiceError
                .otherError(String(data: result, encoding: .utf8) ?? "Unknown Error")
        }

        do {
            let body = try JSONDecoder().decode(ResponseBody.self, from: result)
            return body.formalized()
        } catch {
            // Log through the shared logger instead of dump(_:), matching
            // the streaming path's error handling.
            Logger.service.error("Error decoding response body: \(error)")
            throw error
        }
    }
}
267+
268+
extension ClaudeChatCompletionsService.ResponseBody {
    /// Maps the Anthropic response into the shared completion-response shape,
    /// concatenating every text content block into one message body.
    func formalized() -> ChatCompletionResponseBody {
        let combinedText = content.compactMap(\.text).joined()
        return .init(
            id: id,
            object: "chat.completions",
            model: model,
            message: .init(role: role.formalized, content: combinedText),
            otherChoices: [],
            finishReason: stop_reason ?? ""
        )
    }
}
287+
288+
extension ClaudeChatCompletionsService.StreamDataChunk {
    /// Converts an Anthropic SSE chunk into the shared stream-chunk shape.
    /// A content delta takes precedence; otherwise a message-start event
    /// contributes the role, and anything else yields no message payload.
    func formalized() -> ChatCompletionsStreamDataChunk {
        return .init(
            id: message?.id,
            object: "chat.completions",
            model: message?.model,
            message: {
                guard let delta else {
                    guard let message else { return nil }
                    return .init(role: message.role?.formalized)
                }
                return .init(content: delta.text)
            }(),
            finishReason: delta?.stop_reason
        )
    }
}
307+
308+
extension ClaudeChatCompletionsService.RequestBody {
    /// Builds an Anthropic request from the shared request body.
    ///
    /// Claude has no "system" message role, so every system message is
    /// merged into the dedicated `system` field, while the remaining
    /// messages are translated into user/assistant turns in order.
    init(_ body: ChatCompletionsRequestBody) {
        model = body.model

        let (systemPrompts, turns) = body.messages.reduce(
            into: ([String](), [Message]())
        ) { accumulated, message in
            guard message.role != .system else {
                accumulated.0.append(message.content)
                return
            }
            accumulated.1.append(.init(
                role: {
                    switch message.role {
                    case .user, .system:
                        return .user
                    case .assistant, .tool:
                        return .assistant
                    }
                }(),
                content: [.init(type: .text, text: message.content)]
            ))
        }

        messages = turns
        system = systemPrompts.joined(separator: "\n\n")
            .trimmingCharacters(in: .whitespacesAndNewlines)
        temperature = body.temperature
        stream = body.stream
        stop_sequences = body.stop
        max_tokens = body.maxTokens ?? 4000
    }
}
346+

0 commit comments

Comments (0)