11import AIModel
2+ import ChatBasic
23import AsyncAlgorithms
34import CodableWrappers
45import Foundation
@@ -124,6 +125,14 @@ public actor ClaudeChatCompletionsService: ChatCompletionsStreamAPI, ChatComplet
124125 }
125126
126127 struct RequestBody : Encodable , Equatable {
/// Caching marker that can be attached to a content block or system prompt.
///
/// Encodes as a JSON object with a single `type` field, e.g. `{"type":"ephemeral"}`.
struct CacheControl: Encodable, Equatable {
    /// The kinds of cache control this request body can express.
    enum CacheControlType: String, Codable, Equatable {
        case ephemeral
    }

    /// Defaults to `.ephemeral`, the only kind currently declared.
    var type = CacheControlType.ephemeral
}
135+
127136 struct MessageContent : Encodable , Equatable {
128137 enum MessageContentType : String , Encodable , Equatable {
129138 case text
@@ -141,6 +150,7 @@ public actor ClaudeChatCompletionsService: ChatCompletionsStreamAPI, ChatComplet
141150 var type : MessageContentType
142151 var text : String ?
143152 var source : ImageSource ?
153+ var cache_control : CacheControl ?
144154 }
145155
146156 struct Message : Encodable , Equatable {
@@ -169,13 +179,26 @@ public actor ClaudeChatCompletionsService: ChatCompletionsStreamAPI, ChatComplet
169179 }
170180 }
171181
/// One text block of the request's `system` array.
///
/// The synthesized `Encodable` conformance emits `type`, `text` and, only when it is
/// non-nil, `cache_control`.
struct SystemPrompt: Encodable, Equatable {
    /// Block discriminator. Must be the exact string `text`; stray padding inside the
    /// literal would be encoded verbatim and sent as part of the value.
    let type = "text"
    /// The system prompt text carried by this block.
    var text: String
    /// Set to mark this block as cacheable; `nil` leaves the key out of the JSON.
    var cache_control: CacheControl?
}
187+
/// A tool definition carried in the request body's optional `tools` array.
struct Tool: Encodable, Equatable {
    /// Name of the tool.
    var name: String
    /// Description of the tool.
    var description: String
    /// Schema value for the tool's input.
    /// NOTE(review): presumably a JSON Schema object — confirm against `JSONSchemaValue`.
    var input_schema: JSONSchemaValue
}
193+
172194 var model : String
173- var system : String
195+ var system : [ SystemPrompt ]
174196 var messages : [ Message ]
175197 var temperature : Double ?
176198 var stream : Bool ?
177199 var stop_sequences : [ String ] ?
178200 var max_tokens : Int
201+ var tools : [ RequestBody . Tool ] ?
179202 }
180203
181204 var apiKey : String
@@ -261,6 +284,7 @@ public actor ClaudeChatCompletionsService: ChatCompletionsStreamAPI, ChatComplet
261284 request. httpBody = try encoder. encode ( requestBody)
262285 request. setValue ( " application/json " , forHTTPHeaderField: " Content-Type " )
263286 request. setValue ( " 2023-06-01 " , forHTTPHeaderField: " anthropic-version " )
287+ request. setValue ( " prompt-caching-2024-07-31 " , forHTTPHeaderField: " anthropic-beta " )
264288 if !apiKey. isEmpty {
265289 request. setValue ( apiKey, forHTTPHeaderField: " x-api-key " )
266290 }
@@ -330,37 +354,85 @@ extension ClaudeChatCompletionsService.RequestBody {
330354 init ( _ body: ChatCompletionsRequestBody ) {
331355 model = body. model
332356
333- var systemPrompts = [ String ] ( )
357+ var systemPrompts = [ SystemPrompt ] ( )
334358 var nonSystemMessages = [ Message] ( )
335359
/// How an incoming message should be combined with the messages collected so far.
///
/// - `joinMessage`: same role and same caching state as the previous message.
/// - `appendToList`: there is no previous message, or the role differs from it.
/// - `padMessageAndAppendToList`: same role as the previous message, but the
///   caching state differs.
enum JoinType {
    case joinMessage, appendToList, padMessageAndAppendToList
}
365+
/// Decides how `message` relates to the last entry of `nonSystemMessages`.
///
/// - Parameter message: The incoming non-system message.
/// - Returns: `.appendToList` when nothing has been collected yet or the role changes;
///   `.padMessageAndAppendToList` when the role matches but the message's
///   `cacheIfPossible` flag disagrees with whether the previous message already has
///   any content carrying `cache_control`; `.joinMessage` otherwise.
func checkJoinType(for message: ChatCompletionsRequestBody.Message) -> JoinType {
    guard let previous = nonSystemMessages.last else { return .appendToList }

    // Every role other than `.user` is treated as the assistant side.
    let incomingRole: ClaudeChatCompletionsService.MessageRole
    if message.role == .user {
        incomingRole = .user
    } else {
        incomingRole = .assistant
    }

    guard incomingRole == previous.role else { return .appendToList }

    let previousHasCachedContent = previous.content.contains { item in
        item.cache_control != nil
    }

    guard message.cacheIfPossible == previousHasCachedContent else {
        return .padMessageAndAppendToList
    }

    return .joinMessage
}
384+
336385 for message in body. messages {
337386 switch message. role {
338387 case . system:
339- systemPrompts. append ( message. content)
388+ systemPrompts. append ( . init( text: message. content, cache_control: {
389+ if message. cacheIfPossible {
390+ return . init( )
391+ } else {
392+ return nil
393+ }
394+ } ( ) ) )
340395 case . tool, . assistant:
341- if let last = nonSystemMessages. last, last. role == . assistant {
342- nonSystemMessages [ nonSystemMessages. endIndex - 1 ] . appendText ( message. content)
343- } else {
396+ switch checkJoinType ( for: message) {
397+ case . appendToList:
344398 nonSystemMessages. append ( . init(
345399 role: . assistant,
346400 content: [ . init( type: . text, text: message. content) ]
347401 ) )
402+ case . padMessageAndAppendToList, . joinMessage:
403+ nonSystemMessages [ nonSystemMessages. endIndex - 1 ] . content. append (
404+ . init( type: . text, text: message. content, cache_control: {
405+ if message. cacheIfPossible {
406+ return . init( )
407+ } else {
408+ return nil
409+ }
410+ } ( ) )
411+ )
348412 }
349413 case . user:
350- if let last = nonSystemMessages. last, last. role == . user {
351- nonSystemMessages [ nonSystemMessages. endIndex - 1 ] . appendText ( message. content)
352- } else {
414+ switch checkJoinType ( for: message) {
415+ case . appendToList:
353416 nonSystemMessages. append ( . init(
354417 role: . user,
355418 content: [ . init( type: . text, text: message. content) ]
356419 ) )
420+ case . padMessageAndAppendToList, . joinMessage:
421+ nonSystemMessages [ nonSystemMessages. endIndex - 1 ] . content. append (
422+ . init( type: . text, text: message. content, cache_control: {
423+ if message. cacheIfPossible {
424+ return . init( )
425+ } else {
426+ return nil
427+ }
428+ } ( ) )
429+ )
357430 }
358431 }
359432 }
360433
361434 messages = nonSystemMessages
362- system = systemPrompts. joined ( separator: " \n \n " )
363- . trimmingCharacters ( in: . whitespacesAndNewlines)
435+ system = systemPrompts
364436 temperature = body. temperature
365437 stream = body. stream
366438 stop_sequences = body. stop
0 commit comments