Skip to content

Commit 746cc7a

Browse files
authored
feat(command): add model command to get/set the model (zbirenbaum#643)
1 parent 48c8886 commit 746cc7a

File tree

5 files changed

+272
-2
lines changed

5 files changed

+272
-2
lines changed

lua/copilot/api/init.lua

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,4 +154,14 @@ end
154154
---@alias copilot_window_show_document { uri: string, external?: boolean, takeFocus?: boolean, selection?: boolean }
155155
---@alias copilot_window_show_document_result { success: boolean }
156156

157+
---@alias copilot_model { id: string, modelName: string, scopes: string[], preview?: boolean, default?: boolean }
158+
---@alias copilot_models_data copilot_model[]
159+
160+
---@return any|nil err
161+
---@return copilot_models_data data
162+
---@return table ctx
163+
function M.get_models(client, callback)
164+
return M.request(client, "copilot/models", {}, callback)
165+
end
166+
157167
return M

lua/copilot/client/config.lua

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,11 @@ function M.prepare_client_config(overrides, client)
124124
end
125125

126126
require("copilot.nes").setup(lsp_client)
127+
128+
-- Validate configured model on startup
129+
if config.copilot_model and config.copilot_model ~= "" then
130+
require("copilot.model").validate_current()
131+
end
127132
end)
128133
end,
129134
on_exit = function(code, _, client_id)

lua/copilot/client/utils.lua

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,9 @@ function M.get_workspace_configurations()
2929
filetypes = vim.tbl_deep_extend("keep", filetypes, client_ft.internal_filetypes)
3030
end
3131

32-
local copilot_model = config and config.copilot_model ~= "" and config.copilot_model or ""
32+
-- Use model module to get the current model (supports runtime override)
33+
local model = require("copilot.model")
34+
local copilot_model = model.get_current_model()
3335

