"""E2E Tools Tests""" import os import pytest from pydantic import BaseModel, Field from copilot import ToolSet, define_tool from copilot.rpc import ( PermissionDecisionApproveOnce, PermissionDecisionReject, ) from copilot.session import PermissionHandler, PermissionNoResult from copilot.tools import Tool, ToolInvocation, ToolResult from .testharness import E2ETestContext, get_final_assistant_message pytestmark = pytest.mark.asyncio(loop_scope="module") class TestTools: async def test_invokes_built_in_tools(self, ctx: E2ETestContext): readme_path = os.path.join(ctx.work_dir, "README.md") with open(readme_path, "w") as f: f.write("# ELIZA, the only chatbot you'll ever need") session = await ctx.client.create_session( on_permission_request=PermissionHandler.approve_all ) await session.send("What's the first line of README.md in this directory?") assistant_message = await get_final_assistant_message(session) assert "ELIZA" in assistant_message.data.content async def test_invokes_custom_tool(self, ctx: E2ETestContext): class EncryptParams(BaseModel): input: str = Field(description="String to encrypt") @define_tool("encrypt_string", description="Encrypts a string") def encrypt_string(params: EncryptParams, invocation: ToolInvocation) -> str: return params.input.upper() session = await ctx.client.create_session( on_permission_request=PermissionHandler.approve_all, tools=[encrypt_string] ) await session.send("Use encrypt_string to encrypt this string: Hello") assistant_message = await get_final_assistant_message(session) assert "HELLO" in assistant_message.data.content async def test_low_level_tool_definition(self, ctx: E2ETestContext): class PhaseArgs(BaseModel): phase: str = Field( description="Current phase", pattern="^(searching|analyzing|done)$", ) class SearchArgs(BaseModel): keyword: str current_phase = "" @define_tool("set_current_phase", description="Sets the current phase of the agent") def set_current_phase(params: PhaseArgs, invocation: ToolInvocation) -> str: nonlocal current_phase current_phase = params.phase return f"Phase set to {params.phase}" @define_tool("search_items", description="Search for items by keyword") def search_items(params: SearchArgs, invocation: ToolInvocation) -> str: args = invocation.arguments or {} keyword = str(args.get("keyword", "")) assert keyword == "copilot" return "Found: item_alpha, item_beta" session = await ctx.client.create_session( on_permission_request=PermissionHandler.approve_all, available_tools=ToolSet().add_custom("*").add_builtin("web_fetch"), tools=[set_current_phase, search_items], ) prompt = ( "First, set the current phase to 'analyzing'. Then search for items with " "keyword 'copilot'. Report the phase and search results." ) await session.send(prompt) assistant_message = await get_final_assistant_message(session) content = assistant_message.data.content or "" assert content != "" assert "analyzing" in content.lower() assert "item_alpha" in content.lower() or "item_beta" in content.lower() assert current_phase == "analyzing" async def test_handles_tool_calling_errors(self, ctx: E2ETestContext): @define_tool("get_user_location", description="Gets the user's location") def get_user_location() -> str: raise Exception("Melbourne") session = await ctx.client.create_session( on_permission_request=PermissionHandler.approve_all, tools=[get_user_location] ) await session.send("What is my location? If you can't find out, just say 'unknown'.") answer = await get_final_assistant_message(session) # Check the underlying traffic traffic = await ctx.get_exchanges() last_conversation = traffic[-1] tool_calls = [] for msg in last_conversation["request"]["messages"]: if msg.get("role") == "assistant" and "tool_calls" in msg: tool_calls.extend(msg["tool_calls"]) assert len(tool_calls) == 1 tool_call = tool_calls[0] assert tool_call["type"] == "function" assert tool_call["function"]["name"] == "get_user_location" tool_results = [ msg for msg in last_conversation["request"]["messages"] if msg.get("role") == "tool" ] assert len(tool_results) == 1 tool_result = tool_results[0] assert tool_result["tool_call_id"] == tool_call["id"] # The error message "Melbourne" should NOT be exposed to the LLM assert "Melbourne" not in tool_result["content"] # The assistant should not see the exception information assert "Melbourne" not in (answer.data.content or "") assert "unknown" in (answer.data.content or "").lower() async def test_can_receive_and_return_complex_types(self, ctx: E2ETestContext): class DbQuery(BaseModel): table: str ids: list[int] sortAscending: bool class DbQueryParams(BaseModel): query: DbQuery class City(BaseModel): countryId: int cityName: str population: int expected_session_id = None @define_tool("db_query", description="Performs a database query") def db_query(params: DbQueryParams, invocation: ToolInvocation) -> list[City]: assert params.query.table == "cities" assert params.query.ids == [12, 19] assert params.query.sortAscending is True assert invocation.session_id == expected_session_id return [ City(countryId=19, cityName="Passos", population=135460), City(countryId=12, cityName="San Lorenzo", population=204356), ] session = await ctx.client.create_session( on_permission_request=PermissionHandler.approve_all, tools=[db_query] ) expected_session_id = session.session_id await session.send( "Perform a DB query for the 'cities' table using IDs 12 and 19, " "sorting ascending. Reply only with lines of the form: [cityname] [population]" ) assistant_message = await get_final_assistant_message(session) response_content = assistant_message.data.content or "" assert response_content != "" assert "Passos" in response_content assert "San Lorenzo" in response_content assert "135460" in response_content.replace(",", "") assert "204356" in response_content.replace(",", "") async def test_skippermission_sent_in_tool_definition(self, ctx: E2ETestContext): class LookupParams(BaseModel): id: str = Field(description="ID to look up") @define_tool( "safe_lookup", description="A safe lookup that skips permission", skip_permission=True, ) def safe_lookup(params: LookupParams, invocation: ToolInvocation) -> str: return f"RESULT: {params.id}" did_run_permission_request = False def tracking_handler(request, invocation): nonlocal did_run_permission_request did_run_permission_request = True return PermissionNoResult() session = await ctx.client.create_session( on_permission_request=tracking_handler, tools=[safe_lookup] ) await session.send("Use safe_lookup to look up 'test123'") assistant_message = await get_final_assistant_message(session) assert "RESULT: test123" in assistant_message.data.content assert not did_run_permission_request async def test_overrides_built_in_tool_with_custom_tool(self, ctx: E2ETestContext): class GrepParams(BaseModel): query: str = Field(description="Search query") @define_tool( "grep", description="A custom grep implementation that overrides the built-in", overrides_built_in_tool=True, ) def custom_grep(params: GrepParams, invocation: ToolInvocation) -> str: return f"CUSTOM_GREP_RESULT: {params.query}" session = await ctx.client.create_session( on_permission_request=PermissionHandler.approve_all, tools=[custom_grep] ) await session.send("Use grep to search for the word 'hello'") assistant_message = await get_final_assistant_message(session) assert "CUSTOM_GREP_RESULT" in assistant_message.data.content async def test_invokes_custom_tool_with_permission_handler(self, ctx: E2ETestContext): class EncryptParams(BaseModel): input: str = Field(description="String to encrypt") @define_tool("encrypt_string", description="Encrypts a string") def encrypt_string(params: EncryptParams, invocation: ToolInvocation) -> str: return params.input.upper() permission_requests = [] def on_permission_request(request, invocation): permission_requests.append(request) return PermissionDecisionApproveOnce() session = await ctx.client.create_session( on_permission_request=on_permission_request, tools=[encrypt_string] ) await session.send("Use encrypt_string to encrypt this string: Hello") assistant_message = await get_final_assistant_message(session) assert "HELLO" in assistant_message.data.content # Should have received a custom-tool permission request custom_tool_requests = [r for r in permission_requests if r.kind == "custom-tool"] assert len(custom_tool_requests) > 0 assert custom_tool_requests[0].tool_name == "encrypt_string" async def test_denies_custom_tool_when_permission_denied(self, ctx: E2ETestContext): tool_handler_called = False class EncryptParams(BaseModel): input: str = Field(description="String to encrypt") @define_tool("encrypt_string", description="Encrypts a string") def encrypt_string(params: EncryptParams, invocation: ToolInvocation) -> str: nonlocal tool_handler_called tool_handler_called = True return params.input.upper() def on_permission_request(request, invocation): return PermissionDecisionReject() session = await ctx.client.create_session( on_permission_request=on_permission_request, tools=[encrypt_string] ) await session.send("Use encrypt_string to encrypt this string: Hello") await get_final_assistant_message(session) # The tool handler should NOT have been called since permission was denied assert not tool_handler_called async def test_should_execute_multiple_custom_tools_in_parallel_single_turn( self, ctx: E2ETestContext ): """Multiple custom tools invoked in parallel in the same turn.""" import asyncio city_called: asyncio.Future = asyncio.get_event_loop().create_future() country_called: asyncio.Future = asyncio.get_event_loop().create_future() def lookup_city(invocation: ToolInvocation) -> ToolResult: city = (invocation.arguments or {}).get("city", "") if not city_called.done(): city_called.set_result(city) return ToolResult(text_result_for_llm=f"CITY_{city.upper()}") def lookup_country(invocation: ToolInvocation) -> ToolResult: country = (invocation.arguments or {}).get("country", "") if not country_called.done(): country_called.set_result(country) return ToolResult(text_result_for_llm=f"COUNTRY_{country.upper()}") session = await ctx.client.create_session( on_permission_request=PermissionHandler.approve_all, tools=[ Tool( name="lookup_city", description="Looks up city information", parameters={ "type": "object", "properties": {"city": {"type": "string", "description": "City name"}}, "required": ["city"], }, handler=lookup_city, ), Tool( name="lookup_country", description="Looks up country information", parameters={ "type": "object", "properties": { "country": {"type": "string", "description": "Country name"} }, "required": ["country"], }, handler=lookup_country, ), ], ) try: await session.send( "Use lookup_city with 'Paris' and lookup_country with 'France' at the same time," " then combine both results in your reply." ) city_result = await asyncio.wait_for(city_called, timeout=60.0) country_result = await asyncio.wait_for(country_called, timeout=60.0) assert city_result == "Paris" assert country_result == "France" assistant_message = await get_final_assistant_message(session, timeout=60.0) assert assistant_message is not None content = assistant_message.data.content or "" assert "CITY_PARIS" in content assert "COUNTRY_FRANCE" in content finally: await session.disconnect() async def test_should_respect_availabletools_and_excludedtools_combined( self, ctx: E2ETestContext ): """excluded_tools takes precedence over available_tools.""" excluded_tool_called = False def allowed_handler(invocation: ToolInvocation) -> ToolResult: input_val = (invocation.arguments or {}).get("input", "") return ToolResult(text_result_for_llm=f"ALLOWED_{input_val.upper()}") def excluded_handler(invocation: ToolInvocation) -> ToolResult: nonlocal excluded_tool_called excluded_tool_called = True input_val = (invocation.arguments or {}).get("input", "") return ToolResult(text_result_for_llm=f"EXCLUDED_{input_val.upper()}") session = await ctx.client.create_session( on_permission_request=PermissionHandler.approve_all, tools=[ Tool( name="allowed_tool", description="An allowed tool", parameters={ "type": "object", "properties": {"input": {"type": "string", "description": "Input value"}}, "required": ["input"], }, handler=allowed_handler, ), Tool( name="excluded_tool", description="A tool that should be excluded", parameters={ "type": "object", "properties": {"input": {"type": "string", "description": "Input value"}}, "required": ["input"], }, handler=excluded_handler, ), ], available_tools=["allowed_tool", "excluded_tool"], excluded_tools=["excluded_tool"], ) try: result = await session.send_and_wait( "Use the allowed_tool with input 'test'. Do NOT use excluded_tool.", timeout=60.0, ) assert result is not None assert "ALLOWED_TEST" in (result.data.content or "") assert not excluded_tool_called, "Excluded tool should not have been called" finally: await session.disconnect()