1919import subprocess
2020import threading
2121from dataclasses import asdict , is_dataclass
22- from typing import Any , Optional , cast
22+ from typing import Any , Optional
2323
2424from .generated .session_events import session_event_from_dict
2525from .jsonrpc import JsonRpcClient
3232 GetAuthStatusResponse ,
3333 GetStatusResponse ,
3434 ModelInfo ,
35+ PingResponse ,
3536 ProviderConfig ,
3637 ResumeSessionConfig ,
3738 SessionConfig ,
3839 SessionMetadata ,
40+ StopError ,
3941 ToolHandler ,
4042 ToolInvocation ,
4143 ToolResult ,
@@ -220,7 +222,7 @@ async def start(self) -> None:
220222 self ._state = "error"
221223 raise
222224
223- async def stop (self ) -> list [dict [ str , str ] ]:
225+ async def stop (self ) -> list ["StopError" ]:
224226 """
225227 Stop the CLI server and close all active sessions.
226228
@@ -230,16 +232,16 @@ async def stop(self) -> list[dict[str, str]]:
230232 3. Terminates the CLI server process (if spawned by this client)
231233
232234 Returns:
233- A list of errors that occurred during cleanup, each as a dict with
234- a 'message' key . An empty list indicates all cleanup succeeded.
235+ A list of StopError objects containing error messages that occurred
236+ during cleanup . An empty list indicates all cleanup succeeded.
235237
236238 Example:
237239 >>> errors = await client.stop()
238240 >>> if errors:
239241 ... for error in errors:
240- ... print(f"Cleanup error: {error[' message'] }")
242+ ... print(f"Cleanup error: {error. message}")
241243 """
242- errors : list [dict [ str , str ] ] = []
244+ errors : list [StopError ] = []
243245
244246 # Atomically take ownership of all sessions and clear the dict
245247 # so no other thread can access them
@@ -251,7 +253,9 @@ async def stop(self) -> list[dict[str, str]]:
251253 try :
252254 await session .destroy ()
253255 except Exception as e :
254- errors .append ({"message" : f"Failed to destroy session { session .session_id } : { e } " })
256+ errors .append (
257+ StopError (message = f"Failed to destroy session { session .session_id } : { e } " )
258+ )
255259
256260 # Close client
257261 if self ._client :
@@ -570,67 +574,69 @@ def get_state(self) -> ConnectionState:
570574 """
571575 return self ._state
572576
573- async def ping (self , message : Optional [str ] = None ) -> dict :
577+ async def ping (self , message : Optional [str ] = None ) -> "PingResponse" :
574578 """
575579 Send a ping request to the server to verify connectivity.
576580
577581 Args:
578582 message: Optional message to include in the ping.
579583
580584 Returns:
581- A dict containing the ping response with 'message', 'timestamp',
582- and 'protocolVersion' keys.
585+ A PingResponse object containing the ping response.
583586
584587 Raises:
585588 RuntimeError: If the client is not connected.
586589
587590 Example:
588591 >>> response = await client.ping("health check")
589- >>> print(f"Server responded at {response[' timestamp'] }")
592+ >>> print(f"Server responded at {response. timestamp}")
590593 """
591594 if not self ._client :
592595 raise RuntimeError ("Client not connected" )
593596
594- return await self ._client .request ("ping" , {"message" : message })
597+ result = await self ._client .request ("ping" , {"message" : message })
598+ return PingResponse .from_dict (result )
595599
596600 async def get_status (self ) -> "GetStatusResponse" :
597601 """
598602 Get CLI status including version and protocol information.
599603
600604 Returns:
601- A GetStatusResponse containing version and protocolVersion.
605+ A GetStatusResponse object containing version and protocolVersion.
602606
603607 Raises:
604608 RuntimeError: If the client is not connected.
605609
606610 Example:
607611 >>> status = await client.get_status()
608- >>> print(f"CLI version: {status[' version'] }")
612+ >>> print(f"CLI version: {status. version}")
609613 """
610614 if not self ._client :
611615 raise RuntimeError ("Client not connected" )
612616
613- return await self ._client .request ("status.get" , {})
617+ result = await self ._client .request ("status.get" , {})
618+ return GetStatusResponse .from_dict (result )
614619
615620 async def get_auth_status (self ) -> "GetAuthStatusResponse" :
616621 """
617622 Get current authentication status.
618623
619624 Returns:
620- A GetAuthStatusResponse containing authentication state.
625+ A GetAuthStatusResponse object containing authentication state.
621626
622627 Raises:
623628 RuntimeError: If the client is not connected.
624629
625630 Example:
626631 >>> auth = await client.get_auth_status()
627- >>> if auth[' isAuthenticated'] :
628- ... print(f"Logged in as {auth.get(' login') }")
632+ >>> if auth. isAuthenticated:
633+ ... print(f"Logged in as {auth.login}")
629634 """
630635 if not self ._client :
631636 raise RuntimeError ("Client not connected" )
632637
633- return await self ._client .request ("auth.getStatus" , {})
638+ result = await self ._client .request ("auth.getStatus" , {})
639+ return GetAuthStatusResponse .from_dict (result )
634640
635641 async def list_models (self ) -> list ["ModelInfo" ]:
636642 """
@@ -646,13 +652,14 @@ async def list_models(self) -> list["ModelInfo"]:
646652 Example:
647653 >>> models = await client.list_models()
648654 >>> for model in models:
649- ... print(f"{model['id'] }: {model[' name'] }")
655+ ... print(f"{model.id }: {model. name}")
650656 """
651657 if not self ._client :
652658 raise RuntimeError ("Client not connected" )
653659
654660 response = await self ._client .request ("models.list" , {})
655- return response .get ("models" , [])
661+ models_data = response .get ("models" , [])
662+ return [ModelInfo .from_dict (model ) for model in models_data ]
656663
657664 async def list_sessions (self ) -> list ["SessionMetadata" ]:
658665 """
@@ -661,23 +668,22 @@ async def list_sessions(self) -> list["SessionMetadata"]:
661668 Returns metadata about each session including ID, timestamps, and summary.
662669
663670 Returns:
664- A list of session metadata dictionaries with keys: sessionId (str),
665- startTime (str), modifiedTime (str), summary (str, optional),
666- and isRemote (bool).
671+ A list of SessionMetadata objects.
667672
668673 Raises:
669674 RuntimeError: If the client is not connected.
670675
671676 Example:
672677 >>> sessions = await client.list_sessions()
673678 >>> for session in sessions:
674- ... print(f"Session: {session[' sessionId'] }")
679+ ... print(f"Session: {session. sessionId}")
675680 """
676681 if not self ._client :
677682 raise RuntimeError ("Client not connected" )
678683
679684 response = await self ._client .request ("session.list" , {})
680- return response .get ("sessions" , [])
685+ sessions_data = response .get ("sessions" , [])
686+ return [SessionMetadata .from_dict (session ) for session in sessions_data ]
681687
682688 async def delete_session (self , session_id : str ) -> None :
683689 """
@@ -714,7 +720,7 @@ async def _verify_protocol_version(self) -> None:
714720 """Verify that the server's protocol version matches the SDK's expected version."""
715721 expected_version = get_sdk_protocol_version ()
716722 ping_result = await self .ping ()
717- server_version = ping_result .get ( " protocolVersion" )
723+ server_version = ping_result .protocolVersion
718724
719725 if server_version is None :
720726 raise RuntimeError (
@@ -845,11 +851,11 @@ async def read_port():
845851 if not process or not process .stdout :
846852 raise RuntimeError ("Process not started or stdout not available" )
847853 while True :
848- line = cast ( bytes , await loop .run_in_executor (None , process .stdout .readline ) )
854+ line = await loop .run_in_executor (None , process .stdout .readline )
849855 if not line :
850856 raise RuntimeError ("CLI process exited before announcing port" )
851857
852- line_str = line .decode ()
858+ line_str = line .decode () if isinstance ( line , bytes ) else line
853859 match = re .search (r"listening on port (\d+)" , line_str , re .IGNORECASE )
854860 if match :
855861 self ._actual_port = int (match .group (1 ))
0 commit comments