3436
---@type string[]
3537
local disabled_filetypes = vim.tbl_filter(function(ft)

lua/copilot/model.lua

Lines changed: 250 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,250 @@
1+
local c = require("copilot.client")
2+
local api = require("copilot.api")
3+
local config = require("copilot.config")
4+
local logger = require("copilot.logger")
5+
6+
local M = {}
7+
8+
--- Runtime override of the model (not persisted to user config)
9+
--- When set, this takes precedence over config.copilot_model
10+
---@type string|nil
11+
M.selected_model = nil
12+
13+
--- Get the currently active model ID
14+
---@return string
15+
function M.get_current_model()
16+
return M.selected_model or config.copilot_model or ""
17+
end
18+
19+
--- Filter models that support completions
20+
---@param models copilot_model[]
21+
---@return copilot_model[]
22+
local function get_completion_models(models)
23+
return vim.tbl_filter(function(m)
24+
return vim.tbl_contains(m.scopes or {}, "completion")
25+
end, models)
26+
end
27+
28+
--- Format a model for display
29+
---@param model copilot_model
30+
---@return string
31+
local function format_model(model, show_id)
32+
local parts = { model.modelName }
33+
if show_id then
34+
table.insert(parts, "[" .. model.id .. "]")
35+
end
36+
local annotations = {}
37+
38+
if model.default then
39+
table.insert(annotations, "default")
40+
end
41+
if model.preview then
42+
table.insert(annotations, "preview")
43+
end
44+
45+
if #annotations > 0 then
46+
table.insert(parts, "(" .. table.concat(annotations, ", ") .. ")")
47+
end
48+
49+
return table.concat(parts, " ")
50+
end
51+
52+
--- Apply the selected model by notifying the LSP server
53+
---@param model_id string
54+
local function apply_model(model_id)
55+
M.selected_model = model_id
56+
57+
local client = c.get()
58+
if client then
59+
local utils = require("copilot.client.utils")
60+
local configurations = utils.get_workspace_configurations()
61+
api.notify_change_configuration(client, configurations)
62+
logger.debug("Model changed to: " .. model_id)
63+
end
64+
end
65+
66+
--- Interactive model selection using vim.ui.select
67+
---@param opts? { force?: boolean, args?: string }
68+
function M.select(opts) -- luacheck: ignore opts
69+
local client = c.get()
70+
if not client then
71+
logger.notify("Copilot client not running")
72+
return
73+
end
74+
75+
coroutine.wrap(function()
76+
local err, models = api.get_models(client)
77+
if err then
78+
logger.notify("Failed to get models: " .. vim.inspect(err))
79+
return
80+
end
81+
82+
if not models or #models == 0 then
83+
logger.notify("No models available")
84+
return
85+
end
86+
87+
local completion_models = get_completion_models(models)
88+
if #completion_models == 0 then
89+
logger.notify("No completion models available")
90+
return
91+
end
92+
93+
local current_model = M.get_current_model()
94+
if #completion_models == 1 then
95+
local model = completion_models[1]
96+
local model_name = format_model(model)
97+
logger.notify("Only one completion model available: " .. model_name)
98+
if model.id ~= current_model then
99+
apply_model(model.id)
100+
logger.notify("Copilot model set to: " .. model_name)
101+
else
102+
logger.notify("Copilot model is already set to: " .. model_name)
103+
end
104+
return
105+
end
106+
107+
-- Sort models: default first, then by name
108+
table.sort(completion_models, function(a, b)
109+
if a.default and not b.default then
110+
return true
111+
end
112+
if b.default and not a.default then
113+
return false
114+
end
115+
return a.modelName < b.modelName
116+
end)
117+
118+
vim.ui.select(completion_models, {
119+
prompt = "Select Copilot completion model:",
120+
format_item = function(model)
121+
local display = format_model(model)
122+
if model.id == current_model then
123+
display = display .. " [current]"
124+
end
125+
return display
126+
end,
127+
}, function(selected)
128+
if not selected then
129+
return
130+
end
131+
132+
apply_model(selected.id)
133+
logger.notify("Copilot model set to: " .. format_model(selected))
134+
end)
135+
end)()
136+
end
137+
138+
--- List available completion models
139+
---@param opts? { force?: boolean, args?: string }
140+
function M.list(opts) -- luacheck: ignore opts
141+
local client = c.get()
142+
if not client then
143+
logger.notify("Copilot client not running")
144+
return
145+
end
146+
147+
coroutine.wrap(function()
148+
local err, models = api.get_models(client)
149+
if err then
150+
logger.notify("Failed to get models: " .. vim.inspect(err))
151+
return
152+
end
153+
154+
if not models or #models == 0 then
155+
logger.notify("No models available")
156+
return
157+
end
158+
159+
local completion_models = get_completion_models(models)
160+
if #completion_models == 0 then
161+
logger.notify("No completion models available")
162+
return
163+
end
164+
165+
local current_model = M.get_current_model()
166+
local lines = { "Available completion models:" }
167+
168+
for _, model in ipairs(completion_models) do
169+
local line = " " .. format_model(model, true)
170+
if model.id == current_model then
171+
line = line .. " <- current"
172+
end
173+
table.insert(lines, line)
174+
end
175+
176+
logger.notify(table.concat(lines, "\n"))
177+
end)()
178+
end
179+
180+
--- Show the current model
181+
---@param opts? { force?: boolean, args?: string }
182+
function M.get(opts) -- luacheck: ignore opts
183+
local current = M.get_current_model()
184+
if current == "" then
185+
logger.notify("No model configured (using server default)")
186+
else
187+
logger.notify("Current model: " .. current)
188+
end
189+
end
190+
191+
--- Set the model programmatically
192+
---@param opts { model?: string, force?: boolean, args?: string }
193+
function M.set(opts)
194+
opts = opts or {}
195+
196+
local model_id = opts.model or opts.args
197+
if not model_id or model_id == "" then
198+
logger.notify("Usage: :Copilot model set <model-id>")
199+
return
200+
end
201+
202+
apply_model(model_id)
203+
logger.notify("Copilot model set to: " .. model_id)
204+
end
205+
206+
--- Validate the currently configured model against available models
207+
--- Called on startup to warn if the configured model is invalid
208+
function M.validate_current()
209+
local configured_model = config.copilot_model
210+
if not configured_model or configured_model == "" then
211+
return -- No model configured, nothing to validate
212+
end
213+
214+
local client = c.get()
215+
if not client then
216+
return
217+
end
218+
219+
coroutine.wrap(function()
220+
local err, models = api.get_models(client)
221+
if err then
222+
logger.debug("Failed to validate model: " .. vim.inspect(err))
223+
return
224+
end
225+
226+
if not models or #models == 0 then
227+
return
228+
end
229+
230+
local completion_models = get_completion_models(models)
231+
local valid_ids = vim.tbl_map(function(m)
232+
return m.id
233+
end, completion_models)
234+
235+
if not vim.tbl_contains(valid_ids, configured_model) then
236+
local valid_list = table.concat(valid_ids, ", ")
237+
logger.warn(
238+
string.format(
239+
"Configured copilot_model '%s' is not a valid completion model. Available: %s",
240+
configured_model,
241+
valid_list
242+
)
243+
)
244+
else
245+
logger.debug("Configured model '" .. configured_model .. "' is valid")
246+
end
247+
end)()
248+
end
249+
250+
return M

plugin/copilot.lua

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
local completion_store = {
2-
[""] = { "auth", "attach", "detach", "disable", "enable", "panel", "status", "suggestion", "toggle", "version" },
2+
[""] = { "auth", "attach", "detach", "disable", "enable", "model", "panel", "status", "suggestion", "toggle", "version" },
33
auth = { "signin", "signout", "info" },
4+
model = { "select", "list", "get", "set" },
45
panel = { "accept", "jump_next", "jump_prev", "open", "refresh", "toggle", "close", "is_open" },
56
suggestion = {
67
"accept",
@@ -34,6 +35,8 @@ vim.api.nvim_create_user_command("Copilot", function(opts)
3435
if not action_name then
3536
if mod_name == "auth" then
3637
action_name = "signin"
38+
elseif mod_name == "model" then
39+
action_name = "get"
3740
elseif mod_name == "panel" then
3841
action_name = "open"
3942
elseif mod_name == "suggestion" then

0 commit comments

Comments
 (0)