diff --git a/lua/CopilotChat/init.lua b/lua/CopilotChat/init.lua index daeadc5e..d2e32c32 100644 --- a/lua/CopilotChat/init.lua +++ b/lua/CopilotChat/init.lua @@ -13,11 +13,13 @@ local plugin_name = 'CopilotChat.nvim' --- @field copilot CopilotChat.Copilot? --- @field chat CopilotChat.Chat? --- @field source CopilotChat.config.source? +--- @field config CopilotChat.config? local state = { copilot = nil, chat = nil, window = nil, source = nil, + config = nil, } local function find_lines_between_separator(lines, pattern, at_least_one) @@ -171,16 +173,16 @@ local function complete() vim.fn.complete(cmp_start + 1, items) end -local function get_selection(config) +local function get_selection() local bufnr = state.source.bufnr local winnr = state.source.winnr if - config - and config.selection + state.config + and state.config.selection and vim.api.nvim_buf_is_valid(bufnr) and vim.api.nvim_win_is_valid(winnr) then - return config.selection(state.source) or {} + return state.config.selection(state.source) or {} end return {} end @@ -226,11 +228,15 @@ end function M.open(config, source, no_focus) local should_reset = config and config.window ~= nil and not vim.tbl_isempty(config.window) config = vim.tbl_deep_extend('force', M.config, config or {}) + state.config = config state.source = vim.tbl_extend('keep', source or {}, { bufnr = vim.api.nvim_get_current_buf(), winnr = vim.api.nvim_get_current_win(), }) + -- Exit visual mode if we are in visual mode + vim.api.nvim_feedkeys(vim.api.nvim_replace_termcodes('', true, false, true), 'x', false) + local just_created = false if not state.chat or not state.chat:valid() then @@ -265,14 +271,14 @@ function M.open(config, source, no_focus) if line_count == end_line then vim.api.nvim_buf_set_lines(state.chat.bufnr, start_line, end_line, false, { '' }) end - M.ask(input, nil, state.source) + M.ask(input, state.config, state.source) end end, { buffer = state.chat.bufnr }) end if config.mappings.show_diff then vim.keymap.set('n', config.mappings.show_diff, function() - local selection = get_selection(config) + local selection = get_selection() show_diff_between_selection_and_copilot(selection) end, { buffer = state.chat.bufnr, @@ -281,7 +287,7 @@ function M.open(config, source, no_focus) if config.mappings.accept_diff then vim.keymap.set('n', config.mappings.accept_diff, function() - local selection = get_selection(config) + local selection = get_selection() if not selection.start_row or not selection.end_row then return end @@ -410,7 +416,7 @@ function M.ask(prompt, config, source) M.open(config, source, true) config = vim.tbl_deep_extend('force', M.config, config or {}) - local selection = get_selection(config) + local selection = get_selection() vim.api.nvim_set_current_win(state.window) prompt = prompt or '' diff --git a/lua/CopilotChat/select.lua b/lua/CopilotChat/select.lua index 07d199fd..65a9fb84 100644 --- a/lua/CopilotChat/select.lua +++ b/lua/CopilotChat/select.lua @@ -42,34 +42,11 @@ end function M.visual(source) local bufnr = source.bufnr - local full_line = false - local start_line = nil - local start_col = nil - local finish_line = nil - local finish_col = nil - if 'copilot-chat' ~= vim.api.nvim_buf_get_name(vim.api.nvim_get_current_buf()) then - local start = vim.fn.getpos('v') - start_line = start[2] - start_col = start[3] - local finish = vim.fn.getpos('.') - finish_line = finish[2] - finish_col = finish[3] - if vim.fn.mode() == 'V' then - full_line = true - end - end - - -- Exit visual mode - vim.api.nvim_feedkeys(vim.api.nvim_replace_termcodes('', true, false, true), 'x', true) - - if start_line == finish_line and start_col == finish_col then - start_line, start_col = unpack(vim.api.nvim_buf_get_mark(bufnr, '<')) - finish_line, finish_col = unpack(vim.api.nvim_buf_get_mark(bufnr, '>')) - start_col = start_col + 1 - finish_col = finish_col + 1 - end - - return get_selection_lines(bufnr, start_line, start_col, finish_line, finish_col, full_line) + local start_line, start_col = unpack(vim.api.nvim_buf_get_mark(bufnr, '<')) + local finish_line, finish_col = unpack(vim.api.nvim_buf_get_mark(bufnr, '>')) + start_col = start_col + 1 + finish_col = finish_col + 1 + return get_selection_lines(bufnr, start_line, start_col, finish_line, finish_col, false) end --- Select and process contents of unnamed register ('"')