forked from zbirenbaum/copilot.lua
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodel.lua
More file actions
250 lines (214 loc) · 6.59 KB
/
model.lua
File metadata and controls
250 lines (214 loc) · 6.59 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
local c = require("copilot.client")
local api = require("copilot.api")
local config = require("copilot.config")
local logger = require("copilot.logger")
local M = {}

--- Model ID chosen at runtime (by `apply_model`, used from `M.select`/`M.set`).
--- It is kept only in memory and never written back to the user's config.
--- When non-nil it wins over `config.copilot_model`.
---@type string|nil
M.selected_model = nil

--- Return the model ID currently in effect.
--- Resolution order: the runtime selection, then `config.copilot_model`,
--- and finally the empty string when neither is set.
---@return string
function M.get_current_model()
  if M.selected_model then
    return M.selected_model
  end
  return config.copilot_model or ""
end
--- Keep only the models that advertise the "completion" scope.
--- Models without a `scopes` list are treated as having no scopes.
---@param models copilot_model[]
---@return copilot_model[]
local function get_completion_models(models)
  local function supports_completion(model)
    local scopes = model.scopes or {}
    return vim.tbl_contains(scopes, "completion")
  end
  return vim.tbl_filter(supports_completion, models)
end
--- Format a model for display.
--- Produces "<name> [<id>] (default, preview)":
--- * the "[<id>]" part appears only when `show_id` is truthy and the model has an ID;
--- * the parenthesised annotations appear only for models flagged `default`
---   and/or `preview`.
--- When the server omits `modelName`, the model ID is used as the name so the
--- entry is never rendered as an empty string.
---@param model copilot_model
---@param show_id? boolean include the model ID in brackets after the name
---@return string
local function format_model(model, show_id)
  local parts = {}
  local name = model.modelName or model.id
  if name then
    parts[#parts + 1] = name
  end
  -- Guard against a missing id: concatenating nil would raise.
  if show_id and model.id then
    parts[#parts + 1] = "[" .. model.id .. "]"
  end
  local annotations = {}
  if model.default then
    annotations[#annotations + 1] = "default"
  end
  if model.preview then
    annotations[#annotations + 1] = "preview"
  end
  if #annotations > 0 then
    parts[#parts + 1] = "(" .. table.concat(annotations, ", ") .. ")"
  end
  return table.concat(parts, " ")
end
--- Record `model_id` as the runtime override and, if the Copilot client is
--- running, push the workspace configuration to the server with
--- `api.notify_change_configuration`.
--- The override in `M.selected_model` is stored even when no client is running.
--- NOTE(review): presumably `get_workspace_configurations` reads the active
--- model (via `M.get_current_model`) — not visible from this file; confirm.
---@param model_id string
local function apply_model(model_id)
  M.selected_model = model_id

  local client = c.get()
  if not client then
    return
  end

  local workspace_config = require("copilot.client.utils").get_workspace_configurations()
  api.notify_change_configuration(client, workspace_config)
  logger.debug("Model changed to: " .. model_id)
end
--- Sort models in place for display.
--- Order: the server's default model first, then alphabetically by display name,
--- then by ID as a tie-breaker so that equal names still sort deterministically.
--- The display name falls back to the ID when `modelName` is missing; comparing
--- nil with `<` would otherwise raise inside `table.sort`.
---@param models copilot_model[]
local function sort_models_for_display(models)
  local function display_key(m)
    return m.modelName or m.id or ""
  end
  table.sort(models, function(a, b)
    local a_default, b_default = not not a.default, not not b.default
    if a_default ~= b_default then
      return a_default
    end
    local a_key, b_key = display_key(a), display_key(b)
    if a_key ~= b_key then
      return a_key < b_key
    end
    return (a.id or "") < (b.id or "")
  end)
end

--- Interactive model selection using vim.ui.select.
--- Fetches the server's models and keeps those with the "completion" scope.
--- It notifies and returns early when the client is not running, the request
--- fails, or no (completion) models exist.
--- With exactly one candidate, that model is applied without prompting.
--- Otherwise the user is prompted (default model first, active model marked
--- "[current]") and the chosen model is applied; cancelling does nothing.
---@param opts? { force?: boolean, args?: string }
function M.select(opts) -- luacheck: ignore opts
  local client = c.get()
  if not client then
    logger.notify("Copilot client not running")
    return
  end
  -- api.get_models is called from a coroutine; presumably it yields until the
  -- LSP response arrives (see copilot.api).
  coroutine.wrap(function()
    local err, models = api.get_models(client)
    if err then
      logger.notify("Failed to get models: " .. vim.inspect(err))
      return
    end
    if not models or #models == 0 then
      logger.notify("No models available")
      return
    end
    local completion_models = get_completion_models(models)
    if #completion_models == 0 then
      logger.notify("No completion models available")
      return
    end
    local current_model = M.get_current_model()
    if #completion_models == 1 then
      local model = completion_models[1]
      local model_name = format_model(model)
      logger.notify("Only one completion model available: " .. model_name)
      if model.id ~= current_model then
        apply_model(model.id)
        logger.notify("Copilot model set to: " .. model_name)
      else
        logger.notify("Copilot model is already set to: " .. model_name)
      end
      return
    end
    sort_models_for_display(completion_models)
    vim.ui.select(completion_models, {
      prompt = "Select Copilot completion model:",
      format_item = function(model)
        local display = format_model(model)
        if model.id == current_model then
          display = display .. " [current]"
        end
        return display
      end,
    }, function(selected)
      if not selected then
        return
      end
      apply_model(selected.id)
      logger.notify("Copilot model set to: " .. format_model(selected))
    end)
  end)()
end
--- Report every model the server offers for completions, one per line.
--- Each line shows the model name and ID; the model currently in effect is
--- marked with "<- current". Problems (no client, request error, no models)
--- are reported through `logger.notify` instead.
---@param opts? { force?: boolean, args?: string }
function M.list(opts) -- luacheck: ignore opts
  local client = c.get()
  if not client then
    logger.notify("Copilot client not running")
    return
  end

  local function run()
    local err, models = api.get_models(client)
    if err then
      logger.notify("Failed to get models: " .. vim.inspect(err))
      return
    end
    if not models or #models == 0 then
      logger.notify("No models available")
      return
    end

    local available = get_completion_models(models)
    if #available == 0 then
      logger.notify("No completion models available")
      return
    end

    local active_id = M.get_current_model()
    local output = { "Available completion models:" }
    for idx = 1, #available do
      local entry = available[idx]
      local marker = (entry.id == active_id) and " <- current" or ""
      output[#output + 1] = " " .. format_model(entry, true) .. marker
    end
    logger.notify(table.concat(output, "\n"))
  end

  coroutine.wrap(run)()
end
--- Report the model currently in effect.
--- If neither a runtime selection nor `config.copilot_model` is set, it reports
--- that the server default is used instead.
---@param opts? { force?: boolean, args?: string }
function M.get(opts) -- luacheck: ignore opts
  local active = M.get_current_model()
  local message = (active == "") and "No model configured (using server default)"
    or ("Current model: " .. active)
  logger.notify(message)
end
--- Set the model programmatically.
--- The model ID is taken from `opts.model`, or from `opts.args` when `opts.model`
--- is nil (the command form `:Copilot model set <model-id>`).
--- Surrounding whitespace is trimmed so that stray spaces from command input do
--- not end up in the ID sent to the server.
--- Prints a usage hint when no ID is given; the ID is not checked against the
--- server's model list here.
---@param opts { model?: string, force?: boolean, args?: string }
function M.set(opts)
  opts = opts or {}
  local model_id = opts.model or opts.args
  if type(model_id) == "string" then
    model_id = vim.trim(model_id)
  end
  if not model_id or model_id == "" then
    logger.notify("Usage: :Copilot model set <model-id>")
    return
  end
  apply_model(model_id)
  logger.notify("Copilot model set to: " .. model_id)
end
--- Validate the configured `copilot_model` against the server's completion models.
--- Intended to be called on startup.
--- It logs a warning listing the available completion model IDs when the
--- configured model is not among them.
--- It does nothing when no model is configured or the client is not running.
--- A failed request is only logged at debug level.
--- Only `config.copilot_model` is checked; a runtime override in
--- `M.selected_model` is not.
function M.validate_current()
  local configured_model = config.copilot_model
  if not configured_model or configured_model == "" then
    return -- No model configured, nothing to validate
  end
  local client = c.get()
  if not client then
    return
  end
  coroutine.wrap(function()
    local err, models = api.get_models(client)
    if err then
      logger.debug("Failed to validate model: " .. vim.inspect(err))
      return
    end
    if not models or #models == 0 then
      return
    end
    local completion_models = get_completion_models(models)
    local valid_ids = vim.tbl_map(function(m)
      return m.id
    end, completion_models)
    if vim.tbl_contains(valid_ids, configured_model) then
      logger.debug("Configured model '" .. configured_model .. "' is valid")
      return
    end
    -- Avoid a dangling "Available: " when the server offers no completion models.
    local valid_list = "(none)"
    if #valid_ids > 0 then
      valid_list = table.concat(valid_ids, ", ")
    end
    logger.warn(
      string.format(
        "Configured copilot_model '%s' is not a valid completion model. Available: %s",
        configured_model,
        valid_list
      )
    )
  end)()
end

return M