Skip to content

Commit 19b33b5

Browse files
committed
Support auto token count management
1 parent 9a955c0 commit 19b33b5

2 files changed

Lines changed: 157 additions & 11 deletions

File tree

Core/Sources/OpenAIService/ChatGPTService.swift

Lines changed: 33 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import AsyncAlgorithms
22
import Foundation
3+
import GPTEncoder
34
import Preferences
45

56
public protocol ChatGPTServiceType: ObservableObject {
@@ -112,12 +113,14 @@ public actor ChatGPTService: ChatGPTServiceType {
112113
)
113114
history.append(newMessage)
114115

116+
let (messages, remainingTokens) = combineHistoryWithSystemPrompt()
117+
115118
let requestBody = CompletionRequestBody(
116119
model: model,
117-
messages: combineHistoryWithSystemPrompt(),
120+
messages: messages,
118121
temperature: temperature,
119122
stream: true,
120-
max_tokens: maxToken
123+
max_tokens: remainingTokens
121124
)
122125

123126
isReceivingMessage = true
@@ -190,12 +193,14 @@ public actor ChatGPTService: ChatGPTServiceType {
190193
)
191194
history.append(newMessage)
192195

196+
let (messages, remainingTokens) = combineHistoryWithSystemPrompt()
197+
193198
let requestBody = CompletionRequestBody(
194199
model: model,
195-
messages: combineHistoryWithSystemPrompt(),
200+
messages: messages,
196201
temperature: temperature,
197202
stream: true,
198-
max_tokens: maxToken
203+
max_tokens: remainingTokens
199204
)
200205

201206
isReceivingMessage = true
@@ -210,10 +215,10 @@ public actor ChatGPTService: ChatGPTServiceType {
210215
role: choice.message.role,
211216
content: choice.message.content
212217
))
213-
218+
214219
return choice.message.content
215220
}
216-
221+
217222
return nil
218223
}
219224

@@ -250,17 +255,34 @@ extension ChatGPTService {
250255
uuidGenerator = generator
251256
}
252257

253-
func combineHistoryWithSystemPrompt() -> [CompletionRequestBody.Message] {
258+
func combineHistoryWithSystemPrompt(
259+
minimumReplyTokens: Int = 200,
260+
maxNumberOfMessages: Int = 5,
261+
maxTokens: Int = UserDefaults.shared.value(for: \.chatGPTMaxToken),
262+
encoder: TokenEncoder = GPTEncoder()
263+
)
264+
-> (messages: [CompletionRequestBody.Message], remainingTokens: Int)
265+
{
254266
var all: [CompletionRequestBody.Message] = []
255-
var count = 0
267+
var allTokensCount = encoder.encode(text: systemPrompt).count
256268
for message in history.reversed() {
257-
if count >= 5 { break }
269+
if all.count >= maxNumberOfMessages { break }
258270
if message.content.isEmpty { continue }
271+
let tokensCount = encoder.encode(text: message.content).count
272+
if tokensCount + allTokensCount > maxTokens - minimumReplyTokens {
273+
break
274+
}
275+
allTokensCount += tokensCount
259276
all.append(.init(role: message.role, content: message.content))
260-
count += 1
261277
}
262278

263279
all.append(.init(role: .system, content: systemPrompt))
264-
return all.reversed()
280+
return (all.reversed(), max(minimumReplyTokens, maxTokens - allTokensCount))
265281
}
266282
}
283+
284+
/// Abstraction over a tokenizer so token counting can be mocked in tests.
protocol TokenEncoder {
    /// Encodes `text` into token IDs; the result's `count` is the token count.
    func encode(text: String) -> [Int]
}

