Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 14 additions & 8 deletions lua/CopilotChat/init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,13 @@ local plugin_name = 'CopilotChat.nvim'
--- @field copilot CopilotChat.Copilot?
--- @field chat CopilotChat.Chat?
--- @field source CopilotChat.config.source?
--- @field config CopilotChat.config?
local state = {
copilot = nil,
chat = nil,
window = nil,
source = nil,
config = nil,
}

local function find_lines_between_separator(lines, pattern, at_least_one)
Expand Down Expand Up @@ -171,16 +173,16 @@ local function complete()
vim.fn.complete(cmp_start + 1, items)
end

local function get_selection(config)
local function get_selection()
local bufnr = state.source.bufnr
local winnr = state.source.winnr
if
config
and config.selection
state.config
and state.config.selection
and vim.api.nvim_buf_is_valid(bufnr)
and vim.api.nvim_win_is_valid(winnr)
then
return config.selection(state.source) or {}
return state.config.selection(state.source) or {}
end
return {}
end
Expand Down Expand Up @@ -226,11 +228,15 @@ end
function M.open(config, source, no_focus)
local should_reset = config and config.window ~= nil and not vim.tbl_isempty(config.window)
config = vim.tbl_deep_extend('force', M.config, config or {})
state.config = config
state.source = vim.tbl_extend('keep', source or {}, {
bufnr = vim.api.nvim_get_current_buf(),
winnr = vim.api.nvim_get_current_win(),
})

-- Exit visual mode if we are in visual mode
vim.api.nvim_feedkeys(vim.api.nvim_replace_termcodes('<esc>', true, false, true), 'x', false)

local just_created = false

if not state.chat or not state.chat:valid() then
Expand Down Expand Up @@ -265,14 +271,14 @@ function M.open(config, source, no_focus)
if line_count == end_line then
vim.api.nvim_buf_set_lines(state.chat.bufnr, start_line, end_line, false, { '' })
end
M.ask(input, nil, state.source)
M.ask(input, state.config, state.source)
end
end, { buffer = state.chat.bufnr })
end

if config.mappings.show_diff then
vim.keymap.set('n', config.mappings.show_diff, function()
local selection = get_selection(config)
local selection = get_selection()
show_diff_between_selection_and_copilot(selection)
end, {
buffer = state.chat.bufnr,
Expand All @@ -281,7 +287,7 @@ function M.open(config, source, no_focus)

if config.mappings.accept_diff then
vim.keymap.set('n', config.mappings.accept_diff, function()
local selection = get_selection(config)
local selection = get_selection()
if not selection.start_row or not selection.end_row then
return
end
Expand Down Expand Up @@ -410,7 +416,7 @@ function M.ask(prompt, config, source)
M.open(config, source, true)

config = vim.tbl_deep_extend('force', M.config, config or {})
local selection = get_selection(config)
local selection = get_selection()
vim.api.nvim_set_current_win(state.window)

prompt = prompt or ''
Expand Down
33 changes: 5 additions & 28 deletions lua/CopilotChat/select.lua
Original file line number Diff line number Diff line change
Expand Up @@ -42,34 +42,11 @@ end
function M.visual(source)
local bufnr = source.bufnr

local full_line = false
local start_line = nil
local start_col = nil
local finish_line = nil
local finish_col = nil
if 'copilot-chat' ~= vim.api.nvim_buf_get_name(vim.api.nvim_get_current_buf()) then
local start = vim.fn.getpos('v')
start_line = start[2]
start_col = start[3]
local finish = vim.fn.getpos('.')
finish_line = finish[2]
finish_col = finish[3]
if vim.fn.mode() == 'V' then
full_line = true
end
end

-- Exit visual mode
vim.api.nvim_feedkeys(vim.api.nvim_replace_termcodes('<esc>', true, false, true), 'x', true)

if start_line == finish_line and start_col == finish_col then
start_line, start_col = unpack(vim.api.nvim_buf_get_mark(bufnr, '<'))
finish_line, finish_col = unpack(vim.api.nvim_buf_get_mark(bufnr, '>'))
start_col = start_col + 1
finish_col = finish_col + 1
end

return get_selection_lines(bufnr, start_line, start_col, finish_line, finish_col, full_line)
local start_line, start_col = unpack(vim.api.nvim_buf_get_mark(bufnr, '<'))
local finish_line, finish_col = unpack(vim.api.nvim_buf_get_mark(bufnr, '>'))
start_col = start_col + 1
finish_col = finish_col + 1
return get_selection_lines(bufnr, start_line, start_col, finish_line, finish_col, false)
end

--- Select and process contents of unnamed register ('"')
Expand Down