Skip to content

Commit 17f54a7

Browse files
authored
Re-run uv lock and fix Python tests (github#157)
* Re-run uv lock * Make the lint work * fix(python): fix list_sessions and delete_session e2e tests The tests were failing because: 1. Sessions only persist to disk after a message is sent 2. There's a brief delay before session files are written Changes: - Add send_and_wait() calls to persist sessions before listing - Add small delay before list_sessions() to allow file sync * Fix test
1 parent 8f3e4ba commit 17f54a7

File tree

11 files changed

+155
-353
lines changed

11 files changed

+155
-353
lines changed

python/copilot/client.py

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import subprocess
2020
import threading
2121
from dataclasses import asdict, is_dataclass
22-
from typing import Any, Dict, List, Optional, cast
22+
from typing import Any, Optional, cast
2323

2424
from .generated.session_events import session_event_from_dict
2525
from .jsonrpc import JsonRpcClient
@@ -28,9 +28,11 @@
2828
from .types import (
2929
ConnectionState,
3030
CopilotClientOptions,
31+
CustomAgentConfig,
3132
GetAuthStatusResponse,
3233
GetStatusResponse,
3334
ModelInfo,
35+
ProviderConfig,
3436
ResumeSessionConfig,
3537
SessionConfig,
3638
SessionMetadata,
@@ -132,7 +134,7 @@ def __init__(self, options: Optional[CopilotClientOptions] = None):
132134
self._process: Optional[subprocess.Popen] = None
133135
self._client: Optional[JsonRpcClient] = None
134136
self._state: ConnectionState = "disconnected"
135-
self._sessions: Dict[str, CopilotSession] = {}
137+
self._sessions: dict[str, CopilotSession] = {}
136138
self._sessions_lock = threading.Lock()
137139

138140
def _parse_cli_url(self, url: str) -> tuple[str, int]:
@@ -218,7 +220,7 @@ async def start(self) -> None:
218220
self._state = "error"
219221
raise
220222

221-
async def stop(self) -> List[Dict[str, str]]:
223+
async def stop(self) -> list[dict[str, str]]:
222224
"""
223225
Stop the CLI server and close all active sessions.
224226
@@ -237,7 +239,7 @@ async def stop(self) -> List[Dict[str, str]]:
237239
... for error in errors:
238240
... print(f"Cleanup error: {error['message']}")
239241
"""
240-
errors: List[Dict[str, str]] = []
242+
errors: list[dict[str, str]] = []
241243

242244
# Atomically take ownership of all sessions and clear the dict
243245
# so no other thread can access them
@@ -357,7 +359,7 @@ async def create_session(self, config: Optional[SessionConfig] = None) -> Copilo
357359
definition["parameters"] = tool.parameters
358360
tool_defs.append(definition)
359361

360-
payload: Dict[str, Any] = {}
362+
payload: dict[str, Any] = {}
361363
if cfg.get("model"):
362364
payload["model"] = cfg["model"]
363365
if cfg.get("session_id"):
@@ -482,7 +484,7 @@ async def resume_session(
482484
definition["parameters"] = tool.parameters
483485
tool_defs.append(definition)
484486

485-
payload: Dict[str, Any] = {"sessionId": session_id}
487+
payload: dict[str, Any] = {"sessionId": session_id}
486488
if tool_defs:
487489
payload["tools"] = tool_defs
488490

@@ -612,7 +614,7 @@ async def get_auth_status(self) -> "GetAuthStatusResponse":
612614

613615
return await self._client.request("auth.getStatus", {})
614616

615-
async def list_models(self) -> List["ModelInfo"]:
617+
async def list_models(self) -> list["ModelInfo"]:
616618
"""
617619
List available models with their metadata.
618620
@@ -634,7 +636,7 @@ async def list_models(self) -> List["ModelInfo"]:
634636
response = await self._client.request("models.list", {})
635637
return response.get("models", [])
636638

637-
async def list_sessions(self) -> List["SessionMetadata"]:
639+
async def list_sessions(self) -> list["SessionMetadata"]:
638640
"""
639641
List all available sessions known to the server.
640642
@@ -710,7 +712,9 @@ async def _verify_protocol_version(self) -> None:
710712
f"Please update your SDK or server to ensure compatibility."
711713
)
712714

713-
def _convert_provider_to_wire_format(self, provider: Dict[str, Any]) -> Dict[str, Any]:
715+
def _convert_provider_to_wire_format(
716+
self, provider: ProviderConfig | dict[str, Any]
717+
) -> dict[str, Any]:
714718
"""
715719
Convert provider config from snake_case to camelCase wire format.
716720
@@ -720,7 +724,7 @@ def _convert_provider_to_wire_format(self, provider: Dict[str, Any]) -> Dict[str
720724
Returns:
721725
The provider configuration in camelCase wire format.
722726
"""
723-
wire_provider: Dict[str, Any] = {"type": provider.get("type")}
727+
wire_provider: dict[str, Any] = {"type": provider.get("type")}
724728
if "base_url" in provider:
725729
wire_provider["baseUrl"] = provider["base_url"]
726730
if "api_key" in provider:
@@ -731,14 +735,16 @@ def _convert_provider_to_wire_format(self, provider: Dict[str, Any]) -> Dict[str
731735
wire_provider["bearerToken"] = provider["bearer_token"]
732736
if "azure" in provider:
733737
azure = provider["azure"]
734-
wire_azure: Dict[str, Any] = {}
738+
wire_azure: dict[str, Any] = {}
735739
if "api_version" in azure:
736740
wire_azure["apiVersion"] = azure["api_version"]
737741
if wire_azure:
738742
wire_provider["azure"] = wire_azure
739743
return wire_provider
740744

741-
def _convert_custom_agent_to_wire_format(self, agent: Dict[str, Any]) -> Dict[str, Any]:
745+
def _convert_custom_agent_to_wire_format(
746+
self, agent: CustomAgentConfig | dict[str, Any]
747+
) -> dict[str, Any]:
742748
"""
743749
Convert custom agent config from snake_case to camelCase wire format.
744750
@@ -748,7 +754,7 @@ def _convert_custom_agent_to_wire_format(self, agent: Dict[str, Any]) -> Dict[st
748754
Returns:
749755
The custom agent configuration in camelCase wire format.
750756
"""
751-
wire_agent: Dict[str, Any] = {"name": agent.get("name"), "prompt": agent.get("prompt")}
757+
wire_agent: dict[str, Any] = {"name": agent.get("name"), "prompt": agent.get("prompt")}
752758
if "display_name" in agent:
753759
wire_agent["displayName"] = agent["display_name"]
754760
if "description" in agent:

python/copilot/jsonrpc.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@
1010
import json
1111
import threading
1212
import uuid
13-
from typing import Any, Awaitable, Callable, Dict, Optional, Union
13+
from collections.abc import Awaitable
14+
from typing import Any, Callable, Optional, Union
1415

1516

1617
class JsonRpcError(Exception):
@@ -41,9 +42,9 @@ def __init__(self, process):
4142
process: subprocess.Popen with stdin=PIPE, stdout=PIPE
4243
"""
4344
self.process = process
44-
self.pending_requests: Dict[str, asyncio.Future] = {}
45+
self.pending_requests: dict[str, asyncio.Future] = {}
4546
self.notification_handler: Optional[Callable[[str, dict], None]] = None
46-
self.request_handlers: Dict[str, RequestHandler] = {}
47+
self.request_handlers: dict[str, RequestHandler] = {}
4748
self._running = False
4849
self._read_thread: Optional[threading.Thread] = None
4950
self._loop: Optional[asyncio.AbstractEventLoop] = None

python/copilot/session.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import asyncio
99
import inspect
1010
import threading
11-
from typing import Any, Callable, Dict, List, Optional, Set
11+
from typing import Any, Callable, Optional
1212

1313
from .generated.session_events import SessionEvent, SessionEventType, session_event_from_dict
1414
from .types import (
@@ -62,9 +62,9 @@ def __init__(self, session_id: str, client: Any):
6262
"""
6363
self.session_id = session_id
6464
self._client = client
65-
self._event_handlers: Set[Callable[[SessionEvent], None]] = set()
65+
self._event_handlers: set[Callable[[SessionEvent], None]] = set()
6666
self._event_handlers_lock = threading.Lock()
67-
self._tool_handlers: Dict[str, ToolHandler] = {}
67+
self._tool_handlers: dict[str, ToolHandler] = {}
6868
self._tool_handlers_lock = threading.Lock()
6969
self._permission_handler: Optional[PermissionHandler] = None
7070
self._permission_handler_lock = threading.Lock()
@@ -220,7 +220,7 @@ def _dispatch_event(self, event: SessionEvent) -> None:
220220
except Exception as e:
221221
print(f"Error in session event handler: {e}")
222222

223-
def _register_tools(self, tools: Optional[List[Tool]]) -> None:
223+
def _register_tools(self, tools: Optional[list[Tool]]) -> None:
224224
"""
225225
Register custom tool handlers for this session.
226226
@@ -307,7 +307,7 @@ async def _handle_permission_request(self, request: dict) -> dict:
307307
# Handler failed, deny permission
308308
return {"kind": "denied-no-approval-rule-and-could-not-request-from-user"}
309309

310-
async def get_messages(self) -> List[SessionEvent]:
310+
async def get_messages(self) -> list[SessionEvent]:
311311
"""
312312
Retrieve all events and messages from this session's history.
313313

python/copilot/tools.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
import inspect
1111
import json
12-
from typing import Any, Callable, Type, TypeVar, get_type_hints, overload
12+
from typing import Any, Callable, TypeVar, get_type_hints, overload
1313

1414
from pydantic import BaseModel
1515

@@ -33,7 +33,7 @@ def define_tool(
3333
*,
3434
description: str | None = None,
3535
handler: Callable[[T, ToolInvocation], R],
36-
params_type: Type[T],
36+
params_type: type[T],
3737
) -> Tool: ...
3838

3939

@@ -42,7 +42,7 @@ def define_tool(
4242
*,
4343
description: str | None = None,
4444
handler: Callable[[Any, ToolInvocation], Any] | None = None,
45-
params_type: Type[BaseModel] | None = None,
45+
params_type: type[BaseModel] | None = None,
4646
) -> Tool | Callable[[Callable[[Any, ToolInvocation], Any]], Tool]:
4747
"""
4848
Define a tool with automatic JSON schema generation from Pydantic models.

0 commit comments

Comments
 (0)