Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 9 additions & 31 deletions lua/CopilotChat/client.lua
Original file line number Diff line number Diff line change
Expand Up @@ -216,15 +216,13 @@ end
---@class CopilotChat.client.Client : Class
---@field private providers table<string, CopilotChat.config.providers.Provider>
---@field private provider_cache table<string, table>
---@field private models table<string, CopilotChat.client.Model>?
---@field private model_cache table<string, CopilotChat.client.Model>?
---@field private current_job string?
---@field private headers table<string, string>?
local Client = class(function(self)
self.providers = {}
self.provider_cache = {}
self.models = nil
self.model_cache = nil
self.current_job = nil
self.headers = nil
end)

--- Authenticate with GitHub and get the required headers
Expand All @@ -246,9 +244,9 @@ end

--- Fetch models from the Copilot API
---@return table<string, CopilotChat.client.Model>
function Client:fetch_models()
if self.models then
return self.models
function Client:models()
if self.model_cache then
return self.model_cache
end

local models = {}
Expand Down Expand Up @@ -282,8 +280,8 @@ function Client:fetch_models()
end

log.debug('Fetched models:', #vim.tbl_keys(models))
self.models = models
return self.models
self.model_cache = models
return self.model_cache
end

--- Ask a question to Copilot
Expand All @@ -299,7 +297,7 @@ function Client:ask(prompt, opts)
log.debug('Resources:', #opts.resources)
log.debug('History:', #opts.history)

local models = self:fetch_models()
local models = self:models()
local model_config = models[opts.model]
if not model_config then
error('Model not found: ' .. opts.model)
Expand Down Expand Up @@ -573,26 +571,6 @@ function Client:ask(prompt, opts)
}
end

--- List available models
---@return table<string, table>
function Client:list_models()
local models = self:fetch_models()
local result = vim.tbl_keys(models)

table.sort(result, function(a, b)
a = models[a]
b = models[b]
if a.provider ~= b.provider then
return a.provider < b.provider
end
return a.id < b.id
end)

return vim.tbl_map(function(id)
return models[id]
end, result)
end

--- Generate embeddings for the given inputs
---@param inputs table<CopilotChat.client.Resource>: The inputs to embed
---@param model string
Expand All @@ -603,7 +581,7 @@ function Client:embed(inputs, model)
return inputs
end

local models = self:fetch_models()
local models = self:models()
local ok, provider_name, embed = pcall(resolve_provider_function, 'embed', model, models, self.providers)
if not ok then
---@diagnostic disable-next-line: return-type-mismatch
Expand Down
32 changes: 26 additions & 6 deletions lua/CopilotChat/init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,26 @@ local function update_highlights()
end
end

--- List available models.
--- @return CopilotChat.client.Model[]
local function list_models()
local models = client:models()
local result = vim.tbl_keys(models)

table.sort(result, function(a, b)
a = models[a]
b = models[b]
if a.provider ~= b.provider then
return a.provider < b.provider
end
return a.id < b.id
end)

return vim.tbl_map(function(id)
return models[id]
end, result)
end

--- Finish writing to chat buffer.
---@param start_of_chat boolean?
local function finish(start_of_chat)
Expand Down Expand Up @@ -284,8 +304,8 @@ function M.resolve_functions(prompt, config)
})
end

-- Resolve each tool reference
local function expand_tool(name, input)
-- Resolve each function reference
local function expand_function(name, input)
notify.publish(notify.STATUS, 'Running function: ' .. name)

local tool_id = nil
Expand Down Expand Up @@ -368,7 +388,7 @@ function M.resolve_functions(prompt, config)
for _, pattern in ipairs(matches:keys()) do
if not utils.empty(pattern) then
local match = matches:get(pattern)
local out = expand_tool(match.word, match.input) or pattern
local out = expand_function(match.word, match.input) or pattern
out = out:gsub('%%', '%%%%') -- Escape percent signs for gsub
prompt = prompt:gsub(vim.pesc(pattern), out, 1)
end
Expand Down Expand Up @@ -440,7 +460,7 @@ function M.resolve_model(prompt, config)

local models = vim.tbl_map(function(model)
return model.id
end, client:list_models())
end, list_models())

local selected_model = config.model or ''
prompt = prompt:gsub('%$' .. WORD, function(match)
Expand Down Expand Up @@ -600,7 +620,7 @@ end
---@return table
---@async
function M.complete_items()
local models = client:list_models()
local models = list_models()
local prompts_to_use = M.prompts()
local items = {}

Expand Down Expand Up @@ -767,7 +787,7 @@ end
--- Select default Copilot GPT model.
function M.select_model()
async.run(function()
local models = client:list_models()
local models = list_models()
local choices = vim.tbl_map(function(model)
return {
id = model.id,
Expand Down
Loading