Skip to content

Commit e478657

Browse files
patnikoCopilot
andauthored
feat: add onListModels handler to CopilotClientOptions for BYOK mode (#730)
Add an optional onListModels handler to CopilotClientOptions across all 4 SDKs (Node, Python, Go, .NET). When provided, client.listModels() calls the handler instead of sending the models.list RPC to the CLI server. This enables BYOK users to return their provider's available models in the standard ModelInfo format. - Handler completely replaces CLI RPC when set (no fallback) - Results cached identically to CLI path (same locking/thread-safety) - No connection required when handler is provided - Supports both sync and async handlers - 10 new unit tests across all SDKs - Updated BYOK docs with usage examples in all 4 languages Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent 6195e3e commit e478657

File tree

13 files changed

+600
-51
lines changed

13 files changed

+600
-51
lines changed

docs/auth/byok.md

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,116 @@ provider: {
306306

307307
> **Note:** The `bearerToken` option accepts a **static token string** only. The SDK does not refresh this token automatically. If your token expires, requests will fail and you'll need to create a new session with a fresh token.
308308
309+
## Custom Model Listing
310+
311+
When using BYOK, the CLI server may not know which models your provider supports. You can supply a custom `onListModels` handler at the client level so that `client.listModels()` returns your provider's models in the standard `ModelInfo` format. This lets downstream consumers discover available models without querying the CLI.
312+
313+
<details open>
314+
<summary><strong>Node.js / TypeScript</strong></summary>
315+
316+
```typescript
317+
import { CopilotClient } from "@github/copilot-sdk";
318+
import type { ModelInfo } from "@github/copilot-sdk";
319+
320+
const client = new CopilotClient({
321+
onListModels: () => [
322+
{
323+
id: "my-custom-model",
324+
name: "My Custom Model",
325+
capabilities: {
326+
supports: { vision: false, reasoningEffort: false },
327+
limits: { max_context_window_tokens: 128000 },
328+
},
329+
},
330+
],
331+
});
332+
```
333+
334+
</details>
335+
336+
<details>
337+
<summary><strong>Python</strong></summary>
338+
339+
```python
340+
from copilot import CopilotClient
341+
from copilot.types import ModelInfo, ModelCapabilities, ModelSupports, ModelLimits
342+
343+
client = CopilotClient({
344+
"on_list_models": lambda: [
345+
ModelInfo(
346+
id="my-custom-model",
347+
name="My Custom Model",
348+
capabilities=ModelCapabilities(
349+
supports=ModelSupports(vision=False, reasoning_effort=False),
350+
limits=ModelLimits(max_context_window_tokens=128000),
351+
),
352+
)
353+
],
354+
})
355+
```
356+
357+
</details>
358+
359+
<details>
360+
<summary><strong>Go</strong></summary>
361+
362+
```go
363+
package main
364+
365+
import (
366+
"context"
367+
copilot "github.com/github/copilot-sdk/go"
368+
)
369+
370+
func main() {
371+
client := copilot.NewClient(&copilot.ClientOptions{
372+
OnListModels: func(ctx context.Context) ([]copilot.ModelInfo, error) {
373+
return []copilot.ModelInfo{
374+
{
375+
ID: "my-custom-model",
376+
Name: "My Custom Model",
377+
Capabilities: copilot.ModelCapabilities{
378+
Supports: copilot.ModelSupports{Vision: false, ReasoningEffort: false},
379+
Limits: copilot.ModelLimits{MaxContextWindowTokens: 128000},
380+
},
381+
},
382+
}, nil
383+
},
384+
})
385+
_ = client
386+
}
387+
```
388+
389+
</details>
390+
391+
<details>
392+
<summary><strong>.NET</strong></summary>
393+
394+
```csharp
395+
using GitHub.Copilot.SDK;
396+
397+
var client = new CopilotClient(new CopilotClientOptions
398+
{
399+
OnListModels = (ct) => Task.FromResult(new List<ModelInfo>
400+
{
401+
new()
402+
{
403+
Id = "my-custom-model",
404+
Name = "My Custom Model",
405+
Capabilities = new ModelCapabilities
406+
{
407+
Supports = new ModelSupports { Vision = false, ReasoningEffort = false },
408+
Limits = new ModelLimits { MaxContextWindowTokens = 128000 }
409+
}
410+
}
411+
})
412+
});
413+
```
414+
415+
</details>
416+
417+
Results are cached after the first call, just like the default behavior. The handler completely replaces the CLI's `models.list` RPC — no fallback to the server occurs.
418+
309419
## Limitations
310420

311421
When using BYOK, be aware of these limitations:

dotnet/src/Client.cs

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ public sealed partial class CopilotClient : IDisposable, IAsyncDisposable
7070
private int? _negotiatedProtocolVersion;
7171
private List<ModelInfo>? _modelsCache;
7272
private readonly SemaphoreSlim _modelsCacheLock = new(1, 1);
73+
private readonly Func<CancellationToken, Task<List<ModelInfo>>>? _onListModels;
7374
private readonly List<Action<SessionLifecycleEvent>> _lifecycleHandlers = [];
7475
private readonly Dictionary<string, List<Action<SessionLifecycleEvent>>> _typedLifecycleHandlers = [];
7576
private readonly object _lifecycleHandlersLock = new();
@@ -136,6 +137,7 @@ public CopilotClient(CopilotClientOptions? options = null)
136137
}
137138

138139
_logger = _options.Logger ?? NullLogger.Instance;
140+
_onListModels = _options.OnListModels;
139141

140142
// Parse CliUrl if provided
141143
if (!string.IsNullOrEmpty(_options.CliUrl))
@@ -624,9 +626,6 @@ public async Task<GetAuthStatusResponse> GetAuthStatusAsync(CancellationToken ca
624626
/// <exception cref="InvalidOperationException">Thrown when the client is not connected or not authenticated.</exception>
625627
public async Task<List<ModelInfo>> ListModelsAsync(CancellationToken cancellationToken = default)
626628
{
627-
var connection = await EnsureConnectedAsync(cancellationToken);
628-
629-
// Use semaphore for async locking to prevent race condition with concurrent calls
630629
await _modelsCacheLock.WaitAsync(cancellationToken);
631630
try
632631
{
@@ -636,14 +635,26 @@ public async Task<List<ModelInfo>> ListModelsAsync(CancellationToken cancellatio
636635
return [.. _modelsCache]; // Return a copy to prevent cache mutation
637636
}
638637

639-
// Cache miss - fetch from backend while holding lock
640-
var response = await InvokeRpcAsync<GetModelsResponse>(
641-
connection.Rpc, "models.list", [], cancellationToken);
638+
List<ModelInfo> models;
639+
if (_onListModels is not null)
640+
{
641+
// Use custom handler instead of CLI RPC
642+
models = await _onListModels(cancellationToken);
643+
}
644+
else
645+
{
646+
var connection = await EnsureConnectedAsync(cancellationToken);
647+
648+
// Cache miss - fetch from backend while holding lock
649+
var response = await InvokeRpcAsync<GetModelsResponse>(
650+
connection.Rpc, "models.list", [], cancellationToken);
651+
models = response.Models;
652+
}
642653

643-
// Update cache before releasing lock
644-
_modelsCache = response.Models;
654+
// Update cache before releasing lock (copy to prevent external mutation)
655+
_modelsCache = [.. models];
645656

646-
return [.. response.Models]; // Return a copy to prevent cache mutation
657+
return [.. models]; // Return a copy to prevent cache mutation
647658
}
648659
finally
649660
{

dotnet/src/Types.cs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ protected CopilotClientOptions(CopilotClientOptions? other)
6363
Port = other.Port;
6464
UseLoggedInUser = other.UseLoggedInUser;
6565
UseStdio = other.UseStdio;
66+
OnListModels = other.OnListModels;
6667
}
6768

6869
/// <summary>
@@ -136,6 +137,14 @@ public string? GithubToken
136137
/// </summary>
137138
public bool? UseLoggedInUser { get; set; }
138139

140+
/// <summary>
141+
/// Custom handler for listing available models.
142+
/// When provided, <c>ListModelsAsync()</c> calls this handler instead of
143+
/// querying the CLI server. Useful in BYOK mode to return models
144+
/// available from your custom provider.
145+
/// </summary>
146+
public Func<CancellationToken, Task<List<ModelInfo>>>? OnListModels { get; set; }
147+
139148
/// <summary>
140149
/// Creates a shallow clone of this <see cref="CopilotClientOptions"/> instance.
141150
/// </summary>

dotnet/test/ClientTests.cs

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -274,4 +274,104 @@ public async Task Should_Throw_When_ResumeSession_Called_Without_PermissionHandl
274274
Assert.Contains("OnPermissionRequest", ex.Message);
275275
Assert.Contains("is required", ex.Message);
276276
}
277+
278+
[Fact]
279+
public async Task ListModels_WithCustomHandler_CallsHandler()
280+
{
281+
var customModels = new List<ModelInfo>
282+
{
283+
new()
284+
{
285+
Id = "my-custom-model",
286+
Name = "My Custom Model",
287+
Capabilities = new ModelCapabilities
288+
{
289+
Supports = new ModelSupports { Vision = false, ReasoningEffort = false },
290+
Limits = new ModelLimits { MaxContextWindowTokens = 128000 }
291+
}
292+
}
293+
};
294+
295+
var callCount = 0;
296+
await using var client = new CopilotClient(new CopilotClientOptions
297+
{
298+
OnListModels = (ct) =>
299+
{
300+
callCount++;
301+
return Task.FromResult(customModels);
302+
}
303+
});
304+
await client.StartAsync();
305+
306+
var models = await client.ListModelsAsync();
307+
Assert.Equal(1, callCount);
308+
Assert.Single(models);
309+
Assert.Equal("my-custom-model", models[0].Id);
310+
}
311+
312+
[Fact]
313+
public async Task ListModels_WithCustomHandler_CachesResults()
314+
{
315+
var customModels = new List<ModelInfo>
316+
{
317+
new()
318+
{
319+
Id = "cached-model",
320+
Name = "Cached Model",
321+
Capabilities = new ModelCapabilities
322+
{
323+
Supports = new ModelSupports { Vision = false, ReasoningEffort = false },
324+
Limits = new ModelLimits { MaxContextWindowTokens = 128000 }
325+
}
326+
}
327+
};
328+
329+
var callCount = 0;
330+
await using var client = new CopilotClient(new CopilotClientOptions
331+
{
332+
OnListModels = (ct) =>
333+
{
334+
callCount++;
335+
return Task.FromResult(customModels);
336+
}
337+
});
338+
await client.StartAsync();
339+
340+
await client.ListModelsAsync();
341+
await client.ListModelsAsync();
342+
Assert.Equal(1, callCount); // Only called once due to caching
343+
}
344+
345+
[Fact]
346+
public async Task ListModels_WithCustomHandler_WorksWithoutStart()
347+
{
348+
var customModels = new List<ModelInfo>
349+
{
350+
new()
351+
{
352+
Id = "no-start-model",
353+
Name = "No Start Model",
354+
Capabilities = new ModelCapabilities
355+
{
356+
Supports = new ModelSupports { Vision = false, ReasoningEffort = false },
357+
Limits = new ModelLimits { MaxContextWindowTokens = 128000 }
358+
}
359+
}
360+
};
361+
362+
var callCount = 0;
363+
await using var client = new CopilotClient(new CopilotClientOptions
364+
{
365+
OnListModels = (ct) =>
366+
{
367+
callCount++;
368+
return Task.FromResult(customModels);
369+
}
370+
});
371+
372+
var models = await client.ListModelsAsync();
373+
Assert.Equal(1, callCount);
374+
Assert.Single(models);
375+
Assert.Equal("no-start-model", models[0].Id);
376+
}
277377
}

go/client.go

Lines changed: 33 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ type Client struct {
9292
processErrorPtr *error
9393
osProcess atomic.Pointer[os.Process]
9494
negotiatedProtocolVersion int
95+
onListModels func(ctx context.Context) ([]ModelInfo, error)
9596

9697
// RPC provides typed server-scoped RPC methods.
9798
// This field is nil until the client is connected via Start().
@@ -188,6 +189,9 @@ func NewClient(options *ClientOptions) *Client {
188189
if options.UseLoggedInUser != nil {
189190
opts.UseLoggedInUser = options.UseLoggedInUser
190191
}
192+
if options.OnListModels != nil {
193+
client.onListModels = options.OnListModels
194+
}
191195
}
192196

193197
// Default Env to current environment if not set
@@ -1035,40 +1039,51 @@ func (c *Client) GetAuthStatus(ctx context.Context) (*GetAuthStatusResponse, err
10351039
// Results are cached after the first successful call to avoid rate limiting.
10361040
// The cache is cleared when the client disconnects.
10371041
func (c *Client) ListModels(ctx context.Context) ([]ModelInfo, error) {
1038-
if c.client == nil {
1039-
return nil, fmt.Errorf("client not connected")
1040-
}
1041-
10421042
// Use mutex for locking to prevent race condition with concurrent calls
10431043
c.modelsCacheMux.Lock()
10441044
defer c.modelsCacheMux.Unlock()
10451045

10461046
// Check cache (already inside lock)
10471047
if c.modelsCache != nil {
1048-
// Return a copy to prevent cache mutation
10491048
result := make([]ModelInfo, len(c.modelsCache))
10501049
copy(result, c.modelsCache)
10511050
return result, nil
10521051
}
10531052

1054-
// Cache miss - fetch from backend while holding lock
1055-
result, err := c.client.Request("models.list", listModelsRequest{})
1056-
if err != nil {
1057-
return nil, err
1058-
}
1053+
var models []ModelInfo
1054+
if c.onListModels != nil {
1055+
// Use custom handler instead of CLI RPC
1056+
var err error
1057+
models, err = c.onListModels(ctx)
1058+
if err != nil {
1059+
return nil, err
1060+
}
1061+
} else {
1062+
if c.client == nil {
1063+
return nil, fmt.Errorf("client not connected")
1064+
}
1065+
// Cache miss - fetch from backend while holding lock
1066+
result, err := c.client.Request("models.list", listModelsRequest{})
1067+
if err != nil {
1068+
return nil, err
1069+
}
10591070

1060-
var response listModelsResponse
1061-
if err := json.Unmarshal(result, &response); err != nil {
1062-
return nil, fmt.Errorf("failed to unmarshal models response: %w", err)
1071+
var response listModelsResponse
1072+
if err := json.Unmarshal(result, &response); err != nil {
1073+
return nil, fmt.Errorf("failed to unmarshal models response: %w", err)
1074+
}
1075+
models = response.Models
10631076
}
10641077

1065-
// Update cache before releasing lock
1066-
c.modelsCache = response.Models
1078+
// Update cache before releasing lock (copy to prevent external mutation)
1079+
cache := make([]ModelInfo, len(models))
1080+
copy(cache, models)
1081+
c.modelsCache = cache
10671082

10681083
// Return a copy to prevent cache mutation
1069-
models := make([]ModelInfo, len(response.Models))
1070-
copy(models, response.Models)
1071-
return models, nil
1084+
result := make([]ModelInfo, len(models))
1085+
copy(result, models)
1086+
return result, nil
10721087
}
10731088

10741089
// minProtocolVersion is the minimum protocol version this SDK can communicate with.

0 commit comments

Comments
 (0)