Skip to content

Commit e0c4ca0

Browse files
committed
refactor(context): optimize file loading with caching
Refactor context.lua to introduce file caching mechanism and improve code organization around file handling. The main changes include: - Add file caching to avoid reprocessing unchanged files - Move outline generation logic into separate build_outline function - Consolidate embed type definitions into context.lua - Remove duplicate type definitions from copilot.lua - Optimize file loading with new get_file helper function Signed-off-by: Tomas Slusny <slusnucky@gmail.com>
1 parent a0b89f0 commit e0c4ca0

8 files changed

Lines changed: 124 additions & 109 deletions

File tree

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -719,7 +719,7 @@ require('CopilotChat').setup({
719719

720720
## Roadmap (Wishlist)
721721

722-
- Caching for contexts
722+
- Improved caching for context (persistence through restarts/smarter caching)
723723
- General QOL improvements
724724

725725
## Development

lua/CopilotChat/config.lua

Lines changed: 2 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -7,25 +7,10 @@ local utils = require('CopilotChat.utils')
77
--- @field bufnr number
88
--- @field winnr number
99

10-
---@class CopilotChat.config.selection.diagnostic
11-
---@field content string
12-
---@field start_line number
13-
---@field end_line number
14-
---@field severity string
15-
16-
---@class CopilotChat.config.selection
17-
---@field content string
18-
---@field start_line number
19-
---@field end_line number
20-
---@field filename string
21-
---@field filetype string
22-
---@field bufnr number
23-
---@field diagnostics table<CopilotChat.config.selection.diagnostic>?
24-
2510
---@class CopilotChat.config.context
2611
---@field description string?
2712
---@field input fun(callback: fun(input: string?), source: CopilotChat.config.source)?
28-
---@field resolve fun(input: string?, source: CopilotChat.config.source):table<CopilotChat.copilot.embed>
13+
---@field resolve fun(input: string?, source: CopilotChat.config.source):table<CopilotChat.context.embed>
2914

3015
---@class CopilotChat.config.prompt : CopilotChat.config.shared
3116
---@field prompt string?
@@ -76,7 +61,7 @@ local utils = require('CopilotChat.utils')
7661
---@field temperature number?
7762
---@field headless boolean?
7863
---@field callback fun(response: string, source: CopilotChat.config.source)?
79-
---@field selection nil|fun(source: CopilotChat.config.source):CopilotChat.config.selection?
64+
---@field selection nil|fun(source: CopilotChat.config.source):CopilotChat.select.selection?
8065
---@field window CopilotChat.config.window?
8166
---@field show_help boolean?
8267
---@field show_folds boolean?

lua/CopilotChat/context.lua

Lines changed: 76 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,18 @@
77
---@field end_row number
88
---@field end_col number
99

10-
---@class CopilotChat.context.outline : CopilotChat.copilot.embed
11-
---@field symbols table<string, CopilotChat.context.symbol>
10+
---@class CopilotChat.context.embed
11+
---@field content string
12+
---@field filename string
13+
---@field filetype string
14+
---@field original string?
15+
---@field symbols table<string, CopilotChat.context.symbol>?
16+
---@field embedding table<number>?
1217

1318
local async = require('plenary.async')
1419
local log = require('plenary.log')
1520
local utils = require('CopilotChat.utils')
21+
local file_cache = {}
1622

1723
local M = {}
1824

@@ -79,10 +85,10 @@ local function spatial_distance_cosine(a, b)
7985
end
8086

8187
--- Rank data by relatedness to the query
82-
---@param query CopilotChat.copilot.embed
83-
---@param data table<CopilotChat.copilot.embed>
88+
---@param query CopilotChat.context.embed
89+
---@param data table<CopilotChat.context.embed>
8490
---@param top_n number
85-
---@return table<CopilotChat.copilot.embed>
91+
---@return table<CopilotChat.context.embed>
8692
local function data_ranked_by_relatedness(query, data, top_n)
8793
data = vim.tbl_map(function(item)
8894
return vim.tbl_extend(
@@ -101,7 +107,7 @@ end
101107

102108
--- Rank data by symbols
103109
---@param query string
104-
---@param data table<CopilotChat.context.outline>
110+
---@param data table<CopilotChat.context.embed>
105111
---@param top_n number
106112
local function data_ranked_by_symbols(query, data, top_n)
107113
local query_terms = {}
@@ -193,22 +199,15 @@ end
193199
--- Build an outline and symbols from a string
194200
---@param content string
195201
---@param filename string
196-
---@param ft string?
197-
---@return CopilotChat.context.outline
198-
function M.outline(content, filename, ft)
199-
ft = ft or 'text'
200-
202+
---@param ft string
203+
---@return CopilotChat.context.embed
204+
local function build_outline(content, filename, ft)
201205
local output = {
202206
filename = filename,
203207
filetype = ft,
204208
content = content,
205-
symbols = {},
206209
}
207210

208-
if ft == 'raw' then
209-
return output
210-
end
211-
212211
local lang = vim.treesitter.language.get_lang(ft)
213212
local ok, parser = false, nil
214213
if lang then
@@ -224,6 +223,7 @@ function M.outline(content, filename, ft)
224223

225224
local root = parser:parse()[1]:root()
226225
local lines = vim.split(content, '\n')
226+
local symbols = {}
227227
local outline_lines = {}
228228
local depth = 0
229229

@@ -239,7 +239,7 @@ function M.outline(content, filename, ft)
239239
table.insert(outline_lines, string.rep(' ', depth) .. signature_start)
240240

241241
-- Store symbol information
242-
table.insert(output.symbols, {
242+
table.insert(symbols, {
243243
name = name,
244244
signature = signature_start,
245245
type = type,
@@ -269,15 +269,45 @@ function M.outline(content, filename, ft)
269269
if #outline_lines > 0 then
270270
output.original = content
271271
output.content = table.concat(outline_lines, '\n')
272+
output.symbols = symbols
272273
end
273274

274275
return output
275276
end
276277

278+
--- Get data for a file
279+
---@param filename string
280+
---@param filetype string
281+
---@return CopilotChat.context.embed?
282+
local function get_file(filename, filetype)
283+
local modified = utils.file_mtime(filename)
284+
if not modified then
285+
return nil
286+
end
287+
288+
local cached = file_cache[filename]
289+
if cached and cached.modified >= modified then
290+
return cached.outline
291+
end
292+
293+
local content = utils.read_file(filename)
294+
if content then
295+
local outline = build_outline(content, filename, filetype)
296+
file_cache[filename] = {
297+
outline = outline,
298+
modified = modified,
299+
}
300+
301+
return outline
302+
end
303+
304+
return nil
305+
end
306+
277307
--- Get list of all files in workspace
278308
---@param winnr number?
279309
---@param with_content boolean?
280-
---@return table<CopilotChat.copilot.embed>
310+
---@return table<CopilotChat.context.embed>
281311
function M.files(winnr, with_content)
282312
local cwd = utils.win_cwd(winnr)
283313
local files = utils.scan_dir(cwd, {
@@ -291,24 +321,22 @@ function M.files(winnr, with_content)
291321
if with_content then
292322
async.util.scheduler()
293323

294-
files = vim.tbl_map(function(file)
295-
return {
296-
name = utils.filepath(file),
297-
ft = utils.filetype(file),
298-
}
299-
end, files)
300-
files = vim.tbl_filter(function(file)
301-
return file.ft ~= nil
302-
end, files)
324+
files = vim.tbl_filter(
325+
function(file)
326+
return file.ft ~= nil
327+
end,
328+
vim.tbl_map(function(file)
329+
return {
330+
name = utils.filepath(file),
331+
ft = utils.filetype(file),
332+
}
333+
end, files)
334+
)
303335

304336
for _, file in ipairs(files) do
305-
local content = utils.read_file(file.name)
306-
if content then
307-
table.insert(out, {
308-
content = content,
309-
filename = file.name,
310-
filetype = file.ft,
311-
})
337+
local file_data = get_file(file.name, file.ft)
338+
if file_data then
339+
table.insert(out, file_data)
312340
end
313341
end
314342

@@ -338,28 +366,20 @@ end
338366

339367
--- Get the content of a file
340368
---@param filename string
341-
---@return CopilotChat.copilot.embed?
369+
---@return CopilotChat.context.embed?
342370
function M.file(filename)
343-
local content = utils.read_file(filename)
344-
if not content then
345-
return nil
346-
end
347-
348371
async.util.scheduler()
349-
if not utils.filetype(filename) then
372+
local ft = utils.filetype(filename)
373+
if not ft then
350374
return nil
351375
end
352376

353-
return {
354-
content = content,
355-
filename = utils.filepath(filename),
356-
filetype = utils.filetype(filename),
357-
}
377+
return get_file(utils.filepath(filename), ft)
358378
end
359379

360380
--- Get the content of a buffer
361381
---@param bufnr? number
362-
---@return CopilotChat.copilot.embed?
382+
---@return CopilotChat.context.embed?
363383
function M.buffer(bufnr)
364384
async.util.scheduler()
365385
bufnr = bufnr or vim.api.nvim_get_current_buf()
@@ -373,17 +393,17 @@ function M.buffer(bufnr)
373393
return nil
374394
end
375395

376-
return {
377-
content = table.concat(content, '\n'),
378-
filename = utils.filepath(vim.api.nvim_buf_get_name(bufnr)),
379-
filetype = vim.bo[bufnr].filetype,
380-
}
396+
return build_outline(
397+
table.concat(content, '\n'),
398+
utils.filepath(vim.api.nvim_buf_get_name(bufnr)),
399+
vim.bo[bufnr].filetype
400+
)
381401
end
382402

383403
--- Get current git diff
384404
---@param type string?
385405
---@param winnr number
386-
---@return CopilotChat.copilot.embed?
406+
---@return CopilotChat.context.embed?
387407
function M.gitdiff(type, winnr)
388408
type = type or 'unstaged'
389409
local cwd = utils.win_cwd(winnr)
@@ -411,7 +431,7 @@ end
411431

412432
--- Return contents of specified register
413433
---@param register string?
414-
---@return CopilotChat.copilot.embed?
434+
---@return CopilotChat.context.embed?
415435
function M.register(register)
416436
register = register or '+'
417437
local lines = vim.fn.getreg(register)
@@ -429,19 +449,14 @@ end
429449
--- Filter embeddings based on the query
430450
---@param copilot CopilotChat.Copilot
431451
---@param prompt string
432-
---@param embeddings table<CopilotChat.copilot.embed>
433-
---@return table<CopilotChat.copilot.embed>
452+
---@param embeddings table<CopilotChat.context.embed>
453+
---@return table<CopilotChat.context.embed>
434454
function M.filter_embeddings(copilot, prompt, embeddings)
435455
-- If we dont need to embed anything, just return directly
436456
if #embeddings < MULTI_FILE_THRESHOLD then
437457
return embeddings
438458
end
439459

440-
-- Map embeddings to outlines
441-
embeddings = vim.tbl_map(function(embed)
442-
return M.outline(embed.content, embed.filename, embed.filetype)
443-
end, embeddings)
444-
445460
-- Rank embeddings by symbols
446461
embeddings = data_ranked_by_symbols(prompt, embeddings, TOP_SYMBOLS)
447462
log.debug('Ranked data:', #embeddings)

lua/CopilotChat/copilot.lua

Lines changed: 7 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,13 @@
1-
---@class CopilotChat.copilot.embed
2-
---@field content string
3-
---@field filename string
4-
---@field filetype string
5-
---@field embedding table<number>
6-
71
---@class CopilotChat.copilot.ask.opts
8-
---@field selection CopilotChat.config.selection?
9-
---@field embeddings table<CopilotChat.copilot.embed>?
2+
---@field selection CopilotChat.select.selection?
3+
---@field embeddings table<CopilotChat.context.embed>?
104
---@field system_prompt string?
115
---@field model string?
126
---@field agent string?
137
---@field temperature number?
148
---@field no_history boolean?
159
---@field on_progress nil|fun(response: string):nil
1610

17-
---@class CopilotChat.copilot.embed.opts
18-
---@field model string?
19-
---@field chunk_size number?
20-
2111
local log = require('plenary.log')
2212
local prompts = require('CopilotChat.prompts')
2313
local tiktoken = require('CopilotChat.tiktoken')
@@ -107,7 +97,7 @@ local function generate_line_numbers(content, start_line)
10797
end
10898

10999
--- Generate messages for the given selection
110-
--- @param selection CopilotChat.config.selection
100+
--- @param selection CopilotChat.select.selection
111101
local function generate_selection_messages(selection)
112102
local filename = selection.filename or 'unknown'
113103
local filetype = selection.filetype or 'text'
@@ -167,7 +157,7 @@ local function generate_selection_messages(selection)
167157
end
168158

169159
--- Generate messages for the given embeddings
170-
--- @param embeddings table<CopilotChat.copilot.embed>
160+
--- @param embeddings table<CopilotChat.context.embed>
171161
local function generate_embeddings_messages(embeddings)
172162
local files = {}
173163
for _, embedding in ipairs(embeddings) do
@@ -295,7 +285,7 @@ end
295285

296286
---@class CopilotChat.Copilot : Class
297287
---@field history table
298-
---@field embedding_cache table<CopilotChat.copilot.embed>
288+
---@field embedding_cache table<CopilotChat.context.embed>
299289
---@field policies table<string, boolean>
300290
---@field models table<string, table>?
301291
---@field agents table<string, table>?
@@ -863,8 +853,8 @@ function Copilot:list_agents()
863853
end
864854

865855
--- Generate embeddings for the given inputs
866-
---@param inputs table<CopilotChat.copilot.embed>: The inputs to embed
867-
---@return table<CopilotChat.copilot.embed>
856+
---@param inputs table<CopilotChat.context.embed>: The inputs to embed
857+
---@return table<CopilotChat.context.embed>
868858
function Copilot:embed(inputs)
869859
if not inputs or #inputs == 0 then
870860
return {}

lua/CopilotChat/init.lua

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ local state = {
4242
}
4343

4444
---@param config CopilotChat.config.shared
45-
---@return CopilotChat.config.selection?
45+
---@return CopilotChat.select.selection?
4646
local function get_selection(config)
4747
local bufnr = state.source and state.source.bufnr
4848
local winnr = state.source and state.source.winnr
@@ -231,7 +231,7 @@ end
231231

232232
---@param prompt string
233233
---@param config CopilotChat.config.shared
234-
---@return table<CopilotChat.copilot.embed>, string
234+
---@return table<CopilotChat.context.embed>, string
235235
local function resolve_embeddings(prompt, config)
236236
local contexts = {}
237237
local function parse_context(prompt_context)

0 commit comments

Comments
 (0)