forked from intitni/CopilotForXcode
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathAutoManagedChatGPTMemoryGoogleAIStrategy.swift
More file actions
121 lines (105 loc) · 4.04 KB
/
AutoManagedChatGPTMemoryGoogleAIStrategy.swift
File metadata and controls
121 lines (105 loc) · 4.04 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import Foundation
import GoogleGenerativeAI
import Logger
extension AutoManagedChatGPTMemory {
    /// Memory strategy for Google AI (Gemini) models: counts tokens through the
    /// Google Generative AI SDK and reshapes the history into alternating turns.
    struct GoogleAIStrategy: AutoManagedChatGPTMemoryStrategy {
        let configuration: ChatGPTConfiguration

        /// Counts the tokens of a single chat message.
        ///
        /// - Parameter message: The message to measure.
        /// - Returns: The total token count reported by the model, or `0` when no
        ///   model is configured, the message is empty, or the count request fails.
        func countToken(_ message: ChatMessage) async -> Int {
            guard let chatModel = configuration.model else {
                return 0
            }
            let generativeModel = GenerativeModel(
                name: chatModel.info.modelName,
                apiKey: configuration.apiKey
            )
            guard !message.isEmpty else {
                return 0
            }
            let content = ModelContent(message)
            // Best effort: a failed request is reported as zero tokens.
            let tokenCount = try? await generativeModel.countTokens([content]).totalTokens
            return tokenCount ?? 0
        }

        /// Counts the tokens of a function definition.
        ///
        /// Functions are not supported by this strategy, so the result is always `0`.
        func countToken<F>(_: F) async -> Int where F: ChatGPTFunction {
            0
        }

        /// Rewrites the prompt history into strictly alternating turns.
        ///
        /// Gemini only supports turn-based conversation: a user message must be
        /// followed by a model message. Consecutive messages that map to the same
        /// Gemini role are merged into one message, and a filler model reply is
        /// inserted before the newest user message when it would otherwise follow
        /// another user-side message.
        ///
        /// - Parameter prompt: The prompt whose history should be reshaped.
        /// - Returns: A prompt with the reshaped history and the original
        ///   references and remaining token count.
        func reformat(_ prompt: ChatGPTPrompt) async -> ChatGPTPrompt {
            var pending = prompt.history

            // The newest user message is set aside so it is never merged with
            // earlier messages.
            var trailingUserMessage: ChatMessage?
            if pending.last?.role == .user {
                trailingUserMessage = pending.removeLast()
            }

            var turns = [ChatMessage]()
            for message in pending {
                guard let previous = turns.last,
                      ModelContent.convertRole(previous.role)
                      == ModelContent.convertRole(message.role)
                else {
                    // First message, or a change of speaker: start a new turn.
                    turns.append(message)
                    continue
                }

                // Same speaker twice in a row: fold both into a single turn.
                let earlierText = ModelContent.convertContent(of: previous)
                let laterText = ModelContent.convertContent(of: message)
                turns[turns.count - 1] = ChatMessage(
                    role: message.role == .assistant ? .assistant : .user,
                    content: """
                    \(earlierText)
                    ======
                    \(laterText)
                    """
                )
            }

            if let trailingUserMessage {
                let followsSameSpeaker = turns.last.map { previous in
                    ModelContent.convertRole(previous.role)
                        == ModelContent.convertRole(trailingUserMessage.role)
                } ?? false
                if followsSameSpeaker {
                    // Insert a dummy model reply to keep the turns alternating.
                    turns.append(ChatMessage(role: .assistant, content: "OK"))
                }
                turns.append(trailingUserMessage)
            }

            return ChatGPTPrompt(
                history: turns,
                references: prompt.references,
                remainingTokenCount: prompt.remainingTokenCount
            )
        }
    }
}
extension ModelContent {
    /// Maps a chat message role to the role string used for model content.
    ///
    /// - Parameter role: The role of the chat message.
    /// - Returns: `"model"` for assistant messages, `"user"` for user, system,
    ///   and function messages.
    static func convertRole(_ role: ChatMessage.Role) -> String {
        switch role {
        case .assistant: "model"
        case .user, .system, .function: "user"
        }
    }

    /// Flattens a chat message into plain text.
    ///
    /// - Parameter message: The message to convert.
    /// - Returns: A textual description of the function call for assistant
    ///   messages that carry one; otherwise the message content, or an empty
    ///   string when there is no content.
    static func convertContent(of message: ChatMessage) -> String {
        guard message.role == .assistant, let call = message.functionCall else {
            return message.content ?? ""
        }
        return """
        call function: \(call.name)
        arguments: \(call.arguments)
        """
    }

    /// Creates model content with a single text part from a chat message.
    ///
    /// - Parameter message: The chat message to convert; its role and content
    ///   go through `convertRole(_:)` and `convertContent(of:)`.
    init(_ message: ChatMessage) {
        let textParts: [ModelContent.Part] = [.text(Self.convertContent(of: message))]
        self.init(role: Self.convertRole(message.role), parts: textParts)
    }
}