Skip to content

Commit 910a84b

Browse files
committed
Fix the number of retrieved content items in RAG when the context window is not big enough
1 parent 0a2d371 commit 910a84b

3 files changed

Lines changed: 182 additions & 6 deletions

File tree

Tool/Sources/OpenAIService/Memory/AutoManagedChatGPTMemory.swift

Lines changed: 3 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -286,8 +286,6 @@ extension AutoManagedChatGPTMemory {
286286
text += """
287287
Here are the information you know about the system and the project, \
288288
separated by \(separator)
289-
290-
291289
"""
292290
}
293291

@@ -302,7 +300,7 @@ extension AutoManagedChatGPTMemory {
302300
{
303301
var right = retrievedContent.count
304302
var left = 0
305-
var retrievedContent = retrievedContent
303+
var gappedRetrievedContent = retrievedContent
306304
var tokenCount: Int?
307305
var proposedMessage = buildMessage(retrievedContent: [])
308306

@@ -340,7 +338,7 @@ extension AutoManagedChatGPTMemory {
340338
let (isValid, _tokenCount) = await checkValid(proposedMessage: _proposedMessage)
341339
if isValid {
342340
proposedMessage = _proposedMessage
343-
retrievedContent = _retrievedContent
341+
gappedRetrievedContent = _retrievedContent
344342
tokenCount = _tokenCount
345343
left = count + 1
346344
} else {
@@ -355,7 +353,7 @@ extension AutoManagedChatGPTMemory {
355353
} else {
356354
await strategy.countToken(proposedMessage)
357355
}
358-
return (proposedMessage, retrievedContent, finalCount)
356+
return (proposedMessage, gappedRetrievedContent, finalCount)
359357
}
360358

361359
let (message, references, tokensCount) = await buildMessageThatFits()
Lines changed: 178 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,178 @@
1+
import Foundation
2+
import XCTest
3+
4+
@testable import OpenAIService
5+
6+
/// Tests for `AutoManagedChatGPTMemory.generateRetrievedContentMessage`, checking how many
/// retrieved references end up in the generated message for a given token budget.
///
/// Token counting is done by the file-private `Strategy`, which returns the `count` of a
/// message's content, so budgets in these tests are expressed in characters of the message text.
class AutoManagedChatGPTMemoryRetrievedContentTests: XCTestCase {
7+
/// A run of 32 "=" characters interpolated into the expected message text below.
/// NOTE(review): presumably mirrors the separator used by the memory implementation;
/// the expected strings only match if it does — confirm against the production code.
let separator = String(repeating: "=", count: 32)
8+
9+
/// Builds a `.text` reference whose `content` is `text`; title, subtitle and URI are empty
/// and the line range is nil.
func ref(_ text: String) -> ChatMessage.Reference {
10+
.init(
11+
title: "",
12+
subTitle: "",
13+
content: text,
14+
uri: "",
15+
startLine: nil,
16+
endLine: nil,
17+
kind: .text
18+
)
19+
}
20+
21+
/// With a budget equal to the token count of the full five-document message, all five
/// references are expected to be returned and the message content must equal the full text.
func test_retrieved_content_when_the_context_window_is_large_enough() async {
22+
let strategy = Strategy()
23+
24+
let memory = AutoManagedChatGPTMemory(
25+
systemPrompt: "",
26+
configuration: UserPreferenceChatGPTConfiguration(),
27+
functionProvider: EmptyFunctionProvider()
28+
)
29+
30+
await memory.mutateRetrievedContent([
31+
ref("A"), ref("B"), ref("C"), ref("D"), ref("E"),
32+
])
33+
34+
let fullContent = """
35+
Here are the information you know about the system and the project, \
36+
separated by \(separator)
37+
38+
\(separator)[DOCUMENT 0]
39+
40+
A
41+
42+
\(separator)[DOCUMENT 1]
43+
44+
B
45+
46+
\(separator)[DOCUMENT 2]
47+
48+
C
49+
50+
\(separator)[DOCUMENT 3]
51+
52+
D
53+
54+
\(separator)[DOCUMENT 4]
55+
56+
E
57+
"""
58+
59+
// The budget is exactly what the strategy reports for the complete message.
let maxTokenCount = await strategy.countToken(.init(role: .user, content: fullContent))
60+
61+
let result = await memory.generateRetrievedContentMessage(
62+
maxTokenCount: maxTokenCount,
63+
strategy: strategy
64+
)
65+
66+
// Every reference fits, and the generated message is the full text as a user message.
XCTAssertEqual(result.references.count, 5)
67+
XCTAssertEqual(result.retrievedContent.role, .user)
68+
XCTAssertEqual(result.retrievedContent.content, """
69+
Here are the information you know about the system and the project, \
70+
separated by \(separator)
71+
72+
\(separator)[DOCUMENT 0]
73+
74+
A
75+
76+
\(separator)[DOCUMENT 1]
77+
78+
B
79+
80+
\(separator)[DOCUMENT 2]
81+
82+
C
83+
84+
\(separator)[DOCUMENT 3]
85+
86+
D
87+
88+
\(separator)[DOCUMENT 4]
89+
90+
E
91+
""")
92+
}
93+
94+
/// With a budget one token short of the full five-document message, only the first four
/// references are expected, and the message content must stop after document 3.
func test_retrieved_content_when_the_context_window_is_just_not_large_enough() async {
95+
let strategy = Strategy()
96+
97+
let memory = AutoManagedChatGPTMemory(
98+
systemPrompt: "",
99+
configuration: UserPreferenceChatGPTConfiguration(),
100+
functionProvider: EmptyFunctionProvider()
101+
)
102+
103+
await memory.mutateRetrievedContent([
104+
ref("A"), ref("B"), ref("C"), ref("D"), ref("E"),
105+
])
106+
107+
let fullContent = """
108+
Here are the information you know about the system and the project, \
109+
separated by \(separator)
110+
111+
\(separator)[DOCUMENT 0]
112+
113+
A
114+
115+
\(separator)[DOCUMENT 1]
116+
117+
B
118+
119+
\(separator)[DOCUMENT 2]
120+
121+
C
122+
123+
\(separator)[DOCUMENT 3]
124+
125+
D
126+
127+
\(separator)[DOCUMENT 4]
128+
129+
E
130+
"""
131+
132+
// Token count of the complete message; the call below is given one less than this.
let maxTokenCount = await strategy.countToken(.init(role: .user, content: fullContent))
133+
134+
let result = await memory.generateRetrievedContentMessage(
135+
maxTokenCount: maxTokenCount - 1,
136+
strategy: strategy
137+
)
138+
139+
// The last document no longer fits, so exactly four references must remain.
XCTAssertEqual(result.references.count, 4)
140+
XCTAssertEqual(result.retrievedContent.role, .user)
141+
XCTAssertEqual(result.retrievedContent.content, """
142+
Here are the information you know about the system and the project, \
143+
separated by \(separator)
144+
145+
\(separator)[DOCUMENT 0]
146+
147+
A
148+
149+
\(separator)[DOCUMENT 1]
150+
151+
B
152+
153+
\(separator)[DOCUMENT 2]
154+
155+
C
156+
157+
\(separator)[DOCUMENT 3]
158+
159+
D
160+
""")
161+
}
162+
}
163+
164+
/// Test stub for `ChatGPTFunctionProvider` that exposes no functions and no call strategy.
private struct EmptyFunctionProvider: ChatGPTFunctionProvider {
165+
/// Always an empty list.
var functions: [any ChatGPTFunction] { [] }
166+
/// Always nil.
var functionCallStrategy: FunctionCallStrategy? { nil }
167+
}
168+
169+
/// Test stub for `AutoManagedChatGPTMemoryStrategy` with a deterministic token count,
/// so the tests can compute exact budgets from the expected message text.
private struct Strategy: AutoManagedChatGPTMemoryStrategy {
170+
/// Returns the `count` of the message's content, or 0 when the content is nil.
func countToken(_ message: OpenAIService.ChatMessage) async -> Int {
171+
message.content?.count ?? 0
172+
}
173+
174+
/// Functions are counted as 0 tokens.
func countToken<F>(_: F) async -> Int where F: ChatGPTFunction {
175+
0
176+
}
177+
}
178+

Tool/Tests/OpenAIServiceTests/LimitMessagesTests.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ import XCTest
44

55
@testable import OpenAIService
66

7-
final class AutoManagedChatGPTMemoryTests: XCTestCase {
7+
final class AutoManagedChatGPTMemoryLimitTests: XCTestCase {
88
func test_send_all_messages_if_not_reached_token_limit() async {
99
let (messages, memory) = await runService(
1010
systemPrompt: "system",

0 commit comments

Comments
 (0)