Skip to content

Commit 63f9c23

Browse files
Copilotfriggeri
andauthored
Consistently use Dataclasses in Python SDK (github#216)
* Initial plan * Add TypedDict types for ping() and stop() return values Co-authored-by: friggeri <106686+friggeri@users.noreply.github.com> * Convert all TypedDict response types to dataclasses Co-authored-by: friggeri <106686+friggeri@users.noreply.github.com> * Add proper validation to from_dict methods Co-authored-by: friggeri <106686+friggeri@users.noreply.github.com> * Fix ty type checker errors after dataclass conversion Co-authored-by: friggeri <106686+friggeri@users.noreply.github.com> * fix type error and flaky pyton test --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: friggeri <106686+friggeri@users.noreply.github.com> Co-authored-by: Adrien Friggeri <adrien@friggeri.net>
1 parent 24da763 commit 63f9c23

File tree

6 files changed

+376
-94
lines changed

6 files changed

+376
-94
lines changed

python/copilot/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,13 @@
2424
PermissionHandler,
2525
PermissionRequest,
2626
PermissionRequestResult,
27+
PingResponse,
2728
ProviderConfig,
2829
ResumeSessionConfig,
2930
SessionConfig,
3031
SessionEvent,
3132
SessionMetadata,
33+
StopError,
3234
Tool,
3335
ToolHandler,
3436
ToolInvocation,
@@ -56,11 +58,13 @@
5658
"PermissionHandler",
5759
"PermissionRequest",
5860
"PermissionRequestResult",
61+
"PingResponse",
5962
"ProviderConfig",
6063
"ResumeSessionConfig",
6164
"SessionConfig",
6265
"SessionEvent",
6366
"SessionMetadata",
67+
"StopError",
6468
"Tool",
6569
"ToolHandler",
6670
"ToolInvocation",

python/copilot/client.py

Lines changed: 35 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import subprocess
2020
import threading
2121
from dataclasses import asdict, is_dataclass
22-
from typing import Any, Optional, cast
22+
from typing import Any, Optional
2323

2424
from .generated.session_events import session_event_from_dict
2525
from .jsonrpc import JsonRpcClient
@@ -32,10 +32,12 @@
3232
GetAuthStatusResponse,
3333
GetStatusResponse,
3434
ModelInfo,
35+
PingResponse,
3536
ProviderConfig,
3637
ResumeSessionConfig,
3738
SessionConfig,
3839
SessionMetadata,
40+
StopError,
3941
ToolHandler,
4042
ToolInvocation,
4143
ToolResult,
@@ -220,7 +222,7 @@ async def start(self) -> None:
220222
self._state = "error"
221223
raise
222224

223-
async def stop(self) -> list[dict[str, str]]:
225+
async def stop(self) -> list["StopError"]:
224226
"""
225227
Stop the CLI server and close all active sessions.
226228
@@ -230,16 +232,16 @@ async def stop(self) -> list[dict[str, str]]:
230232
3. Terminates the CLI server process (if spawned by this client)
231233
232234
Returns:
233-
A list of errors that occurred during cleanup, each as a dict with
234-
a 'message' key. An empty list indicates all cleanup succeeded.
235+
A list of StopError objects containing error messages that occurred
236+
during cleanup. An empty list indicates all cleanup succeeded.
235237
236238
Example:
237239
>>> errors = await client.stop()
238240
>>> if errors:
239241
... for error in errors:
240-
... print(f"Cleanup error: {error['message']}")
242+
... print(f"Cleanup error: {error.message}")
241243
"""
242-
errors: list[dict[str, str]] = []
244+
errors: list[StopError] = []
243245

244246
# Atomically take ownership of all sessions and clear the dict
245247
# so no other thread can access them
@@ -251,7 +253,9 @@ async def stop(self) -> list[dict[str, str]]:
251253
try:
252254
await session.destroy()
253255
except Exception as e:
254-
errors.append({"message": f"Failed to destroy session {session.session_id}: {e}"})
256+
errors.append(
257+
StopError(message=f"Failed to destroy session {session.session_id}: {e}")
258+
)
255259

256260
# Close client
257261
if self._client:
@@ -570,67 +574,69 @@ def get_state(self) -> ConnectionState:
570574
"""
571575
return self._state
572576

573-
async def ping(self, message: Optional[str] = None) -> dict:
577+
async def ping(self, message: Optional[str] = None) -> "PingResponse":
574578
"""
575579
Send a ping request to the server to verify connectivity.
576580
577581
Args:
578582
message: Optional message to include in the ping.
579583
580584
Returns:
581-
A dict containing the ping response with 'message', 'timestamp',
582-
and 'protocolVersion' keys.
585+
A PingResponse object containing the ping response.
583586
584587
Raises:
585588
RuntimeError: If the client is not connected.
586589
587590
Example:
588591
>>> response = await client.ping("health check")
589-
>>> print(f"Server responded at {response['timestamp']}")
592+
>>> print(f"Server responded at {response.timestamp}")
590593
"""
591594
if not self._client:
592595
raise RuntimeError("Client not connected")
593596

594-
return await self._client.request("ping", {"message": message})
597+
result = await self._client.request("ping", {"message": message})
598+
return PingResponse.from_dict(result)
595599

596600
async def get_status(self) -> "GetStatusResponse":
597601
"""
598602
Get CLI status including version and protocol information.
599603
600604
Returns:
601-
A GetStatusResponse containing version and protocolVersion.
605+
A GetStatusResponse object containing version and protocolVersion.
602606
603607
Raises:
604608
RuntimeError: If the client is not connected.
605609
606610
Example:
607611
>>> status = await client.get_status()
608-
>>> print(f"CLI version: {status['version']}")
612+
>>> print(f"CLI version: {status.version}")
609613
"""
610614
if not self._client:
611615
raise RuntimeError("Client not connected")
612616

613-
return await self._client.request("status.get", {})
617+
result = await self._client.request("status.get", {})
618+
return GetStatusResponse.from_dict(result)
614619

615620
async def get_auth_status(self) -> "GetAuthStatusResponse":
616621
"""
617622
Get current authentication status.
618623
619624
Returns:
620-
A GetAuthStatusResponse containing authentication state.
625+
A GetAuthStatusResponse object containing authentication state.
621626
622627
Raises:
623628
RuntimeError: If the client is not connected.
624629
625630
Example:
626631
>>> auth = await client.get_auth_status()
627-
>>> if auth['isAuthenticated']:
628-
... print(f"Logged in as {auth.get('login')}")
632+
>>> if auth.isAuthenticated:
633+
... print(f"Logged in as {auth.login}")
629634
"""
630635
if not self._client:
631636
raise RuntimeError("Client not connected")
632637

633-
return await self._client.request("auth.getStatus", {})
638+
result = await self._client.request("auth.getStatus", {})
639+
return GetAuthStatusResponse.from_dict(result)
634640

635641
async def list_models(self) -> list["ModelInfo"]:
636642
"""
@@ -646,13 +652,14 @@ async def list_models(self) -> list["ModelInfo"]:
646652
Example:
647653
>>> models = await client.list_models()
648654
>>> for model in models:
649-
... print(f"{model['id']}: {model['name']}")
655+
... print(f"{model.id}: {model.name}")
650656
"""
651657
if not self._client:
652658
raise RuntimeError("Client not connected")
653659

654660
response = await self._client.request("models.list", {})
655-
return response.get("models", [])
661+
models_data = response.get("models", [])
662+
return [ModelInfo.from_dict(model) for model in models_data]
656663

657664
async def list_sessions(self) -> list["SessionMetadata"]:
658665
"""
@@ -661,23 +668,22 @@ async def list_sessions(self) -> list["SessionMetadata"]:
661668
Returns metadata about each session including ID, timestamps, and summary.
662669
663670
Returns:
664-
A list of session metadata dictionaries with keys: sessionId (str),
665-
startTime (str), modifiedTime (str), summary (str, optional),
666-
and isRemote (bool).
671+
A list of SessionMetadata objects.
667672
668673
Raises:
669674
RuntimeError: If the client is not connected.
670675
671676
Example:
672677
>>> sessions = await client.list_sessions()
673678
>>> for session in sessions:
674-
... print(f"Session: {session['sessionId']}")
679+
... print(f"Session: {session.sessionId}")
675680
"""
676681
if not self._client:
677682
raise RuntimeError("Client not connected")
678683

679684
response = await self._client.request("session.list", {})
680-
return response.get("sessions", [])
685+
sessions_data = response.get("sessions", [])
686+
return [SessionMetadata.from_dict(session) for session in sessions_data]
681687

682688
async def delete_session(self, session_id: str) -> None:
683689
"""
@@ -714,7 +720,7 @@ async def _verify_protocol_version(self) -> None:
714720
"""Verify that the server's protocol version matches the SDK's expected version."""
715721
expected_version = get_sdk_protocol_version()
716722
ping_result = await self.ping()
717-
server_version = ping_result.get("protocolVersion")
723+
server_version = ping_result.protocolVersion
718724

719725
if server_version is None:
720726
raise RuntimeError(
@@ -845,11 +851,11 @@ async def read_port():
845851
if not process or not process.stdout:
846852
raise RuntimeError("Process not started or stdout not available")
847853
while True:
848-
line = cast(bytes, await loop.run_in_executor(None, process.stdout.readline))
854+
line = await loop.run_in_executor(None, process.stdout.readline)
849855
if not line:
850856
raise RuntimeError("CLI process exited before announcing port")
851857

852-
line_str = line.decode()
858+
line_str = line.decode() if isinstance(line, bytes) else line
853859
match = re.search(r"listening on port (\d+)", line_str, re.IGNORECASE)
854860
if match:
855861
self._actual_port = int(match.group(1))

python/copilot/tools.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ def _normalize_result(result: Any) -> ToolResult:
186186

187187
# ToolResult passes through directly
188188
if isinstance(result, dict) and "resultType" in result and "textResultForLlm" in result:
189-
return result # type: ignore
189+
return result
190190

191191
# Strings pass through directly
192192
if isinstance(result, str):

0 commit comments

Comments
 (0)