-
Notifications
You must be signed in to change notification settings - Fork 1.1k
feat: add onListModels handler to CopilotClientOptions for BYOK mode #730
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -274,4 +274,104 @@ public async Task Should_Throw_When_ResumeSession_Called_Without_PermissionHandl | |
| Assert.Contains("OnPermissionRequest", ex.Message); | ||
| Assert.Contains("is required", ex.Message); | ||
| } | ||
|
|
||
| [Fact] | ||
| public async Task ListModels_WithCustomHandler_CallsHandler() | ||
| { | ||
| var customModels = new List<ModelInfo> | ||
| { | ||
| new() | ||
| { | ||
| Id = "my-custom-model", | ||
| Name = "My Custom Model", | ||
| Capabilities = new ModelCapabilities | ||
| { | ||
| Supports = new ModelSupports { Vision = false, ReasoningEffort = false }, | ||
| Limits = new ModelLimits { MaxContextWindowTokens = 128000 } | ||
| } | ||
| } | ||
| }; | ||
|
|
||
| var callCount = 0; | ||
| await using var client = new CopilotClient(new CopilotClientOptions | ||
| { | ||
| OnListModels = (ct) => | ||
| { | ||
| callCount++; | ||
| return Task.FromResult(customModels); | ||
| } | ||
| }); | ||
| await client.StartAsync(); | ||
|
|
||
| var models = await client.ListModelsAsync(); | ||
| Assert.Equal(1, callCount); | ||
| Assert.Single(models); | ||
| Assert.Equal("my-custom-model", models[0].Id); | ||
| } | ||
|
Comment on lines
+278
to
+310
|
||
|
|
||
| [Fact] | ||
| public async Task ListModels_WithCustomHandler_CachesResults() | ||
| { | ||
| var customModels = new List<ModelInfo> | ||
| { | ||
| new() | ||
| { | ||
| Id = "cached-model", | ||
| Name = "Cached Model", | ||
| Capabilities = new ModelCapabilities | ||
| { | ||
| Supports = new ModelSupports { Vision = false, ReasoningEffort = false }, | ||
| Limits = new ModelLimits { MaxContextWindowTokens = 128000 } | ||
| } | ||
| } | ||
| }; | ||
|
|
||
| var callCount = 0; | ||
| await using var client = new CopilotClient(new CopilotClientOptions | ||
| { | ||
| OnListModels = (ct) => | ||
| { | ||
| callCount++; | ||
| return Task.FromResult(customModels); | ||
| } | ||
| }); | ||
| await client.StartAsync(); | ||
|
|
||
| await client.ListModelsAsync(); | ||
| await client.ListModelsAsync(); | ||
| Assert.Equal(1, callCount); // Only called once due to caching | ||
| } | ||
|
|
||
| [Fact] | ||
| public async Task ListModels_WithCustomHandler_WorksWithoutStart() | ||
| { | ||
| var customModels = new List<ModelInfo> | ||
| { | ||
| new() | ||
| { | ||
| Id = "no-start-model", | ||
| Name = "No Start Model", | ||
| Capabilities = new ModelCapabilities | ||
| { | ||
| Supports = new ModelSupports { Vision = false, ReasoningEffort = false }, | ||
| Limits = new ModelLimits { MaxContextWindowTokens = 128000 } | ||
| } | ||
| } | ||
| }; | ||
|
|
||
| var callCount = 0; | ||
| await using var client = new CopilotClient(new CopilotClientOptions | ||
| { | ||
| OnListModels = (ct) => | ||
| { | ||
| callCount++; | ||
| return Task.FromResult(customModels); | ||
| } | ||
| }); | ||
|
|
||
| var models = await client.ListModelsAsync(); | ||
| Assert.Equal(1, callCount); | ||
| Assert.Single(models); | ||
| Assert.Equal("no-start-model", models[0].Id); | ||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -92,6 +92,7 @@ type Client struct { | |
| processErrorPtr *error | ||
| osProcess atomic.Pointer[os.Process] | ||
| negotiatedProtocolVersion int | ||
| onListModels func(ctx context.Context) ([]ModelInfo, error) | ||
|
|
||
| // RPC provides typed server-scoped RPC methods. | ||
| // This field is nil until the client is connected via Start(). | ||
|
|
@@ -188,6 +189,9 @@ func NewClient(options *ClientOptions) *Client { | |
| if options.UseLoggedInUser != nil { | ||
| opts.UseLoggedInUser = options.UseLoggedInUser | ||
| } | ||
| if options.OnListModels != nil { | ||
| client.onListModels = options.OnListModels | ||
| } | ||
| } | ||
|
|
||
| // Default Env to current environment if not set | ||
|
|
@@ -1035,40 +1039,51 @@ func (c *Client) GetAuthStatus(ctx context.Context) (*GetAuthStatusResponse, err | |
| // Results are cached after the first successful call to avoid rate limiting. | ||
| // The cache is cleared when the client disconnects. | ||
| func (c *Client) ListModels(ctx context.Context) ([]ModelInfo, error) { | ||
| if c.client == nil { | ||
| return nil, fmt.Errorf("client not connected") | ||
| } | ||
|
|
||
| // Use mutex for locking to prevent race condition with concurrent calls | ||
| c.modelsCacheMux.Lock() | ||
| defer c.modelsCacheMux.Unlock() | ||
|
|
||
| // Check cache (already inside lock) | ||
| if c.modelsCache != nil { | ||
| // Return a copy to prevent cache mutation | ||
| result := make([]ModelInfo, len(c.modelsCache)) | ||
| copy(result, c.modelsCache) | ||
| return result, nil | ||
| } | ||
|
|
||
| // Cache miss - fetch from backend while holding lock | ||
| result, err := c.client.Request("models.list", listModelsRequest{}) | ||
| if err != nil { | ||
| return nil, err | ||
| } | ||
| var models []ModelInfo | ||
| if c.onListModels != nil { | ||
| // Use custom handler instead of CLI RPC | ||
| var err error | ||
| models, err = c.onListModels(ctx) | ||
| if err != nil { | ||
| return nil, err | ||
| } | ||
| } else { | ||
| if c.client == nil { | ||
| return nil, fmt.Errorf("client not connected") | ||
| } | ||
| // Cache miss - fetch from backend while holding lock | ||
| result, err := c.client.Request("models.list", listModelsRequest{}) | ||
| if err != nil { | ||
| return nil, err | ||
| } | ||
|
|
||
| var response listModelsResponse | ||
| if err := json.Unmarshal(result, &response); err != nil { | ||
| return nil, fmt.Errorf("failed to unmarshal models response: %w", err) | ||
| var response listModelsResponse | ||
| if err := json.Unmarshal(result, &response); err != nil { | ||
| return nil, fmt.Errorf("failed to unmarshal models response: %w", err) | ||
| } | ||
| models = response.Models | ||
| } | ||
|
|
||
| // Update cache before releasing lock | ||
| c.modelsCache = response.Models | ||
| // Update cache before releasing lock (copy to prevent external mutation) | ||
| cache := make([]ModelInfo, len(models)) | ||
| copy(cache, models) | ||
| c.modelsCache = cache | ||
|
|
||
|
Comment on lines
+1053
to
1082
|
||
| // Return a copy to prevent cache mutation | ||
| models := make([]ModelInfo, len(response.Models)) | ||
| copy(models, response.Models) | ||
| return models, nil | ||
| result := make([]ModelInfo, len(models)) | ||
| copy(result, models) | ||
| return result, nil | ||
| } | ||
|
|
||
| // minProtocolVersion is the minimum protocol version this SDK can communicate with. | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In the custom handler path, the list returned by
`_onListModels` is assigned directly to `_modelsCache`. Because the caller likely owns that `List<ModelInfo>` instance, later mutations can affect cached results unexpectedly. Consider caching a copy (e.g., `models.ToList()`) so the cache is isolated from external changes, consistent with the RPC path.