Skip to content

Commit 10713c8

Browse files
committed
Cache token count
1 parent 4338145 commit 10713c8

4 files changed

Lines changed: 26 additions & 11 deletions

File tree

Core/Sources/OpenAIService/ChatGPTService.swift

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -270,10 +270,11 @@ extension ChatGPTService {
270270
{
271271
var all: [CompletionRequestBody.Message] = []
272272
var allTokensCount = encoder.encode(text: systemPrompt).count
273-
for message in history.reversed() {
273+
for (index, message) in history.enumerated().reversed() {
274274
if maxNumberOfMessages > 0, all.count >= maxNumberOfMessages { break }
275275
if message.content.isEmpty { continue }
276-
let tokensCount = encoder.encode(text: message.content).count
276+
let tokensCount = message.tokensCount ?? encoder.encode(text: message.content).count
277+
history[index].tokensCount = tokensCount
277278
if tokensCount + allTokensCount > maxTokens - minimumReplyTokens {
278279
break
279280
}

Core/Sources/OpenAIService/Models.swift

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,14 @@ public struct ChatMessage: Equatable, Codable {
1515
}
1616

1717
public var role: Role
18-
public var content: String
18+
public var content: String {
19+
didSet {
20+
tokensCount = nil
21+
}
22+
}
1923
public var summary: String?
2024
public var id: String
25+
public var tokensCount: Int?
2126

2227
public init(
2328
id: String = UUID().uuidString,

Core/Tests/OpenAIServiceTests/ChatGPTServiceTests.swift

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,10 @@ final class ChatGPTServiceTests: XCTestCase {
6262
.init(role: .user, content: "Hello"),
6363
], "System prompt is included")
6464
XCTAssertEqual(all, ["hello", "my", "friends"], "Text stream is correct")
65-
let history = await service.history
65+
var history = await service.history
66+
for (i, _) in history.enumerated() {
67+
history[i].tokensCount = nil
68+
}
6669
XCTAssertEqual(history, [
6770
.init(id: "0", role: .user, content: "Hello"),
6871
.init(id: "1", role: .assistant, content: "hellomyfriends"),

Core/Tests/OpenAIServiceTests/LimitMessagesTests.swift

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,16 @@ final class LimitMessagesTests: XCTestCase {
2323
"hello",
2424
"world",
2525
])
26-
26+
2727
XCTAssertEqual(remainingTokens, 10000 - 12 - 6)
28+
let history = await service.history
29+
XCTAssertEqual(history.map(\.tokensCount), [
30+
2,
31+
5,
32+
5,
33+
])
2834
}
29-
35+
3036
func test_send_max_message_if_not_reached_token_limit() async {
3137
let service = await createService(systemPrompt: "system", messages: [
3238
"hi",
@@ -45,10 +51,10 @@ final class LimitMessagesTests: XCTestCase {
4551
"hello",
4652
"world",
4753
], "Count from end to start.")
48-
54+
4955
XCTAssertEqual(remainingTokens, 10000 - 10 - 6)
5056
}
51-
57+
5258
func test_reached_token_limit() async {
5359
let service = await createService(systemPrompt: "system", messages: [
5460
"hi",
@@ -66,10 +72,10 @@ final class LimitMessagesTests: XCTestCase {
6672
"system",
6773
"world",
6874
])
69-
75+
7076
XCTAssertEqual(remainingTokens, 201)
7177
}
72-
78+
7379
func test_minimum_reply_tokens_count() async {
7480
let service = await createService(systemPrompt: "system", messages: [
7581
"hi",
@@ -86,7 +92,7 @@ final class LimitMessagesTests: XCTestCase {
8692
XCTAssertEqual(messages, [
8793
"system",
8894
])
89-
95+
9096
XCTAssertEqual(remainingTokens, 200)
9197
}
9298
}

0 commit comments

Comments
 (0)