-
Notifications
You must be signed in to change notification settings - Fork 1.2k
Expand file tree
/
Copy pathtools.test.ts
More file actions
126 lines (110 loc) · 5.35 KB
/
tools.test.ts
File metadata and controls
126 lines (110 loc) · 5.35 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
*--------------------------------------------------------------------------------------------*/
import { writeFile } from "fs/promises";
import { join } from "path";
import { assert, describe, expect, it } from "vitest";
import { z } from "zod";
import { defineTool } from "../../src/index.js";
import { createSdkTestContext } from "./harness/sdkTestContext";
import { getFinalAssistantMessage } from "./harness/sdkTestHelper";
describe("Custom tools", async () => {
const { copilotClient: client, openAiEndpoint, workDir } = await createSdkTestContext();
// Seeds the working directory with a README and asks the assistant about its first line.
// No custom tools are registered for this session; going by the test name, the assistant is
// expected to read the file through a built-in tool. Only the final reply text is checked.
it("invokes built-in tools", async () => {
    const readmePath = join(workDir, "README.md");
    await writeFile(readmePath, "# ELIZA, the only chatbot you'll ever need");

    const session = await client.createSession();
    await session.send({ prompt: "What's the first line of README.md in this directory?" });

    // The reply must echo the distinctive word written into the README above.
    const reply = await getFinalAssistantMessage(session);
    expect(reply?.data.content).toContain("ELIZA");
});
// Registers one custom tool with a zod-described string parameter and checks that the
// tool's return value shows up in the assistant's final reply.
it("invokes custom tool", async () => {
    // The "encryption" is just upper-casing, which makes the expected output easy to spot.
    const encryptTool = defineTool("encrypt_string", {
        description: "Encrypts a string",
        parameters: z.object({
            input: z.string().describe("String to encrypt"),
        }),
        handler: (args) => args.input.toUpperCase(),
    });

    const session = await client.createSession({ tools: [encryptTool] });
    await session.send({ prompt: "Use encrypt_string to encrypt this string: Hello" });

    const reply = await getFinalAssistantMessage(session);
    expect(reply?.data.content).toContain("HELLO");
});
// Verifies that an exception thrown by a custom tool handler is not leaked to the model.
// The handler always throws an Error whose message ("Melbourne") doubles as a marker string:
// the test then inspects the last request captured by the OpenAI endpoint and the final
// assistant reply, and requires that the marker appears in neither.
it("handles tool calling errors", async () => {
    const session = await client.createSession({
        tools: [
            defineTool("get_user_location", {
                description: "Gets the user's location",
                handler: () => {
                    throw new Error("Melbourne");
                },
            }),
        ],
    });

    await session.send({
        prompt: "What is my location? If you can't find out, just say 'unknown'.",
    });
    const answer = await getFinalAssistantMessage(session);

    // Check the underlying traffic
    const traffic = await openAiEndpoint.getExchanges();
    // Fail with a readable assertion instead of a TypeError if nothing was captured.
    expect(traffic.length).toBeGreaterThan(0);
    const lastConversation = traffic[traffic.length - 1];

    // An assistant message may carry no `tool_calls` at all (presumably optional, as in the
    // OpenAI chat-completions format). Fall back to an empty array so such a message
    // contributes nothing, rather than an `undefined` entry that would be counted below.
    const toolCalls = lastConversation.request.messages.flatMap((m) =>
        m.role === "assistant" ? (m.tool_calls ?? []) : []
    );
    expect(toolCalls.length).toBe(1);
    const toolCall = toolCalls[0]!;
    assert(toolCall.type === "function");
    expect(toolCall.function.name).toBe("get_user_location");

    // Exactly one tool result must answer that call, and it must not expose the error text.
    const toolResults = lastConversation.request.messages.filter((m) => m.role === "tool");
    expect(toolResults.length).toBe(1);
    const toolResult = toolResults[0]!;
    expect(toolResult.tool_call_id).toBe(toolCall.id);
    expect(toolResult.content).not.toContain("Melbourne");

    // Importantly, we're checking that the assistant does not see the
    // exception information as if it was the tool's output.
    expect(answer?.data.content).not.toContain("Melbourne");
    expect(answer?.data.content?.toLowerCase()).toContain("unknown");
});
// Round-trips structured data through a custom tool: the tool takes a nested object
// (string, number array, boolean) and returns an array of objects. The test checks both the
// arguments the handler received and that the returned rows surface in the final reply.
it("can receive and return complex types", async () => {
    // Arguments are recorded here and verified after the conversation has finished.
    // Asserting inside the handler is unreliable: the "handles tool calling errors" test in
    // this file expects a throwing handler to still yield an ordinary tool result and a
    // normal assistant reply, so a failed `expect` in the handler would not fail this test
    // at the point of the mismatch.
    const receivedCalls: Array<{
        query: { table: string; ids: number[]; sortAscending: boolean };
        sessionId: unknown;
    }> = [];

    const session = await client.createSession({
        tools: [
            defineTool("db_query", {
                description: "Performs a database query",
                parameters: z.object({
                    query: z.object({
                        table: z.string(),
                        ids: z.array(z.number()),
                        sortAscending: z.boolean(),
                    }),
                }),
                handler: ({ query }, invocation) => {
                    receivedCalls.push({ query, sessionId: invocation.sessionId });
                    return [
                        { countryId: 19, cityName: "Passos", population: 135460 },
                        { countryId: 12, cityName: "San Lorenzo", population: 204356 },
                    ];
                },
            }),
        ],
    });

    await session.send({
        prompt:
            "Perform a DB query for the 'cities' table using IDs 12 and 19, sorting ascending. " +
            "Reply only with lines of the form: [cityname] [population]",
    });
    const assistantMessage = await getFinalAssistantMessage(session);

    // The handler must have run, and every invocation must carry the expected arguments
    // together with the id of the session that registered the tool.
    expect(receivedCalls.length).toBeGreaterThan(0);
    for (const call of receivedCalls) {
        expect(call.query.table).toBe("cities");
        expect(call.query.ids).toEqual([12, 19]);
        expect(call.query.sortAscending).toBe(true);
        expect(call.sessionId).toBe(session.sessionId);
    }

    // Guard before dereferencing; `!= null` rejects both null and undefined.
    assert(assistantMessage != null, "expected a final assistant message");
    const responseContent = assistantMessage.data.content ?? "";
    expect(responseContent).not.toBe("");
    expect(responseContent).toContain("Passos");
    expect(responseContent).toContain("San Lorenzo");

    // The model may format populations with thousands separators ("135,460"); strip commas
    // once so either rendering matches.
    const contentWithoutCommas = responseContent.replace(/,/g, "");
    expect(contentWithoutCommas).toContain("135460");
    expect(contentWithoutCommas).toContain("204356");
});
});