Skip to content

Commit a377916

Browse files
committed
Fix unit tests
1 parent 5d6bd73 commit a377916

2 files changed

Lines changed: 223 additions & 105 deletions

File tree

Tool/Tests/OpenAIServiceTests/ChatGPTStreamTests.swift

Lines changed: 181 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@ import XCTest
55
final class ChatGPTStreamTests: XCTestCase {
66
func test_sending_message() async throws {
77
let memory = ConversationChatGPTMemory(systemPrompt: "system", systemMessageId: "s")
8-
let configuration = UserPreferenceChatGPTConfiguration().overriding()
8+
let configuration = UserPreferenceChatGPTConfiguration().overriding {
9+
$0.model = .init(id: "id", name: "name", format: .openAI, info: .init())
10+
}
911
let functionProvider = NoChatGPTFunctionProvider()
1012
let service = ChatGPTService(
1113
memory: memory,
@@ -64,7 +66,9 @@ final class ChatGPTStreamTests: XCTestCase {
6466

6567
func test_handling_function_call() async throws {
6668
let memory = ConversationChatGPTMemory(systemPrompt: "system", systemMessageId: "s")
67-
let configuration = UserPreferenceChatGPTConfiguration().overriding()
69+
let configuration = UserPreferenceChatGPTConfiguration().overriding {
70+
$0.model = .init(id: "id", name: "name", format: .openAI, info: .init())
71+
}
6872
let functionProvider = FunctionProvider()
6973
let service = ChatGPTService(
7074
memory: memory,
@@ -95,7 +99,7 @@ final class ChatGPTStreamTests: XCTestCase {
9599
"History is not updated"
96100
)
97101
}
98-
102+
99103
XCTAssertEqual(requestBody?.messages, [
100104
.init(role: .system, content: "system"),
101105
.init(role: .user, content: "Hello"),
@@ -105,9 +109,9 @@ final class ChatGPTStreamTests: XCTestCase {
105109
),
106110
.init(role: .function, content: "Function is called.", name: "function"),
107111
], "System prompt is not included")
108-
112+
109113
XCTAssertEqual(all, ["hello", "my", "friends"], "Text stream is not correct")
110-
114+
111115
var history = await memory.history
112116
for (i, _) in history.enumerated() {
113117
history[i].tokensCount = nil
@@ -128,9 +132,13 @@ final class ChatGPTStreamTests: XCTestCase {
128132
name: "function",
129133
summary: nil
130134
),
131-
.init(id: "00000000-0000-0000-0000-0000000000040.0", role: .assistant, content: "hellomyfriends"),
135+
.init(
136+
id: "00000000-0000-0000-0000-0000000000040.0",
137+
role: .assistant,
138+
content: "hellomyfriends"
139+
),
132140
], "History is not updated")
133-
141+
134142
XCTAssertEqual(requestBody?.functions, [
135143
EmptyFunction(),
136144
].map {
@@ -141,7 +149,9 @@ final class ChatGPTStreamTests: XCTestCase {
141149

142150
func test_handling_multiple_function_call() async throws {
143151
let memory = ConversationChatGPTMemory(systemPrompt: "system", systemMessageId: "s")
144-
let configuration = UserPreferenceChatGPTConfiguration().overriding()
152+
let configuration = UserPreferenceChatGPTConfiguration().overriding {
153+
$0.model = .init(id: "id", name: "name", format: .openAI, info: .init())
154+
}
145155
let functionProvider = FunctionProvider()
146156
let service = ChatGPTService(
147157
memory: memory,
@@ -173,7 +183,7 @@ final class ChatGPTStreamTests: XCTestCase {
173183
"History is not updated"
174184
)
175185
}
176-
186+
177187
XCTAssertEqual(requestBody?.messages, [
178188
.init(role: .system, content: "system"),
179189
.init(role: .user, content: "Hello"),
@@ -188,9 +198,9 @@ final class ChatGPTStreamTests: XCTestCase {
188198
),
189199
.init(role: .function, content: "Function is called.", name: "function"),
190200
], "System prompt is not included")
191-
201+
192202
XCTAssertEqual(all, ["hello", "my", "friends"], "Text stream is not correct")
193-
203+
194204
var history = await memory.history
195205
for (i, _) in history.enumerated() {
196206
history[i].tokensCount = nil
@@ -224,106 +234,186 @@ final class ChatGPTStreamTests: XCTestCase {
224234
name: "function",
225235
summary: nil
226236
),
227-
.init(id: "00000000-0000-0000-0000-0000000000070.0", role: .assistant, content: "hellomyfriends"),
237+
.init(
238+
id: "00000000-0000-0000-0000-0000000000070.0",
239+
role: .assistant,
240+
content: "hellomyfriends"
241+
),
228242
], "History is not updated")
229-
243+
230244
XCTAssertEqual(requestBody?.functions, [
231245
EmptyFunction(),
232246
].map {
233247
.init(name: $0.name, description: $0.description, parameters: $0.argumentSchema)
234248
}, "Function schema is not submitted")
235249
}
236250
}
251+
252+
func test_function_calling_unsupported() async throws {
253+
let memory = ConversationChatGPTMemory(systemPrompt: "system", systemMessageId: "s")
254+
let configuration = UserPreferenceChatGPTConfiguration().overriding {
255+
$0.model = .init(
256+
id: "id",
257+
name: "name",
258+
format: .openAI,
259+
info: .init(supportsFunctionCalling: false)
260+
)
261+
}
262+
let functionProvider = FunctionProvider()
263+
let service = ChatGPTService(
264+
memory: memory,
265+
configuration: configuration,
266+
functionProvider: functionProvider
267+
)
268+
var requestBody: CompletionRequestBody?
269+
service.changeBuildCompletionStreamAPI { _, _, _, _requestBody in
270+
requestBody = _requestBody
271+
if _requestBody.messages.count <= 2 {
272+
return MockCompletionStreamAPI_Function()
273+
}
274+
return MockCompletionStreamAPI_Message()
275+
}
276+
277+
try await withDependencies { values in
278+
values.uuid = .incrementing
279+
values.date = .constant(.init(timeIntervalSince1970: 0))
280+
} operation: {
281+
let stream = try await service.send(content: "Hello")
282+
var all = [String]()
283+
for try await text in stream {
284+
all.append(text)
285+
let history = await memory.history
286+
XCTAssertEqual(history.last?.id, "00000000-0000-0000-0000-0000000000040.0")
287+
XCTAssertTrue(
288+
history.last?.content?.hasPrefix(all.joined()) ?? false,
289+
"History is not updated"
290+
)
291+
}
292+
293+
XCTAssertEqual(requestBody?.messages, [
294+
.init(role: .system, content: "system"),
295+
.init(role: .user, content: "Hello"),
296+
.init(
297+
role: .assistant, content: "",
298+
function_call: .init(name: "function", arguments: "{\n\"foo\": 1\n}")
299+
),
300+
.init(role: .function, content: "Function is called.", name: "function"),
301+
], "System prompt is not included")
302+
303+
XCTAssertEqual(all, ["hello", "my", "friends"], "Text stream is not correct")
304+
305+
var history = await memory.history
306+
for (i, _) in history.enumerated() {
307+
history[i].tokensCount = nil
308+
}
309+
XCTAssertEqual(history, [
310+
.init(id: "s", role: .system, content: "system"),
311+
.init(id: "00000000-0000-0000-0000-000000000000", role: .user, content: "Hello"),
312+
.init(
313+
id: "00000000-0000-0000-0000-0000000000010.0",
314+
role: .assistant,
315+
content: nil,
316+
functionCall: .init(name: "function", arguments: "{\n\"foo\": 1\n}")
317+
),
318+
.init(
319+
id: "00000000-0000-0000-0000-000000000003",
320+
role: .function,
321+
content: "Function is called.",
322+
name: "function",
323+
summary: nil
324+
),
325+
.init(
326+
id: "00000000-0000-0000-0000-0000000000040.0",
327+
role: .assistant,
328+
content: "hellomyfriends"
329+
),
330+
], "History is not updated")
331+
332+
XCTAssertEqual(requestBody?.functions, nil, "Functions should be nil")
333+
}
334+
}
237335
}
238336

239337
extension ChatGPTStreamTests {
240338
struct MockCompletionStreamAPI_Message: CompletionStreamAPI {
241339
@Dependency(\.uuid) var uuid
242-
func callAsFunction() async throws -> (
243-
chunkStream: AsyncThrowingStream<CompletionStreamDataChunk, Error>,
244-
cancel: OpenAIService.Cancellable
245-
) {
340+
func callAsFunction() async throws
341+
-> AsyncThrowingStream<OpenAIService.CompletionStreamDataChunk, Error>
342+
{
246343
let id = uuid().uuidString
247-
return (
248-
AsyncThrowingStream<CompletionStreamDataChunk, Error> { continuation in
249-
let chunks: [CompletionStreamDataChunk] = [
250-
.init(id: id, object: "", model: "", choices: [
251-
.init(delta: .init(role: .assistant), index: 0, finish_reason: ""),
252-
]),
253-
.init(id: id, object: "", model: "", choices: [
254-
.init(delta: .init(content: "hello"), index: 0, finish_reason: ""),
255-
]),
256-
.init(id: id, object: "", model: "", choices: [
257-
.init(delta: .init(content: "my"), index: 0, finish_reason: ""),
258-
]),
259-
.init(id: id, object: "", model: "", choices: [
260-
.init(delta: .init(content: "friends"), index: 0, finish_reason: ""),
261-
]),
262-
]
263-
for chunk in chunks {
264-
continuation.yield(chunk)
265-
}
266-
continuation.finish()
267-
},
268-
Cancellable(cancel: {})
269-
)
344+
return AsyncThrowingStream<CompletionStreamDataChunk, Error> { continuation in
345+
let chunks: [CompletionStreamDataChunk] = [
346+
.init(id: id, object: "", model: "", choices: [
347+
.init(delta: .init(role: .assistant), index: 0, finish_reason: ""),
348+
]),
349+
.init(id: id, object: "", model: "", choices: [
350+
.init(delta: .init(content: "hello"), index: 0, finish_reason: ""),
351+
]),
352+
.init(id: id, object: "", model: "", choices: [
353+
.init(delta: .init(content: "my"), index: 0, finish_reason: ""),
354+
]),
355+
.init(id: id, object: "", model: "", choices: [
356+
.init(delta: .init(content: "friends"), index: 0, finish_reason: ""),
357+
]),
358+
]
359+
for chunk in chunks {
360+
continuation.yield(chunk)
361+
}
362+
continuation.finish()
363+
}
270364
}
271365
}
272366

273367
struct MockCompletionStreamAPI_Function: CompletionStreamAPI {
274368
@Dependency(\.uuid) var uuid
275-
func callAsFunction() async throws -> (
276-
chunkStream: AsyncThrowingStream<CompletionStreamDataChunk, Error>,
277-
cancel: OpenAIService.Cancellable
278-
) {
369+
func callAsFunction() async throws
370+
-> AsyncThrowingStream<OpenAIService.CompletionStreamDataChunk, Error>
371+
{
279372
let id = uuid().uuidString
280-
return (
281-
AsyncThrowingStream<CompletionStreamDataChunk, Error> { continuation in
282-
let chunks: [CompletionStreamDataChunk] = [
283-
.init(id: id, object: "", model: "", choices: [
284-
.init(
285-
delta: .init(
286-
role: .assistant,
287-
function_call: .init(name: "function", arguments: "")
288-
),
289-
index: 0,
290-
finish_reason: ""
291-
)]),
292-
.init(id: id, object: "", model: "", choices: [
293-
.init(
294-
delta: .init(
295-
role: .assistant,
296-
function_call: .init(arguments: "{\n")
297-
),
298-
index: 0,
299-
finish_reason: ""
300-
)]),
301-
.init(id: id, object: "", model: "", choices: [
302-
.init(
303-
delta: .init(
304-
role: .assistant,
305-
function_call: .init(arguments: "\"foo\": 1")
306-
),
307-
index: 0,
308-
finish_reason: ""
309-
)]),
310-
.init(id: id, object: "", model: "", choices: [
311-
.init(
312-
delta: .init(
313-
role: .assistant,
314-
function_call: .init(arguments: "\n}")
315-
),
316-
index: 0,
317-
finish_reason: ""
318-
)]),
319-
]
320-
for chunk in chunks {
321-
continuation.yield(chunk)
322-
}
323-
continuation.finish()
324-
},
325-
Cancellable(cancel: {})
326-
)
373+
return AsyncThrowingStream<CompletionStreamDataChunk, Error> { continuation in
374+
let chunks: [CompletionStreamDataChunk] = [
375+
.init(id: id, object: "", model: "", choices: [
376+
.init(
377+
delta: .init(
378+
role: .assistant,
379+
function_call: .init(name: "function", arguments: "")
380+
),
381+
index: 0,
382+
finish_reason: ""
383+
)]),
384+
.init(id: id, object: "", model: "", choices: [
385+
.init(
386+
delta: .init(
387+
role: .assistant,
388+
function_call: .init(arguments: "{\n")
389+
),
390+
index: 0,
391+
finish_reason: ""
392+
)]),
393+
.init(id: id, object: "", model: "", choices: [
394+
.init(
395+
delta: .init(
396+
role: .assistant,
397+
function_call: .init(arguments: "\"foo\": 1")
398+
),
399+
index: 0,
400+
finish_reason: ""
401+
)]),
402+
.init(id: id, object: "", model: "", choices: [
403+
.init(
404+
delta: .init(
405+
role: .assistant,
406+
function_call: .init(arguments: "\n}")
407+
),
408+
index: 0,
409+
finish_reason: ""
410+
)]),
411+
]
412+
for chunk in chunks {
413+
continuation.yield(chunk)
414+
}
415+
continuation.finish()
416+
}
327417
}
328418
}
329419

0 commit comments

Comments (0)