Skip to content

Commit ffb6659

Browse files
authored
fix(client): store models cache per provider (CopilotC-Nvim#1291)
Signed-off-by: Tomas Slusny <slusnucky@gmail.com>
1 parent 93d3bb9 commit ffb6659

3 files changed

Lines changed: 30 additions & 24 deletions

File tree

lua/CopilotChat/client.lua

Lines changed: 26 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -164,14 +164,12 @@ end
164164
---@class CopilotChat.client.Client : Class
165165
---@field private provider_resolver function():table<string, CopilotChat.config.providers.Provider>
166166
---@field private provider_cache table<string, table>
167-
---@field private model_cache table<string, CopilotChat.client.Model>?
168167
---@field private current_job string?
169168
local Client = class(function(self)
170169
self.provider_resolver = nil
171170
self.provider_cache = vim.defaulttable(function()
172171
return {}
173172
end)
174-
self.model_cache = nil
175173
self.current_job = nil
176174
end)
177175

@@ -211,44 +209,49 @@ end
211209
--- Fetch models from the Copilot API
212210
---@return table<string, CopilotChat.client.Model>
213211
function Client:models()
214-
if self.model_cache then
215-
return self.model_cache
216-
end
217-
218212
local models = {}
219213
local providers = self:get_providers()
220214
local provider_order = vim.tbl_keys(providers)
221215
table.sort(provider_order)
222216
for _, provider_name in ipairs(provider_order) do
223217
local provider = providers[provider_name]
224218
if not provider.disabled and provider.get_models then
225-
notify.publish(notify.STATUS, 'Fetching models from ' .. provider_name)
226-
local ok, headers = pcall(self.authenticate, self, provider_name)
227-
if not ok then
228-
log.warn('Failed to authenticate with ' .. provider_name .. ': ' .. headers)
229-
goto continue
230-
end
231-
local ok, provider_models = pcall(provider.get_models, headers)
232-
if not ok then
233-
log.warn('Failed to fetch models from ' .. provider_name .. ': ' .. provider_models)
234-
goto continue
219+
local cache = self.provider_cache[provider_name]
220+
local resolved_models = nil
221+
if cache and cache.models then
222+
resolved_models = cache.models
223+
else
224+
notify.publish(notify.STATUS, 'Fetching models from ' .. provider_name)
225+
local ok, headers = pcall(self.authenticate, self, provider_name)
226+
if not ok then
227+
log.warn('Failed to authenticate with ' .. provider_name .. ': ' .. headers)
228+
goto continue
229+
end
230+
local ok, provider_models = pcall(provider.get_models, headers)
231+
if not ok then
232+
log.warn('Failed to fetch models from ' .. provider_name .. ': ' .. provider_models)
233+
goto continue
234+
end
235+
resolved_models = provider_models
236+
cache.models = resolved_models
235237
end
236238

237-
for _, model in ipairs(provider_models) do
238-
model.provider = provider_name
239-
if models[model.id] then
240-
model.id = model.id .. ':' .. provider_name
239+
if resolved_models then
240+
for _, model in ipairs(resolved_models) do
241+
model.provider = provider_name
242+
if models[model.id] then
243+
model.id = model.id .. ':' .. provider_name
244+
end
245+
models[model.id] = model
241246
end
242-
models[model.id] = model
243247
end
244248

245249
::continue::
246250
end
247251
end
248252

249253
log.debug('Fetched models:', #vim.tbl_keys(models))
250-
self.model_cache = models
251-
return self.model_cache
254+
return models
252255
end
253256

254257
--- Get information about all providers

lua/CopilotChat/config/mappings.lua

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -519,7 +519,7 @@ return {
519519
end
520520

521521
table.insert(lines, header)
522-
table.insert(lines, '```' .. resource.type)
522+
table.insert(lines, '```' .. utils.mimetype_to_filetype(resource.mimetype))
523523
for _, line in ipairs(preview) do
524524
table.insert(lines, line)
525525
end

lua/CopilotChat/utils.lua

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,9 @@ function M.filetype_to_mimetype(filetype)
240240
if filetype == 'html' or filetype == 'css' then
241241
return 'text/' .. filetype
242242
end
243+
if filetype:find('/') then
244+
return filetype
245+
end
243246
return 'text/x-' .. filetype
244247
end
245248

0 commit comments

Comments
 (0)