Skip to content

Commit 99ce70f

Browse files
committed
Update API to support function calls
1 parent d47b4cf commit 99ce70f

File tree

12 files changed

+207
-43
lines changed

12 files changed

+207
-43
lines changed

Core/Sources/ChatContextCollector/ActiveDocumentChatContextCollector.swift

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
import Foundation
2+
import OpenAIService
23
import Preferences
34
import SuggestionModel
45
import XcodeInspector
56

67
public struct ActiveDocumentChatContextCollector: ChatContextCollector {
78
public init() {}
89

9-
public func generateSystemPrompt(history: [String], content prompt: String) -> String {
10+
public func generateSystemPrompt(history: [ChatMessage], content prompt: String) -> String {
1011
let content = getEditorInformation()
1112
let relativePath = content.documentURL.path
1213
.replacingOccurrences(of: content.projectURL.path, with: "")
Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import Foundation
2+
import OpenAIService
23

34
public protocol ChatContextCollector {
4-
func generateSystemPrompt(history: [String], content: String) -> String
5+
func generateSystemPrompt(history: [ChatMessage], content: String) -> String
56
}
7+

Core/Sources/ChatService/ChatService.swift

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,8 @@ public final class ChatService: ObservableObject {
3636
}
3737

3838
public init() {
39-
self.configuration = OverridingUserPreferenceChatGPTConfiguration()
40-
self.memory = AutoManagedChatGPTMemory(systemPrompt: "", configuration: configuration)
39+
configuration = OverridingUserPreferenceChatGPTConfiguration()
40+
memory = AutoManagedChatGPTMemory(systemPrompt: "", configuration: configuration)
4141
chatGPTService = ChatGPTService(memory: memory, configuration: configuration)
4242
pluginController = ChatPluginController(chatGPTService: chatGPTService, plugins: allPlugins)
4343
contextController = DynamicContextController(
@@ -89,14 +89,18 @@ public final class ChatService: ObservableObject {
8989
}
9090

9191
public func resendMessage(id: String) async throws {
92-
if let message = (await memory.history).first(where: { $0.id == id }) {
93-
try await send(content: message.content)
92+
if let message = (await memory.history).first(where: { $0.id == id }),
93+
let content = message.content
94+
{
95+
try await send(content: content)
9496
}
9597
}
9698

9799
public func setMessageAsExtraPrompt(id: String) async {
98-
if let message = (await memory.history).first(where: { $0.id == id }) {
99-
mutateExtraSystemPrompt(message.content)
100+
if let message = (await memory.history).first(where: { $0.id == id }),
101+
let content = message.content
102+
{
103+
mutateExtraSystemPrompt(content)
100104
await mutateHistory { history in
101105
history.append(.init(
102106
role: .assistant,

Core/Sources/ChatService/DynamicContextController.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ final class DynamicContextController {
1515

1616
func updatePromptToMatchContent(systemPrompt: String, content: String) async throws {
1717
let language = UserDefaults.shared.value(for: \.chatGPTLanguage)
18-
let oldMessages = (await memory.history).map(\.content)
18+
let oldMessages = await memory.history
1919
let contextualSystemPrompt = """
2020
\(language.isEmpty ? "" : "You must always reply in \(language)")
2121
\(systemPrompt)

Core/Sources/Service/GUI/ChatProvider+Service.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ extension ChatProvider {
2020
.init(
2121
id: message.id,
2222
isUser: message.role == .user,
23-
text: message.summary ?? message.content
23+
text: message.summary ?? message.content ?? ""
2424
)
2525
}
2626
self.isReceivingMessage = service.isReceivingMessage

Tool/Sources/OpenAIService/ChatGPTService.swift

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,8 @@ public class ChatGPTService: ChatGPTServiceType {
8484
id: uuidGenerator(),
8585
role: .user,
8686
content: content,
87+
name: nil,
88+
functionCall: nil,
8789
summary: summary
8890
)
8991
await memory.appendMessage(newMessage)
@@ -103,7 +105,9 @@ public class ChatGPTService: ChatGPTServiceType {
103105
max_tokens: maxTokenForReply(
104106
model: configuration.model,
105107
remainingTokens: remainingTokens
106-
)
108+
),
109+
function_call: nil,
110+
functions: []
107111
)
108112

109113
let api = buildCompletionStreamAPI(
@@ -118,20 +122,44 @@ public class ChatGPTService: ChatGPTServiceType {
118122
do {
119123
let (trunks, cancel) = try await api()
120124
cancelTask = cancel
125+
var id = ""
126+
var functionCallRawString = ""
121127
for try await trunk in trunks {
128+
id = trunk.id
129+
122130
guard let delta = trunk.choices.first?.delta else { continue }
123131

124132
await memory.streamMessage(
125133
id: trunk.id,
126134
role: delta.role,
127-
content: delta.content
135+
content: delta.content,
136+
functionCall: nil
128137
)
129138

139+
if let call = delta.function_call {
140+
functionCallRawString.append(call)
141+
}
142+
130143
if let content = delta.content {
131144
continuation.yield(content)
132145
}
133146

134-
try await Task.sleep(nanoseconds: 3_500_000)
147+
try await Task.sleep(nanoseconds: 3_000_000)
148+
}
149+
150+
if !functionCallRawString.isEmpty,
151+
let data = functionCallRawString.data(using: .utf8)
152+
{
153+
let function = try JSONDecoder().decode(
154+
ChatMessage.FunctionCall.self,
155+
from: data
156+
)
157+
await memory.streamMessage(
158+
id: id,
159+
role: nil,
160+
content: nil,
161+
functionCall: function
162+
)
135163
}
136164

137165
continuation.finish()
@@ -166,7 +194,7 @@ public class ChatGPTService: ChatGPTServiceType {
166194
)
167195
await memory.appendMessage(newMessage)
168196
}
169-
197+
170198
let messages = await memory.messages.map {
171199
CompletionRequestBody.Message(role: $0.role, content: $0.content)
172200
}
@@ -181,7 +209,9 @@ public class ChatGPTService: ChatGPTServiceType {
181209
max_tokens: maxTokenForReply(
182210
model: configuration.model,
183211
remainingTokens: remainingTokens
184-
)
212+
),
213+
function_call: nil,
214+
functions: []
185215
)
186216

187217
let api = buildCompletionAPI(

Tool/Sources/OpenAIService/CompletionAPI.swift

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,7 @@ protocol CompletionAPI {
1010

1111
/// https://platform.openai.com/docs/api-reference/chat/create
1212
struct CompletionResponseBody: Codable, Equatable {
13-
struct Message: Codable, Equatable {
14-
var role: ChatMessage.Role
15-
var content: String
16-
}
13+
typealias Message = CompletionRequestBody.Message
1714

1815
struct Choice: Codable, Equatable {
1916
var message: Message

Tool/Sources/OpenAIService/CompletionStreamAPI.swift

Lines changed: 60 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import AsyncAlgorithms
22
import Foundation
33
import Preferences
4+
import JSONRPC
45

56
typealias CompletionStreamAPIBuilder = (String, ChatFeatureProvider, URL, CompletionRequestBody) -> CompletionStreamAPI
67

@@ -12,10 +13,64 @@ protocol CompletionStreamAPI {
1213
}
1314

1415
/// https://platform.openai.com/docs/api-reference/chat/create
15-
struct CompletionRequestBody: Codable, Equatable {
16+
struct CompletionRequestBody: Encodable, Equatable {
1617
struct Message: Codable, Equatable {
18+
/// The role of the message.
1719
var role: ChatMessage.Role
18-
var content: String
20+
/// The content of the message.
21+
var content: String?
22+
/// When we want to reply to a function call with the result, we have to provide the
23+
/// name of the function call, and include the result in `content`.
24+
///
25+
/// - important: It's required when the role is `function`.
26+
var name: String?
27+
/// When the bot wants to call a function, it will reply with a function call in format:
28+
/// ```json
29+
/// {
30+
/// "name": "weather",
31+
/// "arguments": "{ \"location\": \"earth\" }"
32+
/// }
33+
/// ```
34+
var function_call: MessageFunctionCall?
35+
}
36+
37+
struct MessageFunctionCall: Codable, Equatable {
38+
/// The name of the function to call.
39+
var name: String
40+
/// A JSON string.
41+
var arguments: String
42+
}
43+
44+
enum FunctionCallStrategy: Encodable, Equatable {
45+
/// Forbid the bot to call any function.
46+
case none
47+
/// Let the bot choose what function to call.
48+
case auto
49+
/// Force the bot to call a function with the given name.
50+
case name(String)
51+
52+
struct CallFunctionNamed: Codable {
53+
var name: String
54+
}
55+
56+
func encode(to encoder: Encoder) throws {
57+
var container = encoder.singleValueContainer()
58+
switch self {
59+
case .none:
60+
try container.encode("none")
61+
case .auto:
62+
try container.encode("auto")
63+
case let .name(name):
64+
try container.encode(CallFunctionNamed(name: name))
65+
}
66+
}
67+
}
68+
69+
struct Function: Codable {
70+
var name: String
71+
var description: String
72+
/// JSON schema.
73+
var arguments: String
1974
}
2075

2176
var model: String
@@ -30,6 +85,8 @@ struct CompletionRequestBody: Codable, Equatable {
3085
var frequency_penalty: Double?
3186
var logit_bias: [String: Double]?
3287
var user: String?
88+
var function_call: FunctionCallStrategy?
89+
var functions: [Int] = []
3390
}
3491

3592
struct CompletionStreamDataTrunk: Codable {
@@ -47,6 +104,7 @@ struct CompletionStreamDataTrunk: Codable {
47104
struct Delta: Codable {
48105
var role: ChatMessage.Role?
49106
var content: String?
107+
var function_call: String?
50108
}
51109
}
52110
}

Tool/Sources/OpenAIService/Memory/AutoManagedChatGPTMemory.swift

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ public actor AutoManagedChatGPTMemory: ChatGPTMemory {
3030
public func mutateSystemPrompt(_ newPrompt: String) {
3131
systemPrompt.content = newPrompt
3232
}
33-
33+
3434
public nonisolated
3535
func observeHistoryChange(_ onChange: @escaping () -> Void) {
3636
Task {
@@ -42,29 +42,31 @@ public actor AutoManagedChatGPTMemory: ChatGPTMemory {
4242
maxNumberOfMessages: Int = UserDefaults.shared.value(for: \.chatGPTMaxMessageCount),
4343
encoder: TokenEncoder = AutoManagedChatGPTMemory.encoder
4444
) -> [ChatMessage] {
45+
func countToken(_ message: inout ChatMessage) -> Int {
46+
if let count = systemPrompt.tokensCount { return count }
47+
let count = encoder.countToken(message: systemPrompt)
48+
message.tokensCount = count
49+
return count
50+
}
51+
4552
var all: [ChatMessage] = []
46-
let systemMessageTokenCount = systemPrompt.tokensCount
47-
?? encoder.encode(text: systemPrompt.content).count
48-
systemPrompt.tokensCount = systemMessageTokenCount
49-
50-
var allTokensCount = systemMessageTokenCount
53+
let systemMessageTokenCount = countToken(&systemPrompt)
54+
var allTokensCount = systemPrompt.isEmpty ? 0 : systemMessageTokenCount
55+
5156
for (index, message) in history.enumerated().reversed() {
52-
var message = message
5357
if maxNumberOfMessages > 0, all.count >= maxNumberOfMessages { break }
54-
if message.content.isEmpty { continue }
55-
let tokensCount = message.tokensCount ?? encoder.encode(text: message.content).count
56-
history[index].tokensCount = tokensCount
58+
if message.isEmpty { continue }
59+
let tokensCount = countToken(&history[index])
5760
if tokensCount + allTokensCount >
5861
configuration.maxTokens - configuration.minimumReplyTokens
5962
{
6063
break
6164
}
62-
message.tokensCount = tokensCount
6365
allTokensCount += tokensCount
6466
all.append(message)
6567
}
6668

67-
if !systemPrompt.content.isEmpty {
69+
if !systemPrompt.isEmpty {
6870
all.append(systemPrompt)
6971
}
7072
return all.reversed()

Tool/Sources/OpenAIService/Memory/ChatGPTMemory.swift

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,20 @@ public extension ChatGPTMemory {
3535
}
3636

3737
/// Stream a message to the history.
38-
func streamMessage(id: String, role: ChatMessage.Role?, content: String?) async {
38+
func streamMessage(
39+
id: String,
40+
role: ChatMessage.Role?,
41+
content: String?,
42+
functionCall: ChatMessage.FunctionCall?
43+
) async {
3944
await mutateHistory { history in
4045
if let index = history.firstIndex(where: { $0.id == id }) {
4146
if let content {
42-
history[index].content.append(content)
47+
if history[index].content == nil {
48+
history[index].content = content
49+
} else {
50+
history[index].content?.append(content)
51+
}
4352
}
4453
if let role {
4554
history[index].role = role
@@ -48,7 +57,9 @@ public extension ChatGPTMemory {
4857
history.append(.init(
4958
id: id,
5059
role: role ?? .system,
51-
content: content ?? ""
60+
content: content,
61+
name: nil,
62+
functionCall: functionCall
5263
))
5364
}
5465
}

0 commit comments

Comments (0)