1919import subprocess
2020import threading
2121from dataclasses import asdict , is_dataclass
22- from typing import Any , Dict , List , Optional , cast
22+ from typing import Any , Optional , cast
2323
2424from .generated .session_events import session_event_from_dict
2525from .jsonrpc import JsonRpcClient
2828from .types import (
2929 ConnectionState ,
3030 CopilotClientOptions ,
31+ CustomAgentConfig ,
3132 GetAuthStatusResponse ,
3233 GetStatusResponse ,
3334 ModelInfo ,
35+ ProviderConfig ,
3436 ResumeSessionConfig ,
3537 SessionConfig ,
3638 SessionMetadata ,
@@ -132,7 +134,7 @@ def __init__(self, options: Optional[CopilotClientOptions] = None):
132134 self ._process : Optional [subprocess .Popen ] = None
133135 self ._client : Optional [JsonRpcClient ] = None
134136 self ._state : ConnectionState = "disconnected"
135- self ._sessions : Dict [str , CopilotSession ] = {}
137+ self ._sessions : dict [str , CopilotSession ] = {}
136138 self ._sessions_lock = threading .Lock ()
137139
138140 def _parse_cli_url (self , url : str ) -> tuple [str , int ]:
@@ -218,7 +220,7 @@ async def start(self) -> None:
218220 self ._state = "error"
219221 raise
220222
221- async def stop (self ) -> List [ Dict [str , str ]]:
223+ async def stop (self ) -> list [ dict [str , str ]]:
222224 """
223225 Stop the CLI server and close all active sessions.
224226
@@ -237,7 +239,7 @@ async def stop(self) -> List[Dict[str, str]]:
237239 ... for error in errors:
238240 ... print(f"Cleanup error: {error['message']}")
239241 """
240- errors : List [ Dict [str , str ]] = []
242+ errors : list [ dict [str , str ]] = []
241243
242244 # Atomically take ownership of all sessions and clear the dict
243245 # so no other thread can access them
@@ -357,7 +359,7 @@ async def create_session(self, config: Optional[SessionConfig] = None) -> Copilo
357359 definition ["parameters" ] = tool .parameters
358360 tool_defs .append (definition )
359361
360- payload : Dict [str , Any ] = {}
362+ payload : dict [str , Any ] = {}
361363 if cfg .get ("model" ):
362364 payload ["model" ] = cfg ["model" ]
363365 if cfg .get ("session_id" ):
@@ -482,7 +484,7 @@ async def resume_session(
482484 definition ["parameters" ] = tool .parameters
483485 tool_defs .append (definition )
484486
485- payload : Dict [str , Any ] = {"sessionId" : session_id }
487+ payload : dict [str , Any ] = {"sessionId" : session_id }
486488 if tool_defs :
487489 payload ["tools" ] = tool_defs
488490
@@ -612,7 +614,7 @@ async def get_auth_status(self) -> "GetAuthStatusResponse":
612614
613615 return await self ._client .request ("auth.getStatus" , {})
614616
615- async def list_models (self ) -> List ["ModelInfo" ]:
617+ async def list_models (self ) -> list ["ModelInfo" ]:
616618 """
617619 List available models with their metadata.
618620
@@ -634,7 +636,7 @@ async def list_models(self) -> List["ModelInfo"]:
634636 response = await self ._client .request ("models.list" , {})
635637 return response .get ("models" , [])
636638
637- async def list_sessions (self ) -> List ["SessionMetadata" ]:
639+ async def list_sessions (self ) -> list ["SessionMetadata" ]:
638640 """
639641 List all available sessions known to the server.
640642
@@ -710,7 +712,9 @@ async def _verify_protocol_version(self) -> None:
710712 f"Please update your SDK or server to ensure compatibility."
711713 )
712714
713- def _convert_provider_to_wire_format (self , provider : Dict [str , Any ]) -> Dict [str , Any ]:
715+ def _convert_provider_to_wire_format (
716+ self , provider : ProviderConfig | dict [str , Any ]
717+ ) -> dict [str , Any ]:
714718 """
715719 Convert provider config from snake_case to camelCase wire format.
716720
@@ -720,7 +724,7 @@ def _convert_provider_to_wire_format(self, provider: Dict[str, Any]) -> Dict[str
720724 Returns:
721725 The provider configuration in camelCase wire format.
722726 """
723- wire_provider : Dict [str , Any ] = {"type" : provider .get ("type" )}
727+ wire_provider : dict [str , Any ] = {"type" : provider .get ("type" )}
724728 if "base_url" in provider :
725729 wire_provider ["baseUrl" ] = provider ["base_url" ]
726730 if "api_key" in provider :
@@ -731,14 +735,16 @@ def _convert_provider_to_wire_format(self, provider: Dict[str, Any]) -> Dict[str
731735 wire_provider ["bearerToken" ] = provider ["bearer_token" ]
732736 if "azure" in provider :
733737 azure = provider ["azure" ]
734- wire_azure : Dict [str , Any ] = {}
738+ wire_azure : dict [str , Any ] = {}
735739 if "api_version" in azure :
736740 wire_azure ["apiVersion" ] = azure ["api_version" ]
737741 if wire_azure :
738742 wire_provider ["azure" ] = wire_azure
739743 return wire_provider
740744
741- def _convert_custom_agent_to_wire_format (self , agent : Dict [str , Any ]) -> Dict [str , Any ]:
745+ def _convert_custom_agent_to_wire_format (
746+ self , agent : CustomAgentConfig | dict [str , Any ]
747+ ) -> dict [str , Any ]:
742748 """
743749 Convert custom agent config from snake_case to camelCase wire format.
744750
@@ -748,7 +754,7 @@ def _convert_custom_agent_to_wire_format(self, agent: Dict[str, Any]) -> Dict[st
748754 Returns:
749755 The custom agent configuration in camelCase wire format.
750756 """
751- wire_agent : Dict [str , Any ] = {"name" : agent .get ("name" ), "prompt" : agent .get ("prompt" )}
757+ wire_agent : dict [str , Any ] = {"name" : agent .get ("name" ), "prompt" : agent .get ("prompt" )}
752758 if "display_name" in agent :
753759 wire_agent ["displayName" ] = agent ["display_name" ]
754760 if "description" in agent :
0 commit comments