/*---------------------------------------------------------------------------------------------
 * Copyright (c) Microsoft Corporation. All rights reserved.
 *--------------------------------------------------------------------------------------------*/

import { writeFile } from "fs/promises";
import { join } from "path";
import { assert, describe, expect, it } from "vitest";
import { z } from "zod";
import { defineTool, approveAll } from "../../src/index.js";
import type { PermissionRequest } from "../../src/index.js";
import { createSdkTestContext } from "./harness/sdkTestContext";

/**
 * End-to-end tests for tools exposed to a session: built-in tools, custom tools
 * declared with `defineTool`, permission handling, and tool filtering.
 *
 * Every test creates its own session from the shared client returned by
 * `createSdkTestContext()`.
 */
describe("Custom tools", async () => {
    const { copilotClient: client, openAiEndpoint, workDir } = await createSdkTestContext();

    // A file written into the working directory should be readable by the assistant.
    it("invokes built-in tools", async () => {
        await writeFile(join(workDir, "README.md"), "# ELIZA, the only chatbot you'll ever need");

        const session = await client.createSession({
            onPermissionRequest: approveAll,
        });
        const assistantMessage = await session.sendAndWait({
            prompt: "What's the first line of README.md in this directory?",
        });

        expect(assistantMessage?.data.content).toContain("ELIZA");
    });

    // The "encryption" is just upper-casing, so the reply must contain "HELLO".
    it("invokes custom tool", async () => {
        const session = await client.createSession({
            onPermissionRequest: approveAll,
            tools: [
                defineTool("encrypt_string", {
                    description: "Encrypts a string",
                    parameters: z.object({
                        input: z.string().describe("String to encrypt"),
                    }),
                    handler: ({ input }) => input.toUpperCase(),
                }),
            ],
        });

        const assistantMessage = await session.sendAndWait({
            prompt: "Use encrypt_string to encrypt this string: Hello",
        });

        expect(assistantMessage?.data.content).toContain("HELLO");
    });

    // A throwing handler must not leak its exception message to the model.
    it("handles tool calling errors", async () => {
        const session = await client.createSession({
            onPermissionRequest: approveAll,
            tools: [
                defineTool("get_user_location", {
                    description: "Gets the user's location",
                    handler: () => {
                        throw new Error("Melbourne");
                    },
                }),
            ],
        });

        const answer = await session.sendAndWait({
            prompt: "What is my location? If you can't find out, just say 'unknown'.",
        });

        // Check the underlying traffic
        const traffic = await openAiEndpoint.getExchanges();
        const lastConversation = traffic[traffic.length - 1];

        // `flatMap` only flattens arrays: returning an absent `tool_calls` as-is
        // would insert a literal `undefined` into the result and skew the count,
        // so fall back to an empty array for assistant messages without tool calls.
        const toolCalls = lastConversation.request.messages.flatMap((m) =>
            m.role === "assistant" ? (m.tool_calls ?? []) : []
        );
        expect(toolCalls.length).toBe(1);
        const toolCall = toolCalls[0]!;
        assert(toolCall.type === "function");
        expect(toolCall.function.name).toBe("get_user_location");

        const toolResults = lastConversation.request.messages.filter((m) => m.role === "tool");
        expect(toolResults.length).toBe(1);
        const toolResult = toolResults[0]!;
        expect(toolResult.tool_call_id).toBe(toolCall.id);
        expect(toolResult.content).not.toContain("Melbourne");

        // Importantly, we're checking that the assistant does not see the
        // exception information as if it was the tool's output.
        expect(answer?.data.content).not.toContain("Melbourne");
        expect(answer?.data.content?.toLowerCase()).toContain("unknown");
    });

    // Nested object/array parameters go in; an array of objects comes back.
    it("can receive and return complex types", async () => {
        const session = await client.createSession({
            onPermissionRequest: approveAll,
            tools: [
                defineTool("db_query", {
                    description: "Performs a database query",
                    parameters: z.object({
                        query: z.object({
                            table: z.string(),
                            ids: z.array(z.number()),
                            sortAscending: z.boolean(),
                        }),
                    }),
                    handler: ({ query }, invocation) => {
                        expect(query.table).toBe("cities");
                        expect(query.ids).toEqual([12, 19]);
                        expect(query.sortAscending).toBe(true);
                        expect(invocation.sessionId).toBe(session.sessionId);

                        return [
                            { countryId: 19, cityName: "Passos", population: 135460 },
                            { countryId: 12, cityName: "San Lorenzo", population: 204356 },
                        ];
                    },
                }),
            ],
        });

        const assistantMessage = await session.sendAndWait({
            prompt:
                "Perform a DB query for the 'cities' table using IDs 12 and 19, sorting ascending. " +
                "Reply only with lines of the form: [cityname] [population]",
        });

        // Verify a message actually arrived before reading from it. A plain
        // not-null expectation would also pass for `undefined`, so narrow with an
        // assertion instead of silencing the compiler with `!`.
        assert(assistantMessage != null, "expected an assistant message");
        const responseContent = assistantMessage.data.content ?? "";

        expect(responseContent).not.toBe("");
        expect(responseContent).toContain("Passos");
        expect(responseContent).toContain("San Lorenzo");
        // Strip thousands separators so "135,460" and "135460" both match.
        expect(responseContent.replace(/,/g, "")).toContain("135460");
        expect(responseContent.replace(/,/g, "")).toContain("204356");
    });

    // The permission callback should be consulted for a custom tool invocation.
    it("invokes custom tool with permission handler", async () => {
        const permissionRequests: PermissionRequest[] = [];

        const session = await client.createSession({
            tools: [
                defineTool("encrypt_string", {
                    description: "Encrypts a string",
                    parameters: z.object({
                        input: z.string().describe("String to encrypt"),
                    }),
                    handler: ({ input }) => input.toUpperCase(),
                }),
            ],
            onPermissionRequest: (request) => {
                permissionRequests.push(request);
                return { kind: "approve-once" };
            },
        });

        const assistantMessage = await session.sendAndWait({
            prompt: "Use encrypt_string to encrypt this string: Hello",
        });

        expect(assistantMessage?.data.content).toContain("HELLO");

        // Should have received a custom-tool permission request
        const customToolRequests = permissionRequests.filter((req) => req.kind === "custom-tool");
        expect(customToolRequests.length).toBeGreaterThan(0);
        expect(customToolRequests[0].toolName).toBe("encrypt_string");
    });

    // A tool flagged with `skipPermission` must run without the callback firing.
    it("skipPermission sent in tool definition", async () => {
        let didRunPermissionRequest = false;

        const session = await client.createSession({
            onPermissionRequest: () => {
                didRunPermissionRequest = true;
                return { kind: "no-result" };
            },
            tools: [
                defineTool("safe_lookup", {
                    description: "A safe lookup that skips permission",
                    parameters: z.object({
                        id: z.string().describe("ID to look up"),
                    }),
                    handler: ({ id }) => `RESULT: ${id}`,
                    skipPermission: true,
                }),
            ],
        });

        const assistantMessage = await session.sendAndWait({
            prompt: "Use safe_lookup to look up 'test123'",
        });

        expect(assistantMessage?.data.content).toContain("RESULT: test123");
        expect(didRunPermissionRequest).toBe(false);
    });

    // A custom tool named like a built-in replaces it when `overridesBuiltInTool` is set.
    it("overrides built-in tool with custom tool", async () => {
        const session = await client.createSession({
            onPermissionRequest: approveAll,
            tools: [
                defineTool("grep", {
                    description: "A custom grep implementation that overrides the built-in",
                    parameters: z.object({
                        query: z.string().describe("Search query"),
                    }),
                    handler: ({ query }) => `CUSTOM_GREP_RESULT: ${query}`,
                    overridesBuiltInTool: true,
                }),
            ],
        });

        const assistantMessage = await session.sendAndWait({
            prompt: "Use grep to search for the word 'hello'",
        });

        expect(assistantMessage?.data.content).toContain("CUSTOM_GREP_RESULT");
    });

    // Rejecting the permission request must prevent the handler from running.
    it("denies custom tool when permission denied", async () => {
        let toolHandlerCalled = false;

        const session = await client.createSession({
            tools: [
                defineTool("encrypt_string", {
                    description: "Encrypts a string",
                    parameters: z.object({
                        input: z.string().describe("String to encrypt"),
                    }),
                    handler: ({ input }) => {
                        toolHandlerCalled = true;
                        return input.toUpperCase();
                    },
                }),
            ],
            onPermissionRequest: () => {
                return { kind: "reject" };
            },
        });

        await session.sendAndWait({
            prompt: "Use encrypt_string to encrypt this string: Hello",
        });

        // The tool handler should NOT have been called since permission was denied
        expect(toolHandlerCalled).toBe(false);
    });

    // Both handlers must run and both results must appear in the single reply.
    it("should execute multiple custom tools in parallel single turn", async () => {
        let lookupCityCalled = false;
        let lookupCountryCalled = false;

        const session = await client.createSession({
            onPermissionRequest: approveAll,
            tools: [
                defineTool("lookup_city", {
                    description: "Looks up city information",
                    parameters: z.object({ city: z.string() }),
                    handler: ({ city }) => {
                        lookupCityCalled = true;
                        return `CITY_${city.toUpperCase()}`;
                    },
                }),
                defineTool("lookup_country", {
                    description: "Looks up country information",
                    parameters: z.object({ country: z.string() }),
                    handler: ({ country }) => {
                        lookupCountryCalled = true;
                        return `COUNTRY_${country.toUpperCase()}`;
                    },
                }),
            ],
        });

        const answer = await session.sendAndWait({
            prompt: "Use lookup_city with 'Paris' and lookup_country with 'France' at the same time, then combine both results in your reply.",
        });

        expect(lookupCityCalled).toBe(true);
        expect(lookupCountryCalled).toBe(true);
        expect(answer?.data.content).toContain("CITY_PARIS");
        expect(answer?.data.content).toContain("COUNTRY_FRANCE");

        await session.disconnect();
    });

    // A tool listed in both `availableTools` and `excludedTools` must not be invoked.
    it("should respect availableTools and excludedTools combined", async () => {
        let allowedToolCalled = false;
        let excludedToolCalled = false;

        const session = await client.createSession({
            onPermissionRequest: approveAll,
            tools: [
                defineTool("allowed_tool", {
                    description: "A tool that is allowed",
                    parameters: z.object({ input: z.string() }),
                    handler: ({ input }) => {
                        allowedToolCalled = true;
                        return `ALLOWED_${input.toUpperCase()}`;
                    },
                }),
                defineTool("excluded_tool", {
                    description: "A tool that should be excluded",
                    parameters: z.object({}),
                    handler: () => {
                        excludedToolCalled = true;
                        return "EXCLUDED_RESULT";
                    },
                }),
            ],
            availableTools: ["allowed_tool", "excluded_tool"],
            excludedTools: ["excluded_tool"],
        });

        const answer = await session.sendAndWait({
            prompt: "Use the allowed_tool with input 'test'. Do NOT use excluded_tool.",
        });

        // allowed_tool should have been called
        expect(allowedToolCalled).toBe(true);
        // excluded_tool should NOT have been called
        expect(excludedToolCalled).toBe(false);
        expect(answer?.data.content).toContain("ALLOWED_TEST");

        await session.disconnect();
    });
});