// GPTEncoder (third-party) is assumed to already expose `encode(text:)` with this
// exact signature, so the conformance body is empty — TODO confirm against the
// GPTEncoder package if its API changes.
extension GPTEncoder: TokenEncoder {}
Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
import Foundation
2+
import XCTest
3+
4+
@testable import OpenAIService
5+
6+
/// Tests for `ChatGPTService.combineHistoryWithSystemPrompt`.
///
/// All token arithmetic below relies on `MockEncoder`, which counts one token per
/// character: systemPrompt "system" = 6 tokens; history "hi" = 2, "hello" = 5,
/// "world" = 5 tokens.
final class LimitMessagesTests: XCTestCase {
    func test_send_all_messages_if_not_reached_token_limit() async {
        let service = await createService(systemPrompt: "system", messages: [
            "hi",
            "hello",
            "world",
        ])

        let (messages, remainingTokens) = await runService(
            service,
            minimumReplyTokens: 200,
            maxNumberOfMessages: 100,
            maxTokens: 10000
        )
        // Everything fits; system prompt comes first, history in original order.
        XCTAssertEqual(messages, [
            "system",
            "hi",
            "hello",
            "world",
        ])

        // 12 = history tokens (2 + 5 + 5), 6 = system prompt tokens.
        XCTAssertEqual(remainingTokens, 10000 - 12 - 6)
    }

    func test_send_max_message_if_not_reached_token_limit() async {
        let service = await createService(systemPrompt: "system", messages: [
            "hi",
            "hello",
            "world",
        ])

        let (messages, remainingTokens) = await runService(
            service,
            minimumReplyTokens: 200,
            maxNumberOfMessages: 2,
            maxTokens: 10000
        )
        // With a 2-message cap, the newest two survive; "hi" is dropped.
        XCTAssertEqual(messages, [
            "system",
            "hello",
            "world",
        ], "Count from end to start.")

        // 10 = kept history tokens ("hello" + "world"), 6 = system prompt tokens.
        XCTAssertEqual(remainingTokens, 10000 - 10 - 6)
    }

    func test_reached_token_limit() async {
        let service = await createService(systemPrompt: "system", messages: [
            "hi",
            "hello",
            "world",
        ])

        let (messages, remainingTokens) = await runService(
            service,
            minimumReplyTokens: 200,
            maxNumberOfMessages: 100,
            maxTokens: 212
        )
        // Prompt budget = 212 - 200 = 12. System (6) + "world" (5) = 11 fits;
        // adding "hello" (5) would reach 16 > 12, so trimming stops there.
        XCTAssertEqual(messages, [
            "system",
            "world",
        ])

        // remaining = max(200, 212 - 11) = 201.
        XCTAssertEqual(remainingTokens, 201)
    }

    func test_minimum_reply_tokens_count() async {
        let service = await createService(systemPrompt: "system", messages: [
            "hi",
            "hello",
            "world",
        ])

        let (messages, remainingTokens) = await runService(
            service,
            minimumReplyTokens: 200,
            maxNumberOfMessages: 100,
            maxTokens: 200
        )
        // Prompt budget = 200 - 200 = 0, so no history fits; only the system
        // prompt is sent (it is always included even over budget).
        XCTAssertEqual(messages, [
            "system",
        ])

        // remaining is clamped up to minimumReplyTokens: max(200, 200 - 6) = 200.
        XCTAssertEqual(remainingTokens, 200)
    }
}
93+
94+
/// Test double for `TokenEncoder` that treats every character as one token,
/// making token arithmetic in the tests exact and readable.
class MockEncoder: TokenEncoder {
    func encode(text: String) -> [Int] {
        // Only the count matters to callers, so the IDs are all zero.
        text.map { _ in 0 }
    }
}
99+
100+
/// Builds a `ChatGPTService` whose history contains the given user messages,
/// in order.
private func createService(systemPrompt: String, messages: [String]) async -> ChatGPTService {
    let service = ChatGPTService(systemPrompt: systemPrompt)
    await service.mutateHistory { history in
        for content in messages {
            history.append(.init(role: .user, content: content))
        }
    }
    return service
}
109+
110+
/// Invokes `combineHistoryWithSystemPrompt` with a `MockEncoder` and flattens
/// the resulting messages to their `content` strings for easy assertions.
private func runService(
    _ service: ChatGPTService,
    minimumReplyTokens: Int,
    maxNumberOfMessages: Int,
    maxTokens: Int
) async -> (messages: [String], remainingTokens: Int) {
    let result = await service.combineHistoryWithSystemPrompt(
        minimumReplyTokens: minimumReplyTokens,
        maxNumberOfMessages: maxNumberOfMessages,
        maxTokens: maxTokens,
        encoder: MockEncoder()
    )
    return (result.messages.map(\.content), result.remainingTokens)
}

0 commit comments

Comments (0)