forked from CopilotC-Nvim/CopilotChat.nvim
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path diff.lua
More file actions
241 lines (214 loc) · 7.56 KB
/
diff.lua
File metadata and controls
241 lines (214 loc) · 7.56 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
local log = require('plenary.log')
local M = {}
--- Split unified-diff text into a list of hunks.
---
--- Every line starting with `@@` opens a new hunk. Its numbers are read from a
--- header of the form `@@ -start_old[,len_old] +start_new[,len_new] @@`; an
--- omitted length counts as 1, and a header that does not fit that shape
--- leaves all four numeric fields nil.
--- Inside a hunk, ' ' (context) lines are added to both sides, '-' lines only
--- to `old_snippet` and '+' lines only to `new_snippet`, each without its
--- one-character prefix. Text before the first header, and lines starting
--- with any other character (including empty lines), are ignored.
---@param diff_text string
---@return table hunks list of { start_old, len_old, start_new, len_new, old_snippet, new_snippet }
local function parse_hunks(diff_text)
  local hunks = {}
  local open_hunk = nil

  -- Header length field: '' (omitted) means 1; nil (no header match) stays nil.
  local function length_of(field)
    if field == '' then
      return 1
    end
    return tonumber(field)
  end

  -- Build an empty hunk from an `@@ ... @@` header line.
  local function new_hunk(header)
    local old_from, old_count, new_from, new_count =
      header:match('@@%s%-(%d+),?(%d*)%s%+(%d+),?(%d*)%s@@')
    return {
      start_old = tonumber(old_from),
      len_old = length_of(old_count),
      start_new = tonumber(new_from),
      len_new = length_of(new_count),
      old_snippet = {},
      new_snippet = {},
    }
  end

  for _, text in ipairs(vim.split(diff_text, '\n')) do
    if text:sub(1, 2) == '@@' then
      if open_hunk then
        hunks[#hunks + 1] = open_hunk
      end
      open_hunk = new_hunk(text)
    elseif open_hunk then
      local marker, body = text:sub(1, 1), text:sub(2)
      if marker == ' ' or marker == '-' then
        table.insert(open_hunk.old_snippet, body)
      end
      if marker == ' ' or marker == '+' then
        table.insert(open_hunk.new_snippet, body)
      end
    end
  end

  if open_hunk then
    hunks[#hunks + 1] = open_hunk
  end
  return hunks
end
--- Locate `old_snippet` in `lines`, trying start positions within
--- `search_range` lines of `approx_start`.
---
--- Lines are compared after trimming surrounding whitespace. Candidates are
--- tried in order of distance from `approx_start` (the lower position first on
--- equal distance). When the snippet occurs more than once in the window, the
--- occurrence nearest the hint therefore wins. The first exact match is
--- returned immediately. Otherwise the best-scoring candidate is returned if
--- at least 80% of its lines match; ties again go to the candidate nearest
--- the hint.
---@param lines table content lines
---@param old_snippet table lines to find
---@param approx_start number hinted 1-based start line
---@param search_range number how many lines on either side of the hint to try
---@return number? matched_start start index of the match, `approx_start` for an empty snippet, or nil
local function find_best_match(lines, old_snippet, approx_start, search_range)
  local old_len = #old_snippet
  if old_len == 0 then
    return approx_start
  end

  local min_start = math.max(1, approx_start - search_range)
  local max_start = math.min(#lines - old_len + 1, approx_start + search_range)

  -- Trim the snippet once instead of once per candidate position.
  local wanted = {}
  for i = 1, old_len do
    wanted[i] = vim.trim(old_snippet[i] or '')
  end

  local best_idx, best_score = nil, -1

  -- Score the candidate starting at `idx` and track the best one seen.
  -- Returns true on an exact match.
  local function try(idx)
    if idx < min_start or idx > max_start then
      return false
    end
    local score = 0
    for i = 1, old_len do
      if vim.trim(lines[idx + i - 1] or '') == wanted[i] then
        score = score + 1
      end
    end
    if score > best_score then
      best_idx, best_score = idx, score
    end
    return score == old_len
  end

  -- Walk outward from the hint so nearer candidates are preferred.
  local max_dist = math.max(approx_start - min_start, max_start - approx_start)
  for dist = 0, max_dist do
    if try(approx_start - dist) then
      return approx_start - dist
    end
    if dist > 0 and try(approx_start + dist) then
      return approx_start + dist
    end
  end

  if best_score >= math.ceil(old_len * 0.8) then
    return best_idx
  end
  return nil
end
--- Apply a single parsed hunk to a string of content.
---
--- Pure insertions (`len_old == 0` and no old-side lines) are placed after
--- line `start_old` (0 = top of the content) without any matching.
--- Otherwise the hunk's old side is located with find_best_match. The first
--- attempt is within +/- 2 lines of `start_old`. If that fails, or
--- `start_old` is missing or out of range, the whole content is searched. If
--- no acceptable match is found anywhere, the hunk is applied at `start_old`
--- when it is a valid line number, else at line 1. In every case
--- `#old_snippet` lines starting at the chosen position are replaced by
--- `new_snippet`.
---@param hunk table hunk from parse_hunks (start_old may already be offset-adjusted)
---@param content string
---@return string patched_content, boolean applied_cleanly true when the old side matches at the hinted `start_old` itself
local function apply_hunk(hunk, content)
  local lines = vim.split(content, '\n')
  local old_snippet, new_snippet = hunk.old_snippet, hunk.new_snippet
  local start_idx = hunk.start_old

  -- A header claiming zero old lines while the body still carries context or
  -- removed lines is inconsistent. Treat it as a replacement, so those lines
  -- are matched instead of being duplicated by a blind insertion.
  if hunk.len_old == 0 and #old_snippet == 0 then
    -- start_old = n means "insert after line n"; 0 means "insert at the top".
    -- Clamp so that a negative offset-adjusted hint still inserts at the top.
    local insert_at = math.max(1, (start_idx or 0) + 1)
    local new_lines = vim.list_slice(lines, 1, insert_at - 1)
    vim.list_extend(new_lines, new_snippet)
    vim.list_extend(new_lines, lines, insert_at, #lines)
    -- Insertions need no matching, so they always count as clean.
    return table.concat(new_lines, '\n'), true
  end

  local has_hint = start_idx ~= nil and start_idx > 0 and start_idx <= #lines
  local match_idx = nil
  if has_hint then
    match_idx = find_best_match(lines, old_snippet, start_idx, 2)
  end
  if not match_idx then
    -- The hint is missing, out of range, or wrong (e.g. a diff written with
    -- inaccurate line numbers). Search the whole content rather than
    -- overwriting unrelated lines at the hint.
    match_idx = find_best_match(lines, old_snippet, has_hint and start_idx or 1, #lines)
  end
  if match_idx then
    start_idx = match_idx
  elseif not has_hint then
    start_idx = 1
  end

  -- Replace old lines with new lines
  local end_idx = start_idx + #old_snippet - 1
  local new_lines = vim.list_slice(lines, 1, start_idx - 1)
  vim.list_extend(new_lines, new_snippet)
  vim.list_extend(new_lines, lines, end_idx + 1, #lines)

  -- Clean only if the old side matches exactly where the header said it would.
  local applied_cleanly = find_best_match(lines, old_snippet, hunk.start_old or start_idx, 0) == start_idx
  return table.concat(new_lines, '\n'), applied_cleanly
end
--- Find the span of lines in `original_content` touched by the changes that
--- turn it into `new_content`, using vim.diff index output.
---@param original_content string
---@param new_content string
---@return integer? first, integer? last 1-based line numbers in original_content, or nil, nil when nothing changed
local function changed_region(original_content, new_content)
  local changes = vim.diff(original_content, new_content, {
    algorithm = 'myers',
    ctxlen = 10,
    interhunkctxlen = 10,
    ignore_whitespace_change = true,
    result_type = 'indices',
  })
  if not changes or #changes == 0 then
    return nil, nil
  end
  local first, last
  for _, change in ipairs(changes) do
    -- Each change is { start_a, count_a, start_b, count_b }. When count_a is 0
    -- (pure insertion), start + count - 1 would precede the start and start_a
    -- may be 0. Clamp so the span is non-empty and 1-based.
    local change_start = math.max(1, change[1])
    local change_end = math.max(change_start, change[1] + change[2] - 1)
    if not first or change_start < first then
      first = change_start
    end
    if not last or change_end > last then
      last = change_end
    end
  end
  return first, last
end

--- Apply unified diff text to content and return the resulting lines.
---
--- Hunks are applied in order with apply_hunk. Each hunk's `start_old` is
--- shifted by the net number of lines added or removed by earlier hunks, so
--- header line numbers that refer to the original file still point at the
--- right place in the partially patched content.
---@param diff_text string unified diff (hunk headers `@@ -a,b +c,d @@`)
---@param original_content string content joined with '\n'
---@return string[] new_lines patched content split into lines
---@return boolean applied true if at least one hunk applied cleanly (false when there are no hunks)
---@return integer? first first changed line in original_content, nil when nothing changed
---@return integer? last last changed line in original_content, nil when nothing changed
function M.apply_unified_diff(diff_text, original_content)
  local hunks = parse_hunks(diff_text)
  local new_content = original_content
  local applied = false
  local offset = 0 -- cumulative line delta from previously applied hunks
  for _, hunk in ipairs(hunks) do
    local adjusted_hunk = vim.deepcopy(hunk)
    if adjusted_hunk.start_old then
      adjusted_hunk.start_old = hunk.start_old + offset
    end
    local patched, ok = apply_hunk(adjusted_hunk, new_content)
    new_content = patched
    applied = applied or ok
    -- (new lines added) - (old lines removed)
    offset = offset + (#hunk.new_snippet - #hunk.old_snippet)
  end
  -- NOTE(review): trimempty drops leading AND trailing empty lines, including
  -- blank lines the file itself starts or ends with. Confirm this is intended.
  local new_lines = vim.split(new_content, '\n', { trimempty = true })
  return new_lines, applied, changed_region(original_content, new_content)
end
--- Produce a unified diff for a chat code block against the given buffer lines.
---
--- A block whose filetype is `diff` already carries a diff, which is returned
--- unchanged. For any other block, the content (leading and trailing empty
--- lines trimmed) is treated as replacement text. When the header has both
--- `start_line` and `end_line`, the text is spliced over that range;
--- otherwise it becomes the whole new file. The result is diffed against the
--- buffer with vim.diff.
---@param block CopilotChat.ui.chat.Block Block containing diff info
---@param lines table table of lines
---@return string diff, string content content is `lines` joined with '\n'
function M.get_diff(block, lines)
  local content = table.concat(lines, '\n')
  local header = block.header
  if header.filetype == 'diff' then
    return block.content, content
  end

  local replacement = vim.split(block.content, '\n', { trimempty = true })
  local range_start, range_end = header.start_line, header.end_line

  local target = replacement
  if range_start and range_end then
    -- Keep buffer lines outside start_line..end_line, swap in the block text.
    target = vim.list_slice(lines, 1, range_start - 1)
    vim.list_extend(target, replacement)
    vim.list_extend(target, lines, range_end + 1, #lines)
  end

  local unified = vim.diff(content, table.concat(target, '\n'), {
    algorithm = 'myers',
    ctxlen = 10,
    interhunkctxlen = 10,
    ignore_whitespace_change = true,
  })
  return tostring(unified), content
end
--- Apply a chat code block to buffer lines and return the resulting lines.
---
--- The block is turned into a unified diff by M.get_diff and applied with
--- M.apply_unified_diff. When no hunk applied cleanly, the diff is logged at
--- debug level; the patched lines are returned either way.
---@param block CopilotChat.ui.chat.Block Block containing diff info
---@param lines table table of lines
---@return table new_lines
function M.apply_diff(block, lines)
  local diff, content = M.get_diff(block, lines)
  local new_lines, applied = M.apply_unified_diff(diff, content)
  if not applied then
    -- tostring(): a block without a filename must not turn a debug log into an error.
    log.debug('Diff for ' .. tostring(block.header.filename) .. ' failed to apply cleanly for:\n' .. diff)
  end
  return new_lines
end
--- Compute which lines of the buffer a chat code block would change.
--- Builds the diff with M.get_diff, applies it with M.apply_unified_diff and
--- returns only the changed-region bounds that call reports.
---@param block CopilotChat.ui.chat.Block Block containing diff info
---@param lines table table of lines
---@return number? first, number? last
function M.get_diff_region(block, lines)
  local diff_text, buffer_content = M.get_diff(block, lines)
  local _, _, region_start, region_end = M.apply_unified_diff(diff_text, buffer_content)
  return region_start, region_end
end
return M