diff --git a/lua/CopilotChat/client.lua b/lua/CopilotChat/client.lua index 93143754..b0a48b85 100644 --- a/lua/CopilotChat/client.lua +++ b/lua/CopilotChat/client.lua @@ -216,15 +216,13 @@ end ---@class CopilotChat.client.Client : Class ---@field private providers table ---@field private provider_cache table ----@field private models table? +---@field private model_cache table? ---@field private current_job string? ----@field private headers table? local Client = class(function(self) self.providers = {} self.provider_cache = {} - self.models = nil + self.model_cache = nil self.current_job = nil - self.headers = nil end) --- Authenticate with GitHub and get the required headers @@ -246,9 +244,9 @@ end --- Fetch models from the Copilot API ---@return table -function Client:fetch_models() - if self.models then - return self.models +function Client:models() + if self.model_cache then + return self.model_cache end local models = {} @@ -282,8 +280,8 @@ function Client:fetch_models() end log.debug('Fetched models:', #vim.tbl_keys(models)) - self.models = models - return self.models + self.model_cache = models + return self.model_cache end --- Ask a question to Copilot @@ -299,7 +297,7 @@ function Client:ask(prompt, opts) log.debug('Resources:', #opts.resources) log.debug('History:', #opts.history) - local models = self:fetch_models() + local models = self:models() local model_config = models[opts.model] if not model_config then error('Model not found: ' .. opts.model) @@ -573,26 +571,6 @@ function Client:ask(prompt, opts) } end ---- List available models ----@return table -function Client:list_models() - local models = self:fetch_models() - local result = vim.tbl_keys(models) - - table.sort(result, function(a, b) - a = models[a] - b = models[b] - if a.provider ~= b.provider then - return a.provider < b.provider - end - return a.id < b.id - end) - - return vim.tbl_map(function(id) - return models[id] - end, result) -end - --- Generate embeddings for the given inputs ---@param inputs table: The inputs to embed ---@param model string @@ -603,7 +581,7 @@ function Client:embed(inputs, model) return inputs end - local models = self:fetch_models() + local models = self:models() local ok, provider_name, embed = pcall(resolve_provider_function, 'embed', model, models, self.providers) if not ok then ---@diagnostic disable-next-line: return-type-mismatch diff --git a/lua/CopilotChat/init.lua b/lua/CopilotChat/init.lua index 9900e62f..9ef35e85 100644 --- a/lua/CopilotChat/init.lua +++ b/lua/CopilotChat/init.lua @@ -120,6 +120,26 @@ local function update_highlights() end end +--- List available models. +--- @return CopilotChat.client.Model[] +local function list_models() + local models = client:models() + local result = vim.tbl_keys(models) + + table.sort(result, function(a, b) + a = models[a] + b = models[b] + if a.provider ~= b.provider then + return a.provider < b.provider + end + return a.id < b.id + end) + + return vim.tbl_map(function(id) + return models[id] + end, result) +end + --- Finish writing to chat buffer. ---@param start_of_chat boolean? local function finish(start_of_chat) @@ -284,8 +304,8 @@ function M.resolve_functions(prompt, config) }) end - -- Resolve each tool reference - local function expand_tool(name, input) + -- Resolve each function reference + local function expand_function(name, input) notify.publish(notify.STATUS, 'Running function: ' .. name) local tool_id = nil @@ -368,7 +388,7 @@ function M.resolve_functions(prompt, config) for _, pattern in ipairs(matches:keys()) do if not utils.empty(pattern) then local match = matches:get(pattern) - local out = expand_tool(match.word, match.input) or pattern + local out = expand_function(match.word, match.input) or pattern out = out:gsub('%%', '%%%%') -- Escape percent signs for gsub prompt = prompt:gsub(vim.pesc(pattern), out, 1) end @@ -440,7 +460,7 @@ function M.resolve_model(prompt, config) local models = vim.tbl_map(function(model) return model.id - end, client:list_models()) + end, list_models()) local selected_model = config.model or '' prompt = prompt:gsub('%$' .. WORD, function(match) @@ -600,7 +620,7 @@ end ---@return table ---@async function M.complete_items() - local models = client:list_models() + local models = list_models() local prompts_to_use = M.prompts() local items = {} @@ -767,7 +787,7 @@ end --- Select default Copilot GPT model. function M.select_model() async.run(function() - local models = client:list_models() + local models = list_models() local choices = vim.tbl_map(function(model) return { id = model.id,