diff --git a/lua/CopilotChat/copilot.lua b/lua/CopilotChat/copilot.lua
index ce3b2be0..b2c7d99d 100644
--- a/lua/CopilotChat/copilot.lua
+++ b/lua/CopilotChat/copilot.lua
@@ -33,6 +33,7 @@
 ---@field load fun(self: CopilotChat.Copilot, name: string, path: string):table
 ---@field running fun(self: CopilotChat.Copilot):boolean
 ---@field list_models fun(self: CopilotChat.Copilot, callback: fun(table):nil):nil
+---@field ask_agent fun(self: CopilotChat.Copilot, prompt: string, opts: CopilotChat.copilot.ask.opts):nil
 
 local log = require('plenary.log')
 local curl = require('plenary.curl')
@@ -843,4 +844,228 @@ function Copilot:running()
   return self.current_job ~= nil
 end
 
+--- Ask a question to the Copilot agent
+---@param prompt string: The prompt to send to Copilot
+---@param opts CopilotChat.copilot.ask.opts: Options for the request
+function Copilot:ask_agent(prompt, opts)
+  opts = opts or {}
+  local embeddings = opts.embeddings or {}
+  local filename = opts.filename or ''
+  local filetype = opts.filetype or ''
+  local selection = opts.selection or ''
+  local start_row = opts.start_row or 0
+  local end_row = opts.end_row or 0
+  local system_prompt = opts.system_prompt or prompts.COPILOT_INSTRUCTIONS
+  local model = opts.model or 'gpt-4o-2024-05-13'
+  local temperature = opts.temperature or 0.1
+  local on_done = opts.on_done
+  local on_progress = opts.on_progress
+  local on_error = opts.on_error
+
+  log.trace('System prompt: ' .. system_prompt)
+  log.trace('Selection: ' .. selection)
+  log.debug('Prompt: ' .. prompt)
+  log.debug('Embeddings: ' .. #embeddings)
+  log.debug('Filename: ' .. filename)
+  log.debug('Filetype: ' .. filetype)
+  log.debug('Model: ' .. model)
+  log.debug('Temperature: ' .. temperature)
+
+  -- If we already have a running job, cancel it and notify the user
+  if self.current_job then
+    self:stop()
+  end
+
+  self:with_auth(function()
+    self:with_models(function()
+      local capabilities = self.models[model] and self.models[model].capabilities
+        or { limits = { max_prompt_tokens = 8192 }, tokenizer = 'cl100k_base' }
+      local max_tokens = capabilities.limits.max_prompt_tokens -- FIXME: Is max_prompt_tokens the right limit?
+      local tokenizer = capabilities.tokenizer
+      log.debug('Max tokens: ' .. max_tokens)
+      log.debug('Tokenizer: ' .. tokenizer)
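+
+      -- Build the selection and embeddings messages up front so their token
+      -- cost can be measured before the request is assembled.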
+      local selection_message =
+        generate_selection_message(filename, filetype, start_row, end_row, selection)
+      local embeddings_message = generate_embeddings_message(embeddings)
+
+      tiktoken.load(tokenizer, function()
+        -- Count required tokens that we cannot reduce
+        local prompt_tokens = tiktoken.count(prompt)
+        local system_tokens = tiktoken.count(system_prompt)
+        local selection_tokens = tiktoken.count(selection_message)
+        local required_tokens = prompt_tokens + system_tokens + selection_tokens
+
+        -- Reserve space for the first embedding if it's smaller than half of the max tokens
+        local reserved_tokens = 0
+        if #embeddings_message.files > 0 then
+          local file_tokens = tiktoken.count(embeddings_message.files[1])
+          if file_tokens < max_tokens / 2 then
+            reserved_tokens = tiktoken.count(embeddings_message.header) + file_tokens
+          end
+        end
+
+        -- Calculate how many tokens we can use for history
+        local history_limit = max_tokens - required_tokens - reserved_tokens
+        local history_tokens = count_history_tokens(self.history)
+
+        -- If we're over the history limit, truncate history from the beginning
+        while history_tokens > history_limit and #self.history > 0 do
+          local removed = table.remove(self.history, 1)
+          history_tokens = history_tokens - tiktoken.count(removed.content)
+        end
+
+        -- Now add as many files as possible with the remaining token budget
+        local remaining_tokens = max_tokens - required_tokens - history_tokens
+        if #embeddings_message.files > 0 then
+          remaining_tokens = remaining_tokens - tiktoken.count(embeddings_message.header)
+          local filtered_files = {}
+          for _, file in ipairs(embeddings_message.files) do
+            local file_tokens = tiktoken.count(file)
+            if remaining_tokens - file_tokens >= 0 then
+              remaining_tokens = remaining_tokens - file_tokens
+              table.insert(filtered_files, file)
+            else
+              break
+            end
+          end
+          embeddings_message.files = filtered_files
+        end
+
+        -- Generate the request
+        local url = 'https://api.githubcopilot.com/agents/perplexityai?chat'
+        local body = vim.json.encode(
+          generate_ask_request(
+            self.history,
+            prompt,
+            embeddings_message,
+            selection_message,
+            system_prompt,
+            model,
+            temperature
+          )
+        )
+
+        local errored = false
+        local last_message = nil
+        local full_response = ''
+
+        local function handle_error(error_msg)
+          if not errored then
+            errored = true
+            log.error(error_msg)
+            if self.current_job and on_error then
+              on_error(error_msg)
+            end
+          end
+        end
+
+        local function callback_func(response)
+          if not response then
+            handle_error('Failed to get response')
+            return
+          end
+
+          if response.status ~= 200 then
+            handle_error(
+              'Failed to get response: ' .. tostring(response.status) .. '\n' .. response.body
+            )
+            return
+          end
+
+          log.trace('Full response: ' .. full_response)
+          log.debug('Last message: ' .. vim.inspect(last_message))
+
+          if on_done then
+            on_done(
+              full_response,
+              last_message and last_message.usage and last_message.usage.total_tokens,
+              max_tokens
+            )
+          end
+
+          table.insert(self.history, {
+            content = prompt,
+            role = 'user',
+          })
+
+          table.insert(self.history, {
+            content = full_response,
+            role = 'assistant',
+          })
+        end
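+
+        -- stream_func receives each server-sent-events line as it arrives:
+        -- it strips the "data: " prefix, skips keep-alives and the [DONE]
+        -- sentinel, and accumulates streamed content into full_response.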
+
+        local function stream_func(err, line)
+          if not line or errored or not self.current_job then
+            return
+          end
+
+          if err or vim.startswith(line, '{"error"') then
+            handle_error('Failed to get response: ' .. (err and vim.inspect(err) or line))
+            return
+          end
+
+          line = line:gsub('^%s*data: ', '')
+          if line == '' or line == '[DONE]' then
+            return
+          end
+
+          local ok, content = pcall(vim.json.decode, line, {
+            luanil = {
+              object = true,
+              array = true,
+            },
+          })
+
+          if not ok then
+            handle_error('Failed to parse response: ' .. vim.inspect(content) .. '\n' .. line)
+            return
+          end
+
+          if not content.choices or #content.choices == 0 then
+            return
+          end
+
+          last_message = content
+          local choice = content.choices[1]
+          local is_full = choice.message ~= nil
+          content = is_full and choice.message.content or choice.delta.content
+
+          if not content then
+            return
+          end
+
+          if on_progress then
+            on_progress(content)
+          end
+
+          full_response = full_response .. content
+        end
+
+        self:with_claude(model, function()
+          self.current_job = curl
+            .post(url, {
+              timeout = timeout,
+              headers = generate_headers(self.token.token, self.sessionid, self.machineid),
+              body = temp_file(body),
+              proxy = self.proxy,
+              insecure = self.allow_insecure,
+              stream = stream_func,
+              callback = callback_func,
+              on_error = function(err)
+                err = 'Failed to get response: ' .. vim.inspect(err)
+                log.error(err)
+                if self.current_job and on_error then
+                  on_error(err)
+                end
+              end,
+            })
+            :after(function()
+              self.current_job = nil
+            end)
+        end, on_error)
+      end)
+    end, on_error)
+  end, on_error)
+end
+
 return Copilot
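Reviewer note: the budgeting above trims the request from two directions, dropping the oldest history messages first and then keeping only the leading embedding files that still fit. A minimal standalone sketch of that strategy, where `fit_request` and the `count` callback are illustrative names, not part of the patch (`count` stands in for `tiktoken.count`):

```lua
-- Illustrative sketch only: mirrors the trimming order in Copilot:ask_agent.
local function fit_request(history, files, budget, count)
  local used = 0
  for _, msg in ipairs(history) do
    used = used + count(msg.content)
  end

  -- Drop the oldest history entries until what remains fits the budget.
  while used > budget and #history > 0 do
    used = used - count(table.remove(history, 1).content)
  end

  -- Greedily keep leading embedding files while the budget allows.
  local kept = {}
  for _, file in ipairs(files) do
    local cost = count(file)
    if used + cost > budget then
      break
    end
    used = used + cost
    table.insert(kept, file)
  end

  return history, kept, budget - used
end
```

As in the patch, files are considered in order, so the first (presumably most relevant) embedding gets priority; that is also why `ask_agent` pre-reserves space for it before truncating history.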
diff --git a/lua/CopilotChat/init.lua b/lua/CopilotChat/init.lua
index 309fcdd3..7414e7ba 100644
--- a/lua/CopilotChat/init.lua
+++ b/lua/CopilotChat/init.lua
@@ -931,4 +931,105 @@ function M.setup(config)
   end, { nargs = '*', force = true, complete = complete_load })
 end
 
+--- Ask a question to the Copilot agent.
+---@param prompt string
+---@param config CopilotChat.config|CopilotChat.config.prompt|nil
+---@param source CopilotChat.config.source?
+function M.ask_agent(prompt, config, source)
+  config = vim.tbl_deep_extend('force', M.config, config or {})
+  prompt = prompt or ''
+  local system_prompt, updated_prompt = update_prompts(prompt, config.system_prompt)
+  updated_prompt = vim.trim(updated_prompt)
+  if updated_prompt == '' then
+    M.open(config, source)
+    return
+  end
+
+  M.open(config, source)
+
+  if config.clear_chat_on_new_prompt then
+    M.stop(true, config)
+  end
+
+  state.last_system_prompt = system_prompt
+  local selection = get_selection()
+  local filetype = selection.filetype
+    or (vim.api.nvim_buf_is_valid(state.source.bufnr) and vim.bo[state.source.bufnr].filetype)
+  local filename = selection.filename
+    or (
+      vim.api.nvim_buf_is_valid(state.source.bufnr)
+      and vim.api.nvim_buf_get_name(state.source.bufnr)
+    )
+  if selection.prompt_extra then
+    updated_prompt = updated_prompt .. ' ' .. selection.prompt_extra
+  end
+
+  if state.copilot:stop() then
+    append('\n\n' .. config.question_header .. config.separator .. '\n\n', config)
+  end
+
+  append(updated_prompt, config)
+  append('\n\n' .. config.answer_header .. config.separator .. '\n\n', config)
+
+  local selected_context = config.context
+  if string.find(prompt, '@buffers') then
+    selected_context = 'buffers'
+  elseif string.find(prompt, '@buffer') then
+    selected_context = 'buffer'
+  end
+  updated_prompt = string.gsub(updated_prompt, '@buffers?%s*', '')
+
+  local function on_error(err)
+    vim.schedule(function()
+      append('\n\n' .. config.error_header .. config.separator .. '\n\n', config)
+      append('```\n' .. err .. '\n```', config)
+      append('\n\n' .. config.question_header .. config.separator .. '\n\n', config)
+      state.chat:finish()
+    end)
+  end
+
+  context.find_for_query(state.copilot, {
+    context = selected_context,
+    prompt = updated_prompt,
+    selection = selection.lines,
+    filename = filename,
+    filetype = filetype,
+    bufnr = state.source.bufnr,
+    on_error = on_error,
+    on_done = function(embeddings)
+      state.copilot:ask_agent(updated_prompt, {
+        selection = selection.lines,
+        embeddings = embeddings,
+        filename = filename,
+        filetype = filetype,
+        start_row = selection.start_row,
+        end_row = selection.end_row,
+        system_prompt = system_prompt,
+        model = config.model,
+        temperature = config.temperature,
+        on_error = on_error,
+        on_done = function(response, token_count, token_max_count)
+          vim.schedule(function()
+            append('\n\n' .. config.question_header .. config.separator .. '\n\n', config)
+            state.response = response
+            if token_count and token_max_count and token_count > 0 then
+              state.chat:finish(token_count .. '/' .. token_max_count .. ' tokens used')
+            else
+              state.chat:finish()
+            end
+            if config.callback then
+              config.callback(response, state.source)
+            end
+          end)
+        end,
+        on_progress = function(token)
+          vim.schedule(function()
+            append(token, config)
+          end)
+        end,
+      })
+    end,
+  })
+end
+
 return M
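For completeness, a possible call into the new entry point, assuming the plugin is already set up. The prompt text is illustrative, and the `config` table accepts the same fields `M.ask_agent` reads above (`model`, `temperature`, `callback`, ...):

```lua
-- Hypothetical usage; the model name mirrors the default in Copilot:ask_agent.
require('CopilotChat').ask_agent('Explain the selected code', {
  model = 'gpt-4o-2024-05-13',
  temperature = 0.1,
  callback = function(response)
    vim.notify('Agent finished with ' .. #response .. ' characters')
  end,
})
```

Like `M.ask`, this opens the chat window, appends the question and answer headers, resolves context via `context.find_for_query`, and streams the agent's reply into the buffer as it arrives.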