225 changes: 225 additions & 0 deletions lua/CopilotChat/copilot.lua
@@ -33,6 +33,7 @@
---@field load fun(self: CopilotChat.Copilot, name: string, path: string):table
---@field running fun(self: CopilotChat.Copilot):boolean
---@field list_models fun(self: CopilotChat.Copilot, callback: fun(table):nil):nil
---@field ask_agent fun(self: CopilotChat.Copilot, prompt: string, opts: CopilotChat.copilot.ask.opts):nil

local log = require('plenary.log')
local curl = require('plenary.curl')
@@ -843,4 +844,228 @@ function Copilot:running()
return self.current_job ~= nil
end

--- Ask a question to the Copilot agent
---@param prompt string: The prompt to send to Copilot
---@param opts CopilotChat.copilot.ask.opts: Options for the request
function Copilot:ask_agent(prompt, opts)
opts = opts or {}
local embeddings = opts.embeddings or {}
local filename = opts.filename or ''
local filetype = opts.filetype or ''
local selection = opts.selection or ''
local start_row = opts.start_row or 0
local end_row = opts.end_row or 0
local system_prompt = opts.system_prompt or prompts.COPILOT_INSTRUCTIONS
local model = opts.model or 'gpt-4o-2024-05-13'
local temperature = opts.temperature or 0.1
local on_done = opts.on_done
local on_progress = opts.on_progress
local on_error = opts.on_error

log.trace('System prompt: ' .. system_prompt)
log.trace('Selection: ' .. selection)
log.debug('Prompt: ' .. prompt)
log.debug('Embeddings: ' .. #embeddings)
log.debug('Filename: ' .. filename)
log.debug('Filetype: ' .. filetype)
log.debug('Model: ' .. model)
log.debug('Temperature: ' .. temperature)

-- If we already have a running job, cancel it before starting a new one
if self.current_job then
self:stop()
end

self:with_auth(function()
self:with_models(function()
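-- Use the advertised capabilities of the chosen model; fall back to
-- conservative defaults (8k prompt window, cl100k_base tokenizer) when
-- the model is missing from the fetched model list.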
local capabilities = self.models[model] and self.models[model].capabilities
or { limits = { max_prompt_tokens = 8192 }, tokenizer = 'cl100k_base' }
local max_tokens = capabilities.limits.max_prompt_tokens -- FIXME: Is max_prompt_tokens the right limit?
local tokenizer = capabilities.tokenizer
log.debug('Max tokens: ' .. max_tokens)
log.debug('Tokenizer: ' .. tokenizer)

local selection_message =
generate_selection_message(filename, filetype, start_row, end_row, selection)
local embeddings_message = generate_embeddings_message(embeddings)

tiktoken.load(tokenizer, function()
-- Count required tokens that we cannot reduce
local prompt_tokens = tiktoken.count(prompt)
local system_tokens = tiktoken.count(system_prompt)
local selection_tokens = tiktoken.count(selection_message)
local required_tokens = prompt_tokens + system_tokens + selection_tokens

-- Reserve space for the first embedding if it is smaller than half of the max tokens
local reserved_tokens = 0
if #embeddings_message.files > 0 then
local file_tokens = tiktoken.count(embeddings_message.files[1])
if file_tokens < max_tokens / 2 then
reserved_tokens = tiktoken.count(embeddings_message.header) + file_tokens
end
end

-- Calculate how many tokens we can use for history
local history_limit = max_tokens - required_tokens - reserved_tokens
local history_tokens = count_history_tokens(self.history)

-- If we're over the history limit, truncate the history from the beginning
while history_tokens > history_limit and #self.history > 0 do
local removed = table.remove(self.history, 1)
history_tokens = history_tokens - tiktoken.count(removed.content)
end

-- Now add as many files as possible with the remaining token budget
local remaining_tokens = max_tokens - required_tokens - history_tokens
if #embeddings_message.files > 0 then
remaining_tokens = remaining_tokens - tiktoken.count(embeddings_message.header)
local filtered_files = {}
for _, file in ipairs(embeddings_message.files) do
local file_tokens = tiktoken.count(file)
if remaining_tokens - file_tokens >= 0 then
remaining_tokens = remaining_tokens - file_tokens
table.insert(filtered_files, file)
else
break
end
end
embeddings_message.files = filtered_files
end

-- Generate the request
local url = 'https://api.githubcopilot.com/agents/perplexityai?chat'
local body = vim.json.encode(
generate_ask_request(
self.history,
prompt,
embeddings_message,
selection_message,
system_prompt,
model,
temperature
)
)

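-- Stream accumulation state: `errored` latches the first failure,
-- `last_message` keeps the final parsed chunk (which carries usage stats),
-- and `full_response` collects the streamed text.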
local errored = false
local last_message = nil
local full_response = ''

local function handle_error(error_msg)
if not errored then
errored = true
log.error(error_msg)
if self.current_job and on_error then
on_error(error_msg)
end
end
end

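-- Final callback: by the time this runs, stream_func below has already
-- consumed the chunked body and assembled full_response.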
local function callback_func(response)
if not response then
handle_error('Failed to get response')
return
end

if response.status ~= 200 then
handle_error(
'Failed to get response: ' .. tostring(response.status) .. '\n' .. response.body
)
return
end

log.trace('Full response: ' .. full_response)
log.debug('Last message: ' .. vim.inspect(last_message))

if on_done then
on_done(
full_response,
last_message and last_message.usage and last_message.usage.total_tokens,
max_tokens
)
end

table.insert(self.history, {
content = prompt,
role = 'user',
})

table.insert(self.history, {
content = full_response,
role = 'assistant',
})
end

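-- Invoked once per line of the chunked HTTP response; lines follow the
-- server-sent events format used by the Copilot API.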
local function stream_func(err, line)
if not line or errored or not self.current_job then
return
end

if err or vim.startswith(line, '{"error"') then
handle_error('Failed to get response: ' .. (err and vim.inspect(err) or line))
return
end

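-- Each payload line looks like (illustrative):
--   data: {"choices":[{"delta":{"content":"Hello"}}]}
-- Strip the "data: " prefix; blank keep-alive lines and the [DONE]
-- sentinel carry no JSON payload.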
line = line:gsub('^%s*data: ', '')
if line == '' or line == '[DONE]' then
return
end

local ok, content = pcall(vim.json.decode, line, {
luanil = {
object = true,
array = true,
},
})

if not ok then
handle_error('Failed to parse response: ' .. vim.inspect(content) .. '\n' .. line)
return
end

if not content.choices or #content.choices == 0 then
return
end

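-- Non-streamed responses carry text in choice.message.content, while
-- streamed chunks use choice.delta.content.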
last_message = content
local choice = content.choices[1]
local is_full = choice.message ~= nil
content = is_full and choice.message.content or choice.delta.content

if not content then
return
end

if on_progress then
on_progress(content)
end

full_response = full_response .. content
end

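-- with_claude performs any extra model enablement (e.g. accepting the
-- Claude policy) before the request is sent.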
self:with_claude(model, function()
self.current_job = curl
.post(url, {
timeout = timeout,
headers = generate_headers(self.token.token, self.sessionid, self.machineid),
body = temp_file(body),
proxy = self.proxy,
insecure = self.allow_insecure,
stream = stream_func,
callback = callback_func,
on_error = function(err)
err = 'Failed to get response: ' .. vim.inspect(err)
log.error(err)
if self.current_job and on_error then
on_error(err)
end
end,
})
:after(function()
self.current_job = nil
end)
end, on_error)
end)
end, on_error)
end, on_error)
end
return Copilot
101 changes: 101 additions & 0 deletions lua/CopilotChat/init.lua
@@ -931,4 +931,105 @@ function M.setup(config)
end, { nargs = '*', force = true, complete = complete_load })
end

--- Ask a question to the Copilot agent.
---@param prompt string
---@param config CopilotChat.config|CopilotChat.config.prompt|nil
---@param source CopilotChat.config.source?
function M.ask_agent(prompt, config, source)
config = vim.tbl_deep_extend('force', M.config, config or {})
prompt = prompt or ''
local system_prompt, updated_prompt = update_prompts(prompt, config.system_prompt)
updated_prompt = vim.trim(updated_prompt)
if updated_prompt == '' then
M.open(config, source)
return
end

M.open(config, source)

if config.clear_chat_on_new_prompt then
M.stop(true, config)
end

state.last_system_prompt = system_prompt
local selection = get_selection()
local filetype = selection.filetype
or (vim.api.nvim_buf_is_valid(state.source.bufnr) and vim.bo[state.source.bufnr].filetype)
local filename = selection.filename
or (
vim.api.nvim_buf_is_valid(state.source.bufnr)
and vim.api.nvim_buf_get_name(state.source.bufnr)
)
if selection.prompt_extra then
updated_prompt = updated_prompt .. ' ' .. selection.prompt_extra
end

if state.copilot:stop() then
append('\n\n' .. config.question_header .. config.separator .. '\n\n', config)
end

append(updated_prompt, config)
append('\n\n' .. config.answer_header .. config.separator .. '\n\n', config)

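-- An inline @buffer/@buffers mention overrides the configured context;
-- the mention is then stripped from the prompt before sending.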
local selected_context = config.context
if string.find(prompt, '@buffers') then
selected_context = 'buffers'
elseif string.find(prompt, '@buffer') then
selected_context = 'buffer'
end
updated_prompt = string.gsub(updated_prompt, '@buffers?%s*', '')

local function on_error(err)
vim.schedule(function()
append('\n\n' .. config.error_header .. config.separator .. '\n\n', config)
append('```\n' .. err .. '\n```', config)
append('\n\n' .. config.question_header .. config.separator .. '\n\n', config)
state.chat:finish()
end)
end

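-- Resolve embeddings for the selected context first; the agent request
-- itself is fired from the on_done callback below.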
context.find_for_query(state.copilot, {
context = selected_context,
prompt = updated_prompt,
selection = selection.lines,
filename = filename,
filetype = filetype,
bufnr = state.source.bufnr,
on_error = on_error,
on_done = function(embeddings)
state.copilot:ask_agent(updated_prompt, {
selection = selection.lines,
embeddings = embeddings,
filename = filename,
filetype = filetype,
start_row = selection.start_row,
end_row = selection.end_row,
system_prompt = system_prompt,
model = config.model,
temperature = config.temperature,
on_error = on_error,
on_done = function(response, token_count, token_max_count)
vim.schedule(function()
append('\n\n' .. config.question_header .. config.separator .. '\n\n', config)
state.response = response
if token_count and token_max_count and token_count > 0 then
state.chat:finish(token_count .. '/' .. token_max_count .. ' tokens used')
else
state.chat:finish()
end
if config.callback then
config.callback(response, state.source)
end
end)
end,
on_progress = function(token)
vim.schedule(function()
append(token, config)
end)
end,
})
end,
})
end
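
-- Illustrative usage sketch (not part of this change): bind a visual-mode
-- mapping that sends the current selection to the agent. `select.visual`
-- is the plugin's stock visual-selection helper; the options table mirrors
-- what M.ask accepts.
--
--   vim.keymap.set('v', '<leader>ca', function()
--     require('CopilotChat').ask_agent('Explain how this code works', {
--       selection = require('CopilotChat.select').visual,
--     })
--   end)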

return M