Skip to content

Commit 3863c2a

Browse files
committed
Support prompt caching in Claude through OpenRouter
1 parent 81c4259 commit 3863c2a

6 files changed

Lines changed: 401 additions & 132 deletions

Tool/Sources/OpenAIService/APIs/BuiltinExtensionChatCompletionsService.swift

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,11 @@ extension BuiltinExtensionChatCompletionsService {
115115
.joined(separator: "\n\n")
116116
let history = Array(messages[0...lastIndexNotUserMessage])
117117
return (message, history.map {
118-
.init(id: UUID().uuidString, role: $0.role.asChatMessageRole, content: $0.content)
118+
.init(
119+
id: UUID().uuidString,
120+
role: $0.role.asChatMessageRole,
121+
content: $0.content
122+
)
119123
})
120124
} else { // everything is user message
121125
let message = messages.map { $0.content }.joined(separator: "\n\n")

Tool/Sources/OpenAIService/APIs/ChatCompletionsAPIDefinition.swift

Lines changed: 68 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,17 @@
11
import AIModel
2+
import ChatBasic
23
import CodableWrappers
34
import Foundation
45
import Preferences
5-
import ChatBasic
66

7-
struct ChatCompletionsRequestBody: Codable, Equatable {
8-
struct Message: Codable, Equatable {
9-
enum Role: String, Codable, Equatable {
7+
struct ChatCompletionsRequestBody: Equatable {
8+
struct Message: Equatable {
9+
enum Role: String, Equatable {
1010
case system
1111
case user
1212
case assistant
1313
case tool
14-
14+
1515
var asChatMessageRole: ChatMessage.Role {
1616
switch self {
1717
case .system:
@@ -25,6 +25,31 @@ struct ChatCompletionsRequestBody: Codable, Equatable {
2525
}
2626
}
2727
}
28+
29+
struct Image: Equatable {
30+
enum Format: String {
31+
case png = "image/png"
32+
case jpeg = "image/jpeg"
33+
case gif = "image/gif"
34+
}
35+
var data: Data
36+
var format: Format
37+
38+
var dataURLString: String {
39+
let base64 = data.base64EncodedString()
40+
return "data:\(format.rawValue);base64,\(base64)"
41+
}
42+
}
43+
44+
struct Audio: Equatable {
45+
enum Format: String {
46+
case wav
47+
case mp3
48+
}
49+
50+
var data: Data
51+
var format: Format
52+
}
2853

2954
/// The role of the message.
3055
var role: Role
@@ -34,25 +59,29 @@ struct ChatCompletionsRequestBody: Codable, Equatable {
3459
/// name of the function call, and include the result in `content`.
3560
///
3661
/// - important: It's required when the role is `function`.
37-
var name: String?
62+
var name: String? = nil
3863
/// Tool calls in an assistant message.
39-
var toolCalls: [MessageToolCall]?
64+
var toolCalls: [MessageToolCall]? = nil
4065
/// When we reply to a tool call with its result, we have to provide the id of that call.
4166
///
4267
/// - important: It's required when the role is `tool`.
43-
var toolCallId: String?
68+
var toolCallId: String? = nil
69+
/// Images to include in the message.
70+
var images: [Image] = []
71+
/// Audios to include in the message.
72+
var audios: [Audio] = []
4473
/// Cache the message if possible.
4574
var cacheIfPossible: Bool = false
4675
}
4776

48-
struct MessageFunctionCall: Codable, Equatable {
77+
struct MessageFunctionCall: Equatable {
4978
/// The name of the function.
5079
var name: String
5180
/// A JSON string.
5281
var arguments: String?
5382
}
5483

55-
struct MessageToolCall: Codable, Equatable {
84+
struct MessageToolCall: Equatable {
5685
/// The id of the tool call.
5786
var id: String
5887
/// The type of the tool.
@@ -61,7 +90,7 @@ struct ChatCompletionsRequestBody: Codable, Equatable {
6190
var function: MessageFunctionCall
6291
}
6392

64-
struct Tool: Codable, Equatable {
93+
struct Tool: Equatable {
6594
var type: String = "function"
6695
var function: ChatGPTFunctionSchema
6796
}
@@ -182,11 +211,11 @@ struct ChatCompletionsStreamDataChunk {
182211
var content: String?
183212
var toolCalls: [ToolCall]?
184213
}
185-
214+
186215
struct Usage: Codable, Equatable {
187216
var promptTokens: Int?
188217
var completionTokens: Int?
189-
218+
190219
var cachedTokens: Int?
191220
var otherUsage: [String: Int]
192221
}
@@ -205,16 +234,35 @@ protocol ChatCompletionsAPI {
205234
func callAsFunction() async throws -> ChatCompletionResponseBody
206235
}
207236

208-
struct ChatCompletionResponseBody: Codable, Equatable {
209-
typealias Message = ChatCompletionsRequestBody.Message
210-
211-
struct Usage: Codable, Equatable {
237+
struct ChatCompletionResponseBody: Equatable {
238+
struct Message: Equatable {
239+
typealias Role = ChatCompletionsRequestBody.Message.Role
240+
typealias MessageToolCall = ChatCompletionsRequestBody.MessageToolCall
241+
242+
/// The role of the message.
243+
var role: Role
244+
/// The content of the message.
245+
var content: String?
246+
/// When we want to reply to a function call with the result, we have to provide the
247+
/// name of the function call, and include the result in `content`.
248+
///
249+
/// - important: It's required when the role is `function`.
250+
var name: String?
251+
/// Tool calls in an assistant message.
252+
var toolCalls: [MessageToolCall]?
253+
/// When we reply to a tool call with its result, we have to provide the id of that call.
254+
///
255+
/// - important: It's required when the role is `tool`.
256+
var toolCallId: String?
257+
}
258+
259+
struct Usage: Equatable {
212260
var promptTokens: Int
213261
var completionTokens: Int
214-
262+
215263
var cachedTokens: Int
216264
var otherUsage: [String: Int]
217-
265+
218266
mutating func merge(with other: ChatCompletionsStreamDataChunk.Usage) {
219267
promptTokens += other.promptTokens ?? 0
220268
completionTokens += other.completionTokens ?? 0
@@ -223,7 +271,7 @@ struct ChatCompletionResponseBody: Codable, Equatable {
223271
otherUsage[key, default: 0] += value
224272
}
225273
}
226-
274+
227275
mutating func merge(with other: Self) {
228276
promptTokens += other.promptTokens
229277
completionTokens += other.completionTokens

Tool/Sources/OpenAIService/APIs/ClaudeChatCompletionsService.swift

Lines changed: 41 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -414,12 +414,6 @@ extension ClaudeChatCompletionsService.RequestBody {
414414
return .appendToList
415415
}
416416

417-
if message.cacheIfPossible != last.content
418-
.contains(where: { $0.cache_control != nil })
419-
{
420-
return .padMessageAndAppendToList
421-
}
422-
423417
return .joinMessage
424418
}
425419

@@ -434,6 +428,41 @@ extension ClaudeChatCompletionsService.RequestBody {
434428
return false
435429
}
436430

431+
func convertMessageContent(
432+
_ message: ChatCompletionsRequestBody.Message
433+
) -> [MessageContent] {
434+
var content = [MessageContent]()
435+
436+
content.append(.init(type: .text, text: message.content, cache_control: {
437+
if message.cacheIfPossible, supportsPromptCache, consumeCacheControl() {
438+
return .init()
439+
} else {
440+
return nil
441+
}
442+
}()))
443+
for image in message.images {
444+
content.append(.init(type: .image, source: .init(
445+
type: "base64",
446+
media_type: image.format.rawValue,
447+
data: image.data.base64EncodedString()
448+
)))
449+
}
450+
451+
return content
452+
}
453+
454+
func convertMessage(_ message: ChatCompletionsRequestBody.Message) -> Message {
455+
let role: ClaudeChatCompletionsService.MessageRole = switch message.role {
456+
case .system: .assistant
457+
case .assistant, .tool: .assistant
458+
case .user: .user
459+
}
460+
461+
let content: [MessageContent] = convertMessageContent(message)
462+
463+
return .init(role: role, content: content)
464+
}
465+
437466
for message in body.messages {
438467
switch message.role {
439468
case .system:
@@ -447,38 +476,18 @@ extension ClaudeChatCompletionsService.RequestBody {
447476
case .tool, .assistant:
448477
switch checkJoinType(for: message) {
449478
case .appendToList:
450-
nonSystemMessages.append(.init(
451-
role: .assistant,
452-
content: [.init(type: .text, text: message.content)]
453-
))
479+
nonSystemMessages.append(convertMessage(message))
454480
case .padMessageAndAppendToList, .joinMessage:
455-
nonSystemMessages[nonSystemMessages.endIndex - 1].content.append(
456-
.init(type: .text, text: message.content, cache_control: {
457-
if message.cacheIfPossible, supportsPromptCache, consumeCacheControl() {
458-
return .init()
459-
} else {
460-
return nil
461-
}
462-
}())
463-
)
481+
nonSystemMessages[nonSystemMessages.endIndex - 1].content
482+
.append(contentsOf: convertMessageContent(message))
464483
}
465484
case .user:
466485
switch checkJoinType(for: message) {
467486
case .appendToList:
468-
nonSystemMessages.append(.init(
469-
role: .user,
470-
content: [.init(type: .text, text: message.content)]
471-
))
487+
nonSystemMessages.append(convertMessage(message))
472488
case .padMessageAndAppendToList, .joinMessage:
473-
nonSystemMessages[nonSystemMessages.endIndex - 1].content.append(
474-
.init(type: .text, text: message.content, cache_control: {
475-
if message.cacheIfPossible, supportsPromptCache, consumeCacheControl() {
476-
return .init()
477-
} else {
478-
return nil
479-
}
480-
}())
481-
)
489+
nonSystemMessages[nonSystemMessages.endIndex - 1].content
490+
.append(contentsOf: convertMessageContent(message))
482491
}
483492
}
484493
}

0 commit comments

Comments
 (0)