Skip to content

Commit bcb0b57

Browse files
committed
Add prompt cache support for ClaudeChatCompletionsService
1 parent 4079a0f commit bcb0b57

1 file changed

Lines changed: 83 additions & 11 deletions

File tree

Tool/Sources/OpenAIService/APIs/ClaudeChatCompletionsService.swift

Lines changed: 83 additions & 11 deletions
Original file line numberDiff line numberDiff line change
```diff
@@ -1,4 +1,5 @@
 import AIModel
+import ChatBasic
 import AsyncAlgorithms
 import CodableWrappers
 import Foundation
```
```diff
@@ -124,6 +125,14 @@ public actor ClaudeChatCompletionsService: ChatCompletionsStreamAPI, ChatComplet
     }
 
     struct RequestBody: Encodable, Equatable {
+        struct CacheControl: Encodable, Equatable {
+            enum CacheControlType: String, Codable, Equatable {
+                case ephemeral
+            }
+
+            var type: CacheControlType = .ephemeral
+        }
+
         struct MessageContent: Encodable, Equatable {
             enum MessageContentType: String, Encodable, Equatable {
                 case text
```
```diff
@@ -141,6 +150,7 @@ public actor ClaudeChatCompletionsService: ChatCompletionsStreamAPI, ChatComplet
             var type: MessageContentType
             var text: String?
             var source: ImageSource?
+            var cache_control: CacheControl?
         }
 
         struct Message: Encodable, Equatable {
```
```diff
@@ -169,13 +179,26 @@ public actor ClaudeChatCompletionsService: ChatCompletionsStreamAPI, ChatComplet
             }
         }
 
+        struct SystemPrompt: Encodable, Equatable {
+            let type = "text"
+            var text: String
+            var cache_control: CacheControl?
+        }
+
+        struct Tool: Encodable, Equatable {
+            var name: String
+            var description: String
+            var input_schema: JSONSchemaValue
+        }
+
         var model: String
-        var system: String
+        var system: [SystemPrompt]
         var messages: [Message]
         var temperature: Double?
         var stream: Bool?
         var stop_sequences: [String]?
         var max_tokens: Int
+        var tools: [RequestBody.Tool]?
     }
 
     var apiKey: String
```
```diff
@@ -261,6 +284,7 @@ public actor ClaudeChatCompletionsService: ChatCompletionsStreamAPI, ChatComplet
         request.httpBody = try encoder.encode(requestBody)
         request.setValue("application/json", forHTTPHeaderField: "Content-Type")
         request.setValue("2023-06-01", forHTTPHeaderField: "anthropic-version")
+        request.setValue("prompt-caching-2024-07-31", forHTTPHeaderField: "anthropic-beta")
         if !apiKey.isEmpty {
             request.setValue(apiKey, forHTTPHeaderField: "x-api-key")
         }
```
```diff
@@ -330,37 +354,85 @@ extension ClaudeChatCompletionsService.RequestBody {
     init(_ body: ChatCompletionsRequestBody) {
         model = body.model
 
-        var systemPrompts = [String]()
+        var systemPrompts = [SystemPrompt]()
         var nonSystemMessages = [Message]()
 
+        enum JoinType {
+            case joinMessage
+            case appendToList
+            case padMessageAndAppendToList
+        }
+
+        func checkJoinType(for message: ChatCompletionsRequestBody.Message) -> JoinType {
+            guard let last = nonSystemMessages.last else { return .appendToList }
+            let newMessageRole: ClaudeChatCompletionsService.MessageRole = message.role == .user
+                ? .user
+                : .assistant
+
+            if newMessageRole != last.role {
+                return .appendToList
+            }
+
+            if message.cacheIfPossible != last.content
+                .contains(where: { $0.cache_control != nil })
+            {
+                return .padMessageAndAppendToList
+            }
+
+            return .joinMessage
+        }
+
         for message in body.messages {
             switch message.role {
             case .system:
-                systemPrompts.append(message.content)
+                systemPrompts.append(.init(text: message.content, cache_control: {
+                    if message.cacheIfPossible {
+                        return .init()
+                    } else {
+                        return nil
+                    }
+                }()))
             case .tool, .assistant:
-                if let last = nonSystemMessages.last, last.role == .assistant {
-                    nonSystemMessages[nonSystemMessages.endIndex - 1].appendText(message.content)
-                } else {
+                switch checkJoinType(for: message) {
+                case .appendToList:
                     nonSystemMessages.append(.init(
                         role: .assistant,
                         content: [.init(type: .text, text: message.content)]
                     ))
+                case .padMessageAndAppendToList, .joinMessage:
+                    nonSystemMessages[nonSystemMessages.endIndex - 1].content.append(
+                        .init(type: .text, text: message.content, cache_control: {
+                            if message.cacheIfPossible {
+                                return .init()
+                            } else {
+                                return nil
+                            }
+                        }())
+                    )
                 }
             case .user:
-                if let last = nonSystemMessages.last, last.role == .user {
-                    nonSystemMessages[nonSystemMessages.endIndex - 1].appendText(message.content)
-                } else {
+                switch checkJoinType(for: message) {
+                case .appendToList:
                     nonSystemMessages.append(.init(
                         role: .user,
                         content: [.init(type: .text, text: message.content)]
                     ))
+                case .padMessageAndAppendToList, .joinMessage:
+                    nonSystemMessages[nonSystemMessages.endIndex - 1].content.append(
+                        .init(type: .text, text: message.content, cache_control: {
+                            if message.cacheIfPossible {
+                                return .init()
+                            } else {
+                                return nil
+                            }
+                        }())
+                    )
                 }
             }
         }
 
         messages = nonSystemMessages
-        system = systemPrompts.joined(separator: "\n\n")
-            .trimmingCharacters(in: .whitespacesAndNewlines)
+        system = systemPrompts
         temperature = body.temperature
         stream = body.stream
         stop_sequences = body.stop
```

0 commit comments

Comments (0)