diff --git a/.all-contributorsrc b/.all-contributorsrc index 3e77a246..e502fd86 100644 --- a/.all-contributorsrc +++ b/.all-contributorsrc @@ -431,6 +431,97 @@ "avatar_url": "https://avatars.githubusercontent.com/u/214497460?v=4", "profile": "https://github.com/danilohorta", "contributions": ["code"] + }, + { + "login": "rakotomandimby", + "name": "Mihamina Rakotomandimby", + "avatar_url": "https://avatars.githubusercontent.com/u/488088?v=4", + "profile": "https://mihamina.rktmb.org", + "contributions": ["doc", "code"] + }, + { + "login": "AjmalShajahan", + "name": "Ajmal S", + "avatar_url": "https://avatars.githubusercontent.com/u/23806715?v=4", + "profile": "http://ajmalshajahan.me", + "contributions": ["code"] + }, + { + "login": "samiulsami", + "name": "Samiul Islam", + "avatar_url": "https://avatars.githubusercontent.com/u/33352407?v=4", + "profile": "https://github.com/samiulsami", + "contributions": ["code"] + }, + { + "login": "ruicsh", + "name": "Rui Costa", + "avatar_url": "https://avatars.githubusercontent.com/u/8294038?v=4", + "profile": "https://ruicsh.github.io", + "contributions": ["code"] + }, + { + "login": "ctchen222", + "name": "CTCHEN", + "avatar_url": "https://avatars.githubusercontent.com/u/49014608?v=4", + "profile": "https://github.com/ctchen222", + "contributions": ["code"] + }, + { + "login": "towoe", + "name": "Tobias Wölfel", + "avatar_url": "https://avatars.githubusercontent.com/u/8666134?v=4", + "profile": "https://github.com/towoe", + "contributions": ["code"] + }, + { + "login": "garcia5", + "name": "Alexander Garcia", + "avatar_url": "https://avatars.githubusercontent.com/u/21695295?v=4", + "profile": "https://github.com/garcia5", + "contributions": ["code"] + }, + { + "login": "kharandziuk", + "name": "Max Kharandziuk", + "avatar_url": "https://avatars.githubusercontent.com/u/3404755?v=4", + "profile": "https://github.com/kharandziuk", + "contributions": ["code"] + }, + { + "login": "pxwg", + "name": "Xinyu Xiang", + "avatar_url": "https://avatars.githubusercontent.com/u/149765160?v=4", + "profile": "https://github.com/pxwg", + "contributions": ["code"] + }, + { + "login": "junqizhang", + "name": "junqizhang", + "avatar_url": "https://avatars.githubusercontent.com/u/22600124?v=4", + "profile": "https://github.com/junqizhang", + "contributions": ["code"] + }, + { + "login": "Tlunch", + "name": "Calum Lynch", + "avatar_url": "https://avatars.githubusercontent.com/u/89159592?v=4", + "profile": "http://card.calumhub.xyz", + "contributions": ["code"] + }, + { + "login": "sirjls", + "name": "sirjls", + "avatar_url": "https://avatars.githubusercontent.com/u/270346599?v=4", + "profile": "https://github.com/sirjls", + "contributions": ["code"] + }, + { + "login": "kolchurinvv", + "name": "Vladimir Kolchurin", + "avatar_url": "https://avatars.githubusercontent.com/u/18503099?v=4", + "profile": "https://github.com/kolchurinvv", + "contributions": ["code"] } ], "contributorsPerLine": 7, diff --git a/.emmyrc.json b/.emmyrc.json new file mode 100644 index 00000000..5f9691d9 --- /dev/null +++ b/.emmyrc.json @@ -0,0 +1,9 @@ +{ + "runtime": { + "version": "LuaJIT", + "requirePattern": ["lua/?.lua", "lua/?/init.lua"] + }, + "workspace": { + "library": ["$VIMRUNTIME"] + } +} diff --git a/.github/FUNDING.yml b/.github/FUNDING.yml index ce9dcccd..4b3b3fbc 100644 --- a/.github/FUNDING.yml +++ b/.github/FUNDING.yml @@ -1,3 +1,3 @@ # These are supported funding model platforms -github: [acheong08, jellydn] +github: [deathbeam, jellydn] diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index e722714f..9ce3f370 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -64,8 +64,4 @@ jobs: luarocksVersion: "3.12.2" - name: run test - shell: bash - run: | - luarocks install luacheck - luarocks install vusted - vusted ./test + run: make test diff --git a/.gitignore b/.gitignore index 94e6f763..fc3fe2ac 100644 --- a/.gitignore +++ b/.gitignore @@ -165,3 +165,5 @@ cython_debug/ # (neo)vim helptags /doc/tags + +.dependencies/ diff --git a/.luarc.json b/.luarc.json index b97a9f11..c4cebd58 100644 --- a/.luarc.json +++ b/.luarc.json @@ -1,4 +1,14 @@ { - "diagnostics.globals": ["describe", "it"], + "runtime.version": "LuaJIT", + "diagnostics.globals": [ + "describe", + "it", + "pending", + "before_each", + "after_each", + "clear", + "assert", + "print" + ], "diagnostics.disable": ["redefined-local"] } diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a90c2791..ae6b39ed 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -4,6 +4,6 @@ repos: hooks: - id: prettier - repo: https://github.com/JohnnyMorganz/StyLua - rev: v2.1.0 + rev: v2.4.1 hooks: - id: stylua-github diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 00000000..560393bb --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,75 @@ +# AGENTS.md + +## Overview + +Neovim plugin (pure Lua) providing GitHub Copilot Chat integration. Requires Neovim 0.10.0+, curl 8.0.0+, plenary.nvim. + +## Commands + +```bash +# Run tests (headless Neovim + plenary test harness) +make test + +# Format check (what CI runs) +stylua --check . +``` + +`make test` runs `nvim --headless --clean -u ./scripts/test.lua`, which clones plenary.nvim into `.dependencies/` on first run, then executes all `tests/*_spec.lua` files via plenary's busted-style harness. + +## Project layout + +``` +plugin/CopilotChat.lua — Neovim plugin entry: commands, highlights, autocmds +lua/CopilotChat/ + init.lua — Main module: setup(), ask(), open/close/toggle, save/load + client.lua — Copilot API client (auth, streaming, tool calls) + config.lua — Default configuration schema + config/ — Sub-configs: functions, mappings, prompts, providers + constants.lua — Shared constants (roles, etc.) + completion.lua — Completion source + functions.lua — Built-in functions/tools exposed to the LLM + prompts.lua — Built-in prompt definitions + resources.lua — Resource handling + select.lua — Selection strategies (visual, buffer, diagnostics, git diff) + tiktoken.lua — Token counting via native tiktoken lib + health.lua — :checkhealth integration + notify.lua — Notification utilities + instructions/ — System prompt templates injected into LLM conversations (not agent guidance) + ui/ — Chat window, overlay, spinner + utils.lua — General utilities + utils/ — Utility modules: class, curl, diff, files, orderedmap, stringbuffer +queries/ — Treesitter queries for copilot-chat filetype +tests/ — Plenary busted-style specs (*_spec.lua) +scripts/ + test.lua — Test runner bootstrap (sets up plenary) + minimal.lua — Minimal reproduction config +doc/CopilotChat.txt — Auto-generated vimdoc (do NOT edit; generated from README by panvimdoc in CI) +``` + +## Style and formatting + +- **Lua formatter:** StyLua — 2-space indent, 120 column width, single quotes preferred, Unix line endings. Config in `.stylua.toml`. +- **Pre-commit hooks:** Prettier (markdown/json/yaml) + StyLua (Lua). CI will fail if StyLua check fails. +- **No linter** (no luacheck/selene configured). +- Type annotations use EmmyLua/LuaCATS `---@class`, `---@param`, `---@return` style. + +## Testing + +- Framework: plenary.nvim busted-style (`describe`, `it`, `before_each`, `after_each`, `assert`). +- Test files live in `tests/` and must be named `*_spec.lua`. +- CI runs tests against Neovim nightly with LuaJIT 2.1 and LuaRocks 3.12.2. +- Tests are unit-level (class, diff, utils, orderedmap, stringbuffer, functions, init). No integration tests requiring Copilot auth. + +## CI and releases + +- CI (`ci.yml`): lint (StyLua) + test (plenary) on all PRs; vimdoc generation on main only. +- Releases via release-please (`simple` type). Version tracked in `version.txt`. +- `doc/CopilotChat.txt` is auto-committed by CI — do not edit manually. +- `CHANGELOG.md` is managed by release-please — do not edit manually. + +## Key gotchas + +- The module is loaded as `require('CopilotChat')` (capital C's) — this matches the `lua/CopilotChat/` directory name. Case matters. +- `init.lua` uses lazy self-initialization via `__index` metamethod — accessing any field triggers `setup()` if not already called. +- `.dependencies/` is gitignored and auto-populated by the test runner (plenary clone). +- `build/` is gitignored and holds downloaded tiktoken native libraries. diff --git a/CHANGELOG.md b/CHANGELOG.md index 12e31352..529ec198 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,204 @@ # Changelog +## [4.7.4](https://github.com/CopilotC-Nvim/CopilotChat.nvim/compare/v4.7.3...v4.7.4) (2025-10-01) + + +### Bug Fixes + +* **url:** ensure main thread scheduling before fetching ([#1453](https://github.com/CopilotC-Nvim/CopilotChat.nvim/issues/1453)) ([7a8e238](https://github.com/CopilotC-Nvim/CopilotChat.nvim/commit/7a8e238e36ea9e1df9d6309434a37bcdc15a9fae)) + +## [4.7.3](https://github.com/CopilotC-Nvim/CopilotChat.nvim/compare/v4.7.2...v4.7.3) (2025-09-28) + + +### Bug Fixes + +* **mappings:** make sure function resolution is not ran in fast context ([#1436](https://github.com/CopilotC-Nvim/CopilotChat.nvim/issues/1436)) ([16aa924](https://github.com/CopilotC-Nvim/CopilotChat.nvim/commit/16aa92419d48957319a3f6b06c9d74ebdcead80c)) +* **os:** use vim.uv.os_uname for OS detection ([#1449](https://github.com/CopilotC-Nvim/CopilotChat.nvim/issues/1449)) ([df8efe9](https://github.com/CopilotC-Nvim/CopilotChat.nvim/commit/df8efe9d2368c876d607b513bb384eaa8daf1d12)) + +## [4.7.2](https://github.com/CopilotC-Nvim/CopilotChat.nvim/compare/v4.7.1...v4.7.2) (2025-09-17) + + +### Bug Fixes + +* **chat:** do not create multiple chat isntances ([#1432](https://github.com/CopilotC-Nvim/CopilotChat.nvim/issues/1432)) ([74611b5](https://github.com/CopilotC-Nvim/CopilotChat.nvim/commit/74611b56e813f50e905122387b92fb832ac9616c)) + +## [4.7.1](https://github.com/CopilotC-Nvim/CopilotChat.nvim/compare/v4.7.0...v4.7.1) (2025-09-16) + + +### Bug Fixes + +* **chat:** ensure user prompt is wrapped in a list ([#1427](https://github.com/CopilotC-Nvim/CopilotChat.nvim/issues/1427)) ([92dceb4](https://github.com/CopilotC-Nvim/CopilotChat.nvim/commit/92dceb4ece955deea39fd1d7a57c26e66d5ce38d)), closes [#1426](https://github.com/CopilotC-Nvim/CopilotChat.nvim/issues/1426) +* **ui:** increase separator virt_text priority ([#1424](https://github.com/CopilotC-Nvim/CopilotChat.nvim/issues/1424)) ([9a63e83](https://github.com/CopilotC-Nvim/CopilotChat.nvim/commit/9a63e83b9fade8e7fa50deb414d58b703352b13a)) + +## [4.7.0](https://github.com/CopilotC-Nvim/CopilotChat.nvim/compare/v4.6.0...v4.7.0) (2025-09-16) + + +### Features + +* **chat:** switch to treesitter based chat parsing ([#1394](https://github.com/CopilotC-Nvim/CopilotChat.nvim/issues/1394)) ([ba364fe](https://github.com/CopilotC-Nvim/CopilotChat.nvim/commit/ba364fe04b36121a594435c3f54261c7a8e450a6)) +* **diff:** add experimental unified diff support, refactor handling ([#1392](https://github.com/CopilotC-Nvim/CopilotChat.nvim/issues/1392)) ([9fdf895](https://github.com/CopilotC-Nvim/CopilotChat.nvim/commit/9fdf8951efff6ab4f46e06945e5d6425bdbf4f80)) +* **diff:** apply all code blocks for a file at once when showing diff ([#1409](https://github.com/CopilotC-Nvim/CopilotChat.nvim/issues/1409)) ([a88874e](https://github.com/CopilotC-Nvim/CopilotChat.nvim/commit/a88874ef3663aea6bc09eb09c1df4a46ae8577f5)) +* **diff:** use diff-match-patch for better diff handling ([#1407](https://github.com/CopilotC-Nvim/CopilotChat.nvim/issues/1407)) ([35ad8ff](https://github.com/CopilotC-Nvim/CopilotChat.nvim/commit/35ad8ff61f47c5546c036b9b7310ce0dd87e8d20)) +* **health:** require markdown parser and copilotchat query ([#1401](https://github.com/CopilotC-Nvim/CopilotChat.nvim/issues/1401)) ([f49df19](https://github.com/CopilotC-Nvim/CopilotChat.nvim/commit/f49df19d5a8925d295ac6472c30b36584bd10d93)) + + +### Bug Fixes + +* **chat:** automatically start treesitter if not started ([#1410](https://github.com/CopilotC-Nvim/CopilotChat.nvim/issues/1410)) ([00d0fb3](https://github.com/CopilotC-Nvim/CopilotChat.nvim/commit/00d0fb310ad364e76e306a6626a40b85fc5bbd98)) +* **client:** correct history handling for headless ask ([#1416](https://github.com/CopilotC-Nvim/CopilotChat.nvim/issues/1416)) ([d5ea51d](https://github.com/CopilotC-Nvim/CopilotChat.nvim/commit/d5ea51d3f55dc1941c13cf0c44440de0a7f8019f)), closes [#1415](https://github.com/CopilotC-Nvim/CopilotChat.nvim/issues/1415) +* **provider:** safely call curl.post for model policy ([#1419](https://github.com/CopilotC-Nvim/CopilotChat.nvim/issues/1419)) ([2279dbe](https://github.com/CopilotC-Nvim/CopilotChat.nvim/commit/2279dbe42702397c969aeaa5aebae475a16bcaa9)) +* **ui:** handle missing filename in chat block header ([#1406](https://github.com/CopilotC-Nvim/CopilotChat.nvim/issues/1406)) ([5c3a558](https://github.com/CopilotC-Nvim/CopilotChat.nvim/commit/5c3a558f2d740df740735fbb3ea0be822004136d)) +* **ui:** improve help rendering and treesitter usage ([#1411](https://github.com/CopilotC-Nvim/CopilotChat.nvim/issues/1411)) ([559e754](https://github.com/CopilotC-Nvim/CopilotChat.nvim/commit/559e75423774b3a291a58d33a1144c94444e52ac)) +* **ui:** preserve extra fields in chat messages ([#1399](https://github.com/CopilotC-Nvim/CopilotChat.nvim/issues/1399)) ([f2f523f](https://github.com/CopilotC-Nvim/CopilotChat.nvim/commit/f2f523fe3fdb855da1b3dcabf4f2981cdc3b2c2d)) + + +### Performance Improvements + +* **chat:** optimize message storage and access ([#1403](https://github.com/CopilotC-Nvim/CopilotChat.nvim/issues/1403)) ([1041ad0](https://github.com/CopilotC-Nvim/CopilotChat.nvim/commit/1041ad0034e65e4a63859172d31e7045c8975d87)) +* **chat:** simplify last line/column calculation ([#1402](https://github.com/CopilotC-Nvim/CopilotChat.nvim/issues/1402)) ([4a45e69](https://github.com/CopilotC-Nvim/CopilotChat.nvim/commit/4a45e69de8ad2b72ef62ede5a554c68c9632e718)) +* **core:** do not require calling setup(), add lazy initialization ([#1413](https://github.com/CopilotC-Nvim/CopilotChat.nvim/issues/1413)) ([c15f65e](https://github.com/CopilotC-Nvim/CopilotChat.nvim/commit/c15f65e5dc5151230c97f9fd4d386e513fc47c63)) + +## [4.6.0](https://github.com/CopilotC-Nvim/CopilotChat.nvim/compare/v4.5.1...v4.6.0) (2025-08-31) + + +### Features + +* **tiktoken:** improve token counting accuracy ([#1382](https://github.com/CopilotC-Nvim/CopilotChat.nvim/issues/1382)) ([a657694](https://github.com/CopilotC-Nvim/CopilotChat.nvim/commit/a6576949e821e7abf9d0135e87576a51ec0e2e68)) + + +### Bug Fixes + +* **auth:** improve token saving and polling logic ([#1389](https://github.com/CopilotC-Nvim/CopilotChat.nvim/issues/1389)) ([b7728f4](https://github.com/CopilotC-Nvim/CopilotChat.nvim/commit/b7728f450bfc95c7c749a322b3f130a16f80e35c)), closes [#1388](https://github.com/CopilotC-Nvim/CopilotChat.nvim/issues/1388) +* **chat:** correct header highlighting for multi-byte characters ([#1385](https://github.com/CopilotC-Nvim/CopilotChat.nvim/issues/1385)) ([f844a68](https://github.com/CopilotC-Nvim/CopilotChat.nvim/commit/f844a684bd9e59b4bfc8882b4beb9be81cccfe23)), closes [#1384](https://github.com/CopilotC-Nvim/CopilotChat.nvim/issues/1384) +* **utils:** use proper empty check ([#1380](https://github.com/CopilotC-Nvim/CopilotChat.nvim/issues/1380)) ([c4b2e03](https://github.com/CopilotC-Nvim/CopilotChat.nvim/commit/c4b2e03cd315c3fd9736dcf796cb20f6a4b9f801)) + +## [4.5.1](https://github.com/CopilotC-Nvim/CopilotChat.nvim/compare/v4.5.0...v4.5.1) (2025-08-28) + + +### Bug Fixes + +* **files:** generate absolute paths in code blocks ([#1378](https://github.com/CopilotC-Nvim/CopilotChat.nvim/issues/1378)) ([0f42bfc](https://github.com/CopilotC-Nvim/CopilotChat.nvim/commit/0f42bfc44202ac4daa0b0f32e30ee4040f69bf35)), closes [#1377](https://github.com/CopilotC-Nvim/CopilotChat.nvim/issues/1377) + +## [4.5.0](https://github.com/CopilotC-Nvim/CopilotChat.nvim/compare/v4.4.1...v4.5.0) (2025-08-27) + + +### ⚠ BREAKING CHANGES + +* **select:** remove selection API in favor of resources +* **prompts:** callback receives the full response object instead of just content. + +### Features + +* **config:** add back selection source config option ([#1360](https://github.com/CopilotC-Nvim/CopilotChat.nvim/issues/1360)) ([c37ec3c](https://github.com/CopilotC-Nvim/CopilotChat.nvim/commit/c37ec3cbdb2c29be73d7d0c48057d64306aa185f)) +* **docs:** add selection source to function table ([#1358](https://github.com/CopilotC-Nvim/CopilotChat.nvim/issues/1358)) ([c7d8547](https://github.com/CopilotC-Nvim/CopilotChat.nvim/commit/c7d85478f775a65ca777cb9b2f685911cbcd8def)) +* **functions:** add configuration parameter to stop on tool failure ([#1364](https://github.com/CopilotC-Nvim/CopilotChat.nvim/issues/1364)) ([8d8f1e7](https://github.com/CopilotC-Nvim/CopilotChat.nvim/commit/8d8f1e7ea594b2db3368e1fa62dd7d0d128e8860)) +* **functions:** add scope=selection to diagnostics ([#1351](https://github.com/CopilotC-Nvim/CopilotChat.nvim/issues/1351)) ([7b4a56b](https://github.com/CopilotC-Nvim/CopilotChat.nvim/commit/7b4a56b29ed926b680ea936bd29fc8568b909d97)) +* **functions:** use cwd for file and grep commands ([#1373](https://github.com/CopilotC-Nvim/CopilotChat.nvim/issues/1373)) ([72216c0](https://github.com/CopilotC-Nvim/CopilotChat.nvim/commit/72216c06fa2ce82406c3406d898a83c02db412a7)), closes [#1108](https://github.com/CopilotC-Nvim/CopilotChat.nvim/issues/1108) +* **prompts:** add support for providing system prompt as function ([#1318](https://github.com/CopilotC-Nvim/CopilotChat.nvim/issues/1318)) ([33e6ffc](https://github.com/CopilotC-Nvim/CopilotChat.nvim/commit/33e6ffc63b77b0340731f2b50bd962045adf9366)) +* **prompts:** support buffer replacement in commit messages ([#1370](https://github.com/CopilotC-Nvim/CopilotChat.nvim/issues/1370)) ([afafec5](https://github.com/CopilotC-Nvim/CopilotChat.nvim/commit/afafec51d2657cdde4fa839bac9cc203037ff60b)) +* **ui:** add auto_fold option for chat messages ([#1354](https://github.com/CopilotC-Nvim/CopilotChat.nvim/issues/1354)) ([80a0994](https://github.com/CopilotC-Nvim/CopilotChat.nvim/commit/80a0994f01096705e0c24dd7ed09032594689e01)), closes [#1300](https://github.com/CopilotC-Nvim/CopilotChat.nvim/issues/1300) +* **ui:** improve auto folding logic in chat window ([#1356](https://github.com/CopilotC-Nvim/CopilotChat.nvim/issues/1356)) ([a7679e1](https://github.com/CopilotC-Nvim/CopilotChat.nvim/commit/a7679e118af8038046b2fc4c841406db7fe71216)) + + +### Bug Fixes + +* **completion.lua:** check if window is valid before calling get_cursor ([#1359](https://github.com/CopilotC-Nvim/CopilotChat.nvim/issues/1359)) ([fdac67a](https://github.com/CopilotC-Nvim/CopilotChat.nvim/commit/fdac67ab62085436b60003f420ae45f104bdf935)) +* **completion:** require tool uri for input completion ([#1328](https://github.com/CopilotC-Nvim/CopilotChat.nvim/issues/1328)) ([76cc416](https://github.com/CopilotC-Nvim/CopilotChat.nvim/commit/76cc41653d63cfdb653f584624b4bf5e721f9514)) +* **config:** correct system_prompt type and callback usage ([#1325](https://github.com/CopilotC-Nvim/CopilotChat.nvim/issues/1325)) ([f99f1cd](https://github.com/CopilotC-Nvim/CopilotChat.nvim/commit/f99f1cdef151ac1c950850cdcc0dbeefad00603c)) +* **makefile:** handle MSYS_NT as a valid Windows environment ([#1347](https://github.com/CopilotC-Nvim/CopilotChat.nvim/issues/1347)) ([9769bf9](https://github.com/CopilotC-Nvim/CopilotChat.nvim/commit/9769bf9a1d215cf0dc22874712d5dcda53a075ee)) +* **prompt:** recursive system prompt expansion ([#1324](https://github.com/CopilotC-Nvim/CopilotChat.nvim/issues/1324)) ([26f7b4f](https://github.com/CopilotC-Nvim/CopilotChat.nvim/commit/26f7b4f157ec75b168c05dc826b5fa3106cfc351)), closes [#1323](https://github.com/CopilotC-Nvim/CopilotChat.nvim/issues/1323) +* **select:** move config inside of marks function to prevent import loop ([#1361](https://github.com/CopilotC-Nvim/CopilotChat.nvim/issues/1361)) ([19a38dd](https://github.com/CopilotC-Nvim/CopilotChat.nvim/commit/19a38dd34e1b61c49349552598e43b2559be2fc7)) +* **test:** run tests automatically in test script ([#1334](https://github.com/CopilotC-Nvim/CopilotChat.nvim/issues/1334)) ([c5057d3](https://github.com/CopilotC-Nvim/CopilotChat.nvim/commit/c5057d3bb6d87e9b117b4f37162409d4c2c74e31)) +* **utils:** always exit insert mode in return_to_normal_mode ([#1313](https://github.com/CopilotC-Nvim/CopilotChat.nvim/issues/1313)) ([957e0a8](https://github.com/CopilotC-Nvim/CopilotChat.nvim/commit/957e0a88c7d7df706380e09412c0b3f24af534ad)), closes [#1307](https://github.com/CopilotC-Nvim/CopilotChat.nvim/issues/1307) +* **utils:** avoid vim.filetype.match in fast event ([#1344](https://github.com/CopilotC-Nvim/CopilotChat.nvim/issues/1344)) ([7993e6d](https://github.com/CopilotC-Nvim/CopilotChat.nvim/commit/7993e6d2a97cb851b8b3a4087005cfaf8427dbf3)) + + +### Miscellaneous Chores + +* mark next release as 4.5.0 ([#1315](https://github.com/CopilotC-Nvim/CopilotChat.nvim/issues/1315)) ([d12f6df](https://github.com/CopilotC-Nvim/CopilotChat.nvim/commit/d12f6dff0e1641f933f9941b843d094bf505a82e)) + + +### Code Refactoring + +* **prompts:** support template substitution in system_prompt ([#1312](https://github.com/CopilotC-Nvim/CopilotChat.nvim/issues/1312)) ([081d4c2](https://github.com/CopilotC-Nvim/CopilotChat.nvim/commit/081d4c20242140bb185ebee142a65454ad375f7d)) +* **select:** remove selection API in favor of resources ([a2429ed](https://github.com/CopilotC-Nvim/CopilotChat.nvim/commit/a2429ed44438f694f1fca60429a7984022d4a9f0)) + +## [4.4.1](https://github.com/CopilotC-Nvim/CopilotChat.nvim/compare/v4.4.0...v4.4.1) (2025-08-12) + + +### Bug Fixes + +* **chat:** schedule chat initialization after window opens ([#1308](https://github.com/CopilotC-Nvim/CopilotChat.nvim/issues/1308)) ([15eebed](https://github.com/CopilotC-Nvim/CopilotChat.nvim/commit/15eebed57156c3ae6a6bb6f73692dbf0547ba9e4)), closes [#1307](https://github.com/CopilotC-Nvim/CopilotChat.nvim/issues/1307) +* **prompts:** update tool instructions for system prompt ([#1304](https://github.com/CopilotC-Nvim/CopilotChat.nvim/issues/1304)) ([5e091bf](https://github.com/CopilotC-Nvim/CopilotChat.nvim/commit/5e091bf1bf11827bec5130edc8d4f87fdd243716)) + +## [4.4.0](https://github.com/CopilotC-Nvim/CopilotChat.nvim/compare/v4.3.1...v4.4.0) (2025-08-09) + + +### Features + +* **completion:** add support for omnifunc and move completion logic to separate module ([1b04ddc](https://github.com/CopilotC-Nvim/CopilotChat.nvim/commit/1b04ddcfe2d04363a3898998a1005ab2f493dff4)) +* **ui:** show assistant reasoning as virtual text ([#1299](https://github.com/CopilotC-Nvim/CopilotChat.nvim/issues/1299)) ([92777fb](https://github.com/CopilotC-Nvim/CopilotChat.nvim/commit/92777fb98ad4de7496188f1e9de336d16871ac43)) + + +### Bug Fixes + +* **chat:** correct block selection logic by cursor ([#1301](https://github.com/CopilotC-Nvim/CopilotChat.nvim/issues/1301)) ([7e027df](https://github.com/CopilotC-Nvim/CopilotChat.nvim/commit/7e027df6e95b622da25282285e84a9fc3806dcf1)) +* **info:** show resource uri instead of name in preview ([#1296](https://github.com/CopilotC-Nvim/CopilotChat.nvim/issues/1296)) ([90c3241](https://github.com/CopilotC-Nvim/CopilotChat.nvim/commit/90c324177b33aec6d4c2bd5043c26bfc9fbc081f)) + +## [4.3.1](https://github.com/CopilotC-Nvim/CopilotChat.nvim/compare/v4.3.0...v4.3.1) (2025-08-08) + + +### Bug Fixes + +* **client:** store models cache per provider ([#1291](https://github.com/CopilotC-Nvim/CopilotChat.nvim/issues/1291)) ([ffb6659](https://github.com/CopilotC-Nvim/CopilotChat.nvim/commit/ffb665919fdafecbfb8dceaf63243d614b50c497)) + +## [4.3.0](https://github.com/CopilotC-Nvim/CopilotChat.nvim/compare/v4.2.0...v4.3.0) (2025-08-08) + + +### ⚠ BREAKING CHANGES + +* **core:** Resource processing and embeddings support have been removed. Any configuration or usage relying on these features will no longer work. + +### Features + +* **keymap:** switch back to <Tab> for completion, add Copilot conflict note ([#1280](https://github.com/CopilotC-Nvim/CopilotChat.nvim/issues/1280)) ([59f5b43](https://github.com/CopilotC-Nvim/CopilotChat.nvim/commit/59f5b43cdd3d27ab4e033882179d5cf028cf1302)) +* **setup:** trigger CopilotChatLoaded user autocommand ([#1288](https://github.com/CopilotC-Nvim/CopilotChat.nvim/issues/1288)) ([1189e37](https://github.com/CopilotC-Nvim/CopilotChat.nvim/commit/1189e376fcad629edf6ffd186aa659f114df0271)) + + +### Bug Fixes + +* **functions:** do not require tool reference in tool prompt, just tool id ([#1273](https://github.com/CopilotC-Nvim/CopilotChat.nvim/issues/1273)) ([4d11c49](https://github.com/CopilotC-Nvim/CopilotChat.nvim/commit/4d11c49b7a1afb573a3b09be5e10a78a3d41649d)), closes [#1269](https://github.com/CopilotC-Nvim/CopilotChat.nvim/issues/1269) +* **ui:** prevent italics from breaking glob pattern highlights ([#1274](https://github.com/CopilotC-Nvim/CopilotChat.nvim/issues/1274)) ([93110a5](https://github.com/CopilotC-Nvim/CopilotChat.nvim/commit/93110a5f289aaed20adbbc13ec803f94dc6c63c6)) + + +### Miscellaneous Chores + +* mark next release as 4.3.0 ([#1275](https://github.com/CopilotC-Nvim/CopilotChat.nvim/issues/1275)) ([7576afa](https://github.com/CopilotC-Nvim/CopilotChat.nvim/commit/7576afad950d4258cc7d455d8d42f7dccac4d19b)) + + +### Code Refactoring + +* **core:** remove resource processing and embeddings ([#1203](https://github.com/CopilotC-Nvim/CopilotChat.nvim/issues/1203)) ([f38319f](https://github.com/CopilotC-Nvim/CopilotChat.nvim/commit/f38319fd8f3a7aaa1f75b78027032f9c07abc425)) + +## [4.2.0](https://github.com/CopilotC-Nvim/CopilotChat.nvim/compare/v4.1.0...v4.2.0) (2025-08-03) + + +### Features + +* **chat:** improve error handling ([#1265](https://github.com/CopilotC-Nvim/CopilotChat.nvim/issues/1265)) ([5c8b457](https://github.com/CopilotC-Nvim/CopilotChat.nvim/commit/5c8b457d617dd1e533b826ff9f9b76ddf988756d)) + +## [4.1.0](https://github.com/CopilotC-Nvim/CopilotChat.nvim/compare/v4.0.0...v4.1.0) (2025-08-03) + + +### Features + +* **ui:** improve keyword highlights accuracy and performance ([#1260](https://github.com/CopilotC-Nvim/CopilotChat.nvim/issues/1260)) ([0d64e26](https://github.com/CopilotC-Nvim/CopilotChat.nvim/commit/0d64e267a5aef3bd7d580a2c488bcc8b66d374a4)) + + +### Bug Fixes + +* **functions:** do not filter schema enum when entering input ([#1264](https://github.com/CopilotC-Nvim/CopilotChat.nvim/issues/1264)) ([8510f30](https://github.com/CopilotC-Nvim/CopilotChat.nvim/commit/8510f30ff8c338482e7c8a2a7d102519cc57315f)), closes [#1263](https://github.com/CopilotC-Nvim/CopilotChat.nvim/issues/1263) + ## [4.0.0](https://github.com/CopilotC-Nvim/CopilotChat.nvim/compare/v3.12.2...v4.0.0) (2025-08-02) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 3be55bac..393f93c9 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -50,67 +50,77 @@ Go to the CopilotChat.nvim in your GitHub account, select your branch, and click ![structure.drawio](https://github.com/CopilotC-Nvim/CopilotChat.nvim/assets/5115805/e7517736-0152-47a3-8cb9-36a5dffcb6cc) -### Main components +### Core -- [init.lua](/lua/CopilotChat/init.lua): This file initializes Copilot Chat - plugin. It includes functions for appending to the chat window, showing help, - completing, getting selection, opening and closing the chat window, asking - questions to the Copilot model, resetting the chat window, enabling/disabling - debug, and setting up the plugin. +- [init.lua](/lua/CopilotChat/init.lua): Main module. Plugin initialization + (`setup()`), chat lifecycle (`ask()`, `open()`, `close()`, `toggle()`, + `reset()`), save/load, and sticky prompt processing. -- [config.lua](/lua/CopilotChat/config.lua): This file contains default - configuration for Copilot Chat plugin. +- [client.lua](/lua/CopilotChat/client.lua): Copilot API client. Handles + authentication, model listing, streaming requests, and tool call execution. -- [copilot.lua](/lua/CopilotChat/copilot.lua): This file contains the core - functionality of the Copilot. It includes functions for generating unique IDs, - finding configuration paths, authenticating, asking questions to the Copilot, - generating embeddings, and managing the running job. +- [config.lua](/lua/CopilotChat/config.lua): Default configuration schema. -- [chat.lua](/lua/CopilotChat/chat.lua): This file manages the chat window. It - includes functions for creating, validating, appending to, clearing, opening, - closing, and focusing on the chat window. +- [config/](/lua/CopilotChat/config/): Sub-configs for + [functions](/lua/CopilotChat/config/functions.lua), + [mappings](/lua/CopilotChat/config/mappings.lua), + [prompts](/lua/CopilotChat/config/prompts.lua), and + [providers](/lua/CopilotChat/config/providers.lua). -- [diff.lua](/lua/CopilotChat/diff.lua): This file manages the diff window. It - includes functions for creating, validating, showing, and restoring the diff - window. +- [constants.lua](/lua/CopilotChat/constants.lua): Shared constants (plugin + name, roles). -- [select.lua](/lua/CopilotChat/select.lua): This file contains functions for - selecting and processing different types of data such as visual selection, - unnamed register, whole buffer, current line, diagnostics, and git diff. +### Chat and UI -- [context.lua](/lua/CopilotChat/context.lua): This file is responsible for - building an outline for a buffer and finding items for a query. It uses spatial - distance and relatedness to rank data. +- [ui/chat.lua](/lua/CopilotChat/ui/chat.lua): Chat window management. + Creating, appending to, clearing, opening, closing, and focusing the chat + window. Handles fold expressions and section parsing. -- [actions.lua](/lua/CopilotChat/actions.lua): This file manages the actions - that can be performed. It includes functions for getting help actions, prompt - actions, and picking an action from a list of actions using `vim.ui.select`. +- [ui/overlay.lua](/lua/CopilotChat/ui/overlay.lua): Overlay buffer used for + displaying diff previews and other transient content. -- [tiktoken.lua](/lua/CopilotChat/tiktoken.lua): This file manages integration - with Tiktoken library and is used for counting tokens. It includes functions - for setting up Tiktoken, checking its availability, encoding prompts, and - counting prompts. +- [ui/spinner.lua](/lua/CopilotChat/ui/spinner.lua): Loading spinner indicator + for the chat window. -- [health.lua](/lua/CopilotChat/health.lua): This file checks the health of the - plugin by checking if commands exist, checking if Lua libraries are installed, - and checking if a Treesitter parsers are available. +### Features -- [spinner.lua](/lua/CopilotChat/spinner.lua): This file manages a spinner that - is used for indicating loading status in chat window. +- [prompts.lua](/lua/CopilotChat/prompts.lua): Prompt resolution, custom + instruction loading, system prompt building, and sticky/resource/tool + parsing from user input. -- [utils.lua](/lua/CopilotChat/utils.lua): This file contains utility functions - for creating classes, getting the log file path, checking if the current - version of Neovim is stable, and joining multiple async functions. +- [functions.lua](/lua/CopilotChat/functions.lua): Built-in functions/tools + exposed to the LLM (e.g., file editing, searching). -- [debuginfo.lua](/lua/CopilotChat/debuginfo.lua): This file is used for - creating `:CopilotChatDebugInfo` command. +- [resources.lua](/lua/CopilotChat/resources.lua): Resource handling for file + and URL content retrieval with caching. -### Integrations +- [completion.lua](/lua/CopilotChat/completion.lua): Completion source for the + chat window (`@tools`, `/prompts`, `#resources`, `$models`). -- [telescope.lua](/lua/CopilotChat/integrations/telescope.lua): This file - integrates the Telescope plugin with CopilotChat. It includes a function for - picking an action from a list of actions. +- [select.lua](/lua/CopilotChat/select.lua): Selection strategies for providing + context (visual selection, buffer, diagnostics, git diff, etc.). -- [fzflua.lua](/lua/CopilotChat/integrations/fzflua.lua): This file integrates - the fzf-lua plugin with CopilotChat. It includes a function for picking an - action from a list of actions. +- [tiktoken.lua](/lua/CopilotChat/tiktoken.lua): Token counting via native + tiktoken library. + +- [instructions/](/lua/CopilotChat/instructions/): System prompt templates + injected into LLM conversations (edit formats, tool use instructions, custom + instructions wrapper). + +### Utilities + +- [utils.lua](/lua/CopilotChat/utils.lua): General utility functions. + +- [utils/](/lua/CopilotChat/utils/): Utility modules + [class.lua](/lua/CopilotChat/utils/class.lua) (OOP helper), + [curl.lua](/lua/CopilotChat/utils/curl.lua) (HTTP requests), + [diff.lua](/lua/CopilotChat/utils/diff.lua) (unified diff parsing and application), + [files.lua](/lua/CopilotChat/utils/files.lua) (file I/O and filetype detection), + [notify.lua](/lua/CopilotChat/utils/notify.lua) (pub/sub notification system for status and message events) + [orderedmap.lua](/lua/CopilotChat/utils/orderedmap.lua) (insertion-ordered map), + [stringbuffer.lua](/lua/CopilotChat/utils/stringbuffer.lua) (efficient string concatenation). + +### Other + +- [health.lua](/lua/CopilotChat/health.lua): `:checkhealth` integration. + Verifies commands, libraries, and Treesitter parsers. diff --git a/Makefile b/Makefile index c5d53c52..ebfe5768 100644 --- a/Makefile +++ b/Makefile @@ -10,6 +10,9 @@ else ifeq ($(UNAME), Darwin) else ifeq ($(UNAME), Windows_NT) OS := windows EXT := dll +else ifneq ($(findstring MSYS_NT,$(UNAME)),) + OS := windows + EXT := dll else $(error Unsupported operating system: $(UNAME)) endif @@ -19,28 +22,12 @@ BUILD_DIR := build .PHONY: help install-cli install-pre-commit install test tiktoken clean -help: - @echo "Available commands:" - @echo " install-cli - Install Lua and Luarocks using Homebrew" - @echo " install-pre-commit - Install pre-commit using pip" - @echo " install - Install vusted using Luarocks" - @echo " test - Run tests using vusted" - @echo " tiktoken - Download tiktoken_core library" - @echo " clean - Remove build directory" - -install-cli: - brew install luarocks - brew install lua - install-pre-commit: pip install pre-commit pre-commit install -install: - luarocks install vusted - test: - vusted test + nvim --headless --clean -u ./scripts/test.lua all: luajit diff --git a/README.md b/README.md index c35320a4..62364cf3 100644 --- a/README.md +++ b/README.md @@ -18,12 +18,13 @@ https://github.com/user-attachments/assets/8cad5643-63b2-4641-a5c4-68bc313f20e6 CopilotChat.nvim brings GitHub Copilot Chat capabilities directly into Neovim with a focus on transparency and user control. -- 🤖 **Multiple AI Models** - GitHub Copilot (GPT-4o, Claude 3.7 Sonnet, Gemini 2.0 Flash) + custom providers (Ollama, Mistral.ai) -- 🔧 **Tool Calling** - LLM can use workspace functions (file reading, git operations, search) with your explicit approval -- 🔒 **Explicit Control** - Only shares what you specifically request - no background data collection -- 📝 **Interactive Chat** - Rich UI with completion, diffs, and quickfix integration +- 🤖 **Multiple AI Models** - GitHub Copilot (including GPT-4o, Gemini 2.5 Pro, Claude 4 Sonnet, Claude 3.7 Sonnet, Claude 3.5 Sonnet, o3-mini, o4-mini) + custom providers (Ollama, Mistral.ai). The exact list of available models depends on your [GitHub Copilot settings](https://github.com/settings/copilot/features) and the models provided by GitHub's API. +- 🔧 **Tool Calling** - LLM can call workspace functions (file reading, git operations, search) with manual approval or automatic execution for trusted tools +- 🔒 **Privacy First** - Only shares what you explicitly request - no background data collection +- 📝 **Interactive Chat** - Interactive UI with completion, diffs, and quickfix integration - 🎯 **Smart Prompts** - Composable templates and sticky prompts for consistent context -- ⚡ **Efficient** - Smart token usage with tiktoken counting and history management +- ⚡ **Token Efficient** - Resource replacement prevents duplicate context, history management via tiktoken counting +- 🔗 **Scriptable** - Comprehensive Lua API for automation and headless mode operation - 🔌 **Extensible** - [Custom functions](https://github.com/CopilotC-Nvim/CopilotChat.nvim/discussions/categories/functions) and [providers](https://github.com/CopilotC-Nvim/CopilotChat.nvim/discussions/categories/providers), plus integrations like [mcphub.nvim](https://github.com/ravitemer/mcphub.nvim) # Installation @@ -91,29 +92,13 @@ EOF # Core Concepts - **Resources** (`#`) - Add specific content (files, git diffs, URLs) to your prompt -- **Tools** (`@`) - Give LLM access to functions it can call with your approval +- **Tools** (`@`) - Give LLM access to functions it can call during the chat, with manual approval by default - **Sticky Prompts** (`> `) - Persist context across single chat session - **Models** (`$`) - Specify which AI model to use for the chat - **Prompts** (`/PromptName`) - Use predefined prompt templates for common tasks -## Examples - -```markdown -# Add specific file to context - -#file:src/main.lua - -# Give LLM access to workspace tools - -@copilot What files are in this project? - -# Sticky prompt that persists - -> #buffer:current -> You are a helpful coding assistant -``` - -When you use `@copilot`, the LLM can call functions like `glob`, `file`, `gitdiff` etc. You'll see the proposed function call and can approve/reject it before execution. +> [!TIP] +> Press `` after typing `#` or `@` to see available options and auto-complete. This is the easiest way to discover what's available! # Usage @@ -135,40 +120,54 @@ When you use `@copilot`, the LLM can call functions like `glob`, `file`, `gitdif ## Chat Key Mappings -| Insert | Normal | Action | -| ----------- | ------- | ------------------------------------------ | -| `` | - | Trigger/accept completion menu for tokens | -| `` | `q` | Close the chat window | -| `` | `` | Reset and clear the chat window | -| `` | `` | Submit the current prompt | -| - | `grr` | Toggle sticky prompt for line under cursor | -| - | `grx` | Clear all sticky prompts in prompt | -| `` | `` | Accept nearest diff | -| - | `gj` | Jump to section of nearest diff | -| - | `gqa` | Add all answers from chat to quickfix list | -| - | `gqd` | Add all diffs from chat to quickfix list | -| - | `gy` | Yank nearest diff to register | -| - | `gd` | Show diff between source and nearest diff | -| - | `gc` | Show info about current chat | -| - | `gh` | Show help message | +| Insert | Normal | Action | +| ------- | ------- | ---------------------------------------------------- | +| `` | - | **Autocomplete resources/files/options** (use this!) | +| `` | `q` | Close the chat window | +| `` | `` | Reset and clear the chat window | +| `` | `` | Submit the current prompt | +| `` | `` | Accept nearest diff | +| - | `gj` | Jump to section of nearest diff | +| - | `gqa` | Add all answers from chat to quickfix | +| - | `gqd` | Add all diffs from chat to quickfix | +| - | `gy` | Yank nearest diff to register | +| - | `gd` | Show diff between source and nearest diff | +| - | `gc` | Show info about current chat | +| - | `gh` | Show help message | + +**💡 Pro tip:** After typing `#`, `@`, `#buffer:`, or `#file:`, press `` to see available options. This is the fastest way to work! + +> [!NOTE] +> **Tab key not working?** Some plugins (e.g. `copilot.vim`) also map `` in insert mode. +> To fix conflicts, disable the other plugin's `` mapping: +> +> ```lua +> -- For copilot.vim +> vim.g.copilot_no_tab_map = true +> vim.keymap.set('i', '', 'copilot#Accept("\\")', { expr = true, replace_keycodes = false }) +> ``` +> +> Or customize CopilotChat keymaps in your config. ## Predefined Functions All predefined functions belong to the `copilot` group. -| Function | Description | Example Usage | -| ------------- | ------------------------------------------------ | ---------------------- | -| `buffer` | Retrieves content from a specific buffer | `#buffer` | -| `buffers` | Fetches content from multiple buffers | `#buffers:visible` | -| `diagnostics` | Collects code diagnostics (errors, warnings) | `#diagnostics:current` | -| `file` | Reads content from a specified file path | `#file:path/to/file` | -| `gitdiff` | Retrieves git diff information | `#gitdiff:staged` | -| `gitstatus` | Retrieves git status information | `#gitstatus` | -| `glob` | Lists filenames matching a pattern in workspace | `#glob:**/*.lua` | -| `grep` | Searches for a pattern across files in workspace | `#grep:TODO` | -| `quickfix` | Includes content of files in quickfix list | `#quickfix` | -| `register` | Provides access to specified Vim register | `#register:+` | -| `url` | Fetches content from a specified URL | `#url:https://...` | +| Function | Manual `#...` | Description | Available Options | +| ----------- | ------------- | ------------------------------------------------------ | --------------------------------------------------------------------- | +| `bash` | No | Executes a bash command and returns output | Tool-only (use `@copilot`) | +| `buffer` | Yes | Retrieves content from buffer(s) with diagnostics | `active`, `visible`, `listed`, `quickfix`, buffer number, or filename | +| `clipboard` | Yes | Provides access to system clipboard content | No options | +| `edit` | No | Applies a unified diff to a file | Tool-only (use `@copilot`) | +| `file` | Yes | Reads content from a specified file path | Any file path (use `` for completion) | +| `gitdiff` | Yes | Retrieves git diff information | `unstaged` (default), `staged`, or commit SHA | +| `glob` | Yes | Lists filenames matching a pattern in workspace | Any glob pattern (default: `**/*`) | +| `grep` | Yes | Searches for a pattern across files in workspace | Any search pattern | +| `selection` | Yes | Includes the current visual selection with diagnostics | No options | +| `url` | Yes | Fetches content from a specified URL | Any HTTPS URL | + +- **`#`** - Embeds output directly in your message (e.g., `#buffer:listed`, `#file:src/main.lua`) +- **`@`** - Makes function(s) available for LLM to call when needed (e.g., `@copilot`, `@file`) ## Predefined Prompts @@ -182,6 +181,55 @@ All predefined functions belong to the `copilot` group. | `Tests` | Generate tests for selected code | | `Commit` | Generate commit message with commitizen convention from staged changes | +## Resource Usage + +```markdown +# Current buffer + +#buffer:active + +# All open buffers (replaces old #buffers) + +#buffer:listed + +# All visible buffers + +#buffer:visible + +# Specific file + +#file:src/main.lua + +# Git changes + +#gitdiff:staged + +# URL content + +#url:https://example.com/docs +``` + +## Tool Usage + +When you use `@copilot`, the LLM can call functions from the `copilot` group such as `bash`, `edit`, `file`, `glob`, `grep`, and `gitdiff`. + +```markdown +# Give LLM access to workspace tools + +@copilot What files are in this project? + +# Sticky context with tools + +> #buffer:listed +> @copilot +> Refactor the authentication code +``` + +By default, tool calls require manual approval. Configure `trusted_tools` to automatically run specific tools (see [Functions](#functions)). + +> [!WARNING] +> `trusted_tools = true` allows the model to run every enabled tool without asking. Only use it if you fully trust the tool set and workspace. + # Configuration For all available configuration options, see [`lua/CopilotChat/config.lua`](lua/CopilotChat/config.lua). @@ -194,6 +242,7 @@ Most users only need to configure a few options: { model = 'gpt-4.1', -- AI model to use temperature = 0.1, -- Lower = focused, higher = creative + trusted_tools = nil, -- Require approval for all tool calls window = { layout = 'vertical', -- 'vertical', 'horizontal', 'float' width = 0.5, -- 50% of screen width @@ -216,21 +265,24 @@ Most users only need to configure a few options: }, headers = { - user = '👤 You: ', - assistant = '🤖 Copilot: ', - tool = '🔧 Tool: ', + user = '👤 You', + assistant = '🤖 Copilot', + tool = '🔧 Tool', }, + separator = '━━', - show_folds = false, -- Disable folding for cleaner look + auto_fold = true, -- Automatically folds non-assistant messages } ``` +`window.layout` also supports `'replace'` to reuse the current window. + ## Buffer Behavior ```lua -- Auto-command to customize chat buffer behavior vim.api.nvim_create_autocmd('BufEnter', { - pattern = 'copilot-*', + pattern = 'copilot-chat', callback = function() vim.opt_local.relativenumber = false vim.opt_local.number = false @@ -247,13 +299,13 @@ You can customize colors by setting highlight groups in your config: -- In your colorscheme or init.lua vim.api.nvim_set_hl(0, 'CopilotChatHeader', { fg = '#7C3AED', bold = true }) vim.api.nvim_set_hl(0, 'CopilotChatSeparator', { fg = '#374151' }) -vim.api.nvim_set_hl(0, 'CopilotChatKeyword', { fg = '#10B981', italic = true }) ``` Types of copilot highlights: - `CopilotChatHeader` - Header highlight in chat buffer - `CopilotChatSeparator` - Separator highlight in chat buffer +- `CopilotChatSelection` - Selection highlight in source buffer - `CopilotChatStatus` - Status and spinner in chat buffer - `CopilotChatHelp` - Help text in chat buffer - `CopilotChatResource` - Resource highlight in chat buffer (e.g. `#file`, `#gitdiff`) @@ -261,8 +313,8 @@ Types of copilot highlights: - `CopilotChatPrompt` - Prompt highlight in chat buffer (e.g. `/Explain`, `/Review`) - `CopilotChatModel` - Model highlight in chat buffer (e.g. `$gpt-4.1`) - `CopilotChatUri` - URI highlight in chat buffer (e.g. `##https://...`) -- `CopilotChatSelection` - Selection highlight in source buffer - `CopilotChatAnnotation` - Annotation highlight in chat buffer (file headers, tool call headers, tool call body) +- `CopilotChatAnnotationHeader` - Annotation header highlight in chat buffer ## Prompts @@ -281,7 +333,7 @@ Define your own prompts in the configuration: system_prompt = 'You are fascinated by pirates, so please respond in pirate speak.', }, NiceInstructions = { - system_prompt = 'You are a nice coding tutor, so please respond in a friendly and helpful manner.' .. require('CopilotChat.config.prompts').COPILOT_BASE.system_prompt, + system_prompt = 'You are a nice coding tutor, so please respond in a friendly and helpful manner.', } } } @@ -289,14 +341,46 @@ Define your own prompts in the configuration: ## Functions +Use `trusted_tools` to control which tool calls are executed automatically: + +```lua +{ + trusted_tools = nil, -- default: require approval for all tool calls + + -- trust all functions in a group + -- trusted_tools = 'copilot', + + -- trust specific functions by name or groups by name + -- trusted_tools = { 'file', 'glob', 'grep' }, + + -- trust every enabled tool call + -- trusted_tools = true, +} +``` + +**How tool trust works:** + +A tool is trusted when any of these match: + +- Its function definition sets `trusted = true` +- Its function name appears in `trusted_tools` +- Its function group appears in `trusted_tools` +- `trusted_tools = true` + +**Recommended setup:** Trust read-only functions like `file`, `glob`, or `grep` for a smoother workflow without compromising safety. + +> [!WARNING] +> Trusted tools run without asking for confirmation. Be especially careful with tools like `bash` and `edit`, which can change your workspace. + Define your own functions in the configuration with input handling and schema: ```lua { functions = { birthday = { - description = "Retrieves birthday information for a person", - uri = "birthday://{name}", + description = 'Retrieves birthday information for a person', + uri = 'birthday://{name}', + trusted = false, schema = { type = 'object', required = { 'name' }, @@ -314,34 +398,15 @@ Define your own functions in the configuration with input handling and schema: uri = 'birthday://' .. input.name, mimetype = 'text/plain', data = input.name .. ' birthday info', - } + }, } - end - } + end, + }, } } ``` -## Selections - -Control what content is automatically included: - -```lua -{ - -- Use visual selection, fallback to current line - selection = function(source) - return require('CopilotChat.select').visual(source) or - require('CopilotChat.select').line(source) - end, -} -``` - -**Available selections:** - -- `require('CopilotChat.select').visual` - Current visual selection -- `require('CopilotChat.select').buffer` - Entire buffer content -- `require('CopilotChat.select').line` - Current line content -- `require('CopilotChat.select').unnamed` - Unnamed register (last deleted/changed/yanked) +If a function has a `uri`, it can be used manually with `#birthday:Alice`. Functions without a `uri` are tool-only and can only be called by the model. ## Providers @@ -351,9 +416,9 @@ Add custom AI providers: { providers = { my_provider = { - get_url = function(opts) return "https://api.example.com/chat" end, - get_headers = function() return { ["Authorization"] = "Bearer " .. api_key } end, - get_models = function() return { { id = "gpt-4.1", name = "GPT-4.1 model" } } end, + get_url = function(opts) return 'https://api.example.com/chat' end, + get_headers = function() return { ['Authorization'] = 'Bearer ' .. api_key } end, + get_models = function() return { { id = 'gpt-4.1', name = 'GPT-4.1 model' } } end, prepare_input = require('CopilotChat.config.providers').copilot.prepare_input, prepare_output = require('CopilotChat.config.providers').copilot.prepare_output, } @@ -368,11 +433,8 @@ Add custom AI providers: -- Optional: Disable provider disabled?: boolean, - -- Optional: Embeddings provider name or function - embed?: string|function, - -- Optional: Extra info about the provider displayed in info panel - get_info?(): string[] + get_info?(headers: table): string[] -- Optional: Get extra request headers with optional expiration time get_headers?(): table, number?, @@ -388,28 +450,26 @@ Add custom AI providers: -- Optional: Get available models get_models?(headers: table): table, + + -- Optional: Resolve a user-facing model id to a provider model id + resolve_model?(headers: table, model: string): string, } ``` **Built-in providers:** - `copilot` - GitHub Copilot (default) -- `github_models` - GitHub Marketplace models (disabled by default) -- `copilot_embeddings` - Copilot embeddings provider +- `github_models` - GitHub Models (disabled by default) # API Reference ## Core ```lua -local chat = require("CopilotChat") +local chat = require('CopilotChat') -- Basic Chat Functions chat.ask(prompt, config) -- Ask a question with optional config -chat.response() -- Get the last response text -chat.resolve_prompt() -- Resolve prompt references -chat.resolve_functions() -- Resolve functions that are available for automatic use by LLM (WARN: async, requires plenary.async.run) -chat.resolve_model() -- Resolve model from prompt (WARN: async, requires plenary.async.run) -- Window Management chat.open(config) -- Open chat window with optional config @@ -418,27 +478,13 @@ chat.toggle(config) -- Toggle chat window visibility with optional con chat.reset() -- Reset the chat chat.stop() -- Stop current output --- Source Management -chat.get_source() -- Get the current source buffer and window -chat.set_source(winnr) -- Set the source window - --- Selection Management -chat.get_selection() -- Get the current selection -chat.set_selection(bufnr, start_line, end_line, clear) -- Set or clear selection - -- Prompt & Model Management chat.select_prompt(config) -- Open prompt selector with optional config chat.select_model() -- Open model selector -chat.prompts() -- Get all available prompts - --- Completion -chat.trigger_complete() -- Trigger completion in chat window -chat.complete_info() -- Get completion info for custom providers -chat.complete_items() -- Get completion items (WARN: async, requires plenary.async.run) -- History Management -chat.save(name, history_path) -- Save chat history chat.load(name, history_path) -- Load chat history +chat.save(name, history_path) -- Save chat history -- Configuration chat.setup(config) -- Update configuration @@ -450,16 +496,17 @@ chat.log_level(level) -- Set log level (debug, info, etc.) You can also access the chat window UI methods through the `chat.chat` object: ```lua -local window = require("CopilotChat").chat +local window = require('CopilotChat').chat -- Chat UI State window:visible() -- Check if chat window is visible window:focused() -- Check if chat window is focused -- Message Management -window:get_message(role) -- Get last chat message by role (user, assistant, tool) +window:get_message(role, cursor) -- Get chat message by role, either last or closest to cursor window:add_message({ role, content }, replace) -- Add or replace a message in chat -window:add_sticky(sticky) -- Add sticky prompt to chat message +window:remove_message(role, cursor) -- Remove chat message by role, either last or closest to cursor +window:get_block(role, cursor) -- Get code block by role, either last or closest to cursor -- Content Management window:append(text) -- Append text to chat window @@ -467,36 +514,53 @@ window:clear() -- Clear chat window content window:start() -- Start writing to chat window window:finish() -- Finish writing to chat window +-- Source Management +window:get_source() -- Get the current source buffer and window +window:set_source(winnr) -- Set the source window + -- Navigation window:follow() -- Move cursor to end of chat content window:focus() -- Focus the chat window -- Advanced Features -window:get_closest_message(role) -- Get message closest to cursor -window:get_closest_block(role) -- Get code block closest to cursor -window:overlay(opts) -- Show overlay with specified options +window:overlay(opts) -- Show overlay with specified options +``` + +## Prompt parser + +```lua +local parser = require('CopilotChat.prompts') + +parser.resolve_prompt() -- Resolve prompt references +parser.resolve_tools() -- Resolve tools shared with the model via @... +parser.resolve_functions() -- Resolve manual function/resource references via #... +parser.resolve_model() -- Resolve model from prompt (WARN: async, requires plenary.async.run) ``` ## Example Usage ```lua -- Open chat, ask a question and handle response -require("CopilotChat").open() -require("CopilotChat").ask("#buffer Explain this code", { +require('CopilotChat').open() +require('CopilotChat').ask('#buffer Explain this code', { callback = function(response) - vim.notify("Got response: " .. response:sub(1, 50) .. "...") - return response + vim.notify('Got response: ' .. vim.trim(response.content):sub(1, 50) .. '...') end, }) -- Save and load chat history -require("CopilotChat").save("my_debugging_session") -require("CopilotChat").load("my_debugging_session") +require('CopilotChat').save('my_debugging_session') +require('CopilotChat').load('my_debugging_session') -- Use custom sticky and model -require("CopilotChat").ask("How can I optimize this?", { - model = "gpt-4.1", - sticky = {"#buffer", "#gitdiff:staged"} +require('CopilotChat').ask('How can I optimize this?', { + model = 'gpt-4.1', + sticky = { '#buffer', '#gitdiff:staged' }, +}) + +-- Automatically trust a small read-only tool set +require('CopilotChat').setup({ + trusted_tools = { 'file', 'glob', 'grep' }, }) ``` @@ -518,7 +582,6 @@ cd CopilotChat.nvim 2. Install development dependencies: ```bash -# Install pre-commit hooks make install-pre-commit ``` @@ -528,6 +591,12 @@ To run tests: make test ``` +To run the same formatting check as CI: + +```bash +stylua --check . +``` + ## Contributing 1. Fork the repository @@ -625,6 +694,23 @@ Thanks goes to these wonderful people ([emoji key](https://allcontributors.org/d Aaron D Borden
Aaron D Borden

💻 Md. Iftakhar Awal Chowdhury
Md. Iftakhar Awal Chowdhury

💻 📖 Danilo Horta
Danilo Horta

💻 + Mihamina Rakotomandimby
Mihamina Rakotomandimby

📖 💻 + Ajmal S
Ajmal S

💻 + + + Samiul Islam
Samiul Islam

💻 + Rui Costa
Rui Costa

💻 + CTCHEN
CTCHEN

💻 + Tobias Wölfel
Tobias Wölfel

💻 + Alexander Garcia
Alexander Garcia

💻 + Max Kharandziuk
Max Kharandziuk

💻 + Xinyu Xiang
Xinyu Xiang

💻 + + + junqizhang
junqizhang

💻 + Calum Lynch
Calum Lynch

💻 + sirjls
sirjls

💻 + Vladimir Kolchurin
Vladimir Kolchurin

💻 diff --git a/doc/CopilotChat.txt b/doc/CopilotChat.txt index f49db7f1..ec5ad985 100644 --- a/doc/CopilotChat.txt +++ b/doc/CopilotChat.txt @@ -1,4 +1,5 @@ -*CopilotChat.txt* For NVIM v0.8.0 Last change: 2025 August 03 +*CopilotChat.txt* + For NVIM v0.8.0 Last change: 2026 April 26 ============================================================================== Table of Contents *CopilotChat-table-of-contents* @@ -10,12 +11,13 @@ Table of Contents *CopilotChat-table-of-contents* - lazy.nvim |CopilotChat-lazy.nvim| - vim-plug |CopilotChat-vim-plug| 2. Core Concepts |CopilotChat-core-concepts| - - Examples |CopilotChat-examples| 3. Usage |CopilotChat-usage| - Commands |CopilotChat-commands| - Chat Key Mappings |CopilotChat-chat-key-mappings| - Predefined Functions |CopilotChat-predefined-functions| - Predefined Prompts |CopilotChat-predefined-prompts| + - Resource Usage |CopilotChat-resource-usage| + - Tool Usage |CopilotChat-tool-usage| 4. Configuration |CopilotChat-configuration| - Quick Setup |CopilotChat-quick-setup| - Window & Appearance |CopilotChat-window-&-appearance| @@ -23,11 +25,11 @@ Table of Contents *CopilotChat-table-of-contents* - Highlights |CopilotChat-highlights| - Prompts |CopilotChat-prompts| - Functions |CopilotChat-functions| - - Selections |CopilotChat-selections| - Providers |CopilotChat-providers| 5. API Reference |CopilotChat-api-reference| - Core |CopilotChat-core| - Chat Window |CopilotChat-chat-window| + - Prompt parser |CopilotChat-prompt-parser| - Example Usage |CopilotChat-example-usage| 6. Development |CopilotChat-development| - Setup |CopilotChat-setup| @@ -39,12 +41,13 @@ Table of Contents *CopilotChat-table-of-contents* CopilotChat.nvim brings GitHub Copilot Chat capabilities directly into Neovim with a focus on transparency and user control. -- 🤖 **Multiple AI Models** - GitHub Copilot (GPT-4o, Claude 3.7 Sonnet, Gemini 2.0 Flash) + custom providers (Ollama, Mistral.ai) -- 🔧 **Tool Calling** - LLM can use workspace functions (file reading, git operations, search) with your explicit approval -- 🔒 **Explicit Control** - Only shares what you specifically request - no background data collection -- 📝 **Interactive Chat** - Rich UI with completion, diffs, and quickfix integration +- 🤖 **Multiple AI Models** - GitHub Copilot (including GPT-4o, Gemini 2.5 Pro, Claude 4 Sonnet, Claude 3.7 Sonnet, Claude 3.5 Sonnet, o3-mini, o4-mini) + custom providers (Ollama, Mistral.ai). The exact list of available models depends on your GitHub Copilot settings and the models provided by GitHub’s API. +- 🔧 **Tool Calling** - LLM can call workspace functions (file reading, git operations, search) with manual approval or automatic execution for trusted tools +- 🔒 **Privacy First** - Only shares what you explicitly request - no background data collection +- 📝 **Interactive Chat** - Interactive UI with completion, diffs, and quickfix integration - 🎯 **Smart Prompts** - Composable templates and sticky prompts for consistent context -- ⚡ **Efficient** - Smart token usage with tiktoken counting and history management +- ⚡ **Token Efficient** - Resource replacement prevents duplicate context, history management via tiktoken counting +- 🔗 **Scriptable** - Comprehensive Lua API for automation and headless mode operation - 🔌 **Extensible** - Custom functions and providers , plus integrations like mcphub.nvim @@ -124,33 +127,14 @@ VIM-PLUG *CopilotChat-vim-plug* 2. Core Concepts *CopilotChat-core-concepts* - **Resources** (`#`) - Add specific content (files, git diffs, URLs) to your prompt -- **Tools** (`@`) - Give LLM access to functions it can call with your approval +- **Tools** (`@`) - Give LLM access to functions it can call during the chat, with manual approval by default - **Sticky Prompts** (`> `) - Persist context across single chat session - **Models** (`$`) - Specify which AI model to use for the chat - **Prompts** (`/PromptName`) - Use predefined prompt templates for common tasks -EXAMPLES *CopilotChat-examples* - ->markdown - # Add specific file to context - - #file:src/main.lua - - # Give LLM access to workspace tools - - @copilot What files are in this project? - - # Sticky prompt that persists - - > #buffer:current - > You are a helpful coding assistant -< - -When you use `@copilot`, the LLM can call functions like `glob`, `file`, -`gitdiff` etc. You’ll see the proposed function call and can approve/reject -it before execution. - + [!TIP] Press `` after typing `#` or `@` to see available options and + auto-complete. This is the easiest way to discover what’s available! ============================================================================== 3. Usage *CopilotChat-usage* @@ -174,56 +158,90 @@ COMMANDS *CopilotChat-commands* CHAT KEY MAPPINGS *CopilotChat-chat-key-mappings* - Insert Normal Action - ----------- -------- -------------------------------------------- - - Trigger/accept completion menu for tokens - q Close the chat window - Reset and clear the chat window - Submit the current prompt - - grr Toggle sticky prompt for line under cursor - - grx Clear all sticky prompts in prompt - Accept nearest diff - - gj Jump to section of nearest diff - - gqa Add all answers from chat to quickfix list - - gqd Add all diffs from chat to quickfix list - - gy Yank nearest diff to register - - gd Show diff between source and nearest diff - - gc Show info about current chat - - gh Show help message + ------------------------------------------------------------------------- + Insert Normal Action + -------- -------- ------------------------------------------------------- + - Autocomplete resources/files/options (use this!) + + q Close the chat window + + Reset and clear the chat window + + Submit the current prompt + + Accept nearest diff + + - gj Jump to section of nearest diff + + - gqa Add all answers from chat to quickfix + + - gqd Add all diffs from chat to quickfix + + - gy Yank nearest diff to register + + - gd Show diff between source and nearest diff + + - gc Show info about current chat + + - gh Show help message + ------------------------------------------------------------------------- +**💡 Pro tip:** After typing `#`, `@`, `#buffer:`, or `#file:`, press `` +to see available options. This is the fastest way to work! + + + [!NOTE] **Tab key not working?** Some plugins (e.g. `copilot.vim`) also map + `` in insert mode. To fix conflicts, disable the other plugin’s `` + mapping: + >lua + -- For copilot.vim + vim.g.copilot_no_tab_map = true + vim.keymap.set('i', '', 'copilot#Accept("\\")', { expr = true, replace_keycodes = false }) + < + Or customize CopilotChat keymaps in your config. PREDEFINED FUNCTIONS *CopilotChat-predefined-functions* All predefined functions belong to the `copilot` group. - ------------------------------------------------------------------------------ - Function Description Example Usage - ------------- ----------------------------------------- ---------------------- - buffer Retrieves content from a specific buffer #buffer + --------------------------------------------------------------------------------- + Function Manual Description Available Options + #... + ----------- -------- -------------------------- --------------------------------- + bash No Executes a bash command Tool-only (use @copilot) + and returns output - buffers Fetches content from multiple buffers #buffers:visible + buffer Yes Retrieves content from active, visible, listed, + buffer(s) with diagnostics quickfix, buffer number, or + filename - diagnostics Collects code diagnostics (errors, #diagnostics:current - warnings) + clipboard Yes Provides access to system No options + clipboard content - file Reads content from a specified file path #file:path/to/file + edit No Applies a unified diff to Tool-only (use @copilot) + a file - gitdiff Retrieves git diff information #gitdiff:staged + file Yes Reads content from a Any file path (use for + specified file path completion) - gitstatus Retrieves git status information #gitstatus + gitdiff Yes Retrieves git diff unstaged (default), staged, or + information commit SHA - glob Lists filenames matching a pattern in #glob:**/*.lua - workspace + glob Yes Lists filenames matching a Any glob pattern (default: **/*) + pattern in workspace - grep Searches for a pattern across files in #grep:TODO - workspace + grep Yes Searches for a pattern Any search pattern + across files in workspace - quickfix Includes content of files in quickfix #quickfix - list + selection Yes Includes the current No options + visual selection with + diagnostics - register Provides access to specified Vim register #register:+ + url Yes Fetches content from a Any HTTPS URL + specified URL + --------------------------------------------------------------------------------- +- **#** - Embeds output directly in your message (e.g., `#buffer:listed`, `#file:src/main.lua`) +- **@** - Makes function(s) available for LLM to call when needed (e.g., `@copilot`, `@file`) - url Fetches content from a specified URL #url:https://... - ------------------------------------------------------------------------------ PREDEFINED PROMPTS *CopilotChat-predefined-prompts* @@ -246,6 +264,59 @@ PREDEFINED PROMPTS *CopilotChat-predefined-prompts* changes ------------------------------------------------------------------------- +RESOURCE USAGE *CopilotChat-resource-usage* + +>markdown + # Current buffer + + #buffer:active + + # All open buffers (replaces old #buffers) + + #buffer:listed + + # All visible buffers + + #buffer:visible + + # Specific file + + #file:src/main.lua + + # Git changes + + #gitdiff:staged + + # URL content + + #url:https://example.com/docs +< + + +TOOL USAGE *CopilotChat-tool-usage* + +When you use `@copilot`, the LLM can call functions from the `copilot` group +such as `bash`, `edit`, `file`, `glob`, `grep`, and `gitdiff`. + +>markdown + # Give LLM access to workspace tools + + @copilot What files are in this project? + + # Sticky context with tools + + > #buffer:listed + > @copilot + > Refactor the authentication code +< + +By default, tool calls require manual approval. Configure `trusted_tools` to +automatically run specific tools (see |CopilotChat-functions|). + + + [!WARNING] `trusted_tools = true` allows the model to run every enabled tool + without asking. Only use it if you fully trust the tool set and workspace. + ============================================================================== 4. Configuration *CopilotChat-configuration* @@ -261,6 +332,7 @@ Most users only need to configure a few options: { model = 'gpt-4.1', -- AI model to use temperature = 0.1, -- Lower = focused, higher = creative + trusted_tools = nil, -- Require approval for all tool calls window = { layout = 'vertical', -- 'vertical', 'horizontal', 'float' width = 0.5, -- 50% of screen width @@ -284,22 +356,25 @@ WINDOW & APPEARANCE *CopilotChat-window-&-appearance* }, headers = { - user = '👤 You: ', - assistant = '🤖 Copilot: ', - tool = '🔧 Tool: ', + user = '👤 You', + assistant = '🤖 Copilot', + tool = '🔧 Tool', }, + separator = '━━', - show_folds = false, -- Disable folding for cleaner look + auto_fold = true, -- Automatically folds non-assistant messages } < +`window.layout` also supports `'replace'` to reuse the current window. + BUFFER BEHAVIOR *CopilotChat-buffer-behavior* >lua -- Auto-command to customize chat buffer behavior vim.api.nvim_create_autocmd('BufEnter', { - pattern = 'copilot-*', + pattern = 'copilot-chat', callback = function() vim.opt_local.relativenumber = false vim.opt_local.number = false @@ -317,22 +392,22 @@ You can customize colors by setting highlight groups in your config: -- In your colorscheme or init.lua vim.api.nvim_set_hl(0, 'CopilotChatHeader', { fg = '#7C3AED', bold = true }) vim.api.nvim_set_hl(0, 'CopilotChatSeparator', { fg = '#374151' }) - vim.api.nvim_set_hl(0, 'CopilotChatKeyword', { fg = '#10B981', italic = true }) < Types of copilot highlights: - `CopilotChatHeader` - Header highlight in chat buffer - `CopilotChatSeparator` - Separator highlight in chat buffer +- `CopilotChatSelection` - Selection highlight in source buffer - `CopilotChatStatus` - Status and spinner in chat buffer - `CopilotChatHelp` - Help text in chat buffer -- `CopilotChatResource` - Resource highlight in chat buffer (e.g. `#file`, `#gitdiff`) -- `CopilotChatTool` - Tool call highlight in chat buffer (e.g. `@copilot`) -- `CopilotChatPrompt` - Prompt highlight in chat buffer (e.g. `/Explain`, `/Review`) -- `CopilotChatModel` - Model highlight in chat buffer (e.g. `$gpt-4.1`) -- `CopilotChatUri` - URI highlight in chat buffer (e.g. `##https://...`) -- `CopilotChatSelection` - Selection highlight in source buffer +- `CopilotChatResource` - Resource highlight in chat buffer (e.g. `#file`, `#gitdiff`) +- `CopilotChatTool` - Tool call highlight in chat buffer (e.g. `@copilot`) +- `CopilotChatPrompt` - Prompt highlight in chat buffer (e.g. `/Explain`, `/Review`) +- `CopilotChatModel` - Model highlight in chat buffer (e.g. `$gpt-4.1`) +- `CopilotChatUri` - URI highlight in chat buffer (e.g. `##https://...`) - `CopilotChatAnnotation` - Annotation highlight in chat buffer (file headers, tool call headers, tool call body) +- `CopilotChatAnnotationHeader` - Annotation header highlight in chat buffer PROMPTS *CopilotChat-prompts* @@ -352,7 +427,7 @@ Define your own prompts in the configuration: system_prompt = 'You are fascinated by pirates, so please respond in pirate speak.', }, NiceInstructions = { - system_prompt = 'You are a nice coding tutor, so please respond in a friendly and helpful manner.' .. require('CopilotChat.config.prompts').COPILOT_BASE.system_prompt, + system_prompt = 'You are a nice coding tutor, so please respond in a friendly and helpful manner.', } } } @@ -361,14 +436,47 @@ Define your own prompts in the configuration: FUNCTIONS *CopilotChat-functions* +Use `trusted_tools` to control which tool calls are executed automatically: + +>lua + { + trusted_tools = nil, -- default: require approval for all tool calls + + -- trust all functions in a group + -- trusted_tools = 'copilot', + + -- trust specific functions by name or groups by name + -- trusted_tools = { 'file', 'glob', 'grep' }, + + -- trust every enabled tool call + -- trusted_tools = true, + } +< + +**How tool trust works:** + +A tool is trusted when any of these match: + +- Its function definition sets `trusted = true` +- Its function name appears in `trusted_tools` +- Its function group appears in `trusted_tools` +- `trusted_tools = true` + +**Recommended setup:** Trust read-only functions like `file`, `glob`, or `grep` +for a smoother workflow without compromising safety. + + + [!WARNING] Trusted tools run without asking for confirmation. Be especially + careful with tools like `bash` and `edit`, which can change your workspace. Define your own functions in the configuration with input handling and schema: >lua { functions = { birthday = { - description = "Retrieves birthday information for a person", - uri = "birthday://{name}", + description = 'Retrieves birthday information for a person', + uri = 'birthday://{name}', + trusted = false, schema = { type = 'object', required = { 'name' }, @@ -386,35 +494,16 @@ Define your own functions in the configuration with input handling and schema: uri = 'birthday://' .. input.name, mimetype = 'text/plain', data = input.name .. ' birthday info', - } + }, } - end - } + end, + }, } } < - -SELECTIONS *CopilotChat-selections* - -Control what content is automatically included: - ->lua - { - -- Use visual selection, fallback to current line - selection = function(source) - return require('CopilotChat.select').visual(source) or - require('CopilotChat.select').line(source) - end, - } -< - -**Available selections:** - -- `require('CopilotChat.select').visual` - Current visual selection -- `require('CopilotChat.select').buffer` - Entire buffer content -- `require('CopilotChat.select').line` - Current line content -- `require('CopilotChat.select').unnamed` - Unnamed register (last deleted/changed/yanked) +If a function has a `uri`, it can be used manually with `#birthday:Alice`. +Functions without a `uri` are tool-only and can only be called by the model. PROVIDERS *CopilotChat-providers* @@ -425,9 +514,9 @@ Add custom AI providers: { providers = { my_provider = { - get_url = function(opts) return "https://api.example.com/chat" end, - get_headers = function() return { ["Authorization"] = "Bearer " .. api_key } end, - get_models = function() return { { id = "gpt-4.1", name = "GPT-4.1 model" } } end, + get_url = function(opts) return 'https://api.example.com/chat' end, + get_headers = function() return { ['Authorization'] = 'Bearer ' .. api_key } end, + get_models = function() return { { id = 'gpt-4.1', name = 'GPT-4.1 model' } } end, prepare_input = require('CopilotChat.config.providers').copilot.prepare_input, prepare_output = require('CopilotChat.config.providers').copilot.prepare_output, } @@ -442,11 +531,8 @@ Add custom AI providers: -- Optional: Disable provider disabled?: boolean, - -- Optional: Embeddings provider name or function - embed?: string|function, - -- Optional: Extra info about the provider displayed in info panel - get_info?(): string[] + get_info?(headers: table): string[] -- Optional: Get extra request headers with optional expiration time get_headers?(): table, number?, @@ -462,14 +548,16 @@ Add custom AI providers: -- Optional: Get available models get_models?(headers: table): table, + + -- Optional: Resolve a user-facing model id to a provider model id + resolve_model?(headers: table, model: string): string, } < **Built-in providers:** - `copilot` - GitHub Copilot (default) -- `github_models` - GitHub Marketplace models (disabled by default) -- `copilot_embeddings` - Copilot embeddings provider +- `github_models` - GitHub Models (disabled by default) ============================================================================== @@ -479,14 +567,10 @@ Add custom AI providers: CORE *CopilotChat-core* >lua - local chat = require("CopilotChat") + local chat = require('CopilotChat') -- Basic Chat Functions chat.ask(prompt, config) -- Ask a question with optional config - chat.response() -- Get the last response text - chat.resolve_prompt() -- Resolve prompt references - chat.resolve_functions() -- Resolve functions that are available for automatic use by LLM (WARN: async, requires plenary.async.run) - chat.resolve_model() -- Resolve model from prompt (WARN: async, requires plenary.async.run) -- Window Management chat.open(config) -- Open chat window with optional config @@ -495,27 +579,13 @@ CORE *CopilotChat-core* chat.reset() -- Reset the chat chat.stop() -- Stop current output - -- Source Management - chat.get_source() -- Get the current source buffer and window - chat.set_source(winnr) -- Set the source window - - -- Selection Management - chat.get_selection() -- Get the current selection - chat.set_selection(bufnr, start_line, end_line, clear) -- Set or clear selection - -- Prompt & Model Management chat.select_prompt(config) -- Open prompt selector with optional config chat.select_model() -- Open model selector - chat.prompts() -- Get all available prompts - - -- Completion - chat.trigger_complete() -- Trigger completion in chat window - chat.complete_info() -- Get completion info for custom providers - chat.complete_items() -- Get completion items (WARN: async, requires plenary.async.run) -- History Management - chat.save(name, history_path) -- Save chat history chat.load(name, history_path) -- Load chat history + chat.save(name, history_path) -- Save chat history -- Configuration chat.setup(config) -- Update configuration @@ -528,16 +598,17 @@ CHAT WINDOW *CopilotChat-chat-window* You can also access the chat window UI methods through the `chat.chat` object: >lua - local window = require("CopilotChat").chat + local window = require('CopilotChat').chat -- Chat UI State window:visible() -- Check if chat window is visible window:focused() -- Check if chat window is focused -- Message Management - window:get_message(role) -- Get last chat message by role (user, assistant, tool) + window:get_message(role, cursor) -- Get chat message by role, either last or closest to cursor window:add_message({ role, content }, replace) -- Add or replace a message in chat - window:add_sticky(sticky) -- Add sticky prompt to chat message + window:remove_message(role, cursor) -- Remove chat message by role, either last or closest to cursor + window:get_block(role, cursor) -- Get code block by role, either last or closest to cursor -- Content Management window:append(text) -- Append text to chat window @@ -545,14 +616,28 @@ You can also access the chat window UI methods through the `chat.chat` object: window:start() -- Start writing to chat window window:finish() -- Finish writing to chat window + -- Source Management + window:get_source() -- Get the current source buffer and window + window:set_source(winnr) -- Set the source window + -- Navigation window:follow() -- Move cursor to end of chat content window:focus() -- Focus the chat window -- Advanced Features - window:get_closest_message(role) -- Get message closest to cursor - window:get_closest_block(role) -- Get code block closest to cursor - window:overlay(opts) -- Show overlay with specified options + window:overlay(opts) -- Show overlay with specified options +< + + +PROMPT PARSER *CopilotChat-prompt-parser* + +>lua + local parser = require('CopilotChat.prompts') + + parser.resolve_prompt() -- Resolve prompt references + parser.resolve_tools() -- Resolve tools shared with the model via @... + parser.resolve_functions() -- Resolve manual function/resource references via #... + parser.resolve_model() -- Resolve model from prompt (WARN: async, requires plenary.async.run) < @@ -560,22 +645,26 @@ EXAMPLE USAGE *CopilotChat-example-usage* >lua -- Open chat, ask a question and handle response - require("CopilotChat").open() - require("CopilotChat").ask("#buffer Explain this code", { + require('CopilotChat').open() + require('CopilotChat').ask('#buffer Explain this code', { callback = function(response) - vim.notify("Got response: " .. response:sub(1, 50) .. "...") - return response + vim.notify('Got response: ' .. vim.trim(response.content):sub(1, 50) .. '...') end, }) -- Save and load chat history - require("CopilotChat").save("my_debugging_session") - require("CopilotChat").load("my_debugging_session") + require('CopilotChat').save('my_debugging_session') + require('CopilotChat').load('my_debugging_session') -- Use custom sticky and model - require("CopilotChat").ask("How can I optimize this?", { - model = "gpt-4.1", - sticky = {"#buffer", "#gitdiff:staged"} + require('CopilotChat').ask('How can I optimize this?', { + model = 'gpt-4.1', + sticky = { '#buffer', '#gitdiff:staged' }, + }) + + -- Automatically trust a small read-only tool set + require('CopilotChat').setup({ + trusted_tools = { 'file', 'glob', 'grep' }, }) < @@ -601,7 +690,6 @@ To set up the environment: 1. Install development dependencies: >bash - # Install pre-commit hooks make install-pre-commit < @@ -611,6 +699,12 @@ To run tests: make test < +To run the same formatting check as CI: + +>bash + stylua --check . +< + CONTRIBUTING *CopilotChat-contributing* @@ -629,7 +723,7 @@ See CONTRIBUTING.md for detailed guidelines. Thanks goes to these wonderful people (emoji key ): -gptlang💻 📖Dung Duc Huynh (Kaka)💻 📖Ahmed Haracic💻Trí Thiện Nguyễn💻He Zhizhou💻Guruprakash Rajakkannu💻kristofka💻PostCyberPunk📖Katsuhiko Nishimra💻Erno Hopearuoho💻Shaun Garwood💻neutrinoA4💻 📖Jack Muratore💻Adriel Velazquez💻 📖Tomas Slusny💻 📖Nisal📖Tobias Gårdhus📖Petr Dlouhý📖Dylan Madisetti💻Aaron Weisberg💻 📖Jose Tlacuilo💻 📖Kevin Traver💻 📖dTry💻Arata Furukawa💻Ling💻Ivan Frolov💻Folke Lemaitre💻 📖GitMurf💻Dmitrii Lipin💻jinzhongjia📖guill💻Sjon-Paul Brown💻Renzo Mondragón💻 📖fjchen7💻Radosław Woźniak💻JakubPecenka💻thomastthai📖Tomáš Janoušek💻Toddneal Stallworth📖Sergey Alexandrov💻Léopold Mebazaa💻JunKi Jin💻abdennourzahaf📖Josiah💻Tony Fischer💻 📖Kohei Wada💻Sebastian Yaghoubi📖johncming💻Rokas Brazdžionis💻Sola📖 💻Mani Chandra💻Nischal Basuti📖Teo Ljungberg💻Joe Price💻Yufan You📖 💻Manish Kumar💻Anton Ždanov📖 💻Fredrik Averpil💻Aaron D Borden💻Md. Iftakhar Awal Chowdhury💻 📖Danilo Horta💻This project follows the all-contributors +gptlang💻 📖Dung Duc Huynh (Kaka)💻 📖Ahmed Haracic💻Trí Thiện Nguyễn💻He Zhizhou💻Guruprakash Rajakkannu💻kristofka💻PostCyberPunk📖Katsuhiko Nishimra💻Erno Hopearuoho💻Shaun Garwood💻neutrinoA4💻 📖Jack Muratore💻Adriel Velazquez💻 📖Tomas Slusny💻 📖Nisal📖Tobias Gårdhus📖Petr Dlouhý📖Dylan Madisetti💻Aaron Weisberg💻 📖Jose Tlacuilo💻 📖Kevin Traver💻 📖dTry💻Arata Furukawa💻Ling💻Ivan Frolov💻Folke Lemaitre💻 📖GitMurf💻Dmitrii Lipin💻jinzhongjia📖guill💻Sjon-Paul Brown💻Renzo Mondragón💻 📖fjchen7💻Radosław Woźniak💻JakubPecenka💻thomastthai📖Tomáš Janoušek💻Toddneal Stallworth📖Sergey Alexandrov💻Léopold Mebazaa💻JunKi Jin💻abdennourzahaf📖Josiah💻Tony Fischer💻 📖Kohei Wada💻Sebastian Yaghoubi📖johncming💻Rokas Brazdžionis💻Sola📖 💻Mani Chandra💻Nischal Basuti📖Teo Ljungberg💻Joe Price💻Yufan You📖 💻Manish Kumar💻Anton Ždanov📖 💻Fredrik Averpil💻Aaron D Borden💻Md. Iftakhar Awal Chowdhury💻 📖Danilo Horta💻Mihamina Rakotomandimby📖 💻Ajmal S💻Samiul Islam💻Rui Costa💻CTCHEN💻Tobias Wölfel💻Alexander Garcia💻Max Kharandziuk💻Xinyu Xiang💻junqizhang💻Calum Lynch💻sirjls💻Vladimir Kolchurin💻This project follows the all-contributors specification. Contributions of any kind are welcome! diff --git a/lua/CopilotChat/client.lua b/lua/CopilotChat/client.lua index 11c353b6..7bbc65d8 100644 --- a/lua/CopilotChat/client.lua +++ b/lua/CopilotChat/client.lua @@ -1,17 +1,17 @@ ---@class CopilotChat.client.AskOptions ---@field headless boolean ---@field history table ----@field selection CopilotChat.select.Selection? ---@field tools table? ---@field resources table? ---@field system_prompt string ---@field model string ---@field temperature number ----@field on_progress? fun(response: string):nil +---@field on_progress fun(response: CopilotChat.client.Message)? ---@class CopilotChat.client.Message ---@field role string ---@field content string +---@field reasoning string? ---@field tool_call_id string? ---@field tool_calls table? @@ -31,16 +31,16 @@ ---@field description string description of the tool ---@field schema table? schema of the tool ----@class CopilotChat.client.Embed ----@field index number ----@field embedding table +---@class CopilotChat.client.ResourceAnnotations +---@field start_line number? +---@field end_line number? ---@class CopilotChat.client.Resource ----@field name string ----@field type string ---@field data string - ----@class CopilotChat.client.EmbeddedResource : CopilotChat.client.Resource, CopilotChat.client.Embed +---@field name string? +---@field mimetype string? +---@field uri string? +---@field annotations CopilotChat.client.ResourceAnnotations? ---@class CopilotChat.client.Model ---@field provider string? @@ -51,96 +51,67 @@ ---@field max_output_tokens number? ---@field streaming boolean? ---@field tools boolean? +---@field reasoning boolean? local log = require('plenary.log') +local constants = require('CopilotChat.constants') +local notify = require('CopilotChat.utils.notify') local tiktoken = require('CopilotChat.tiktoken') -local notify = require('CopilotChat.notify') local utils = require('CopilotChat.utils') -local class = utils.class +local curl = require('CopilotChat.utils.curl') +local class = require('CopilotChat.utils.class') +local files = require('CopilotChat.utils.files') +local orderedmap = require('CopilotChat.utils.orderedmap') +local stringbuffer = require('CopilotChat.utils.stringbuffer') --- Constants -local RESOURCE_FORMAT = '# %s\n```%s\n%s\n```' -local LINE_CHARACTERS = 100 -local BIG_EMBED_THRESHOLD = 200 * LINE_CHARACTERS - ---- Resolve provider function ----@param model string ----@param models table ----@param providers table ----@return string, function -local function resolve_provider_function(name, model, models, providers) - local model_config = models[model] - if not model_config then - error('Model not found: ' .. model) - end - - local provider_name = model_config.provider - if not provider_name then - error('Provider not found for model: ' .. model) - end - local provider = providers[provider_name] - if not provider then - error('Provider not found: ' .. provider_name) - end - - local func = provider[name] - if type(func) == 'string' then - provider_name = func - provider = providers[provider_name] - if not provider then - error('Provider not found: ' .. provider_name) - end - func = provider[name] - end - if not func then - error('Function not found: ' .. name) +local RESOURCE_SHORT_FORMAT = '# %s\n```%s start_line=%s end_line=%s\n%s\n```' +local RESOURCE_LONG_FORMAT = '# %s\n```%s path=%s start_line=%s end_line=%s\n%s\n```' +local CACHE_TTL = 300 -- 5 minutes + +--- Get a cached value or fill it if not present +--- @param cache table: The cache table to use +--- @param key string: The key to look up in the cache +--- @param filler function: A function that returns the value to cache if not present +local function get_cached(cache, key, filler) + local now = math.floor(os.time()) + if cache and cache[key] and cache[key .. '_expires_at'] > now then + return cache[key] end - return provider_name, func + local value = filler() + cache[key] = value + cache[key .. '_expires_at'] = now + CACHE_TTL + return value end ---- Generate content block with line numbers, truncating if necessary +--- Generate resource block with line numbers, truncating if necessary ---@param content string ----@param start_line number?: The starting line number +---@param start_line number: The starting line number ---@return string -local function generate_content_block(content, start_line) - if start_line ~= nil then - local lines = vim.split(content, '\n') - local total_lines = #lines - local max_length = #tostring(total_lines) - for i, line in ipairs(lines) do - local formatted_line_number = string.format('%' .. max_length .. 'd', i - 1 + (start_line or 1)) - lines[i] = formatted_line_number .. ': ' .. line - end - - return table.concat(lines, '\n') +local function generate_resource_block(content, mimetype, name, path, start_line, end_line) + local lines = vim.split(content, '\n') + local total_lines = #lines + local max_length = #tostring(total_lines) + for i, line in ipairs(lines) do + local formatted_line_number = string.format('%' .. max_length .. 'd', i - 1 + (start_line or 1)) + lines[i] = formatted_line_number .. ': ' .. line end - return content -end - ---- Generate messages for the given selection ---- @param selection CopilotChat.select.Selection ---- @return CopilotChat.client.Message? -local function generate_selection_message(selection) - local filename = selection.filename or 'unknown' - local filetype = selection.filetype or 'text' - local content = selection.content - - if not content or content == '' then - return nil + local updated_content = table.concat(lines, '\n') + local filetype = files.mimetype_to_filetype(mimetype or 'text') + if not start_line then + start_line = 1 end - - local out = "User's active selection:\n" - if selection.start_line and selection.end_line then - out = out .. string.format('Excerpt from %s, lines %s to %s:\n', filename, selection.start_line, selection.end_line) + if not end_line then + end_line = start_line and (start_line + total_lines - 1) or 1 end - out = out .. string.format('```%s\n%s\n```', filetype, generate_content_block(content, selection.start_line)) - return { - content = out, - role = 'user', - } + if path then + return string.format(RESOURCE_LONG_FORMAT, name, filetype, path, start_line, end_line, updated_content) + else + return string.format(RESOURCE_SHORT_FORMAT, name, filetype, start_line, end_line, updated_content) + end end --- Generate messages for the given resources @@ -153,22 +124,28 @@ local function generate_resource_messages(resources) return resource.data and resource.data ~= '' end) :map(function(resource) - local content = generate_content_block(resource.data, 1) - + local start_line = resource.annotations and resource.annotations.start_line or 1 + local end_line = resource.annotations and resource.annotations.end_line or nil return { - content = string.format(RESOURCE_FORMAT, resource.name, resource.type, content), - role = 'user', + content = generate_resource_block( + resource.data, + resource.mimetype, + resource.uri, + resource.name, + start_line, + end_line + ), + role = constants.ROLE.USER, } end) :totable() end --- Generate ask request ---- @param prompt string --- @param system_prompt string --- @param history table --- @param generated_messages table -local function generate_ask_request(prompt, system_prompt, history, generated_messages) +local function generate_ask_request(system_prompt, history, generated_messages) local messages = {} system_prompt = vim.trim(system_prompt) @@ -177,59 +154,62 @@ local function generate_ask_request(prompt, system_prompt, history, generated_me if not utils.empty(system_prompt) then table.insert(messages, { content = system_prompt, - role = 'system', + role = constants.ROLE.SYSTEM, }) end -- Include generated messages and history - for _, message in ipairs(generated_messages) do - table.insert(messages, { - content = message.content, - role = message.role, - }) - end - for _, message in ipairs(history) do - table.insert(messages, message) - end - if not utils.empty(prompt) and utils.empty(history) then - -- Include user prompt if we have no history - table.insert(messages, { - content = prompt, - role = 'user', - }) - end - + vim.list_extend(messages, generated_messages) + vim.list_extend(messages, history) return messages end ---- Generate embedding request ---- @param inputs table ---- @param threshold number ---- @return table -local function generate_embedding_request(inputs, threshold) - return vim.tbl_map(function(embedding) - local content = generate_content_block(embedding.data, threshold) - return string.format(RESOURCE_FORMAT, embedding.name, embedding.type, content) - end, inputs) -end - ---@class CopilotChat.client.Client : Class ----@field private providers table +---@field private provider_resolver function():table ---@field private provider_cache table ----@field private model_cache table? ---@field private current_job string? local Client = class(function(self) - self.providers = {} - self.provider_cache = {} - self.model_cache = nil + self.provider_resolver = nil + self.provider_cache = vim.defaulttable(function() + return {} + end) self.current_job = nil end) +--- Get all providers from the client +---@param supported_method? string: The method to filter providers by (optional) +---@return OrderedMap +function Client:get_providers(supported_method) + local out = orderedmap() + + if not self.provider_resolver then + return out + end + + local providers = self.provider_resolver() + local provider_names = vim.tbl_keys(providers) + table.sort(provider_names) + + for _, provider_name in ipairs(provider_names) do + local provider = providers[provider_name] + if provider and not provider.disabled and (not supported_method or provider[supported_method]) then + out:set(provider_name, provider) + end + end + return out +end + +--- Set a provider resolver on the client +---@param resolver function: A function that returns a table of providers +function Client:set_providers(resolver) + self.provider_resolver = resolver +end + --- Authenticate with GitHub and get the required headers ---@param provider_name string: The provider to authenticate with ---@return table function Client:authenticate(provider_name) - local provider = self.providers[provider_name] + local provider = self:get_providers():get(provider_name) local headers = self.provider_cache[provider_name].headers local expires_at = self.provider_cache[provider_name].expires_at @@ -245,81 +225,77 @@ end --- Fetch models from the Copilot API ---@return table function Client:models() - if self.model_cache then - return self.model_cache - end - - local models = {} - local provider_order = vim.tbl_keys(self.providers) - table.sort(provider_order) - for _, provider_name in ipairs(provider_order) do - local provider = self.providers[provider_name] - if not provider.disabled and provider.get_models then - notify.publish(notify.STATUS, 'Fetching models from ' .. provider_name) - local ok, headers = pcall(self.authenticate, self, provider_name) - if not ok then - log.warn('Failed to authenticate with ' .. provider_name .. ': ' .. headers) - goto continue - end - local ok, provider_models = pcall(provider.get_models, headers) - if not ok then - log.warn('Failed to fetch models from ' .. provider_name .. ': ' .. provider_models) - goto continue - end + local out = {} + local providers = self:get_providers('get_models') + + for _, provider_name in ipairs(providers:keys()) do + local provider = providers:get(provider_name) + for _, model in + ipairs(get_cached(self.provider_cache[provider_name], 'models', function() + notify.publish(notify.STATUS, 'Fetching models from ' .. provider_name) + + local ok, headers = pcall(self.authenticate, self, provider_name) + if not ok then + log.warn('Failed to authenticate with ' .. provider_name .. ': ' .. headers) + return {} + end - for _, model in ipairs(provider_models) do - model.provider = provider_name - if models[model.id] then - model.id = model.id .. ':' .. provider_name + local ok, models = pcall(provider.get_models, headers) + if not ok then + log.warn('Failed to fetch models from ' .. provider_name .. ': ' .. models) + return {} end - models[model.id] = model - end - ::continue:: + return models or {} + end)) + do + model.provider = provider_name + if out[model.id] then + model.id = model.id .. ':' .. provider_name + end + out[model.id] = model end end - log.debug('Fetched models:', #vim.tbl_keys(models)) - self.model_cache = models - return self.model_cache + log.debug('Fetched models:', #vim.tbl_keys(out)) + return out end --- Get information about all providers ---@return table function Client:info() - local infos = {} - local now = math.floor(os.time()) - local CACHE_TTL = 300 -- 5 minutes + local out = {} + local providers = self:get_providers('get_info') - for provider_name, provider in pairs(self.providers) do - if not provider.disabled and provider.get_info then - local cache = self.provider_cache[provider_name] - if cache and cache.info and cache.info_expires_at and cache.info_expires_at > now then - infos[provider_name] = cache.info - else - local ok, info = pcall(provider.get_info, self:authenticate(provider_name)) - if ok then - infos[provider_name] = info - if cache then - cache.info = info - cache.info_expires_at = now + CACHE_TTL - end - else - log.warn('Failed to get info for provider ' .. provider_name .. ': ' .. info) - end + for _, provider_name in ipairs(providers:keys()) do + local provider = providers:get(provider_name) + out[provider_name] = get_cached(self.provider_cache[provider_name], 'infos', function() + notify.publish(notify.STATUS, 'Fetching info from ' .. provider_name) + + local ok, headers = pcall(self.authenticate, self, provider_name) + if not ok then + log.warn('Failed to authenticate with ' .. provider_name .. ': ' .. headers) + return {} end - end + + local ok, infos = pcall(provider.get_info, headers) + if not ok then + log.warn('Failed to fetch info from ' .. provider_name .. ': ' .. infos) + return {} + end + + return infos or {} + end) end - log.debug('Fetched provider infos:', #vim.tbl_keys(infos)) - return infos + log.debug('Fetched provider infos:', #vim.tbl_keys(out)) + return out end --- Ask a question to Copilot ----@param prompt string: The prompt to send to Copilot ---@param opts CopilotChat.client.AskOptions: Options for the request ---@return CopilotChat.client.AskResponse? -function Client:ask(prompt, opts) +function Client:ask(opts) opts = opts or {} local job_id = utils.uuid() @@ -338,11 +314,21 @@ function Client:ask(prompt, opts) if not provider_name then error('Provider not found for model: ' .. opts.model) end - local provider = self.providers[provider_name] + local provider = self:get_providers():get(provider_name) if not provider then error('Provider not found: ' .. provider_name) end + if provider.resolve_model then + local headers = self:authenticate(provider_name) + local resolved_model = provider.resolve_model(headers, opts.model) + opts.model = resolved_model + model_config = models[opts.model] + if not model_config then + error('Resolved model not found: ' .. opts.model) + end + end + local options = { model = vim.tbl_extend('force', model_config, { id = opts.model:gsub(':' .. provider_name .. '$', ''), @@ -356,48 +342,46 @@ function Client:ask(prompt, opts) log.debug('Tokenizer:', tokenizer) if max_tokens and tokenizer then - tiktoken.load(tokenizer) + tiktoken:load(tokenizer) end if not opts.headless then notify.publish(notify.STATUS, 'Generating request') end - local history = not opts.headless and vim.deepcopy(opts.history) or {} - local tool_calls = utils.ordered_map() + local history = vim.deepcopy(opts.history) + local tool_calls = orderedmap() local generated_messages = {} - local selection_message = opts.selection and generate_selection_message(opts.selection) local resource_messages = generate_resource_messages(opts.resources) - if selection_message then - table.insert(generated_messages, selection_message) - end - if max_tokens then -- Count required tokens that we cannot reduce - local selection_tokens = selection_message and tiktoken.count(selection_message.content) or 0 - local prompt_tokens = tiktoken.count(prompt) - local system_tokens = tiktoken.count(opts.system_prompt) - local resource_tokens = #resource_messages > 0 and tiktoken.count(resource_messages[1].content) or 0 - local required_tokens = prompt_tokens + system_tokens + selection_tokens + resource_tokens + local system_tokens = tiktoken:count(opts.system_prompt) + local prompt_tokens = #history > 0 and tiktoken:count(history[#history].content) or 0 + local resource_tokens = #resource_messages > 0 and tiktoken:count(resource_messages[1].content) or 0 + local required_tokens = prompt_tokens + system_tokens + resource_tokens + + log.debug('System tokens:', system_tokens) + log.debug('Prompt tokens:', prompt_tokens) + log.debug('Resource tokens:', resource_tokens) -- Calculate how many tokens we can use for history local history_limit = max_tokens - required_tokens local history_tokens = 0 for _, msg in ipairs(history) do - history_tokens = history_tokens + tiktoken.count(msg.content) + history_tokens = history_tokens + tiktoken:count(msg.content) end - -- Remove history messages until we are under the limit - while history_tokens > history_limit and #history > 0 do + -- Remove history messages except prompt until we are under the limit + while history_tokens > history_limit and #history > 1 do local entry = table.remove(history, 1) - history_tokens = history_tokens - tiktoken.count(entry.content) + history_tokens = history_tokens - tiktoken:count(entry.content) end -- Now add as many files as possible with remaining token budget local remaining_tokens = max_tokens - required_tokens - history_tokens for _, message in ipairs(resource_messages) do - local tokens = tiktoken.count(message.content) + local tokens = tiktoken:count(message.content) if remaining_tokens - tokens >= 0 then remaining_tokens = remaining_tokens - tokens table.insert(generated_messages, message) @@ -412,15 +396,16 @@ function Client:ask(prompt, opts) end end - local errored = false + local errored = nil local finished = false local token_count = 0 - local response_buffer = utils.string_buffer() + local out_model = nil + local response_content_buffer = stringbuffer() + local response_reasoning_buffer = stringbuffer() local function finish_stream(err, job) if err then - errored = true - response_buffer:set(err) + errored = err end log.debug('Finishing stream', err) @@ -460,20 +445,44 @@ function Client:ask(prompt, opts) if out.tool_calls then for _, tool_call in ipairs(out.tool_calls) do - local val = tool_calls:get(tool_call.index) - if not val then - tool_calls:set(tool_call.index, tool_call) + local key = tostring(tool_call.index or tool_call.id or tool_call.name or #tool_calls:values() + 1) + local existing = tool_calls:get(key) + + if not existing then + tool_calls:set(key, tool_call) else - val.arguments = val.arguments .. tool_call.arguments + existing.arguments = existing.arguments .. tool_call.arguments + if tool_call.id then + existing.id = tool_call.id + end + if tool_call.index then + existing.index = tool_call.index + end + if tool_call.name then + existing.name = tool_call.name + end end end end if out.content then - response_buffer:add(out.content) - if opts.on_progress then - opts.on_progress(out.content) - end + response_content_buffer:put(out.content) + end + + if out.reasoning then + response_reasoning_buffer:put(out.reasoning) + end + + if out.model then + out_model = out.model + end + + if opts.on_progress then + opts.on_progress({ + role = constants.ROLE.ASSISTANT, + content = out.content or '', + reasoning = out.reasoning or '', + }) end if out.finish_reason then @@ -528,8 +537,14 @@ function Client:ask(prompt, opts) end local headers = self:authenticate(provider_name) - local request = - provider.prepare_input(generate_ask_request(prompt, opts.system_prompt, history, generated_messages), options) + + local request, extra_headers = + provider.prepare_input(generate_ask_request(opts.system_prompt, history, generated_messages), options) + + if extra_headers then + headers = vim.tbl_extend('force', headers, extra_headers) + end + local is_stream = request.stream local args = { @@ -541,7 +556,7 @@ function Client:ask(prompt, opts) args.stream = stream_func end - local response, err = utils.curl_post(provider.get_url(options), args) + local response, err = curl.post(provider.get_url(options), args) if not opts.headless then if self.current_job ~= job_id then @@ -569,15 +584,15 @@ function Client:ask(prompt, opts) end error(error_msg) - return end - local response_text = response_buffer:tostring() if errored then - error(response_text) - return + error(errored) end + local response_text = response_content_buffer:tostring() + local response_reasoning = response_reasoning_buffer:tostring() + if response then if is_stream then if utils.empty(response_text) and not finished then @@ -588,101 +603,28 @@ function Client:ask(prompt, opts) else parse_line(response.body) end - response_text = response_buffer:tostring() + response_text = response_content_buffer:tostring() + response_reasoning = response_reasoning_buffer:tostring() end + -- Filter out tool calls that don't have names (streaming deltas used only for accumulation) + local final_tool_calls = vim.tbl_filter(function(tc) + return tc.name ~= nil + end, tool_calls:values()) + return { message = { - role = 'assistant', + role = constants.ROLE.ASSISTANT, content = response_text, - tool_calls = #tool_calls:values() > 0 and tool_calls:values() or nil, + reasoning = response_reasoning, + tool_calls = #final_tool_calls > 0 and final_tool_calls or nil, + model = out_model, }, token_count = token_count, token_max_count = max_tokens, } end ---- Generate embeddings for the given inputs ----@param inputs table: The inputs to embed ----@param model string ----@return table -function Client:embed(inputs, model) - if not inputs or #inputs == 0 then - ---@diagnostic disable-next-line: return-type-mismatch - return inputs - end - - local models = self:models() - local ok, provider_name, embed = pcall(resolve_provider_function, 'embed', model, models, self.providers) - if not ok then - ---@diagnostic disable-next-line: return-type-mismatch - return inputs - end - - notify.publish(notify.STATUS, 'Generating embeddings for ' .. #inputs .. ' inputs') - - -- Initialize essentials - local to_process = inputs - local results = {} - local initial_chunk_size = 10 - - -- Process inputs in batches with adaptive chunk size - while #to_process > 0 do - local chunk_size = initial_chunk_size -- Reset chunk size for each new batch - local threshold = BIG_EMBED_THRESHOLD -- Reset threshold for each new batch - local last_error = nil - - -- Take next chunk - local batch = {} - for _ = 1, math.min(chunk_size, #to_process) do - table.insert(batch, table.remove(to_process, 1)) - end - - -- Try to get embeddings for batch - local success = false - local attempts = 0 - while not success and attempts < 5 do -- Limit total attempts to 5 - local ok, data = pcall(embed, generate_embedding_request(batch, threshold), self:authenticate(provider_name)) - - if not ok then - log.debug('Failed to get embeddings: ', data) - last_error = data - attempts = attempts + 1 - -- If we have few items and the request failed, try reducing threshold first - if #batch <= 5 then - threshold = math.max(5 * LINE_CHARACTERS, math.floor(threshold / 2)) - log.debug(string.format('Reducing threshold to %d and retrying...', threshold)) - else - -- Otherwise reduce batch size first - chunk_size = math.max(1, math.floor(chunk_size / 2)) - -- Put items back in to_process - for i = #batch, 1, -1 do - table.insert(to_process, 1, table.remove(batch, i)) - end - -- Take new smaller batch - batch = {} - for _ = 1, math.min(chunk_size, #to_process) do - table.insert(batch, table.remove(to_process, 1)) - end - log.debug(string.format('Reducing batch size to %d and retrying...', chunk_size)) - end - else - success = true - for _, embedding in ipairs(data) do - local result = vim.tbl_extend('force', batch[embedding.index + 1], embedding) - table.insert(results, result) - end - end - end - - if not success then - error(last_error) - end - end - - return results -end - --- Stop the running job ---@return boolean function Client:stop() @@ -700,13 +642,5 @@ function Client:running() return self.current_job ~= nil end ---- Load providers to client -function Client:load_providers(providers) - self.providers = providers - for provider_name, _ in pairs(providers) do - self.provider_cache[provider_name] = {} - end -end - --- @type CopilotChat.client.Client return Client() diff --git a/lua/CopilotChat/completion.lua b/lua/CopilotChat/completion.lua new file mode 100644 index 00000000..fdb509de --- /dev/null +++ b/lua/CopilotChat/completion.lua @@ -0,0 +1,239 @@ +local async = require('plenary.async') +local client = require('CopilotChat.client') +local constants = require('CopilotChat.constants') +local config = require('CopilotChat.config') +local functions = require('CopilotChat.functions') +local utils = require('CopilotChat.utils') + +local M = {} + +--- Get the completion info for the chat window, for use with custom completion providers +---@return table +function M.info() + return { + triggers = { '@', '/', '#', '$' }, + pattern = [[\%(@\|/\|#\|\$\)\S*]], + } +end + +--- Get the completion items for the chat window, for use with custom completion providers +---@return table +---@async +function M.items() + local models = client:models() + local prompts = config.prompts or {} + local items = {} + + for name, prompt in pairs(prompts) do + if type(prompt) == 'string' then + prompt = { + prompt = prompt, + } + end + + local kind = '' + local info = '' + if prompt.prompt then + kind = constants.ROLE.USER + info = prompt.prompt + elseif prompt.system_prompt then + kind = constants.ROLE.SYSTEM + info = prompt.system_prompt + end + + items[#items + 1] = { + word = '/' .. name, + abbr = name, + kind = kind, + info = info, + menu = prompt.description or '', + icase = 1, + dup = 0, + empty = 0, + } + end + + for id, model in pairs(models) do + items[#items + 1] = { + word = '$' .. id, + abbr = id, + kind = model.provider, + menu = model.name, + icase = 1, + dup = 0, + empty = 0, + } + end + + local groups = {} + for name, tool in pairs(config.functions) do + if tool.group then + groups[tool.group] = groups[tool.group] or {} + groups[tool.group][name] = tool + end + end + for name, group in pairs(groups) do + local group_tools = vim.tbl_keys(group) + items[#items + 1] = { + word = '@' .. name, + abbr = name, + kind = 'group', + info = table.concat(group_tools, '\n'), + menu = string.format('%s tools', #group_tools), + icase = 1, + dup = 0, + empty = 0, + } + end + for name, tool in pairs(config.functions) do + items[#items + 1] = { + word = '@' .. name, + abbr = name, + kind = constants.ROLE.TOOL, + info = tool.description, + menu = tool.group or '', + icase = 1, + dup = 0, + empty = 0, + } + end + + local tools_to_use = functions.parse_tools(config.functions) + for _, tool in pairs(tools_to_use) do + local uri = config.functions[tool.name].uri + if uri then + local info = + string.format('%s\n\n%s', tool.description, tool.schema and vim.inspect(tool.schema, { indent = ' ' }) or '') + + items[#items + 1] = { + word = '#' .. tool.name, + abbr = tool.name, + kind = config.functions[tool.name].group or 'resource', + info = info, + menu = uri, + icase = 1, + dup = 0, + empty = 0, + } + end + end + + table.sort(items, function(a, b) + if a.kind == b.kind then + return a.word < b.word + end + return a.kind < b.kind + end) + + return items +end + +--- Trigger the completion for the chat window. +---@param without_input boolean? +function M.complete(without_input) + local source = require('CopilotChat').chat:get_source() + local info = M.info() + local bufnr = vim.api.nvim_get_current_buf() + local line = vim.api.nvim_get_current_line() + local win = vim.api.nvim_get_current_win() + local row, col = unpack(vim.api.nvim_win_get_cursor(win)) + + local prefix, cmp_start = unpack(vim.fn.matchstrpos(line:sub(1, col), info.pattern)) + if not prefix then + return + end + + if not without_input and vim.startswith(prefix, '#') and vim.endswith(prefix, ':') then + local found_tool = config.functions[prefix:sub(2, -2)] + local found_schema = found_tool and functions.parse_schema(found_tool) + if found_tool and found_schema and found_tool.uri then + async.run(function() + local value = functions.enter_input(found_schema, source) + if not value then + return + end + + utils.schedule_main() + vim.api.nvim_buf_set_text(bufnr, row - 1, col, row - 1, col, { value }) + vim.api.nvim_win_set_cursor(0, { row, col + #value }) + end) + end + + return + end + + utils.debounce('copilot_chat_complete', function() + async.run(function() + local items = M.items() + utils.schedule_main() + + if not vim.api.nvim_win_is_valid(win) then + return + end + + local row_changed = vim.api.nvim_win_get_cursor(win)[1] ~= row + local mode = vim.api.nvim_get_mode().mode + if row_changed or not (mode == 'i' or mode == 'ic') then + return + end + + vim.fn.complete( + cmp_start + 1, + vim.tbl_filter(function(item) + return vim.startswith(item.word:lower(), prefix:lower()) + end, items) + ) + end) + end, 100) +end + +--- Omnifunc for the chat window completion. +---@param findstart integer 0 or 1, decides behavior +---@param base integer findstart=0, text to match against +---@return number|table +function M.omnifunc(findstart, base) + assert(base) + local bufnr = vim.api.nvim_get_current_buf() + local ft = vim.bo[bufnr].filetype + + if ft ~= 'copilot-chat' then + return findstart == 1 and -1 or {} + end + + M.complete(true) + return -2 -- Return -2 to indicate that we are handling the completion asynchronously +end + +--- Enable the completion for specific buffer. +---@param bufnr number: the buffer number to enable completion for +---@param autocomplete boolean: whether to enable autocomplete +function M.enable(bufnr, autocomplete) + if autocomplete then + vim.api.nvim_create_autocmd('TextChangedI', { + buffer = bufnr, + callback = function() + local completeopt = vim.opt.completeopt:get() + if not vim.tbl_contains(completeopt, 'noinsert') and not vim.tbl_contains(completeopt, 'noselect') then + -- Don't trigger completion if completeopt is not set to noinsert or noselect + return + end + + M.complete(true) + end, + }) + + -- Add noinsert completeopt if not present + if vim.fn.has('nvim-0.11.0') == 1 then + local completeopt = vim.opt.completeopt:get() + if not vim.tbl_contains(completeopt, 'noinsert') then + table.insert(completeopt, 'noinsert') + vim.bo[bufnr].completeopt = table.concat(completeopt, ',') + end + end + else + -- Just set the omnifunc for the buffer + vim.bo[bufnr].omnifunc = [[v:lua.require'CopilotChat.completion'.omnifunc]] + end +end + +return M diff --git a/lua/CopilotChat/config.lua b/lua/CopilotChat/config.lua index 30ad2bd8..a54a66f3 100644 --- a/lua/CopilotChat/config.lua +++ b/lua/CopilotChat/config.lua @@ -14,17 +14,18 @@ ---@field blend number? ---@class CopilotChat.config.Shared ----@field system_prompt string? +---@field system_prompt nil|string ---@field model string? ---@field tools string|table|nil +---@field resources string|table|nil ---@field sticky string|table|nil +---@field trusted_tools boolean|string|table|nil +---@field diff 'block'|'unified'? ---@field language string? ----@field resource_processing boolean? ---@field temperature number? ---@field headless boolean? ----@field callback nil|fun(response: string, source: CopilotChat.source) +---@field callback nil|fun(response: CopilotChat.client.Message, source: CopilotChat.ui.chat.Source) ---@field remember_as_sticky boolean? ----@field selection false|nil|fun(source: CopilotChat.source):CopilotChat.select.Selection? ---@field window CopilotChat.config.Window? ---@field show_help boolean? ---@field show_folds boolean? @@ -32,8 +33,10 @@ ---@field highlight_headers boolean? ---@field auto_follow_cursor boolean? ---@field auto_insert_mode boolean? +---@field auto_fold boolean? ---@field insert_at_end boolean? ---@field clear_chat_on_new_prompt boolean? +---@field stop_on_function_failure boolean? --- CopilotChat default configuration ---@class CopilotChat.config.Config : CopilotChat.config.Shared @@ -41,6 +44,8 @@ ---@field log_level 'trace'|'debug'|'info'|'warn'|'error'|'fatal'? ---@field proxy string? ---@field allow_insecure boolean? +---@field instruction_files table? +---@field selection 'visual'|'unnamed'|nil ---@field chat_autocomplete boolean? ---@field log_path string? ---@field history_path string? @@ -54,23 +59,21 @@ return { -- Shared config starts here (can be passed to functions at runtime and configured via setup function) - system_prompt = 'COPILOT_INSTRUCTIONS', -- System prompt to use (can be specified manually in prompt via /). + system_prompt = require('CopilotChat.config.prompts').COPILOT_INSTRUCTIONS.system_prompt, -- System prompt to use (can be specified manually in prompt via /). model = 'gpt-4.1', -- Default model to use, see ':CopilotChatModels' for available models (can be specified manually in prompt via $). tools = nil, -- Default tool or array of tools (or groups) to share with LLM (can be specified manually in prompt via @). + resources = 'selection', -- Default resources to share with LLM (can be specified manually in prompt via #). sticky = nil, -- Default sticky prompt or array of sticky prompts to use at start of every new chat (can be specified manually in prompt via >). + trusted_tools = nil, -- Trust tool calls from specific functions or groups, or all trusted tools when true (e.g., {'buffer', 'file'} or 'copilot'). + diff = 'block', -- Default diff format to use, 'block' or 'unified'. language = 'English', -- Default language to use for answers - resource_processing = false, -- Enable intelligent resource processing (skips unnecessary resources to save tokens) - temperature = 0.1, -- Result temperature headless = false, -- Do not write to chat buffer and use history (useful for using custom processing) callback = nil, -- Function called when full response is received remember_as_sticky = true, -- Remember config as sticky prompts when asking questions - -- default selection - selection = require('CopilotChat.select').visual, - -- default window options window = { layout = 'vertical', -- 'vertical', 'horizontal', 'float', 'replace', or a function that returns the layout @@ -89,12 +92,14 @@ return { show_help = true, -- Shows help message as virtual lines when waiting for user input show_folds = true, -- Shows folds for sections in chat + auto_fold = false, -- Automatically non-assistant messages in chat (requires 'show_folds' to be true) highlight_selection = true, -- Highlight selection highlight_headers = true, -- Highlight headers in chat auto_follow_cursor = true, -- Auto-follow cursor in chat auto_insert_mode = false, -- Automatically enter insert mode when opening window and on new prompt insert_at_end = false, -- Move cursor to end of buffer when inserting text clear_chat_on_new_prompt = false, -- Clears chat on every new prompt + stop_on_function_failure = false, -- Stop processing prompt if any function fails (preserves quota) -- Static config starts here (can be configured only via setup function) @@ -103,15 +108,22 @@ return { proxy = nil, -- [protocol://]host[:port] Use this proxy allow_insecure = false, -- Allow insecure server connections + -- Instruction files to look for in current working directory + instruction_files = { + '.github/copilot-instructions.md', + 'AGENTS.md', + }, + + selection = 'visual', -- Selection source chat_autocomplete = true, -- Enable chat autocompletion (when disabled, requires manual `mappings.complete` trigger) log_path = vim.fn.stdpath('state') .. '/CopilotChat.log', -- Default path to log file history_path = vim.fn.stdpath('data') .. '/copilotchat_history', -- Default path to stored history headers = { - user = '## User ', -- Header to use for user questions - assistant = '## Copilot ', -- Header to use for AI answers - tool = '## Tool ', -- Header to use for tool calls + user = 'User', -- Header to use for user questions + assistant = 'Copilot', -- Header to use for AI answers + tool = 'Tool', -- Header to use for tool calls }, separator = '───', -- Separator to use in chat diff --git a/lua/CopilotChat/config/functions.lua b/lua/CopilotChat/config/functions.lua index 25c3f6c5..0ec22b10 100644 --- a/lua/CopilotChat/config/functions.lua +++ b/lua/CopilotChat/config/functions.lua @@ -1,17 +1,52 @@ local resources = require('CopilotChat.resources') local utils = require('CopilotChat.utils') - ----@class CopilotChat.config.functions.Result ----@field data string ----@field mimetype string? ----@field uri string? +local files = require('CopilotChat.utils.files') + +--- Get diagnostics for a buffer and format them as text +---@param bufnr number +---@param start_line number? +---@param end_line number? +---@return string +local function get_diagnostics_text(bufnr, start_line, end_line) + local diagnostics = vim.diagnostic.get(bufnr, { + severity = { min = vim.diagnostic.severity.HINT }, + }) + + if #diagnostics == 0 then + return '' + end + + local diag_lines = { '\n--- Diagnostics ---' } + for _, diag in ipairs(diagnostics) do + local diag_lnum = diag.lnum + 1 + -- If range is specified, filter diagnostics within range + if not start_line or (diag_lnum >= start_line and diag_lnum <= end_line) then + local severity = vim.diagnostic.severity[diag.severity] or 'UNKNOWN' + local line_text = vim.api.nvim_buf_get_lines(bufnr, diag.lnum, diag.lnum + 1, false)[1] or '' + table.insert( + diag_lines, + string.format( + '%s line=%d-%d: %s\n > %s', + severity, + diag.lnum + 1, + diag.end_lnum and (diag.end_lnum + 1) or (diag.lnum + 1), + diag.message, + line_text + ) + ) + end + end + + return #diag_lines > 1 and table.concat(diag_lines, '\n') or '' +end ---@class CopilotChat.config.functions.Function ---@field description string? ---@field schema table? ---@field group string? +---@field trusted boolean? ---@field uri string? ----@field resolve fun(input: table, source: CopilotChat.source, prompt: string):table +---@field resolve fun(input: table, source: CopilotChat.ui.chat.Source):CopilotChat.client.Resource[] ---@type table return { @@ -28,8 +63,9 @@ return { type = 'string', description = 'Path to file to include in chat context.', enum = function(source) - return utils.glob(source.cwd(), { + return files.glob(source.cwd(), { max_count = 0, + hidden = true, }) end, }, @@ -46,6 +82,7 @@ return { return { { uri = 'file://' .. input.path, + name = input.path, mimetype = mimetype, data = data, }, @@ -53,64 +90,38 @@ return { end, }, - glob = { + url = { group = 'copilot', - uri = 'files://glob/{pattern}', - description = 'Lists filenames matching a pattern in your workspace. Useful for discovering relevant files or understanding the project structure.', + uri = 'https://{url}', + description = 'Fetches content from a specified URL. Useful for referencing documentation, examples, or other online resources.', schema = { type = 'object', - required = { 'pattern' }, + required = { 'url' }, properties = { - pattern = { + url = { type = 'string', - description = 'Glob pattern to match files.', - default = '**/*', + description = 'URL to include in chat context.', }, }, }, - resolve = function(input, source) - local files = utils.glob(source.cwd(), { - pattern = input.pattern, - }) - - return { - { - uri = 'files://glob/' .. input.pattern, - mimetype = 'text/plain', - data = table.concat(files, '\n'), - }, - } - end, - }, - - grep = { - group = 'copilot', - uri = 'files://grep/{pattern}', - description = 'Searches for a pattern across files in your workspace. Helpful for finding specific code elements or patterns.', - - schema = { - type = 'object', - required = { 'pattern' }, - properties = { - pattern = { - type = 'string', - description = 'Pattern to search for.', - }, - }, - }, + resolve = function(input) + if not input.url:match('^https?://') then + input.url = 'https://' .. input.url + end - resolve = function(input, source) - local files = utils.grep(source.cwd(), { - pattern = input.pattern, - }) + utils.schedule_main() + local data, mimetype = resources.get_url(input.url) + if not data then + error('URL not found: ' .. input.url) + end return { { - uri = 'files://grep/' .. input.pattern, - mimetype = 'text/plain', - data = table.concat(files, '\n'), + uri = input.url, + mimetype = mimetype, + data = data, }, } end, @@ -118,288 +129,217 @@ return { buffer = { group = 'copilot', - uri = 'buffer://{name}', - description = 'Retrieves content from a specific buffer. Useful for discussing or analyzing code from a particular file that is currently loaded.', + uri = 'neovim://buffer/{scope}', + description = 'Retrieves content from buffer(s) with diagnostics. Scope can be a buffer number, filename, or one of: active, visible, listed, quickfix.', schema = { type = 'object', - required = { 'name' }, + required = { 'scope' }, properties = { - name = { + scope = { type = 'string', - description = 'Buffer filename to include in chat context.', + description = 'Buffer scope: active (current), visible (shown in windows), listed (all listed buffers), quickfix (buffers in quickfix list), or a specific buffer number/filename.', enum = function() - return vim - .iter(vim.api.nvim_list_bufs()) - :filter(function(buf) - return buf and utils.buf_valid(buf) and vim.fn.buflisted(buf) == 1 - end) - :map(function(buf) - return vim.api.nvim_buf_get_name(buf) - end) - :totable() + local opts = { + { display = 'active (current buffer)', value = 'active' }, + { display = 'visible (all visible buffers)', value = 'visible' }, + { display = 'listed (all listed buffers)', value = 'listed' }, + { display = 'quickfix (buffers in quickfix)', value = 'quickfix' }, + } + + for _, buf in ipairs(vim.api.nvim_list_bufs()) do + if utils.buf_valid(buf) and vim.fn.buflisted(buf) == 1 then + local name = vim.api.nvim_buf_get_name(buf) + if name and name ~= '' then + local display_name = vim.fn.fnamemodify(name, ':~:.') + table.insert(opts, { display = display_name, value = tostring(buf) }) + end + end + end + return opts end, + default = 'active', }, }, }, resolve = function(input, source) utils.schedule_main() - local name = input.name or vim.api.nvim_buf_get_name(source.bufnr) - local found_buf = nil - for _, buf in ipairs(vim.api.nvim_list_bufs()) do - if vim.api.nvim_buf_get_name(buf) == name then - found_buf = buf - break + local scope = input.scope or 'active' + local buffers = {} + + -- Determine which buffers to include based on scope + if scope == 'active' then + if source and source.bufnr and utils.buf_valid(source.bufnr) then + buffers = { source.bufnr } + end + elseif scope == 'visible' then + buffers = vim.tbl_filter(function(b) + return utils.buf_valid(b) and vim.fn.buflisted(b) == 1 and #vim.fn.win_findbuf(b) > 0 + end, vim.api.nvim_list_bufs()) + elseif scope == 'listed' then + buffers = vim.tbl_filter(function(b) + return utils.buf_valid(b) and vim.fn.buflisted(b) == 1 + end, vim.api.nvim_list_bufs()) + elseif scope == 'quickfix' then + local items = vim.fn.getqflist() + local file_to_bufnr = {} + for _, item in ipairs(items) do + local filename = item.filename or vim.api.nvim_buf_get_name(item.bufnr) + if filename and item.bufnr and utils.buf_valid(item.bufnr) then + file_to_bufnr[filename] = item.bufnr + end + end + buffers = vim.tbl_values(file_to_bufnr) + elseif tonumber(scope) then + local bufnr = tonumber(scope) + if utils.buf_valid(bufnr) then + buffers = { bufnr } end end - if not found_buf then - error('Buffer not found: ' .. name) + + if #buffers == 0 then + error('No buffers found for input: ' .. scope) end - local data, mimetype = resources.get_buffer(found_buf) - if not data then - error('Buffer not found: ' .. name) + + local results = {} + for _, bufnr in ipairs(buffers) do + local name = vim.api.nvim_buf_get_name(bufnr) + local data, mimetype = resources.get_buffer(bufnr) + if data then + local diag_text = get_diagnostics_text(bufnr) + if diag_text ~= '' then + data = data .. diag_text + end + + table.insert(results, { + uri = 'buffer://' .. bufnr, + name = name, + mimetype = mimetype, + data = data, + }) + end end - return { - { - uri = 'buffer://' .. name, - mimetype = mimetype, - data = data, - }, - } + + return results end, }, - buffers = { + selection = { group = 'copilot', - uri = 'buffers://{scope}', - description = 'Fetches content from multiple buffers. Helps with discussing or analyzing code across multiple files simultaneously.', - - schema = { - type = 'object', - required = { 'scope' }, - properties = { - scope = { - type = 'string', - description = 'Scope of buffers to include in chat context.', - enum = { 'listed', 'visible' }, - default = 'listed', - }, - }, - }, + uri = 'neovim://selection', + description = 'Includes the content of the current visual selection with diagnostics. Useful for discussing specific code snippets or text blocks.', - resolve = function(input) + resolve = function(_, source) utils.schedule_main() - return vim - .iter(vim.api.nvim_list_bufs()) - :filter(function(bufnr) - return utils.buf_valid(bufnr) - and vim.fn.buflisted(bufnr) == 1 - and (input.scope == 'listed' or #vim.fn.win_findbuf(bufnr) > 0) - end) - :map(function(bufnr) - local name = vim.api.nvim_buf_get_name(bufnr) - local data, mimetype = resources.get_buffer(bufnr) - if not data then - return nil - end - return { - uri = 'buffer://' .. name, - mimetype = mimetype, - data = data, - } - end) - :filter(function(file_data) - return file_data ~= nil - end) - :totable() + + local select = require('CopilotChat.select') + local selection = select.get(source.bufnr) + if not selection then + return {} + end + + local data = selection.content + local diag_text = get_diagnostics_text(source.bufnr, selection.start_line, selection.end_line) + if diag_text ~= '' then + data = data .. diag_text + end + + return { + { + uri = 'neovim://selection', + name = selection.filename, + mimetype = files.filetype_to_mimetype(selection.filetype), + data = data, + annotations = { + start_line = selection.start_line, + end_line = selection.end_line, + }, + }, + } end, }, - quickfix = { + clipboard = { group = 'copilot', - uri = 'neovim://quickfix', - description = 'Includes the content of all files referenced in the current quickfix list. Useful for discussing compilation errors, search results, or other collected locations.', + uri = 'neovim://clipboard', + description = 'Provides access to the system clipboard content. Useful for discussing copied text or code snippets.', resolve = function() utils.schedule_main() - - local items = vim.fn.getqflist() - if not items or #items == 0 then + local lines = vim.fn.getreg('+') + if not lines or lines == '' then return {} end - local file_to_bufnr = {} - for _, item in ipairs(items) do - local filename = item.filename or vim.api.nvim_buf_get_name(item.bufnr) - if filename then - if item.bufnr and utils.buf_valid(item.bufnr) then - file_to_bufnr[filename] = item.bufnr - else - file_to_bufnr[filename] = false - end - end - end - - return vim - .iter(vim.tbl_keys(file_to_bufnr)) - :map(function(file) - local bufnr = file_to_bufnr[file] - local data, mimetype, uri - if bufnr and bufnr ~= false then - data, mimetype = resources.get_buffer(bufnr) - uri = 'buffer://' .. file - else - data, mimetype = resources.get_file(file) - uri = 'file://' .. file - end - if not data then - return nil - end - return { - uri = uri, - mimetype = mimetype, - data = data, - } - end) - :filter(function(file_data) - return file_data ~= nil - end) - :totable() + return { + { + uri = 'neovim://clipboard', + mimetype = 'text/plain', + data = lines, + }, + } end, }, - diagnostics = { + glob = { group = 'copilot', - uri = 'neovim://diagnostics/{scope}/{severity}', - description = 'Collects code diagnostics (errors, warnings, etc.) from specified buffers. Helpful for troubleshooting and fixing code issues.', + uri = 'files://glob/{pattern}', + description = 'Lists filenames matching a pattern in your workspace. Useful for discovering relevant files or understanding the project structure.', schema = { type = 'object', - required = { 'scope', 'severity' }, + required = { 'pattern' }, properties = { - scope = { - type = 'string', - description = 'Scope of buffers to use for retrieving diagnostics.', - enum = { 'current', 'listed', 'visible' }, - default = 'current', - }, - severity = { + pattern = { type = 'string', - description = 'Minimum severity level of diagnostics to include.', - enum = { 'error', 'warn', 'info', 'hint' }, - default = 'warn', + description = 'Glob pattern to match files.', + default = '**/*', }, }, }, resolve = function(input, source) - utils.schedule_main() - local out = {} - local scope = input.scope or 'current' - local buffers = {} - - -- Get buffers based on scope - if scope == 'current' then - if source and source.bufnr and utils.buf_valid(source.bufnr) then - buffers = { source.bufnr } - end - elseif scope == 'listed' then - buffers = vim.tbl_filter(function(b) - return utils.buf_valid(b) and vim.fn.buflisted(b) == 1 - end, vim.api.nvim_list_bufs()) - elseif scope == 'visible' then - buffers = vim.tbl_filter(function(b) - return utils.buf_valid(b) and vim.fn.buflisted(b) == 1 and #vim.fn.win_findbuf(b) > 0 - end, vim.api.nvim_list_bufs()) - else - buffers = vim.tbl_filter(function(b) - return utils.buf_valid(b) and vim.api.nvim_buf_get_name(b) == input.scope - end, vim.api.nvim_list_bufs()) - end - - -- Collect diagnostics for each buffer - for _, bufnr in ipairs(buffers) do - local name = vim.api.nvim_buf_get_name(bufnr) - local diagnostics = vim.diagnostic.get(bufnr, { - severity = { - min = vim.diagnostic.severity[input.severity:upper()], - }, - }) - - if #diagnostics > 0 then - local diag_lines = {} - for _, diag in ipairs(diagnostics) do - local severity = vim.diagnostic.severity[diag.severity] or 'UNKNOWN' - local line_text = vim.api.nvim_buf_get_lines(bufnr, diag.lnum, diag.lnum + 1, false)[1] or '' - - table.insert( - diag_lines, - string.format( - '%s line=%d-%d: %s\n > %s', - severity, - diag.lnum + 1, - diag.end_lnum and (diag.end_lnum + 1) or (diag.lnum + 1), - diag.message, - line_text - ) - ) - end - - table.insert(out, { - uri = 'neovim://diagnostics/' .. name, - mimetype = 'text/plain', - data = table.concat(diag_lines, '\n'), - }) - end - end + local out = files.glob(source.cwd(), { + pattern = input.pattern, + }) - return out + return { + { + uri = 'files://glob/' .. input.pattern, + mimetype = 'text/plain', + data = table.concat(out, '\n'), + }, + } end, }, - register = { + grep = { group = 'copilot', - uri = 'neovim://register/{register}', - description = 'Provides access to the content of a specified Vim register. Useful for discussing yanked text, clipboard content, or previously executed commands.', + uri = 'files://grep/{pattern}', + description = 'Searches for a pattern across files in your workspace. Helpful for finding specific code elements or patterns.', schema = { type = 'object', - required = { 'register' }, + required = { 'pattern' }, properties = { - register = { + pattern = { type = 'string', - description = 'Register to include in chat context.', - enum = { - '+', - '*', - '"', - '0', - '-', - '.', - '%', - ':', - '#', - '=', - '/', - }, - default = '+', + description = 'Pattern to search for.', }, }, }, - resolve = function(input) - utils.schedule_main() - local lines = vim.fn.getreg(input.register) - if not lines or lines == '' then - return {} - end + resolve = function(input, source) + local out = files.grep(source.cwd(), { + pattern = input.pattern, + }) return { { - uri = 'neovim://register/' .. input.register, + uri = 'files://grep/' .. input.pattern, mimetype = 'text/plain', - data = lines, + data = table.concat(out, '\n'), }, } end, @@ -453,64 +393,104 @@ return { end, }, - gitstatus = { + bash = { group = 'copilot', - uri = 'git://status', - description = 'Retrieves the status of the current git repository. Useful for discussing changes, commits, and other git-related tasks.', + description = 'Executes a bash command and returns its output. Useful for running shell commands, checking file contents, or gathering system information.', - resolve = function(_, source) - local cmd = { - 'git', - '-C', - source.cwd(), - 'status', - } + schema = { + type = 'object', + required = { 'command' }, + properties = { + command = { + type = 'string', + description = 'Bash command to execute.', + }, + }, + }, - local out = utils.system(cmd) + resolve = function(input, source) + local cmd = { 'bash', '-c', input.command } + local out = utils.system(cmd, source.cwd()) return { { - uri = 'git://status', - mimetype = 'text/plain', data = out.stdout, }, } end, }, - url = { + edit = { group = 'copilot', - uri = 'https://{url}', - description = 'Fetches content from a specified URL. Useful for referencing documentation, examples, or other online resources.', + description = 'Applies a unified diff to a file. The diff should be in unified diff format (similar to diff -U0 output).', schema = { type = 'object', - required = { 'url' }, + required = { 'filename', 'diff' }, properties = { - url = { + filename = { type = 'string', - description = 'URL to include in chat context.', + description = 'Path to file to edit.', + }, + diff = { + type = 'string', + description = 'Unified diff content to apply to the file.', }, }, }, - resolve = function(input) - if not input.url:match('^https?://') then - input.url = 'https://' .. input.url + resolve = function(input, source) + utils.schedule_main() + + local select = require('CopilotChat.select') + local diff = require('CopilotChat.utils.diff') + + -- Find or create the buffer for the file + local filename = input.filename + local diff_bufnr = nil + + -- Try to find matching buffer first + for _, buf in ipairs(vim.api.nvim_list_bufs()) do + if files.filename_same(vim.api.nvim_buf_get_name(buf), filename) then + diff_bufnr = buf + break + end end - local data, mimetype = resources.get_url(input.url) - if not data then - error('URL not found: ' .. input.url) + -- If still not found, try to load or create buffer + if not diff_bufnr then + diff_bufnr = vim.fn.bufadd(filename) + vim.fn.bufload(diff_bufnr) end - return { - { - uri = input.url, - mimetype = mimetype, - data = data, - }, - } + -- Get current buffer content + local lines = vim.api.nvim_buf_get_lines(diff_bufnr, 0, -1, false) + local content = table.concat(lines, '\n') + + -- Apply the unified diff + local new_lines, applied, first, last = diff.apply_unified_diff(input.diff, content) + + if applied then + -- Apply changes to buffer + vim.api.nvim_buf_set_lines(diff_bufnr, 0, -1, false, new_lines) + + -- If source window is valid, switch to the edited buffer and highlight changes + if source and source.winnr and vim.api.nvim_win_is_valid(source.winnr) then + vim.api.nvim_win_set_buf(source.winnr, diff_bufnr) + if first and last then + select.set(diff_bufnr, source.winnr, first, last) + select.highlight(diff_bufnr) + end + end + + return { + { + data = string.format('Successfully applied diff to %s (lines %d-%d)', filename, first or 0, last or 0), + }, + } + else + error('Failed to apply diff to ' .. filename) + end end, }, } diff --git a/lua/CopilotChat/config/mappings.lua b/lua/CopilotChat/config/mappings.lua index aaaa0ea6..c3028f8a 100644 --- a/lua/CopilotChat/config/mappings.lua +++ b/lua/CopilotChat/config/mappings.lua @@ -1,104 +1,32 @@ local async = require('plenary.async') -local copilot = require('CopilotChat') local client = require('CopilotChat.client') +local constants = require('CopilotChat.constants') +local select = require('CopilotChat.select') local utils = require('CopilotChat.utils') +local files = require('CopilotChat.utils.files') ----@class CopilotChat.config.mappings.Diff ----@field change string ----@field reference string ----@field filename string ----@field filetype string ----@field start_line number ----@field end_line number ----@field bufnr number? - ---- Get diff data from a block ----@param block CopilotChat.ui.chat.Block? ----@return CopilotChat.config.mappings.Diff? -local function get_diff(block) - -- If no block found, return nil - if not block then - return nil +--- Prepare a buffer for applying a diff +---@param filename string? +---@param source CopilotChat.ui.chat.Source +---@return integer +local function prepare_diff_buffer(filename, source) + if not filename then + filename = vim.api.nvim_buf_get_name(source.bufnr) end - -- Initialize variables with selection if available - local header = block.header - local selection = copilot.get_selection() - local reference = selection and selection.content - local start_line = selection and selection.start_line - local end_line = selection and selection.end_line - local filename = selection and selection.filename - local filetype = selection and selection.filetype - local bufnr = selection and selection.bufnr - - -- If we have header info, use it as source of truth - if header.start_line and header.end_line then - -- Try to find matching buffer and window - bufnr = nil - for _, win in ipairs(vim.api.nvim_list_wins()) do - local win_buf = vim.api.nvim_win_get_buf(win) - if utils.filename_same(vim.api.nvim_buf_get_name(win_buf), header.filename) then - bufnr = win_buf - break - end - end - - filename = header.filename - filetype = header.filetype or utils.filetype(filename) - start_line = header.start_line - end_line = header.end_line - - -- If we found a valid buffer, get the reference content - if bufnr and utils.buf_valid(bufnr) then - reference = table.concat(vim.api.nvim_buf_get_lines(bufnr, start_line - 1, end_line, false), '\n') - filetype = vim.bo[bufnr].filetype + -- Try to find matching buffer first + local diff_bufnr = nil + for _, buf in ipairs(vim.api.nvim_list_bufs()) do + if files.filename_same(vim.api.nvim_buf_get_name(buf), filename) then + diff_bufnr = buf + break end end - -- If we are missing info, there is no diff to be made - if not start_line or not end_line or not filename then - return nil - end - - return { - change = block.content, - reference = reference or '', - filetype = filetype or '', - filename = filename, - start_line = start_line, - end_line = end_line, - bufnr = bufnr, - } -end - ---- Prepare a buffer for applying a diff ----@param diff CopilotChat.config.mappings.Diff? ----@param source CopilotChat.source? ----@return CopilotChat.config.mappings.Diff? -local function prepare_diff_buffer(diff, source) - if not diff then - return diff - end - - local diff_bufnr = diff.bufnr - - -- If buffer is not found, try to load it + -- If not found, create a new buffer if not diff_bufnr then - -- Try to find matching buffer first - for _, buf in ipairs(vim.api.nvim_list_bufs()) do - if utils.filename_same(vim.api.nvim_buf_get_name(buf), diff.filename) then - diff_bufnr = buf - break - end - end - - -- If still not found, create a new buffer - if not diff_bufnr then - diff_bufnr = vim.fn.bufadd(diff.filename) - vim.fn.bufload(diff_bufnr) - end - - diff.bufnr = diff_bufnr + diff_bufnr = vim.fn.bufadd(filename) + vim.fn.bufload(diff_bufnr) end -- If source exists, update it to point to the diff buffer @@ -107,39 +35,35 @@ local function prepare_diff_buffer(diff, source) vim.api.nvim_win_set_buf(source.winnr, diff_bufnr) end - return diff + return diff_bufnr end ---@class CopilotChat.config.mapping ---@field normal string? ---@field insert string? ----@field callback fun(source: CopilotChat.source) +---@field callback fun(source: CopilotChat.ui.chat.Source) ---@class CopilotChat.config.mapping.yank_diff : CopilotChat.config.mapping ---@field register string? ----@class CopilotChat.config.mapping.show_diff : CopilotChat.config.mapping ----@field full_diff boolean? - ---@class CopilotChat.config.mappings ---@field complete CopilotChat.config.mapping|false|nil ---@field close CopilotChat.config.mapping|false|nil ---@field reset CopilotChat.config.mapping|false|nil ---@field submit_prompt CopilotChat.config.mapping|false|nil ----@field toggle_sticky CopilotChat.config.mapping|false|nil ---@field accept_diff CopilotChat.config.mapping|false|nil ---@field jump_to_diff CopilotChat.config.mapping|false|nil ---@field quickfix_diffs CopilotChat.config.mapping|false|nil +---@field quickfix_answers CopilotChat.config.mapping|false|nil ---@field yank_diff CopilotChat.config.mapping.yank_diff|false|nil ----@field show_diff CopilotChat.config.mapping.show_diff|false|nil +---@field show_diff CopilotChat.config.mapping|false|nil ---@field show_info CopilotChat.config.mapping|false|nil ----@field show_context CopilotChat.config.mapping|false|nil ---@field show_help CopilotChat.config.mapping|false|nil return { complete = { - insert = '', + insert = '', callback = function() - copilot.trigger_complete() + require('CopilotChat.completion').complete() end, }, @@ -147,7 +71,7 @@ return { normal = 'q', insert = '', callback = function() - copilot.close() + require('CopilotChat').close() end, }, @@ -155,7 +79,7 @@ return { normal = '', insert = '', callback = function() - copilot.reset() + require('CopilotChat').reset() end, }, @@ -163,7 +87,8 @@ return { normal = '', insert = '', callback = function() - local message = copilot.chat:get_closest_message('user') + local copilot = require('CopilotChat') + local message = copilot.chat:get_message(constants.ROLE.USER, true) if not message then return end @@ -172,284 +97,210 @@ return { end, }, - toggle_sticky = { - normal = 'grr', - callback = function() - local message = copilot.chat:get_message('user') - local section = message and message.section - if not section then - return - end + accept_diff = { + normal = '', + insert = '', + callback = function(source) + local chat = require('CopilotChat').chat + local diff = require('CopilotChat.utils.diff') - local cursor = vim.api.nvim_win_get_cursor(copilot.chat.winnr) - if cursor[1] < section.start_line or cursor[1] > section.end_line then + local block = chat:get_block(constants.ROLE.ASSISTANT, true) + if not block then return end - local current_line = vim.trim(vim.api.nvim_get_current_line()) - if current_line == '' then - return + local path = block.header.filename + local bufnr = prepare_diff_buffer(path, source) + local lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false) + local new_lines = diff.apply_diff(block, lines) + vim.api.nvim_buf_set_lines(bufnr, 0, -1, false, new_lines) + local first, last = diff.get_diff_region(block, lines) + if first and last then + select.set(bufnr, source.winnr, first, last) + select.highlight(bufnr) end + end, + }, - local cur_line = cursor[1] - vim.api.nvim_buf_set_lines(copilot.chat.bufnr, cur_line - 1, cur_line, false, {}) + jump_to_diff = { + normal = 'gj', + callback = function(source) + local chat = require('CopilotChat').chat + local diff = require('CopilotChat.utils.diff') - if vim.startswith(current_line, '> ') then + local block = chat:get_block(constants.ROLE.ASSISTANT, true) + if not block then return end - copilot.chat:add_sticky(current_line) - vim.api.nvim_win_set_cursor(copilot.chat.winnr, cursor) + local path = block.header.filename + local bufnr = prepare_diff_buffer(path, source) + local lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false) + local first, last = diff.get_diff_region(block, lines) + if first and last and bufnr then + select.set(bufnr, source.winnr, first, last) + select.highlight(bufnr) + end end, }, - clear_stickies = { - normal = 'grx', + yank_diff = { + normal = 'gy', + register = '"', -- Default register to use for yanking callback = function() - local message = copilot.chat:get_message('user') - local section = message and message.section - if not section then + local config = require('CopilotChat.config') + local chat = require('CopilotChat').chat + local block = chat:get_block(constants.ROLE.ASSISTANT, true) + if not block then return end - local lines = vim.split(message.content, '\n') - local new_lines = {} - local changed = false - - for _, line in ipairs(lines) do - if not vim.startswith(vim.trim(line), '> ') then - table.insert(new_lines, line) - else - changed = true - end - end - - if changed then - message.content = table.concat(new_lines, '\n') - copilot.chat:add_message(message, true) - end + vim.fn.setreg(config.mappings.yank_diff.register, block.content) end, }, - accept_diff = { - normal = '', - insert = '', + show_diff = { + normal = 'gd', callback = function(source) - local diff = get_diff(copilot.chat:get_closest_block()) - diff = prepare_diff_buffer(diff, source) - if not diff then + local chat = require('CopilotChat').chat + local diff = require('CopilotChat.utils.diff') + + local block = chat:get_block(constants.ROLE.ASSISTANT, true) + if not block then return end - local lines = vim.split(diff.change, '\n', { trimempty = false }) - vim.api.nvim_buf_set_lines(diff.bufnr, diff.start_line - 1, diff.end_line, false, lines) - copilot.set_selection(diff.bufnr, diff.start_line, diff.start_line + #lines - 1) - end, - }, + local path = block.header.filename + local bufnr = prepare_diff_buffer(path, source) - jump_to_diff = { - normal = 'gj', - callback = function(source) - local diff = get_diff(copilot.chat:get_closest_block()) - diff = prepare_diff_buffer(diff, source) - if not diff then - return + -- Collect all blocks for the same filename + local message = chat:get_message(constants.ROLE.ASSISTANT, true) + local blocks = {} + if message and message.section and message.section.blocks then + for _, b in ipairs(message.section.blocks) do + if b.header.filename == path then + table.insert(blocks, b) + end + end + else + blocks = { block } end - copilot.set_selection(diff.bufnr, diff.start_line, diff.end_line) - end, - }, + -- Apply all diffs for the filename + local new_lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false) + for i = #blocks, 1, -1 do + new_lines = diff.apply_diff(blocks[i], new_lines) + end - quickfix_answers = { - normal = 'gqa', - callback = function() - local items = {} - for i, message in ipairs(copilot.chat.messages) do - if message.section and message.role == 'assistant' then - local prev_message = copilot.chat.messages[i - 1] - local text = '' - if prev_message then - text = prev_message.content - end + local opts = { + filetype = vim.bo[bufnr].filetype, + text = table.concat(new_lines, '\n'), + } - table.insert(items, { - bufnr = copilot.chat.bufnr, - lnum = message.section.start_line, - end_lnum = message.section.end_line, - text = text, - }) - end + opts.on_show = function() + vim.api.nvim_win_call(source.winnr, function() + vim.cmd('diffthis') + end) + + vim.api.nvim_win_call(chat.winnr, function() + vim.cmd('diffthis') + end) end - vim.fn.setqflist(items) - vim.cmd('copen') + opts.on_hide = function() + vim.api.nvim_win_call(chat.winnr, function() + vim.cmd('diffoff') + end) + end + + chat:overlay(opts) end, }, quickfix_diffs = { normal = 'gqd', callback = function() - local selection = copilot.get_selection() + local chat = require('CopilotChat').chat local items = {} - - for _, message in ipairs(copilot.chat.messages) do + local messages = chat:get_messages() + for _, message in ipairs(messages) do if message.section then for _, block in ipairs(message.section.blocks) do - local header = block.header - - if not header.start_line and selection then - header.filename = selection.filename .. ' (selection)' - header.start_line = selection.start_line - header.end_line = selection.end_line - end - - local text = string.format('%s (%s)', header.filename, header.filetype) - if header.start_line and header.end_line then - text = text .. string.format(' [lines %d-%d]', header.start_line, header.end_line) + local text = string.format('%s (%s)', block.header.filename, block.header.filetype) + if block.header.start_line and block.header.end_line then + text = text .. string.format(' [lines %d-%d]', block.header.start_line, block.header.end_line) end table.insert(items, { - bufnr = copilot.chat.bufnr, + bufnr = chat.bufnr, lnum = block.start_line, end_lnum = block.end_line, text = text, }) end end - - vim.fn.setqflist(items) - vim.cmd('copen') - end - end, - }, - - yank_diff = { - normal = 'gy', - register = '"', -- Default register to use for yanking - callback = function() - local block = copilot.chat:get_closest_block() - if not block then - return end - vim.fn.setreg(copilot.config.mappings.yank_diff.register, block.content) + vim.fn.setqflist(items) + vim.cmd('copen') end, }, - show_diff = { - normal = 'gd', - full_diff = false, -- Show full diff instead of unified diff when showing diff window - callback = function(source) - local diff = get_diff(copilot.chat:get_closest_block()) - diff = prepare_diff_buffer(diff, source) - if not diff then - return - end - - local opts = { - filetype = diff.filetype, - syntax = 'diff', - } - - if copilot.config.mappings.show_diff.full_diff then - local modified = utils.buf_valid(diff.bufnr) and vim.api.nvim_buf_get_lines(diff.bufnr, 0, -1, false) or {} - - -- Apply all diffs from same file - if #modified > 0 then - -- Find all diffs from the same file in this section - local message = copilot.chat:get_closest_message('assistant') - local section = message and message.section - local same_file_diffs = {} - if section then - for _, block in ipairs(section.blocks) do - local block_diff = get_diff(block) - if block_diff and block_diff.bufnr == diff.bufnr then - table.insert(same_file_diffs, block_diff) - end - end - - -- Sort diffs bottom to top to preserve line numbering - table.sort(same_file_diffs, function(a, b) - return a.start_line > b.start_line - end) - end - - for _, file_diff in ipairs(same_file_diffs) do - local start_idx = file_diff.start_line - local end_idx = file_diff.end_line - for _ = start_idx, end_idx do - table.remove(modified, start_idx) - end - local change_lines = vim.split(file_diff.change, '\n') - for i, line in ipairs(change_lines) do - table.insert(modified, start_idx + i, line) - end + quickfix_answers = { + normal = 'gqa', + callback = function() + local chat = require('CopilotChat').chat + local items = {} + local messages = chat:get_messages() + for i, message in ipairs(messages) do + if message.section and message.role == constants.ROLE.ASSISTANT then + local prev_message = messages[i - 1] + local text = '' + if prev_message then + text = prev_message.content end - modified = vim.tbl_filter(function(line) - return line ~= nil - end, modified) - - opts.text = table.concat(modified, '\n') - else - opts.text = diff.change - end - - opts.on_show = function() - vim.api.nvim_win_call(vim.fn.bufwinid(diff.bufnr), function() - vim.cmd('diffthis') - end) - - vim.api.nvim_win_call(copilot.chat.winnr, function() - vim.cmd('diffthis') - end) - end - - opts.on_hide = function() - vim.api.nvim_win_call(copilot.chat.winnr, function() - vim.cmd('diffoff') - end) + table.insert(items, { + bufnr = chat.bufnr, + lnum = message.section.start_line, + end_lnum = message.section.end_line, + text = text, + }) end - else - opts.text = tostring(vim.diff(diff.reference, diff.change, { - result_type = 'unified', - ignore_blank_lines = true, - ignore_whitespace = true, - ignore_whitespace_change = true, - ignore_whitespace_change_at_eol = true, - ignore_cr_at_eol = true, - algorithm = 'myers', - ctxlen = #diff.reference, - })) end - copilot.chat:overlay(opts) + vim.fn.setqflist(items) + vim.cmd('copen') end, }, show_info = { normal = 'gc', callback = function(source) - local message = copilot.chat:get_closest_message('user') + local chat = require('CopilotChat').chat + local prompts = require('CopilotChat.prompts') + + local message = chat:get_message(constants.ROLE.USER, true) if not message then return end local lines = {} - local config, prompt = copilot.resolve_prompt(message.content) - local system_prompt = config.system_prompt async.run(function() + local config, prompt = prompts.resolve_prompt(message.content) + local system_prompt = config.system_prompt + local selected_tools = prompts.resolve_tools(prompt, config) + local selected_model = prompts.resolve_model(prompt, config) local infos = client:info() - local selected_model = copilot.resolve_model(prompt, config) - local selected_tools, resolved_resources = copilot.resolve_functions(prompt, config) + selected_tools = vim.tbl_map(function(tool) return tool.name end, selected_tools) utils.schedule_main() - table.insert(lines, '**Logs**: `' .. copilot.config.log_path .. '`') - table.insert(lines, '**History**: `' .. copilot.config.history_path .. '`') + table.insert(lines, '**Logs**: `' .. config.log_path .. '`') + table.insert(lines, '**History**: `' .. config.history_path .. '`') table.insert(lines, '') for provider, infolines in pairs(infos) do @@ -489,7 +340,7 @@ return { table.insert(lines, '') end - local selection = copilot.get_selection() + local selection = select.get(source.bufnr) if selection then table.insert(lines, '**Selection**') table.insert(lines, '') @@ -505,29 +356,7 @@ return { table.insert(lines, '') end - if not utils.empty(resolved_resources) then - table.insert(lines, '**Resources**') - table.insert(lines, '') - end - - for _, resource in ipairs(resolved_resources) do - local resource_lines = vim.split(resource.data, '\n') - local preview = vim.list_slice(resource_lines, 1, math.min(10, #resource_lines)) - local header = string.format('**%s** (%s lines)', resource.name, #resource_lines) - if #resource_lines > 10 then - header = header .. ' (truncated)' - end - - table.insert(lines, header) - table.insert(lines, '```' .. resource.type) - for _, line in ipairs(preview) do - table.insert(lines, line) - end - table.insert(lines, '```') - table.insert(lines, '') - end - - copilot.chat:overlay({ + chat:overlay({ text = vim.trim(table.concat(lines, '\n')) .. '\n', }) end) @@ -537,6 +366,9 @@ return { show_help = { normal = 'gh', callback = function() + local config = require('CopilotChat.config') + local chat = require('CopilotChat').chat + local chat_help = '**`Special tokens`**\n' chat_help = chat_help .. '`@` to share function\n' chat_help = chat_help .. '`#` to add resource\n' @@ -546,22 +378,22 @@ return { chat_help = chat_help .. '`> ` to make a sticky prompt (copied to next prompt)\n' chat_help = chat_help .. '\n**`Mappings`**\n' - local chat_keys = vim.tbl_keys(copilot.config.mappings) + local chat_keys = vim.tbl_keys(config.mappings) table.sort(chat_keys, function(a, b) - a = copilot.config.mappings[a] + a = config.mappings[a] a = a and (a.normal or a.insert) or '' - b = copilot.config.mappings[b] + b = config.mappings[b] b = b and (b.normal or b.insert) or '' return a < b end) for _, name in ipairs(chat_keys) do - local info = utils.key_to_info(name, copilot.config.mappings[name], '`') + local info = utils.key_to_info(name, config.mappings[name], '`') if info ~= '' then chat_help = chat_help .. info .. '\n' end end - copilot.chat:overlay({ + chat:overlay({ text = chat_help, }) end, diff --git a/lua/CopilotChat/config/prompts.lua b/lua/CopilotChat/config/prompts.lua index 8764f914..53baa21f 100644 --- a/lua/CopilotChat/config/prompts.lua +++ b/lua/CopilotChat/config/prompts.lua @@ -1,5 +1,13 @@ -local COPILOT_BASE = [[ -When asked for your name, you must respond with "GitHub Copilot". +---@class CopilotChat.config.prompts.Prompt : CopilotChat.config.Shared +---@field prompt string? +---@field description string? +---@field mapping string? + +---@type table +return { + COPILOT_BASE = { + system_prompt = [[ +When asked for your name, you must respond with "Copilot". Follow the user's requirements carefully & to the letter. Keep your answers short and impersonal. Always answer in {LANGUAGE} unless explicitly asked otherwise. @@ -13,74 +21,40 @@ The user works in editor called Neovim which has these core concepts: - Normal/Insert/Visual/Command modes: Different interaction states - LSP (Language Server Protocol): Provides code intelligence features like completion, diagnostics, and code actions - Treesitter: Provides syntax highlighting, code folding, and structural text editing based on syntax tree parsing +- Visual selection: Text selected in visual mode that can be shared as context The user is working on a {OS_NAME} machine. Please respond with system specific commands if applicable. -The user is currently in workspace directory {DIR} (typically the project root). Current file paths will be relative to this directory. +The user is currently in workspace directory {DIR} (project root). File paths are relative to this directory. + +Context is provided to you in several ways: +- Resources: Contextual data shared via "# " headers and referenced via "##" links +- Code blocks with file path labels and line numbers (e.g., ```lua path=/file.lua start_line=1 end_line=10```) + Note: Each line in code block can be prefixed with : for your reference only. NEVER include these line numbers in your responses. +- Visual selections: Text selected in visual mode that can be shared as context +- Diffs: Changes shown in unified diff format (+, -, etc.) +- Conversation history +When resources (like buffers, files, or diffs) change, their content in the chat history is replaced with the latest version rather than appended as new data. + The user will ask a question or request a task that may require analysis to answer correctly. If you can infer the project type (languages, frameworks, libraries) from context, consider them when making changes. For implementing features, break down the request into concepts and provide a clear solution. Think creatively to provide complete solutions based on the information available. -Never fabricate or hallucinate file contents you haven't actually seen. +Never fabricate or hallucinate file contents you haven't actually seen in the provided context. +When outputting code or diffs, NEVER include line number prefixes - they are only for reference when analyzing the provided context. - -If tools are explicitly defined in your system context: -- Follow JSON schema precisely when using tools, including all required properties and outputting valid JSON. -- Use appropriate tools for tasks rather than asking for manual actions. -- Execute actions directly when you indicate you'll do so, without asking for permission. -- Only use tools that exist and use proper invocation procedures - no multi_tool_use.parallel. -- Before using tools to retrieve information, check if it's already available in context: - 1. Resources shared via "# " headers and referenced via "##" links - 2. Code blocks with file path labels - 3. Other contextual sharing like selected text or conversation history -- If you don't have explicit tool definitions in your system context, assume NO tools are available and clearly state this limitation when asked. NEVER pretend to retrieve content you cannot access. - - -You will receive code snippets that include line number prefixes - use these to maintain correct position references but remove them when generating output. -Always use code blocks to present code changes, even if the user doesn't ask for it. - -When presenting code changes: -1. For each change, use the following markdown code block format with triple backticks: - ``` path= start_line= end_line= - - ``` - - Examples: - - ```lua path=lua/CopilotChat/init.lua start_line=40 end_line=50 - local function example() - print("This is an example function.") - end - ``` - - ```python path=scripts/example.py start_line=10 end_line=15 - def example_function(): - print("This is an example function.") - ``` - - ```json path=config/settings.json start_line=5 end_line=8 - { - "setting": "value", - "enabled": true - } - ``` -2. Keep changes minimal and focused to produce short diffs. -3. Include complete replacement code for the specified line range with: - - Proper indentation matching the source - - All necessary lines (no eliding with comments) - - No line number prefixes in the code -4. Address any diagnostics issues when fixing code. -5. If multiple changes are needed, present them as separate code blocks. - -]] - -local COPILOT_INSTRUCTIONS = [[ +]], + }, + + COPILOT_INSTRUCTIONS = { + system_prompt = [[ You are a code-focused AI programming assistant that specializes in practical software engineering solutions. -]] .. COPILOT_BASE +]], + }, -local COPILOT_EXPLAIN = [[ + COPILOT_EXPLAIN = { + system_prompt = [[ You are a programming instructor focused on clear, practical explanations. -]] .. COPILOT_BASE .. [[ When explaining code: - Provide concise high-level overview first @@ -90,11 +64,12 @@ When explaining code: - Focus on complex parts rather than basic syntax - Use short paragraphs with clear structure - Mention performance considerations where relevant -]] +]], + }, -local COPILOT_REVIEW = [[ + COPILOT_REVIEW = { + system_prompt = [[ You are a code reviewer focused on improving code quality and maintainability. -]] .. COPILOT_BASE .. [[ Format each issue you find precisely as: line=: @@ -117,29 +92,7 @@ Multiple issues on one line should be separated by semicolons. End with: "**`To clear buffer highlights, please ask a different question.`**" If no issues found, confirm the code is well-written and explain why. -]] - ----@class CopilotChat.config.prompts.Prompt : CopilotChat.config.Shared ----@field prompt string? ----@field description string? ----@field mapping string? - ----@type table -return { - COPILOT_BASE = { - system_prompt = COPILOT_BASE, - }, - - COPILOT_INSTRUCTIONS = { - system_prompt = COPILOT_INSTRUCTIONS, - }, - - COPILOT_EXPLAIN = { - system_prompt = COPILOT_EXPLAIN, - }, - - COPILOT_REVIEW = { - system_prompt = COPILOT_REVIEW, +]], }, Explain = { @@ -152,7 +105,7 @@ return { system_prompt = 'COPILOT_REVIEW', callback = function(response, source) local diagnostics = {} - for line in response:gmatch('[^\r\n]+') do + for line in response.content:gmatch('[^\r\n]+') do if line:find('^line=') then local start_line = nil local end_line = nil @@ -205,6 +158,8 @@ return { Commit = { prompt = 'Write commit message for the change with commitizen convention. Keep the title under 50 characters and wrap message at 72 characters. Format as a gitcommit code block.', - sticky = '#gitdiff:staged', + resources = { + 'gitdiff:staged', + }, }, } diff --git a/lua/CopilotChat/config/providers.lua b/lua/CopilotChat/config/providers.lua index d2ac7976..44233f3d 100644 --- a/lua/CopilotChat/config/providers.lua +++ b/lua/CopilotChat/config/providers.lua @@ -1,6 +1,10 @@ -local notify = require('CopilotChat.notify') -local utils = require('CopilotChat.utils') +local log = require('plenary.log') local plenary_utils = require('plenary.async.util') +local constants = require('CopilotChat.constants') +local notify = require('CopilotChat.utils.notify') +local utils = require('CopilotChat.utils') +local curl = require('CopilotChat.utils.curl') +local files = require('CopilotChat.utils.files') local EDITOR_VERSION = 'Neovim/' .. vim.version().major .. '.' .. vim.version().minor .. '.' .. vim.version().patch @@ -13,7 +17,7 @@ local function load_tokens() local config_path = vim.fs.normalize(vim.fn.stdpath('data') .. '/copilot_chat') local cache_file = config_path .. '/tokens.json' - local file = utils.read_file(cache_file) + local file = files.read_file(cache_file) if file then token_cache = vim.json.decode(file) else @@ -38,10 +42,14 @@ local function set_token(tag, token, save) return token end + utils.schedule_main() local tokens = load_tokens() tokens[tag] = token local config_path = vim.fs.normalize(vim.fn.stdpath('data') .. '/copilot_chat') - utils.write_file(config_path .. '/tokens.json', vim.json.encode(tokens)) + local file_path = config_path .. '/tokens.json' + vim.fn.mkdir(vim.fn.fnamemodify(file_path, ':p:h'), 'p') + files.write_file(file_path, vim.json.encode(tokens)) + log.info('Token for ' .. tag .. ' saved to ' .. file_path) return token end @@ -49,7 +57,7 @@ end ---@return string local function github_device_flow(tag, client_id, scope) local function request_device_code() - local res = utils.curl_post('https://github.com/login/device/code', { + local res = curl.post('https://github.com/login/device/code', { body = { client_id = client_id, scope = scope, @@ -62,23 +70,25 @@ local function github_device_flow(tag, client_id, scope) end local function poll_for_token(device_code, interval) - while true do - plenary_utils.sleep(interval * 1000) - - local res = utils.curl_post('https://github.com/login/oauth/access_token', { - body = { - client_id = client_id, - device_code = device_code, - grant_type = 'urn:ietf:params:oauth:grant-type:device_code', - }, - headers = { ['Accept'] = 'application/json' }, - }) - local data = vim.json.decode(res.body) - if data.access_token then - return data.access_token - elseif data.error ~= 'authorization_pending' then - error('Auth error: ' .. (data.error or 'unknown')) - end + plenary_utils.sleep(interval * 1000) + + local res = curl.post('https://github.com/login/oauth/access_token', { + json_response = true, + body = { + client_id = client_id, + device_code = device_code, + grant_type = 'urn:ietf:params:oauth:grant-type:device_code', + }, + headers = { ['Accept'] = 'application/json' }, + }) + + local data = res.body + if data.access_token then + return data.access_token + elseif data.error ~= 'authorization_pending' then + error('Auth error: ' .. (data.error or 'unknown')) + else + return poll_for_token(device_code, interval) end end @@ -94,6 +104,8 @@ local function github_device_flow(tag, client_id, scope) ) notify.publish(notify.STATUS, '[' .. tag .. '] Waiting for authorization...') token = poll_for_token(code_data.device_code, code_data.interval) + notify.publish(notify.MESSAGE, '') + notify.publish(notify.STATUS, '') return set_token(tag, token, true) end @@ -140,7 +152,7 @@ local function get_github_copilot_token(tag) } for _, file_path in ipairs(file_paths) do - local file_data = utils.read_file(file_path) + local file_data = files.read_file(file_path) if file_data then local parsed_data = utils.json_decode(file_data) if parsed_data then @@ -171,7 +183,7 @@ local function get_github_models_token(tag) end -- loading token from gh cli if available - if vim.fn.executable('gh') == 0 then + if vim.fn.executable('gh') == 1 then local result = utils.system({ 'gh', 'auth', 'token', '-h', 'github.com' }) if result and result.code == 0 and result.stdout then local gh_token = vim.trim(result.stdout) @@ -184,6 +196,322 @@ local function get_github_models_token(tag) return github_device_flow(tag, '178c6fc778ccc68e1d6a', 'read:user copilot') end +--- Resolve the Copilot API base URL from token endpoint response. +--- Falls back to the default api.githubcopilot.com if no business endpoint is found. +---@param token_body table The decoded JSON body from the token endpoint +---@return string base_url The base URL (no trailing slash) +local function resolve_copilot_base_url(token_body) + -- The token response may include an `endpoints` table with an `api` field + -- pointing to the correct base URL for business/enterprise accounts, + -- e.g. https://api.business.githubcopilot.com + if token_body and token_body.endpoints and token_body.endpoints.api then + local url = token_body.endpoints.api + -- Strip trailing slash if present + return url:gsub('/$', '') + end + return 'https://api.githubcopilot.com' +end + +--- Prepare input for Responses API +---@param inputs CopilotChat.client.Message[] +---@param opts CopilotChat.config.providers.Options +---@return table +local function prepare_responses_input(inputs, opts) + local instructions = nil + local input_messages = {} + + for _, msg in ipairs(inputs) do + if msg.role == constants.ROLE.SYSTEM then + instructions = instructions and (instructions .. '\n\n' .. msg.content) or msg.content + elseif msg.role == constants.ROLE.TOOL then + table.insert(input_messages, { + type = 'function_call_output', + call_id = msg.tool_call_id, + output = msg.content, + }) + else + table.insert(input_messages, { + role = msg.role, + content = msg.content, + }) + + if msg.tool_calls then + for _, tool_call in ipairs(msg.tool_calls) do + table.insert(input_messages, { + type = 'function_call', + call_id = tool_call.id, + name = tool_call.name, + arguments = tool_call.arguments or '', + }) + end + end + end + end + + local out = { + model = opts.model.id, + stream = opts.model.streaming ~= false, + input = input_messages, + } + + if instructions then + out.instructions = instructions + end + + if opts.tools and opts.model.tools then + out.tools = vim.tbl_map(function(tool) + return { + type = 'function', + name = tool.name, + description = tool.description, + parameters = tool.schema, + } + end, opts.tools) + end + + return out +end + +--- Prepare input for Chat Completions API +---@param inputs CopilotChat.client.Message[] +---@param opts CopilotChat.config.providers.Options +---@return table +local function prepare_chat_input(inputs, opts) + local is_o1 = vim.startswith(opts.model.id, 'o1') + local is_codex = opts.model.id:find('codex') ~= nil + + inputs = vim.tbl_map(function(input) + local output = { + role = (is_o1 and input.role == constants.ROLE.SYSTEM) and constants.ROLE.USER or input.role, + content = input.content, + } + + if input.tool_call_id then + output.tool_call_id = input.tool_call_id + end + + if input.tool_calls then + output.tool_calls = vim.tbl_map(function(tool_call) + return { + id = tool_call.id, + type = 'function', + ['function'] = { + name = tool_call.name, + arguments = tool_call.arguments or nil, + }, + } + end, input.tool_calls) + end + + return output + end, inputs) + + local out = { + messages = inputs, + model = opts.model.id, + stream = opts.model.streaming or false, + } + + if opts.tools and opts.model.tools then + out.tools = vim.tbl_map(function(tool) + return { + type = 'function', + ['function'] = { + name = tool.name, + description = tool.description, + parameters = tool.schema, + }, + } + end, opts.tools) + end + + if not is_o1 and not is_codex then + out.n = 1 + out.top_p = 1 + out.temperature = opts.temperature + end + + if opts.model.max_output_tokens then + out.max_tokens = opts.model.max_output_tokens + end + + return out +end +---@param parts table Array of content parts +---@return string The concatenated text content +local function extract_text_from_parts(parts) + if not parts or type(parts) ~= 'table' then + return '' + end + + local content = '' + for _, part in ipairs(parts) do + if type(part) == 'string' then + content = content .. part + elseif type(part) == 'table' then + -- Responses API: parts have type field + if part.type == 'text' or part.type == 'output_text' or part.type == 'input_text' then + content = content .. (part.text or '') + -- Fallback for simpler structures + elseif part.text then + content = content .. part.text + end + end + end + return content +end + +--- Parse Responses API output (both streaming and non-streaming) +---@param output table Raw API response +---@return CopilotChat.config.providers.Output +local function prepare_responses_output(output) + local content = '' + local reasoning = '' + local finish_reason = nil + local total_tokens = nil + local tool_calls = {} + local model = nil + + -- Handle errors + local error_msg = output.error or (output.response and output.response.error) + if error_msg then + if type(error_msg) == 'table' then + error_msg = error_msg.message or vim.inspect(error_msg) + end + return { + content = '', + reasoning = '', + finish_reason = 'error: ' .. tostring(error_msg), + total_tokens = nil, + tool_calls = {}, + model = nil, + } + end + + -- Handle streaming events + if output.type then + if output.type == 'response.output_text.delta' then + -- Streaming text delta + if output.delta and type(output.delta) == 'string' then + content = output.delta + elseif output.delta and output.delta.text then + content = output.delta.text + end + elseif output.type == 'response.output_item.done' then + local item = output.item + if item and item.type == 'function_call' then + table.insert(tool_calls, { + id = item.call_id, + index = output.output_index, + name = item.name, + arguments = item.arguments or '', + }) + end + elseif output.type == 'response.completed' or output.type == 'response.done' then + local response = output.response + if response then + if response.reasoning and response.reasoning.summary then + reasoning = response.reasoning.summary + end + if response.usage then + total_tokens = response.usage.total_tokens + end + if response.model then + model = response.model + end + finish_reason = 'stop' + end + elseif output.type == 'response.failed' then + finish_reason = 'error: ' .. (output.error and output.error.message or 'unknown error') + end + -- Handle non-streaming response + elseif output.response then + local response = output.response + if response.output and #response.output > 0 then + for _, msg in ipairs(response.output) do + if msg.content then + content = content .. extract_text_from_parts(msg.content) + end + if msg.tool_calls then + for i, tool_call in ipairs(msg.tool_calls) do + table.insert(tool_calls, { + id = tool_call.call_id, + index = i, + name = tool_call.name, + arguments = tool_call.arguments or '', + }) + end + end + end + end + if response.reasoning and response.reasoning.summary then + reasoning = response.reasoning.summary + end + if response.usage then + total_tokens = response.usage.total_tokens + end + if response.model then + model = response.model + end + finish_reason = response.status == 'completed' and 'stop' or nil + end + + return { + content = content, + reasoning = reasoning, + finish_reason = finish_reason, + total_tokens = total_tokens, + tool_calls = tool_calls, + model = model, + } +end + +--- Parse Chat Completions API output (both streaming and non-streaming) +---@param output table Raw API response +---@return CopilotChat.config.providers.Output +local function prepare_chat_output(output) + local tool_calls = {} + + local choice + if output.choices and #output.choices > 0 then + for _, c in ipairs(output.choices) do + local message = c.message or c.delta + if message and message.tool_calls then + for i, tool_call in ipairs(message.tool_calls) do + local fn = tool_call['function'] + if fn then + table.insert(tool_calls, { + id = tool_call.id, + index = tool_call.index or i, + name = fn.name, + arguments = fn.arguments or '', + }) + end + end + end + end + choice = output.choices[1] + else + choice = output + end + + local message = choice.message or choice.delta + local content = message and message.content + local reasoning = message and (message.reasoning or message.reasoning_content) + local usage = choice.usage and choice.usage.total_tokens or output.usage and output.usage.total_tokens + local finish_reason = choice.finish_reason or choice.done_reason or output.finish_reason or output.done_reason + local model = choice.model or output.model + + return { + content = content, + reasoning = reasoning, + finish_reason = finish_reason, + total_tokens = usage, + tool_calls = tool_calls, + model = model, + } +end + ---@class CopilotChat.config.providers.Options ---@field model CopilotChat.client.Model ---@field temperature number? @@ -191,17 +519,19 @@ end ---@class CopilotChat.config.providers.Output ---@field content string +---@field reasoning string? ---@field finish_reason string? ---@field total_tokens number? ---@field tool_calls table +---@field model string? ---@class CopilotChat.config.providers.Provider ---@field disabled nil|boolean ---@field get_headers nil|fun():table,number? ---@field get_info nil|fun(headers:table):string[] ---@field get_models nil|fun(headers:table):table ----@field embed nil|string|fun(inputs:table, headers:table):table ----@field prepare_input nil|fun(inputs:table, opts:CopilotChat.config.providers.Options):table +---@field resolve_model nil|fun(headers:table, model: string):string +---@field prepare_input nil|fun(inputs:CopilotChat.client.Message[], opts:CopilotChat.config.providers.Options):table,table? ---@field prepare_output nil|fun(output:table, opts:CopilotChat.config.providers.Options):CopilotChat.config.providers.Output ---@field get_url nil|fun(opts:CopilotChat.config.providers.Options):string @@ -209,10 +539,8 @@ end local M = {} M.copilot = { - embed = 'copilot_embeddings', - get_headers = function() - local response, err = utils.curl_get('https://api.github.com/copilot_internal/v2/token', { + local response, err = curl.get('https://api.github.com/copilot_internal/v2/token', { json_response = true, headers = { ['Authorization'] = 'Token ' .. get_github_copilot_token('github_copilot'), @@ -223,17 +551,25 @@ M.copilot = { error(err) end + -- Resolve the base URL from the token response so that business/enterprise + -- accounts using *.business.githubcopilot.com are handled automatically. + local base_url = resolve_copilot_base_url(response.body) + return { ['Authorization'] = 'Bearer ' .. response.body.token, ['Editor-Version'] = EDITOR_VERSION, ['Editor-Plugin-Version'] = 'CopilotChat.nvim/*', ['Copilot-Integration-Id'] = 'vscode-chat', + ['x-github-api-version'] = '2025-10-01', + -- Store the resolved base URL in a custom header so that get_models, + -- resolve_model, and get_url can read it without making another request. + ['x-copilot-base-url'] = base_url, }, response.body.expires_at end, - get_info = function(headers) - local response, err = utils.curl_get('https://api.github.com/copilot_internal/user', { + get_info = function() + local response, err = curl.get('https://api.github.com/copilot_internal/user', { json_response = true, headers = { ['Authorization'] = 'Token ' .. get_github_copilot_token('github_copilot'), @@ -283,9 +619,16 @@ M.copilot = { end, get_models = function(headers) - local response, err = utils.curl_get('https://api.githubcopilot.com/models', { + -- Use the resolved base URL carried in the custom header, falling back to + -- the default if it is absent (e.g. during tests or manual calls). + local base_url = headers['x-copilot-base-url'] or 'https://api.githubcopilot.com' + + -- Build request headers without our internal routing header. + local request_headers = vim.tbl_extend('force', headers, { ['x-copilot-base-url'] = nil }) + + local response, err = curl.get(base_url .. '/models', { json_response = true, - headers = headers, + headers = request_headers, }) if err then @@ -298,6 +641,10 @@ M.copilot = { return model.capabilities.type == 'chat' and model.model_picker_enabled end) :map(function(model) + local supported_endpoints = model.supported_endpoints or {} + -- Pre-compute whether this model uses the Responses API + local use_responses = vim.tbl_contains(supported_endpoints, '/responses') + return { id = model.id, name = model.name, @@ -308,6 +655,10 @@ M.copilot = { tools = model.capabilities.supports.tool_calls, policy = not model['policy'] or model['policy']['state'] == 'enabled', version = model.version, + use_responses = use_responses, + -- Carry the base URL into the model so get_url and resolve_model + -- can use it without needing access to the headers again. + base_url = base_url, } end) :totable() @@ -323,137 +674,87 @@ M.copilot = { for _, model in ipairs(models) do if not model.policy then - utils.curl_post('https://api.githubcopilot.com/models/' .. model.id .. '/policy', { - headers = headers, + pcall(curl.post, base_url .. '/models/' .. model.id .. '/policy', { + headers = request_headers, json_request = true, body = { state = 'enabled' }, }) end end + -- Auto model selector + table.insert(models, { + id = 'auto', + name = 'Auto (Copilot)', + description = 'Auto selects the best model for your request.', + base_url = base_url, + }) + return models end, - prepare_input = function(inputs, opts) - local is_o1 = vim.startswith(opts.model.id, 'o1') - - inputs = vim.tbl_map(function(input) - local output = { - role = input.role, - content = input.content, - } - - if is_o1 then - if input.role == 'system' then - output.role = 'user' - end - end - - if input.tool_call_id then - output.tool_call_id = input.tool_call_id - end - - if input.tool_calls then - output.tool_calls = vim.tbl_map(function(tool_call) - return { - id = tool_call.id, - type = 'function', - ['function'] = { - name = tool_call.name, - arguments = tool_call.arguments or nil, - }, - } - end, input.tool_calls) - end - - return output - end, inputs) - - local out = { - messages = inputs, - model = opts.model.id, - stream = opts.model.streaming or false, - } - - if opts.tools and opts.model.tools then - out.tools = vim.tbl_map(function(tool) - return { - type = 'function', - ['function'] = { - name = tool.name, - description = tool.description, - parameters = tool.schema, - }, - } - end, opts.tools) + resolve_model = function(headers, model) + if model ~= 'auto' then + return model end - if not is_o1 then - out.n = 1 - out.top_p = 1 - out.temperature = opts.temperature - end + local base_url = headers['x-copilot-base-url'] or 'https://api.githubcopilot.com' + local request_headers = vim.tbl_extend('force', headers, { ['x-copilot-base-url'] = nil }) + + local url = base_url .. '/models/session' + local response, err = curl.post(url, { + headers = request_headers, + body = { auto_mode = { model_hints = { 'auto' } } }, + json_response = true, + json_request = true, + }) - if opts.model.max_output_tokens then - out.max_tokens = opts.model.max_output_tokens + if err then + error(err) end - return out + return response.body.selected_model end, - prepare_output = function(output) - local tool_calls = {} - - local choice - if output.choices and #output.choices > 0 then - for _, choice in ipairs(output.choices) do - local message = choice.message or choice.delta - if message and message.tool_calls then - for i, tool_call in ipairs(message.tool_calls) do - local fn = tool_call['function'] - if fn then - local index = tool_call.index or i - local id = utils.empty(tool_call.id) and ('tooluse_' .. index) or tool_call.id - table.insert(tool_calls, { - id = id, - index = index, - name = fn.name, - arguments = fn.arguments or '', - }) - end - end - end - end - - choice = output.choices[1] + prepare_input = function(inputs, opts) + local request + if opts.model.use_responses then + request = prepare_responses_input(inputs, opts) else - choice = output + request = prepare_chat_input(inputs, opts) end - local message = choice.message or choice.delta - local content = message and message.content - local usage = choice.usage and choice.usage.total_tokens - if not usage then - usage = output.usage and output.usage.total_tokens + if inputs and #inputs > 0 then + local last_msg = inputs[#inputs] + if last_msg.role == constants.ROLE.TOOL then + return request, { ['x-initiator'] = 'agent' } + end end - local finish_reason = choice.finish_reason or choice.done_reason or output.finish_reason or output.done_reason - return { - content = content, - finish_reason = finish_reason, - total_tokens = usage, - tool_calls = tool_calls, - } + return request end, - get_url = function() - return 'https://api.githubcopilot.com/chat/completions' + prepare_output = function(output, opts) + if opts and opts.model and opts.model.use_responses then + return prepare_responses_output(output) + end + return prepare_chat_output(output) + end, + + get_url = function(opts) + -- Use the base URL stored on the model (populated by get_models), falling + -- back to the default for backwards compatibility. + local base_url = (opts and opts.model and opts.model.base_url) or 'https://api.githubcopilot.com' + + if opts and opts.model and opts.model.use_responses then + return base_url .. '/responses' + end + return base_url .. '/chat/completions' end, } M.github_models = { disabled = true, - embed = 'copilot_embeddings', get_headers = function() return { @@ -462,7 +763,7 @@ M.github_models = { end, get_models = function(headers) - local response, err = utils.curl_get('https://models.github.ai/catalog/models', { + local response, err = curl.get('https://models.github.ai/catalog/models', { json_response = true, headers = headers, }) @@ -474,16 +775,15 @@ M.github_models = { return vim .iter(response.body) :map(function(model) - local max_output_tokens = model.limits.max_output_tokens - local max_input_tokens = model.limits.max_input_tokens return { id = model.id, name = model.name, - tokenizer = 'o200k_base', - max_input_tokens = max_input_tokens, - max_output_tokens = max_output_tokens, - streaming = vim.tbl_contains(model.capabilities, 'streaming'), - tools = vim.tbl_contains(model.capabilities, 'tool-calling'), + tokenizer = 'o200k_base', -- GitHub Models doesn't expose tokenizer info + max_input_tokens = model.limits and model.limits.max_input_tokens, + max_output_tokens = model.limits and model.limits.max_output_tokens, + streaming = model.capabilities and vim.tbl_contains(model.capabilities, 'streaming') or false, + tools = model.capabilities and vim.tbl_contains(model.capabilities, 'tool-calling') or false, + reasoning = model.capabilities and vim.tbl_contains(model.capabilities, 'reasoning') or false, version = model.version, } end) @@ -498,27 +798,4 @@ M.github_models = { end, } -M.copilot_embeddings = { - get_headers = M.copilot.get_headers, - - embed = function(inputs, headers) - local response, err = utils.curl_post('https://api.githubcopilot.com/embeddings', { - headers = headers, - json_request = true, - json_response = true, - body = { - dimensions = 512, - input = inputs, - model = 'text-embedding-3-small', - }, - }) - - if err then - error(err) - end - - return response.body.data - end, -} - return M diff --git a/lua/CopilotChat/constants.lua b/lua/CopilotChat/constants.lua new file mode 100644 index 00000000..7c6f7561 --- /dev/null +++ b/lua/CopilotChat/constants.lua @@ -0,0 +1,10 @@ +return { + PLUGIN_NAME = 'CopilotChat', + + ROLE = { + USER = 'user', + ASSISTANT = 'assistant', + SYSTEM = 'system', + TOOL = 'tool', + }, +} diff --git a/lua/CopilotChat/functions.lua b/lua/CopilotChat/functions.lua index 6e936a3d..bed78c8a 100644 --- a/lua/CopilotChat/functions.lua +++ b/lua/CopilotChat/functions.lua @@ -57,7 +57,14 @@ local function filter_schema(tbl, root) for k, v in pairs(tbl) do if not utils.empty(v) then if type(v) ~= 'function' and k ~= 'examples' then - result[k] = type(v) == 'table' and filter_schema(v) or v + if k == 'enum' and type(v) == 'table' and type(v[1]) == 'table' and v[1].value then + -- If enum contains objects with value/display, extract just the values + result[k] = vim.tbl_map(function(item) + return item.value + end, v) + else + result[k] = type(v) == 'table' and filter_schema(v) or v + end end end end @@ -114,15 +121,16 @@ function M.match_uri(uri, pattern) return result end ----@param tool CopilotChat.config.functions.Function -function M.parse_schema(tool) - local schema = tool.schema +--- Parse function schema and return a JSON schema object +---@param fn CopilotChat.config.functions.Function +function M.parse_schema(fn) + local schema = fn.schema -- If schema is missing but uri is present, generate a default schema from uri - if not schema and tool.uri then + if not schema and fn.uri then -- Extract parameter names from the uri pattern, e.g. file://{path} local param_names = {} - for param in tool.uri:gmatch(URI_PARAM_PATTERN) do + for param in fn.uri:gmatch(URI_PARAM_PATTERN) do table.insert(param_names, param) end if #param_names > 0 then @@ -138,26 +146,22 @@ function M.parse_schema(tool) end end - if schema then - schema = filter_schema(schema, true) - end - return schema end ---- Prepare the schema for use ----@param tools table +--- Prepare functions for tool use +---@param functions table ---@return table -function M.parse_tools(tools) - local tool_names = vim.tbl_keys(tools) +function M.parse_tools(functions) + local tool_names = vim.tbl_keys(functions) table.sort(tool_names) return vim.tbl_map(function(name) - local tool = tools[name] + local tool = functions[name] return { name = name, description = tool.description, - schema = M.parse_schema(tool), + schema = filter_schema(M.parse_schema(tool), true), } end, tool_names) end @@ -195,7 +199,7 @@ end --- Get input from the user based on the schema ---@param schema table? ----@param source CopilotChat.source +---@param source CopilotChat.ui.chat.Source ---@return string? function M.enter_input(schema, source) if not schema or not schema.properties then @@ -214,11 +218,30 @@ function M.enter_input(schema, source) if #choices == 0 then choice = nil elseif #choices == 1 then - choice = choices[1] + -- Handle both string and table choices + choice = type(choices[1]) == 'table' and choices[1].value or choices[1] else - choice = utils.select(choices, { - prompt = string.format('Select %s> ', prop_name), - }) + -- Check if choices are objects with display/value + local has_display = type(choices[1]) == 'table' and choices[1].display ~= nil + local selected + + if has_display then + -- Use format_item to display the display field + selected = utils.select(choices, { + prompt = string.format('Select %s> ', prop_name), + format_item = function(item) + return item.display + end, + }) + -- Extract the value from the selected item + choice = selected and selected.value or nil + else + -- Regular string choices + selected = utils.select(choices, { + prompt = string.format('Select %s> ', prop_name), + }) + choice = selected + end end table.insert(out, choice or '') diff --git a/lua/CopilotChat/health.lua b/lua/CopilotChat/health.lua index 1c8bc3b4..3a67e706 100644 --- a/lua/CopilotChat/health.lua +++ b/lua/CopilotChat/health.lua @@ -38,6 +38,15 @@ local function treesitter_parser_available(ft) return res and parser ~= nil end +--- Check if a treesitter query is available +---@param ft string +---@param query_name string +---@return boolean +local function treesitter_query_available(ft, query_name) + local query = vim.treesitter.query.get(ft, query_name) + return query ~= nil +end + function M.check() start('CopilotChat.nvim [core]') @@ -48,11 +57,11 @@ function M.check() error('nvim: unsupported, please upgrade to 0.10.0 or later. See "https://neovim.io/".') end - local setup_called = require('CopilotChat').config ~= nil - if setup_called then - ok('setup: called') + local initialized = require('CopilotChat').initialized + if initialized then + ok('initialized: true') else - error('setup: not called, required for plugin to work. See `:h CopilotChat-installation`.') + error('initialized: false, something went wrong. See `:h CopilotChat-installation`.') end local testfile = os.tmpname() @@ -145,8 +154,16 @@ function M.check() if treesitter_parser_available('markdown') then ok('treesitter[markdown]: installed') else - warn( - 'treesitter[markdown]: missing, optional for better chat highlighting. Install `nvim-treesitter/nvim-treesitter` plugin and run `:TSInstall markdown`.' + error( + 'treesitter[markdown]: missing, required for chat parsing. Install `nvim-treesitter/nvim-treesitter` plugin and run `:TSInstall markdown`.' + ) + end + + if treesitter_query_available('markdown', 'copilotchat') then + ok('treesitter[markdown/copilotchat]: found') + else + error( + 'treesitter[markdown/copilotchat]: missing, required for chat parsing. See `:h CopilotChat-installation` for instructions.' ) end diff --git a/lua/CopilotChat/init.lua b/lua/CopilotChat/init.lua index 4e81c934..e4dfdc3a 100644 --- a/lua/CopilotChat/init.lua +++ b/lua/CopilotChat/init.lua @@ -1,48 +1,79 @@ local async = require('plenary.async') local log = require('plenary.log') -local functions = require('CopilotChat.functions') -local resources = require('CopilotChat.resources') local client = require('CopilotChat.client') -local notify = require('CopilotChat.notify') +local constants = require('CopilotChat.constants') +local functions = require('CopilotChat.functions') +local prompts = require('CopilotChat.prompts') +local select = require('CopilotChat.select') local utils = require('CopilotChat.utils') +local curl = require('CopilotChat.utils.curl') +local orderedmap = require('CopilotChat.utils.orderedmap') -local PLUGIN_NAME = 'CopilotChat' -local WORD = '([^%s:]+)' -local WORD_NO_INPUT = '([^%s]+)' -local WORD_WITH_INPUT_QUOTED = WORD .. ':`([^`]+)`' -local WORD_WITH_INPUT_UNQUOTED = WORD .. ':?([^%s`]*)' local BLOCK_OUTPUT_FORMAT = '```%s\n%s\n```' ---@class CopilotChat ---@field config CopilotChat.config.Config ---@field chat CopilotChat.ui.chat.Chat -local M = {} +local M = setmetatable({}, { + __index = function(t, key) + if key == 'config' then + return require('CopilotChat.config') + end + + -- Lazy initialize + local initialized = rawget(t, 'initialized') + if not initialized then + rawset(t, 'initialized', true) + rawget(t, 'setup')() + end + + return rawget(t, key) + end, +}) ---- @class CopilotChat.source ---- @field bufnr number ---- @field winnr number ---- @field cwd fun():string +---@param config CopilotChat.config.Shared +---@param tool_name string +---@return boolean +local function is_trusted_tool(config, tool_name) + local tool_spec = config.functions[tool_name] + if not tool_spec then + return false + end ---- @class CopilotChat.state ---- @field source CopilotChat.source? ---- @field sticky string[]? -local state = { - -- Current state tracking - source = nil, + if tool_spec.trusted then + return true + end - -- Last state tracking - sticky = nil, -} + local trusted_tools = config.trusted_tools + if trusted_tools == true then + return true + end ---- Insert sticky values from config into prompt + for _, trusted_pattern in ipairs(utils.to_table(trusted_tools)) do + if tool_name == trusted_pattern then + return true + end + + if tool_spec.group == trusted_pattern then + return true + end + end + + return false +end + +--- Process sticky values from prompt and config +--- Extracts stickies from prompt, adds config-based stickies, stores them, returns clean prompt ---@param prompt string ---@param config CopilotChat.config.Shared -local function insert_sticky(prompt, config) - local existing_prompt = M.chat:get_message('user') +---@return string clean_prompt The prompt without sticky prefixes +local function process_sticky(prompt, config) + local existing_prompt = M.chat:get_message(constants.ROLE.USER) local combined_prompt = (existing_prompt and existing_prompt.content or '') .. '\n' .. (prompt or '') local lines = vim.split(prompt or '', '\n') - local stickies = utils.ordered_map() + local stickies = orderedmap() + -- Extract existing stickies from combined prompt local sticky_indices = {} local in_code_block = false for _, line in ipairs(vim.split(combined_prompt, '\n')) do @@ -53,8 +84,14 @@ local function insert_sticky(prompt, config) stickies:set(vim.trim(line:sub(3)), true) end end + + -- Find sticky lines in new prompt to remove them + in_code_block = false for i, line in ipairs(lines) do - if vim.startswith(line, '> ') then + if line:match('^```') then + in_code_block = not in_code_block + end + if vim.startswith(line, '> ') and not in_code_block then table.insert(sticky_indices, i) end end @@ -64,6 +101,7 @@ local function insert_sticky(prompt, config) lines = vim.split(vim.trim(table.concat(lines, '\n')), '\n') + -- Add config-based stickies if config.remember_as_sticky and config.model and config.model ~= M.config.model then stickies:set('$' .. config.model, true) end @@ -74,6 +112,12 @@ local function insert_sticky(prompt, config) end end + if config.remember_as_sticky and config.resources and not vim.deep_equal(config.resources, M.config.resources) then + for _, resource in ipairs(utils.to_table(config.resources)) do + stickies:set('#' .. resource, true) + end + end + if config.remember_as_sticky and config.system_prompt @@ -89,84 +133,22 @@ local function insert_sticky(prompt, config) end end - -- Insert stickies at start of prompt - local prompt_lines = {} + -- Store stickies + local sticky_array = {} for _, sticky in ipairs(stickies:keys()) do if sticky ~= '' then - table.insert(prompt_lines, '> ' .. sticky) + table.insert(sticky_array, sticky) end end - if #prompt_lines > 0 then - table.insert(prompt_lines, '') - end - for _, line in ipairs(lines) do - table.insert(prompt_lines, line) - end - if #lines == 0 then - table.insert(prompt_lines, '') - end + M.chat:set_sticky(sticky_array) - return table.concat(prompt_lines, '\n') -end - -local function store_sticky(prompt) - local sticky = {} - local in_code_block = false - for _, line in ipairs(vim.split(prompt, '\n')) do - if line:match('^```') then - in_code_block = not in_code_block - end - if vim.startswith(line, '> ') and not in_code_block then - table.insert(sticky, line:sub(3)) - end - end - state.sticky = sticky -end - ---- Update the highlights for chat buffer -local function update_highlights() - local selection_ns = vim.api.nvim_create_namespace('copilot-chat-selection') - for _, buf in ipairs(vim.api.nvim_list_bufs()) do - vim.api.nvim_buf_clear_namespace(buf, selection_ns, 0, -1) - end - - if M.chat.config.highlight_selection and M.chat:focused() then - local selection = M.get_selection() - if not selection or not utils.buf_valid(selection.bufnr) or not selection.start_line or not selection.end_line then - return - end - - vim.api.nvim_buf_set_extmark(selection.bufnr, selection_ns, selection.start_line - 1, 0, { - hl_group = 'CopilotChatSelection', - end_row = selection.end_line, - strict = false, - }) - end -end - ---- List available models. ---- @return CopilotChat.client.Model[] -local function list_models() - local models = client:models() - local result = vim.tbl_keys(models) - - table.sort(result, function(a, b) - a = models[a] - b = models[b] - if a.provider ~= b.provider then - return a.provider < b.provider - end - return a.id < b.id - end) - - return vim.tbl_map(function(id) - return models[id] - end, result) + -- Return clean prompt + return table.concat(lines, '\n') end --- Finish writing to chat buffer. ---@param start_of_chat boolean? -local function finish(start_of_chat) +local function finish(start_of_chat, remaining_tool_calls) if start_of_chat then local sticky = {} if M.config.sticky then @@ -174,15 +156,19 @@ local function finish(start_of_chat) table.insert(sticky, sticky_line) end end - state.sticky = sticky + M.chat:set_sticky(sticky) end local prompt_content = '' - local last_message = M.chat.messages[#M.chat.messages] - local tool_calls = last_message and last_message.tool_calls or {} + local tool_calls = remaining_tool_calls + if not tool_calls then + local assistant_message = M.chat:get_message(constants.ROLE.ASSISTANT) + tool_calls = assistant_message and assistant_message.tool_calls or {} + end - if not utils.empty(state.sticky) then - for _, sticky in ipairs(state.sticky) do + local current_sticky = M.chat:get_sticky() + if not utils.empty(current_sticky) then + for _, sticky in ipairs(current_sticky) do prompt_content = prompt_content .. '> ' .. sticky .. '\n' end prompt_content = prompt_content .. '\n' @@ -196,7 +182,7 @@ local function finish(start_of_chat) end M.chat:add_message({ - role = 'user', + role = constants.ROLE.USER, content = prompt_content, }) @@ -204,17 +190,38 @@ local function finish(start_of_chat) end --- Show an error in the chat window. ----@param err string|table|nil -local function show_error(err) - err = err or 'Unknown error' - err = utils.make_string(err) +---@param config CopilotChat.config.Shared +---@param cb function +---@return any +local function handle_error(config, cb) + return function() + local function error_handler(err) + return { + err = utils.make_string(err), + traceback = debug.traceback(), + } + end - M.chat:add_message({ - role = 'assistant', - content = '\n' .. string.format(BLOCK_OUTPUT_FORMAT, 'error', err) .. '\n', - }) + local ok, out = xpcall(cb, error_handler) + if ok then + return out + end + log.error(out.err .. '\n' .. out.traceback) + + if config.headless then + return + end + + utils.schedule_main() + out = out.err + + M.chat:add_message({ + role = constants.ROLE.ASSISTANT, + content = '\n' .. string.format(BLOCK_OUTPUT_FORMAT, 'error', out) .. '\n', + }) - finish() + finish() + end end --- Map a key to a function. @@ -229,7 +236,7 @@ local function map_key(name, bufnr, fn) if not fn then fn = function() - key.callback(state.source) + key.callback(M.chat:get_source()) end end @@ -238,7 +245,7 @@ local function map_key(name, bufnr, fn) 'n', key.normal, fn, - { buffer = bufnr, nowait = true, desc = PLUGIN_NAME .. ' ' .. name:gsub('_', ' ') } + { buffer = bufnr, nowait = true, desc = constants.PLUGIN_NAME .. ' ' .. name:gsub('_', ' ') } ) end if key.insert and key.insert ~= '' then @@ -252,532 +259,14 @@ local function map_key(name, bufnr, fn) else fn() end - end, { buffer = bufnr, desc = PLUGIN_NAME .. ' ' .. name:gsub('_', ' ') }) + end, { buffer = bufnr, desc = constants.PLUGIN_NAME .. ' ' .. name:gsub('_', ' ') }) end end --- Updates the source buffer based on previous or current window. local function update_source() local use_prev_window = M.chat:focused() - M.set_source(use_prev_window and vim.fn.win_getid(vim.fn.winnr('#')) or vim.api.nvim_get_current_win()) -end - ---- Call and resolve function calls from the prompt. ----@param prompt string? ----@param config CopilotChat.config.Shared? ----@return table, table, table, string ----@async -function M.resolve_functions(prompt, config) - config, prompt = M.resolve_prompt(prompt, config) - - local tools = {} - for _, tool in ipairs(functions.parse_tools(M.config.functions)) do - tools[tool.name] = tool - end - - local enabled_tools = {} - local resolved_resources = {} - local resolved_tools = {} - local matches = utils.to_table(config.tools) - local tool_calls = {} - for _, message in ipairs(M.chat.messages) do - if message.tool_calls then - for _, tool_call in ipairs(message.tool_calls) do - table.insert(tool_calls, tool_call) - end - end - end - - -- Check for @tool pattern to find enabled tools - prompt = prompt:gsub('@' .. WORD, function(match) - for name, tool in pairs(M.config.functions) do - if name == match or tool.group == match then - table.insert(matches, match) - return '' - end - end - return '@' .. match - end) - for _, match in ipairs(matches) do - for name, tool in pairs(M.config.functions) do - if name == match or tool.group == match then - enabled_tools[name] = true - end - end - end - - local matches = utils.ordered_map() - - -- Check for #word:`input` pattern - for word, input in prompt:gmatch('#' .. WORD_WITH_INPUT_QUOTED) do - local pattern = string.format('#%s:`%s`', word, input) - matches:set(pattern, { - word = word, - input = input, - }) - end - - -- Check for #word:input pattern - for word, input in prompt:gmatch('#' .. WORD_WITH_INPUT_UNQUOTED) do - local pattern = utils.empty(input) and string.format('#%s', word) or string.format('#%s:%s', word, input) - matches:set(pattern, { - word = word, - input = input, - }) - end - - -- Check for ##word:input pattern - for word in prompt:gmatch('##' .. WORD_NO_INPUT) do - local pattern = string.format('##%s', word) - matches:set(pattern, { - word = word, - }) - end - - -- Resolve each function reference - local function expand_function(name, input) - notify.publish(notify.STATUS, 'Running function: ' .. name) - - local tool_id = nil - if not utils.empty(tool_calls) then - for _, tool_call in ipairs(tool_calls) do - if tool_call.name == name and vim.trim(tool_call.id) == vim.trim(input) then - input = utils.empty(tool_call.arguments) and {} or utils.json_decode(tool_call.arguments) - tool_id = tool_call.id - break - end - end - end - - local tool = M.config.functions[name] - if not tool then - -- Check if input matches uri - for tool_name, tool_spec in pairs(M.config.functions) do - if tool_spec.uri then - local match = functions.match_uri(name, tool_spec.uri) - if match then - name = tool_name - tool = tool_spec - input = match - break - end - end - end - end - if not tool then - return nil - end - if tool_id and not enabled_tools[name] and not tool.uri then - return nil - end - - local schema = tools[name] and tools[name].schema or nil - local result = '' - local ok, output = pcall(tool.resolve, functions.parse_input(input, schema), state.source or {}, prompt) - if not ok then - result = string.format(BLOCK_OUTPUT_FORMAT, 'error', utils.make_string(output)) - else - for _, content in ipairs(output) do - if content then - local content_out = nil - if content.uri then - content_out = '##' .. content.uri - table.insert(resolved_resources, resources.to_resource(content)) - if tool_id then - table.insert(state.sticky, content_out) - end - else - content_out = string.format(BLOCK_OUTPUT_FORMAT, utils.mimetype_to_filetype(content.mimetype), content.data) - end - - if not utils.empty(result) then - result = result .. '\n' - end - result = result .. content_out - end - end - end - - if tool_id then - table.insert(resolved_tools, { - id = tool_id, - result = result, - }) - - return nil - end - - return result - end - - -- Resolve and process all tools - for _, pattern in ipairs(matches:keys()) do - if not utils.empty(pattern) then - local match = matches:get(pattern) - local out = expand_function(match.word, match.input) or pattern - out = out:gsub('%%', '%%%%') -- Escape percent signs for gsub - prompt = prompt:gsub(vim.pesc(pattern), out, 1) - end - end - - return vim.tbl_map(function(name) - return tools[name] - end, vim.tbl_keys(enabled_tools)), - resolved_resources, - resolved_tools, - prompt -end - ---- Resolve the final prompt and config from prompt template. ----@param prompt string? ----@param config CopilotChat.config.Shared? ----@return CopilotChat.config.prompts.Prompt, string -function M.resolve_prompt(prompt, config) - if not prompt then - local message = M.chat:get_message('user') - if message then - prompt = message.content - end - end - - local prompts_to_use = M.prompts() - local depth = 0 - local MAX_DEPTH = 10 - - local function resolve(inner_config, inner_prompt) - if depth >= MAX_DEPTH then - return inner_config, inner_prompt - end - depth = depth + 1 - - inner_prompt = string.gsub(inner_prompt, '/' .. WORD, function(match) - local p = prompts_to_use[match] - if p then - local resolved_config, resolved_prompt = resolve(p, p.prompt or '') - inner_config = vim.tbl_deep_extend('force', inner_config, resolved_config) - return resolved_prompt - end - - return '/' .. match - end) - - depth = depth - 1 - return inner_config, inner_prompt - end - - config = vim.tbl_deep_extend('force', M.config, config or {}) - config, prompt = resolve(config, prompt or '') - if prompts_to_use[config.system_prompt] then - config.system_prompt = prompts_to_use[config.system_prompt].system_prompt - end - - if config.system_prompt then - config.system_prompt = config.system_prompt:gsub('{OS_NAME}', jit.os) - config.system_prompt = config.system_prompt:gsub('{LANGUAGE}', config.language) - if state.source then - config.system_prompt = config.system_prompt:gsub('{DIR}', state.source.cwd()) - end - end - - return config, prompt -end - ---- Resolve the model from the prompt. ----@param prompt string? ----@param config CopilotChat.config.Shared? ----@return string, string ----@async -function M.resolve_model(prompt, config) - config, prompt = M.resolve_prompt(prompt, config) - - local models = vim.tbl_map(function(model) - return model.id - end, list_models()) - - local selected_model = config.model or '' - prompt = prompt:gsub('%$' .. WORD, function(match) - if vim.tbl_contains(models, match) then - selected_model = match - return '' - end - return '$' .. match - end) - - return selected_model, prompt -end - ---- Get the current source buffer and window. -function M.get_source() - return state.source -end - ---- Sets the source to the given window. ----@param source_winnr number ----@return boolean if the source was set -function M.set_source(source_winnr) - local source_bufnr = vim.api.nvim_win_get_buf(source_winnr) - - -- Check if the window is valid to use as a source - if source_winnr ~= M.chat.winnr and source_bufnr ~= M.chat.bufnr and vim.fn.win_gettype(source_winnr) == '' then - state.source = { - bufnr = source_bufnr, - winnr = source_winnr, - cwd = function() - local ok, dir = pcall(function() - return vim.w[source_winnr].cchat_cwd - end) - if not ok or not dir or dir == '' then - return '.' - end - return dir - end, - } - - return true - end - - return false -end - ---- Get the selection from the source buffer. ----@return CopilotChat.select.Selection? -function M.get_selection() - local config = vim.tbl_deep_extend('force', M.config, M.chat.config) - local selection = config.selection - local bufnr = state.source and state.source.bufnr - local winnr = state.source and state.source.winnr - - if selection and utils.buf_valid(bufnr) and winnr and vim.api.nvim_win_is_valid(winnr) then - return selection(state.source) - end - - return nil -end - ---- Sets the selection to specific lines in buffer. ----@param bufnr number ----@param start_line number ----@param end_line number ----@param clear boolean? -function M.set_selection(bufnr, start_line, end_line, clear) - if not utils.buf_valid(bufnr) then - return - end - - if clear then - for _, mark in ipairs({ '<', '>', '[', ']' }) do - pcall(vim.api.nvim_buf_del_mark, bufnr, mark) - end - update_highlights() - return - end - - local winnr = vim.fn.win_findbuf(bufnr)[1] - if not winnr and state.source then - winnr = state.source.winnr - end - if not winnr then - return - end - - pcall(vim.api.nvim_buf_set_mark, bufnr, '<', start_line, 0, {}) - pcall(vim.api.nvim_buf_set_mark, bufnr, '>', end_line, 0, {}) - pcall(vim.api.nvim_buf_set_mark, bufnr, '[', start_line, 0, {}) - pcall(vim.api.nvim_buf_set_mark, bufnr, ']', end_line, 0, {}) - pcall(vim.api.nvim_win_set_cursor, winnr, { start_line, 0 }) - update_highlights() -end - ---- Trigger the completion for the chat window. ----@param without_input boolean? -function M.trigger_complete(without_input) - local info = M.complete_info() - local bufnr = vim.api.nvim_get_current_buf() - local line = vim.api.nvim_get_current_line() - local cursor = vim.api.nvim_win_get_cursor(0) - local row = cursor[1] - local col = cursor[2] - if col == 0 or #line == 0 then - return - end - - local prefix, cmp_start = unpack(vim.fn.matchstrpos(line:sub(1, col), info.pattern)) - if not prefix then - return - end - - if not without_input and vim.startswith(prefix, '#') and vim.endswith(prefix, ':') then - local found_tool = M.config.functions[prefix:sub(2, -2)] - local found_schema = found_tool and functions.parse_schema(found_tool) - if found_tool and found_schema then - async.run(function() - local value = functions.enter_input(found_schema, state.source) - if not value then - return - end - - utils.schedule_main() - vim.api.nvim_buf_set_text(bufnr, row - 1, col, row - 1, col, { value }) - vim.api.nvim_win_set_cursor(0, { row, col + #value }) - end) - end - - return - end - - async.run(function() - local items = M.complete_items() - utils.schedule_main() - - if vim.fn.mode() ~= 'i' then - return - end - - vim.fn.complete( - cmp_start + 1, - vim.tbl_filter(function(item) - return vim.startswith(item.word:lower(), prefix:lower()) - end, items) - ) - end) -end - ---- Get the completion info for the chat window, for use with custom completion providers ----@return table -function M.complete_info() - return { - triggers = { '@', '/', '#', '$' }, - pattern = [[\%(@\|/\|#\|\$\)\S*]], - } -end - ---- Get the completion items for the chat window, for use with custom completion providers ----@return table ----@async -function M.complete_items() - local models = list_models() - local prompts_to_use = M.prompts() - local items = {} - - for name, prompt in pairs(prompts_to_use) do - local kind = '' - local info = '' - if prompt.prompt then - kind = 'user' - info = prompt.prompt - elseif prompt.system_prompt then - kind = 'system' - info = prompt.system_prompt - end - - items[#items + 1] = { - word = '/' .. name, - abbr = name, - kind = kind, - info = info, - menu = prompt.description or '', - icase = 1, - dup = 0, - empty = 0, - } - end - - for _, model in ipairs(models) do - items[#items + 1] = { - word = '$' .. model.id, - abbr = model.id, - kind = model.provider, - menu = model.name, - icase = 1, - dup = 0, - empty = 0, - } - end - - local groups = {} - for name, tool in pairs(M.config.functions) do - if tool.group then - groups[tool.group] = groups[tool.group] or {} - groups[tool.group][name] = tool - end - end - for name, group in pairs(groups) do - local group_tools = vim.tbl_keys(group) - items[#items + 1] = { - word = '@' .. name, - abbr = name, - kind = 'group', - info = table.concat(group_tools, '\n'), - menu = string.format('%s tools', #group_tools), - icase = 1, - dup = 0, - empty = 0, - } - end - for name, tool in pairs(M.config.functions) do - items[#items + 1] = { - word = '@' .. name, - abbr = name, - kind = 'tool', - info = tool.description, - menu = tool.group or '', - icase = 1, - dup = 0, - empty = 0, - } - end - - local tools_to_use = functions.parse_tools(M.config.functions) - for _, tool in pairs(tools_to_use) do - local uri = M.config.functions[tool.name].uri - if uri then - local info = - string.format('%s\n\n%s', tool.description, tool.schema and vim.inspect(tool.schema, { indent = ' ' }) or '') - - items[#items + 1] = { - word = '#' .. tool.name, - abbr = tool.name, - kind = M.config.functions[tool.name].group or 'resource', - info = info, - menu = uri, - icase = 1, - dup = 0, - empty = 0, - } - end - end - - table.sort(items, function(a, b) - if a.kind == b.kind then - return a.word < b.word - end - return a.kind < b.kind - end) - - return items -end - ---- Get the prompts to use. ----@return table -function M.prompts() - local prompts_to_use = {} - - for name, prompt in pairs(M.config.prompts) do - local val = prompt - if type(prompt) == 'string' then - val = { - prompt = prompt, - } - end - - if val.system_prompt and M.config.prompts[val.system_prompt] then - val.system_prompt = M.config.prompts[val.system_prompt].system_prompt - end - - prompts_to_use[name] = val - end - - return prompts_to_use + M.chat:set_source(use_prev_window and vim.fn.win_getid(vim.fn.winnr('#')) or vim.api.nvim_get_current_win()) end --- Open the chat window. @@ -789,13 +278,21 @@ function M.open(config) M.chat:open(config) -- Add sticky values from provided config when opening the chat - local message = M.chat:get_message('user') + local message = M.chat:get_message(constants.ROLE.USER) if message then - local prompt = insert_sticky(message.content, config) - if prompt then + local clean_prompt = process_sticky(message.content, config) + local stickies = M.chat:get_sticky() + local content = '' + if not vim.tbl_isempty(stickies) then + content = '\n> ' .. table.concat(stickies, '\n> ') .. '\n\n' + end + if clean_prompt and clean_prompt ~= '' then + content = content .. clean_prompt + end + if content ~= '' then M.chat:add_message({ - role = 'user', - content = '\n' .. prompt, + role = constants.ROLE.USER, + content = content, }, true) end end @@ -806,7 +303,7 @@ end --- Close the chat window. function M.close() - M.chat:close(state.source and state.source.bufnr or nil) + M.chat:close() end --- Toggle the chat window. @@ -822,7 +319,22 @@ end --- Select default Copilot GPT model. function M.select_model() async.run(function() - local models = list_models() + local models = client:models() + local result = vim.tbl_keys(models) + + table.sort(result, function(a, b) + a = models[a] + b = models[b] + if a.provider ~= b.provider then + return a.provider < b.provider + end + return a.id < b.id + end) + + models = vim.tbl_map(function(id) + return models[id] + end, result) + local choices = vim.tbl_map(function(model) return { id = model.id, @@ -830,6 +342,7 @@ function M.select_model() provider = model.provider, streaming = model.streaming, tools = model.tools, + reasoning = model.reasoning, selected = model.id == M.config.model, } end, models) @@ -854,6 +367,9 @@ function M.select_model() if item.tools then table.insert(indicators, 'tools') end + if item.reasoning then + table.insert(indicators, 'reasoning') + end if #indicators > 0 then out = out .. ' [' .. table.concat(indicators, ', ') .. ']' @@ -872,8 +388,8 @@ end --- Select a prompt template to use. ---@param config CopilotChat.config.Shared? function M.select_prompt(config) - local prompts = M.prompts() - local keys = vim.tbl_keys(prompts) + local prompt_list = prompts.list_prompts() + local keys = vim.tbl_keys(prompt_list) table.sort(keys) local choices = vim @@ -881,8 +397,8 @@ function M.select_prompt(config) :map(function(name) return { name = name, - description = prompts[name].description, - prompt = prompts[name].prompt, + description = prompt_list[name].description, + prompt = prompt_list[name].prompt, } end) :filter(function(choice) @@ -897,7 +413,7 @@ function M.select_prompt(config) end, }, function(choice) if choice then - M.ask(prompts[choice.name].prompt, vim.tbl_extend('force', prompts[choice.name], config or {})) + M.ask(prompt_list[choice.name].prompt, vim.tbl_extend('force', prompt_list[choice.name], config or {})) end end) end @@ -913,6 +429,9 @@ function M.ask(prompt, config) vim.diagnostic.reset(vim.api.nvim_create_namespace('copilot-chat-diagnostics')) config = vim.tbl_deep_extend('force', M.config, config or {}) + local schedule = function(cb) + return cb() + end -- Stop previous conversation and open window if not config.headless then @@ -923,172 +442,230 @@ function M.ask(prompt, config) end if not M.chat:focused() then M.open(config) + schedule = vim.schedule end else update_source() end -- Resolve prompt after window is opened - prompt = insert_sticky(prompt, config) + prompt = process_sticky(prompt, config) prompt = vim.trim(prompt) + prompt = table.concat(M.chat:get_sticky(), '\n') .. '\n\n' .. prompt - -- Prepare chat - if not config.headless then - store_sticky(prompt) - M.chat:start() - M.chat:append('\n') - end - - -- Resolve prompt references - config, prompt = M.resolve_prompt(prompt, config) - local system_prompt = config.system_prompt or '' - - -- Remove sticky prefix - prompt = table.concat( - vim.tbl_map(function(l) - return l:gsub('^>%s+', '') - end, vim.split(prompt, '\n')), - '\n' - ) - - -- Retrieve the selection - local selection = M.get_selection() - - local ok, err = pcall(async.run, function() - local selected_tools, resolved_resources, resolved_tools, prompt = M.resolve_functions(prompt, config) - local selected_model, prompt = M.resolve_model(prompt, config) - - if config.resource_processing then - local query_ok, processed_resources = - pcall(resources.process_resources, prompt, selected_model, resolved_resources) - if query_ok then - resolved_resources = processed_resources - else - log.warn('Failed to process resources', processed_resources) - end + -- After opening window we need to schedule to next cycle so everything properly resolves + schedule(function() + if not config.headless then + -- Prepare chat + M.chat:start() + M.chat:append('\n') end - prompt = vim.trim(prompt) - utils.schedule_main() + async.run(handle_error(config, function() + config, prompt = prompts.resolve_prompt(prompt, config) + local system_prompt = config.system_prompt or '' + local selected_tools, prompt = prompts.resolve_tools(prompt, config) + local resolved_resources, resolved_tools, prompt = prompts.resolve_functions(prompt, config) + local selected_model, prompt = prompts.resolve_model(prompt, config) - if not config.headless then - local assistant_message = M.chat:get_message('assistant') - if assistant_message and assistant_message.tool_calls then - local handled_ids = {} - for _, tool in ipairs(resolved_tools) do - handled_ids[tool.id] = true + prompt = vim.trim(prompt) + + if not config.headless then + utils.schedule_main() + local assistant_message = M.chat:get_message(constants.ROLE.ASSISTANT) + if assistant_message and assistant_message.tool_calls then + local handled_ids = {} + for _, tool in ipairs(resolved_tools) do + handled_ids[tool.id] = true + end + + -- If we skipped any tool calls, send that as result + for _, tool_call in ipairs(assistant_message.tool_calls) do + if not handled_ids[tool_call.id] then + table.insert(resolved_tools, { + id = tool_call.id, + result = 'User skipped this function call.', + }) + handled_ids[tool_call.id] = true + end + end end - -- If we skipped any tool calls, send that as result - for _, tool_call in ipairs(assistant_message.tool_calls) do - if not handled_ids[tool_call.id] then - table.insert(resolved_tools, { - id = tool_call.id, - result = string.format(BLOCK_OUTPUT_FORMAT, 'error', 'User skipped this function call.'), + if not utils.empty(resolved_tools) then + -- If we are handling tools, replace user message with tool results + M.chat:remove_message(constants.ROLE.USER) + for _, tool in ipairs(resolved_tools) do + M.chat:add_message({ + id = tool.id, + role = constants.ROLE.TOOL, + tool_call_id = tool.id, + content = '\n' .. tool.result .. '\n', }) - handled_ids[tool_call.id] = true end + else + -- Otherwise just replace the user message with resolved prompt + M.chat:add_message({ + role = constants.ROLE.USER, + content = '\n' .. prompt .. '\n', + }, true) end end - if not utils.empty(resolved_tools) then - -- If we are handling tools, replace user message with tool results - M.chat:remove_message('user') - for _, tool in ipairs(resolved_tools) do - M.chat:add_message({ - id = tool.id, - role = 'tool', - tool_call_id = tool.id, - content = '\n' .. tool.result .. '\n', - }) + if utils.empty(prompt) and utils.empty(resolved_tools) then + if not config.headless then + M.chat:remove_message(constants.ROLE.USER) + finish() end - else - -- Otherwise just replace the user message with resolved prompt - M.chat:add_message({ - role = 'user', - content = '\n' .. prompt .. '\n', - }, true) + return end - end - if utils.empty(prompt) and utils.empty(resolved_tools) then + -- Build history, when in headless mode its just current prompt + local history if not config.headless then - M.chat:remove_message('user') - finish() + history = M.chat:get_messages() + else + history = { + { + content = prompt, + role = constants.ROLE.USER, + }, + } end - return - end - local ask_ok, ask_response = pcall(client.ask, client, prompt, { - headless = config.headless, - history = M.chat.messages, - selection = selection, - resources = resolved_resources, - tools = selected_tools, - system_prompt = system_prompt, - model = selected_model, - temperature = config.temperature, - on_progress = vim.schedule_wrap(function(token) - if not config.headless then - M.chat:add_message({ - content = token, - role = 'assistant', - }) - end - end), - }) + local ask_response = client:ask({ + headless = config.headless, + history = history, + resources = resolved_resources, + tools = selected_tools, + system_prompt = system_prompt, + model = selected_model, + temperature = config.temperature, + on_progress = vim.schedule_wrap(function(message) + if not config.headless then + M.chat:add_message(message) + end + end), + }) - utils.schedule_main() + -- If there was no error and no response, it means job was canceled + if ask_response == nil then + return + end - if not ask_ok then - log.error(ask_response) - if not config.headless then - show_error(ask_response) + local response = ask_response.message + local token_count = ask_response.token_count + local token_max_count = ask_response.token_max_count + + -- Call the callback function + if config.callback then + utils.schedule_main() + config.callback(response, M.chat:get_source()) end - return - end - -- If there was no error and no response, it means job was cancelled - if ask_response == nil then - return - end + if not config.headless then + response.content = vim.trim(response.content) + if utils.empty(response.content) then + response.content = '' + else + response.content = '\n' .. response.content .. '\n' + end - local response = ask_response.message - local token_count = ask_response.token_count - local token_max_count = ask_response.token_max_count + utils.schedule_main() + M.chat:add_message(response, true) + M.chat.token_count = token_count + M.chat.token_max_count = token_max_count + + -- Execute trusted tool calls automatically + if response.tool_calls and #response.tool_calls > 0 then + local trusted_tool_calls = {} + local untrusted_tool_calls = {} + + for _, tool_call in ipairs(response.tool_calls) do + if is_trusted_tool(config, tool_call.name) then + table.insert(trusted_tool_calls, tool_call) + else + table.insert(untrusted_tool_calls, tool_call) + end + end - -- Call the callback function - if config.callback then - local callback_ok, callback_response = pcall(config.callback, response.content, state.source) - if not callback_ok then - log.error('Callback error: ' .. callback_response) - if not config.headless then - show_error(callback_response) + if #trusted_tool_calls > 0 then + async.run(handle_error(config, function() + local trusted_tool_results = {} + local source = M.chat:get_source() + + for _, tool_call in ipairs(trusted_tool_calls) do + local input = {} + if not utils.empty(tool_call.arguments) then + input = utils.json_decode(tool_call.arguments) + end + + local ok, output = prompts.execute_tool_call(tool_call.name, input, config, source) + local result = prompts.format_tool_output(ok, output) + + table.insert(trusted_tool_results, { + id = tool_call.id, + result = result, + }) + end + + if not utils.empty(trusted_tool_results) then + utils.schedule_main() + for _, tool in ipairs(trusted_tool_results) do + M.chat:add_message({ + id = tool.id, + role = constants.ROLE.TOOL, + tool_call_id = tool.id, + content = '\n' .. tool.result .. '\n', + }) + end + + if #untrusted_tool_calls > 0 then + finish(nil, untrusted_tool_calls) + else + local continue_response = client:ask({ + headless = config.headless, + history = M.chat:get_messages(), + resources = resolved_resources, + tools = selected_tools, + system_prompt = system_prompt, + model = selected_model, + temperature = config.temperature, + on_progress = vim.schedule_wrap(function(message) + if not config.headless then + M.chat:add_message(message) + end + end), + }) + + if continue_response then + local continue_message = continue_response.message + continue_message.content = vim.trim(continue_message.content) + if utils.empty(continue_message.content) then + continue_message.content = '' + else + continue_message.content = '\n' .. continue_message.content .. '\n' + end + + utils.schedule_main() + M.chat:add_message(continue_message, true) + M.chat.token_count = continue_response.token_count + M.chat.token_max_count = continue_response.token_max_count + end + + finish() + end + else + finish() + end + end)) + return + end end - return - end - end - if not config.headless then - response.content = vim.trim(response.content) - if utils.empty(response.content) then - response.content = '' - else - response.content = '\n' .. response.content .. '\n' + finish() end - M.chat:add_message(response, true) - M.chat.token_count = token_count - M.chat.token_max_count = token_max_count - finish() - end + end)) end) - - if not ok then - log.error(err) - if not config.headless then - show_error(err) - end - end end --- Stop current copilot output and optionally reset the chat ten show the help message. @@ -1099,11 +676,7 @@ function M.stop(reset) if reset then M.chat:clear() vim.diagnostic.reset(vim.api.nvim_create_namespace('copilot-chat-diagnostics')) - - -- Clear the selection - if state.source then - M.set_selection(state.source.bufnr, 0, 0, true) - end + select.set(M.chat:get_source().bufnr) end if stopped or reset then @@ -1129,7 +702,7 @@ function M.save(name, history_path) return end - local history = vim.deepcopy(M.chat.messages) + local history = vim.deepcopy(M.chat:get_messages()) for _, message in ipairs(history) do message.section = nil end @@ -1190,64 +763,71 @@ function M.log_level(level) M.config.log_level = level M.config.debug = level == 'debug' - log.new({ - plugin = PLUGIN_NAME, - level = level, - outfile = M.config.log_path, - fmt_msg = function(is_console, mode_name, src_path, src_line, msg) - local nameupper = mode_name:upper() - if is_console then - return string.format('[%s] %s', nameupper, msg) - else - local lineinfo = src_path .. ':' .. src_line - return string.format('[%-6s%s] %s: %s\n', nameupper, os.date(), lineinfo, msg) - end - end, - }, true) + if level ~= log.level then + log.new({ + plugin = constants.PLUGIN_NAME, + level = level, + outfile = M.config.log_path, + fmt_msg = function(is_console, mode_name, src_path, src_line, msg) + local nameupper = mode_name:upper() + if is_console then + return string.format('[%s] %s', nameupper, msg) + else + local lineinfo = src_path .. ':' .. src_line + return string.format('[%-6s%s] %s: %s\n', nameupper, os.date(), lineinfo, msg) + end + end, + }, true) + log.level = level + end end --- Set up the plugin ---@param config CopilotChat.config.Config? function M.setup(config) - local default_config = require('CopilotChat.config') - M.config = vim.tbl_deep_extend('force', default_config, config or {}) - state.highlights_loaded = false - - -- Save proxy and insecure settings - utils.curl_store_args({ - insecure = M.config.allow_insecure, - proxy = M.config.proxy, - }) + for k, v in pairs(vim.tbl_deep_extend('force', M.config, config or {})) do + M.config[k] = v + end - -- Load the providers - client:stop() - client:load_providers(M.config.providers) + if not M.config.separator or M.config.separator == '' then + log.warn( + 'Empty separator is not allowed, using default separator instead. Set `separator` in config to change this.' + ) + M.config.separator = '---' + end + -- Set log level if M.config.debug then M.log_level('debug') else M.log_level(M.config.log_level) end - if not M.config.separator or M.config.separator == '' then - log.warn( - 'Empty separator is not allowed, using default separator instead. Set `separator` in config to change this.' - ) - M.config.separator = default_config.separator - end + -- Save proxy and insecure settings + curl.store_args({ + insecure = M.config.allow_insecure, + proxy = M.config.proxy, + }) + -- Load the providers + client:stop() + client:set_providers(function() + return M.config.providers + end) + + -- Initialize chat + require('CopilotChat.utils.notify').clear() if M.chat then - M.chat:close(state.source and state.source.bufnr or nil) + M.chat:close() M.chat:delete() - end - M.chat = require('CopilotChat.ui.chat')( - M.config, - utils.key_to_info('show_help', M.config.mappings.show_help), - function(bufnr) + else + M.chat = require('CopilotChat.ui.chat')(M.config, function(bufnr) for name, _ in pairs(M.config.mappings) do map_key(name, bufnr) end + require('CopilotChat.completion').enable(bufnr, M.config.chat_autocomplete) + vim.api.nvim_create_autocmd({ 'BufEnter', 'BufLeave' }, { buffer = bufnr, callback = function(ev) @@ -1255,7 +835,9 @@ function M.setup(config) update_source() end - vim.schedule(update_highlights) + vim.schedule(function() + select.highlight(M.chat:get_source().bufnr, not (M.config.highlight_selection and M.chat:focused())) + end) end, }) @@ -1270,44 +852,11 @@ function M.setup(config) }) end - if M.config.chat_autocomplete then - vim.api.nvim_create_autocmd('TextChangedI', { - buffer = bufnr, - callback = function() - local completeopt = vim.opt.completeopt:get() - if not vim.tbl_contains(completeopt, 'noinsert') and not vim.tbl_contains(completeopt, 'noselect') then - -- Don't trigger completion if completeopt is not set to noinsert or noselect - return - end - - local line = vim.api.nvim_get_current_line() - local cursor = vim.api.nvim_win_get_cursor(0) - local col = cursor[2] - local char = line:sub(col, col) - - if vim.tbl_contains(M.complete_info().triggers, char) then - utils.debounce('complete', function() - M.trigger_complete(true) - end, 100) - end - end, - }) - - -- Add noinsert completeopt if not present - if vim.fn.has('nvim-0.11.0') == 1 then - local completeopt = vim.opt.completeopt:get() - if not vim.tbl_contains(completeopt, 'noinsert') then - table.insert(completeopt, 'noinsert') - vim.bo[bufnr].completeopt = table.concat(completeopt, ',') - end - end - end - finish(true) - end - ) + end) + end - for name, prompt in pairs(M.prompts()) do + for name, prompt in pairs(prompts.list_prompts()) do if prompt.prompt then vim.api.nvim_create_user_command('CopilotChat' .. name, function(args) local input = prompt.prompt @@ -1321,13 +870,13 @@ function M.setup(config) nargs = '*', force = true, range = true, - desc = prompt.description or (PLUGIN_NAME .. ' ' .. name), + desc = prompt.description or (constants.PLUGIN_NAME .. ' ' .. name), }) if prompt.mapping then vim.keymap.set({ 'n', 'v' }, prompt.mapping, function() M.ask(prompt.prompt, prompt) - end, { desc = prompt.description or (PLUGIN_NAME .. ' ' .. name) }) + end, { desc = prompt.description or (constants.PLUGIN_NAME .. ' ' .. name) }) end end end diff --git a/lua/CopilotChat/instructions/custom_instructions.lua b/lua/CopilotChat/instructions/custom_instructions.lua new file mode 100644 index 00000000..57b1ba44 --- /dev/null +++ b/lua/CopilotChat/instructions/custom_instructions.lua @@ -0,0 +1,6 @@ +return [[ + +Custom instructions from user's `{FILENAME}`: +{CONTENT} + +]] diff --git a/lua/CopilotChat/instructions/edit_file_block.lua b/lua/CopilotChat/instructions/edit_file_block.lua new file mode 100644 index 00000000..f5f9bf9e --- /dev/null +++ b/lua/CopilotChat/instructions/edit_file_block.lua @@ -0,0 +1,26 @@ +return [[ + +Use these instructions when editing files via code blocks. Present changes as clear, minimal, and precise file edits. + +For each change, use this markdown code block format: +``` path= start_line= end_line= + +``` + +Example: +```lua path={DIR}/lua/CopilotChat/init.lua start_line=40 end_line=50 +local function example() + print("This is an example function.") +end +``` + +Code content requirements: +Always use absolute file paths in headers. Convert relative paths to absolute by prefixing with {DIR}. +Keep changes minimal and focused. Include complete replacement code for the specified line range. +Use proper indentation matching the source file. Include all necessary lines without eliding code. +NEVER include line number prefixes in output code blocks - output only valid code as it should appear in the file. +Address any diagnostics issues when fixing code. + +Present multiple changes as separate code blocks. + +]] diff --git a/lua/CopilotChat/instructions/edit_file_unified.lua b/lua/CopilotChat/instructions/edit_file_unified.lua new file mode 100644 index 00000000..9eb8f56f --- /dev/null +++ b/lua/CopilotChat/instructions/edit_file_unified.lua @@ -0,0 +1,34 @@ +return [[ + +Return edits similar to unified diffs that `diff -U0` would produce. + +Make sure you include the first 2 lines with the file paths. +Don't include timestamps with the file paths. +Do not use any file path prefixes, just use --- path/to/file and +++ path/to/file. + +Start each hunk of changes with a `@@` line. + +The user's patch tool needs CORRECT patches that apply cleanly against the current contents of the file! +Code can start with line number prefixes for reference (e.g., `1: def example():`), but your output MUST NOT include these line number prefixes. +Think carefully and make sure you include and mark all lines that need to be removed or changed as `-` lines. +Make sure you mark all new or modified lines with `+`. +Don't leave out any lines or the diff patch won't apply correctly. + +Indentation matters in the diffs! + +Start a new hunk for each section of the file that needs changes. + +Only output hunks that specify changes with `+` or `-` lines. + +Output hunks in whatever order makes the most sense. +Hunks don't need to be in any particular order. + +When editing a function, method, loop, etc use a hunk to replace the *entire* code block. +Delete the entire existing version with `-` lines and then add a new, updated version with `+` lines. +This will help you generate correct code and correct diffs. + +To move code within a file, use 2 hunks: 1 to delete it from its current location, 1 to insert it in the new location. + +To make a new file, show a diff from `--- /dev/null` to `+++ path/to/new/file.ext`. + +]] diff --git a/lua/CopilotChat/instructions/tool_use.lua b/lua/CopilotChat/instructions/tool_use.lua new file mode 100644 index 00000000..989bf209 --- /dev/null +++ b/lua/CopilotChat/instructions/tool_use.lua @@ -0,0 +1,12 @@ +return [[ + +If tools are available for a requested action (such as file edit, read, search, diagnostics, etc.), you MUST use the tool to perform the action. Only provide manual code or instructions if no tool exists for that purpose. +- Always prefer tool usage over manual edits or suggestions. +- Follow JSON schema precisely when using tools, including all required properties and outputting valid JSON. +- Use appropriate tools for tasks rather than asking for manual actions or generating code for actions you can perform directly. +- Execute actions directly when you indicate you'll do so, without asking for permission. +- Only use tools that exist and use proper invocation procedures - no multi_tool_use.parallel unless specified. +- Before using tools to retrieve information, check if context is already available as described in the context instructions above. +- If you don't have explicit tool definitions in your system prompt, clearly state this limitation when asked. NEVER pretend to have tool capabilities you don't possess. + +]] diff --git a/lua/CopilotChat/prompts.lua b/lua/CopilotChat/prompts.lua new file mode 100644 index 00000000..7c4e60ce --- /dev/null +++ b/lua/CopilotChat/prompts.lua @@ -0,0 +1,418 @@ +local client = require('CopilotChat.client') +local constants = require('CopilotChat.constants') +local functions = require('CopilotChat.functions') +local notify = require('CopilotChat.utils.notify') +local files = require('CopilotChat.utils.files') +local orderedmap = require('CopilotChat.utils.orderedmap') +local utils = require('CopilotChat.utils') + +local WORD = '([^%s:]+)' +local WORD_NO_INPUT = '([^%s]+)' +local WORD_WITH_INPUT_QUOTED = WORD .. ':`([^`]+)`' +local WORD_WITH_INPUT_UNQUOTED = WORD .. ':?([^%s`]*)' + +--- Find custom instructions in the current working directory. +---@param cwd string +---@param config CopilotChat.config.Config +---@return table +local function find_custom_instructions(cwd, config) + local out = {} + local files_to_check = {} + for _, relpath in ipairs(config.instruction_files or {}) do + table.insert(files_to_check, vim.fs.joinpath(cwd, relpath)) + end + for _, path in ipairs(files_to_check) do + local content = files.read_file(path) + if content then + table.insert(out, { + filename = path, + content = vim.trim(content), + }) + end + end + return out +end + +local M = {} + +--- List available prompts. +---@return table +function M.list_prompts() + local config = require('CopilotChat.config') + local prompts_to_use = {} + + for name, prompt in pairs(config.prompts) do + local val = prompt + if type(prompt) == 'string' then + val = { + prompt = prompt, + } + end + + prompts_to_use[name] = val + end + + return prompts_to_use +end + +--- Resolve enabled tools from the prompt. +---@param prompt string? +---@param config CopilotChat.config.Shared? +---@return table, string +function M.resolve_tools(prompt, config) + config, prompt = M.resolve_prompt(prompt, config) + + local tools = {} + for _, tool in ipairs(functions.parse_tools(config.functions)) do + tools[tool.name] = tool + end + + local enabled_tools = orderedmap() + local tool_matches = utils.to_table(config.tools) + + -- Check for @tool pattern to find enabled tools + prompt = prompt:gsub('@' .. WORD, function(match) + for name, tool in pairs(config.functions) do + if name == match or tool.group == match then + table.insert(tool_matches, match) + return '' + end + end + return '@' .. match + end) + for _, match in ipairs(tool_matches) do + for name, tool in pairs(config.functions) do + if name == match or tool.group == match then + enabled_tools:set(name, tools[name]) + end + end + end + + return enabled_tools:values(), prompt +end + +--- Execute a tool call and return the raw output. +---@param name string Tool name +---@param input table|string Input arguments +---@param config CopilotChat.config.Shared +---@param source CopilotChat.client.Source +---@return boolean ok +---@return any output +---@async +function M.execute_tool_call(name, input, config, source) + local tool = config.functions[name] + if not tool or not tool.resolve then + return false, 'Tool not found: ' .. name + end + + local schema = nil + for _, t in ipairs(functions.parse_tools(config.functions)) do + if t.name == name then + schema = t.schema + break + end + end + + local ok, output + if config.stop_on_function_failure then + output = tool.resolve(functions.parse_input(input, schema), source) + ok = true + else + ok, output = pcall(tool.resolve, functions.parse_input(input, schema), source) + end + + return ok, output +end + +--- Format tool output as plain text. +---@param ok boolean +---@param output any +---@return string +function M.format_tool_output(ok, output) + local result = '' + if not ok then + result = utils.make_string(output) + elseif type(output) ~= 'table' then + result = utils.make_string(output) + else + for _, content in ipairs(output) do + if content then + local data = content.data or content.uri + if data then + result = result .. (utils.empty(result) and '' or '\n') .. data + end + end + end + end + + return result +end + +--- Call and resolve function calls from the prompt. +---@param prompt string? +---@param config CopilotChat.config.Shared? +---@return table, table, string +---@async +function M.resolve_functions(prompt, config) + config, prompt = M.resolve_prompt(prompt, config) + + local chat = require('CopilotChat').chat + local source = chat:get_source() + + if config.resources then + local resources = utils.to_table(config.resources) + local lines = utils.split_lines(prompt) + for i = #resources, 1, -1 do + local resource = resources[i] + table.insert(lines, 1, '#' .. resource) + end + prompt = table.concat(lines, '\n') + end + + local resolved_resources = {} + local resolved_tools = {} + local tool_calls = {} + + utils.schedule_main() + for _, message in ipairs(chat:get_messages()) do + if message.tool_calls then + for _, tool_call in ipairs(message.tool_calls) do + table.insert(tool_calls, tool_call) + end + end + end + + local resource_matches = {} + + -- Check for #word:`input` pattern + for word, input in prompt:gmatch('#' .. WORD_WITH_INPUT_QUOTED) do + local pattern = string.format('#%s:`%s`', word, input) + table.insert(resource_matches, { + pattern = pattern, + word = word, + input = input, + }) + end + + -- Check for #word:input pattern + for word, input in prompt:gmatch('#' .. WORD_WITH_INPUT_UNQUOTED) do + local pattern = utils.empty(input) and string.format('#%s', word) or string.format('#%s:%s', word, input) + table.insert(resource_matches, { + pattern = pattern, + word = word, + input = input, + }) + end + + -- Check for ##word:input pattern + for word in prompt:gmatch('##' .. WORD_NO_INPUT) do + local pattern = string.format('##%s', word) + table.insert(resource_matches, { + pattern = pattern, + word = word, + }) + end + + -- Resolve each function reference + local function expand_function(name, input) + notify.publish(notify.STATUS, 'Running function: ' .. name) + + local tool_id = nil + if not utils.empty(tool_calls) then + for _, tool_call in ipairs(tool_calls) do + if tool_call.name == name and vim.trim(tool_call.id) == vim.trim(input) then + input = utils.empty(tool_call.arguments) and {} or utils.json_decode(tool_call.arguments) + tool_id = tool_call.id + break + end + end + end + + local tool = config.functions[name] + if not tool then + -- Check if input matches uri + for tool_name, tool_spec in pairs(config.functions) do + if tool_spec.uri then + local match = functions.match_uri(name, tool_spec.uri) + if match then + name = tool_name + tool = tool_spec + input = match + break + end + end + end + end + if not tool then + return nil + end + if not tool_id and not tool.uri then + return nil + end + + local ok, output = M.execute_tool_call(name, input, config, source) + + if tool_id then + table.insert(resolved_tools, { + id = tool_id, + result = M.format_tool_output(ok, output), + }) + + return '' + end + + if not ok then + return utils.make_string(output) + end + + if type(output) ~= 'table' then + return utils.make_string(output) + end + + local result = '' + for _, content in ipairs(output) do + if content then + local content_out = nil + if content.uri then + if + not vim.tbl_contains(resolved_resources, function(resource) + return resource.uri == content.uri + end, { predicate = true }) + then + content_out = '##' .. content.uri + table.insert(resolved_resources, content) + end + else + content_out = content.data + end + + if content_out then + if not utils.empty(result) then + result = result .. '\n' + end + result = result .. content_out + end + end + end + + return result + end + + -- Resolve and process all tools + for _, match in ipairs(resource_matches) do + if not utils.empty(match.pattern) then + local out = expand_function(match.word, match.input) + if out == nil then + out = match.pattern + end + out = out:gsub('%%', '%%%%') -- Escape percent signs for gsub + prompt = prompt:gsub(vim.pesc(match.pattern), out, 1) + end + end + + return resolved_resources, resolved_tools, prompt +end + +--- Resolve the final prompt and config from prompt template. +---@param prompt string? +---@param config CopilotChat.config.Shared? +---@return CopilotChat.config.prompts.Prompt, string +---@async +function M.resolve_prompt(prompt, config) + local chat = require('CopilotChat').chat + local source = chat:get_source() + + if prompt == nil then + utils.schedule_main() + local message = chat:get_message(constants.ROLE.USER) + if message then + prompt = message.content + end + end + + local prompts_to_use = M.list_prompts() + local depth = 0 + local MAX_DEPTH = 10 + + local function resolve(inner_config, inner_prompt) + if depth >= MAX_DEPTH then + return inner_config, inner_prompt + end + depth = depth + 1 + + inner_prompt = string.gsub(inner_prompt, '/' .. WORD, function(match) + local p = prompts_to_use[match] + if p then + local resolved_config, resolved_prompt = resolve(p, p.prompt or '') + inner_config = vim.tbl_deep_extend('force', inner_config, resolved_config) + return resolved_prompt + end + + return '/' .. match + end) + + depth = depth - 1 + return inner_config, inner_prompt + end + + config = vim.tbl_deep_extend('force', require('CopilotChat.config'), config or {}) + config, prompt = resolve(config, prompt or '') + + if config.system_prompt then + if config.prompts[config.system_prompt] then + -- Name references are good for making system prompt auto sticky + config.system_prompt = config.prompts[config.system_prompt].system_prompt + end + + local custom_instructions = vim.trim(require('CopilotChat.instructions.custom_instructions')) + for _, instruction in ipairs(find_custom_instructions(source.cwd(), config)) do + config.system_prompt = vim.trim(config.system_prompt) + .. '\n' + .. custom_instructions:gsub('{FILENAME}', instruction.filename):gsub('{CONTENT}', instruction.content) + end + + config.system_prompt = vim.trim(config.system_prompt) .. '\n' .. config.prompts.COPILOT_BASE.system_prompt + config.system_prompt = vim.trim(config.system_prompt) + .. '\n' + .. vim.trim(require('CopilotChat.instructions.tool_use')) + + if config.diff == 'unified' then + config.system_prompt = vim.trim(config.system_prompt) + .. '\n' + .. vim.trim(require('CopilotChat.instructions.edit_file_unified')) + else + config.system_prompt = vim.trim(config.system_prompt) + .. '\n' + .. vim.trim(require('CopilotChat.instructions.edit_file_block')) + end + + config.system_prompt = config.system_prompt:gsub('{OS_NAME}', vim.uv.os_uname().sysname) + config.system_prompt = config.system_prompt:gsub('{LANGUAGE}', config.language) + config.system_prompt = config.system_prompt:gsub('{DIR}', source.cwd) + end + + return config, prompt +end + +--- Resolve the model from the prompt. +---@param prompt string? +---@param config CopilotChat.config.Shared? +---@return string, string +---@async +function M.resolve_model(prompt, config) + config, prompt = M.resolve_prompt(prompt, config) + local models = vim.tbl_keys(client:models()) + + local selected_model = config.model or '' + prompt = prompt:gsub('%$' .. WORD, function(match) + if vim.tbl_contains(models, match) then + selected_model = match + return '' + end + return '$' .. match + end) + + return selected_model, prompt +end + +return M diff --git a/lua/CopilotChat/resources.lua b/lua/CopilotChat/resources.lua index da79d6ec..22c97c4b 100644 --- a/lua/CopilotChat/resources.lua +++ b/lua/CopilotChat/resources.lua @@ -1,411 +1,36 @@ ----@class CopilotChat.resources.Symbol ----@field name string? ----@field signature string ----@field type string ----@field start_row number ----@field start_col number ----@field end_row number ----@field end_col number - local async = require('plenary.async') -local log = require('plenary.log') -local client = require('CopilotChat.client') -local notify = require('CopilotChat.notify') local utils = require('CopilotChat.utils') +local curl = require('CopilotChat.utils.curl') +local files = require('CopilotChat.utils.files') local file_cache = {} local url_cache = {} -local embedding_cache = {} -local outline_cache = {} local M = {} -local OUTLINE_TYPES = { - 'local_function', - 'function_item', - 'arrow_function', - 'function_definition', - 'function_declaration', - 'method_definition', - 'method_declaration', - 'proc_declaration', - 'template_declaration', - 'macro_declaration', - 'constructor_declaration', - 'field_declaration', - 'class_definition', - 'class_declaration', - 'interface_definition', - 'interface_declaration', - 'record_declaration', - 'type_alias_declaration', - 'import_statement', - 'import_from_statement', - 'atx_heading', - 'list_item', -} - -local NAME_TYPES = { - 'name', - 'identifier', - 'heading_content', -} - -local OFF_SIDE_RULE_LANGUAGES = { - 'python', - 'coffeescript', - 'nim', - 'elm', - 'curry', - 'fsharp', -} - -local MULTI_FILE_THRESHOLD = 5 - ---- Compute the cosine similarity between two vectors ----@param a table ----@param b table ----@return number -local function spatial_distance_cosine(a, b) - if not a or not b then - return 0 - end - - local dot_product = 0 - local magnitude_a = 0 - local magnitude_b = 0 - for i = 1, #a do - dot_product = dot_product + a[i] * b[i] - magnitude_a = magnitude_a + a[i] * a[i] - magnitude_b = magnitude_b + b[i] * b[i] - end - magnitude_a = math.sqrt(magnitude_a) - magnitude_b = math.sqrt(magnitude_b) - return dot_product / (magnitude_a * magnitude_b) -end - ---- Rank data by relatedness to the query ----@param query CopilotChat.client.EmbeddedResource ----@param data table ----@return table -local function data_ranked_by_relatedness(query, data) - for _, item in ipairs(data) do - local score = spatial_distance_cosine(item.embedding, query.embedding) - item.score = score or item.score or 0 - end - - table.sort(data, function(a, b) - return a.score > b.score - end) - - -- Apply dynamic filtering for embedding-based ranking - local filtered = {} - - if #data > 0 then - -- Calculate statistics for score distribution - local sum = 0 - local max_score = data[1].score - - for _, item in ipairs(data) do - sum = sum + item.score - end - - local mean = sum / #data - - -- Calculate standard deviation - local sum_squared_diff = 0 - for _, item in ipairs(data) do - sum_squared_diff = sum_squared_diff + ((item.score - mean) * (item.score - mean)) - end - local std_dev = math.sqrt(sum_squared_diff / #data) - - -- Calculate z-scores and use them to determine significance - -- Include items with z-score > -0.5 (meaning within 0.5 std dev below mean) - -- This is a statistical approach to find "significantly" related items - for _, result in ipairs(data) do - local z_score = (result.score - mean) / std_dev - if z_score > -0.5 then - table.insert(filtered, result) - end - end - - -- If we didn't get enough results or the distribution is very tight, - -- use a percentage of max score as fallback - if #filtered < MULTI_FILE_THRESHOLD then - filtered = {} - local adaptive_threshold = max_score * 0.6 -- 60% of max score - - for i, result in ipairs(data) do - if i <= MULTI_FILE_THRESHOLD or result.score >= adaptive_threshold then - table.insert(filtered, result) - end - end - end - end - - return filtered -end - --- Create trigrams from text (e.g., "hello" -> {"hel", "ell", "llo"}) -local function get_trigrams(text) - local trigrams = {} - text = text:lower() - for i = 1, #text - 2 do - trigrams[text:sub(i, i + 2)] = true - end - return trigrams -end - --- Calculate Jaccard similarity between two trigram sets -local function trigram_similarity(set1, set2) - local intersection = 0 - local union = 0 - - -- Count intersection and union - for trigram in pairs(set1) do - if set2[trigram] then - intersection = intersection + 1 - end - union = union + 1 - end - - for trigram in pairs(set2) do - if not set1[trigram] then - union = union + 1 - end - end - - return intersection / union -end - ---- Rank data by symbols and filenames ----@param query string ----@param data table ----@return table -local function data_ranked_by_symbols(query, data) - -- Get query trigrams including compound versions - local query_trigrams = {} - - -- Add trigrams for each word - for term in query:gmatch('%w+') do - for trigram in pairs(get_trigrams(term)) do - query_trigrams[trigram] = true - end - end - - -- Add trigrams for compound query - local compound_query = query:gsub('[^%w]', '') - for trigram in pairs(get_trigrams(compound_query)) do - query_trigrams[trigram] = true - end - - local max_score = 0 - - for _, entry in ipairs(data) do - local basename = utils.filename(entry.name):gsub('%..*$', '') - - -- Get trigrams for basename and compound version - local file_trigrams = get_trigrams(basename) - local compound_trigrams = get_trigrams(basename:gsub('[^%w]', '')) - - -- Calculate similarities - local name_sim = trigram_similarity(query_trigrams, file_trigrams) - local compound_sim = trigram_similarity(query_trigrams, compound_trigrams) - - -- Take best match - local score = (entry.score or 0) + math.max(name_sim, compound_sim) - - -- Add symbol matches - if entry.symbols then - local symbol_score = 0 - for _, symbol in ipairs(entry.symbols) do - if symbol.name then - local symbol_trigrams = get_trigrams(symbol.name) - local sym_sim = trigram_similarity(query_trigrams, symbol_trigrams) - symbol_score = math.max(symbol_score, sym_sim) - end - end - score = score + (symbol_score * 0.5) -- Weight symbol matches less - end - - max_score = math.max(max_score, score) - entry.score = score - end - - -- Normalize scores - for _, entry in ipairs(data) do - entry.score = entry.score / max_score - end - - -- Sort by score first - table.sort(data, function(a, b) - return a.score > b.score - end) - - -- Use elbow method to find natural cutoff point for symbol-based ranking - local filtered_results = {} - - if #data > 0 then - -- Always include at least the top result - table.insert(filtered_results, data[1]) - - -- Find the point of maximum drop-off (the "elbow") - local max_drop = 0 - local cutoff_index = math.min(MULTI_FILE_THRESHOLD, #data) - - for i = 2, math.min(20, #data) do - local drop = data[i - 1].score - data[i].score - if drop > max_drop then - max_drop = drop - cutoff_index = i - end - end - - -- Include everything up to the cutoff point - for i = 2, cutoff_index do - table.insert(filtered_results, data[i]) - end - - -- Also include any remaining items that have scores close to the cutoff - local cutoff_score = data[cutoff_index].score - local threshold = cutoff_score * 0.8 -- Within 80% of the cutoff score - - for i = cutoff_index + 1, #data do - if data[i].score >= threshold then - table.insert(filtered_results, data[i]) - end - end - end - - return filtered_results -end - ---- Get the full signature of a declaration ----@param start_row number ----@param start_col number ----@param lines table ----@return string -local function get_full_signature(start_row, start_col, lines) - local start_line = lines[start_row + 1] - local signature = vim.trim(start_line:sub(start_col + 1)) - - -- Look ahead for opening brace on next line - if not signature:match('{') and (start_row + 2) <= #lines then - local next_line = vim.trim(lines[start_row + 2]) - if next_line:match('^{') then - signature = signature .. ' {' - end - end - - return signature -end - ---- Get the name of a node ----@param node table ----@param content string ----@return string? -local function get_node_name(node, content) - for _, name_type in ipairs(NAME_TYPES) do - local name_field = node:field(name_type) - if name_field and #name_field > 0 then - return vim.treesitter.get_node_text(name_field[1], content) - end - end - - return nil -end - ---- Build an outline and symbols from a string ----@param content string ----@param ft string ----@return string?, table? -local function get_outline(content, ft) - if not ft or ft == '' then - return nil - end - - local lang = vim.treesitter.language.get_lang(ft) - local ok, parser = false, nil - if lang then - ok, parser = pcall(vim.treesitter.get_string_parser, content, lang) - end - if not ok or not parser then - ft = string.gsub(ft, 'react', '') - ok, parser = pcall(vim.treesitter.get_string_parser, content, ft) - if not ok or not parser then - return nil - end - end - - local root = utils.ts_parse(parser) - local lines = vim.split(content, '\n') - local symbols = {} - local outline_lines = {} - local depth = 0 - - local function parse_node(node) - local type = node:type() - local is_outline = vim.tbl_contains(OUTLINE_TYPES, type) - local start_row, start_col, end_row, end_col = node:range() - - if is_outline then - depth = depth + 1 - local name = get_node_name(node, content) - local signature_start = get_full_signature(start_row, start_col, lines) - table.insert(outline_lines, string.rep(' ', depth) .. signature_start) - - -- Store symbol information - table.insert(symbols, { - name = name, - signature = signature_start, - type = type, - start_row = start_row + 1, - start_col = start_col + 1, - end_row = end_row, - end_col = end_col, - }) - end - - for child in node:iter_children() do - parse_node(child) - end - - if is_outline then - if not vim.tbl_contains(OFF_SIDE_RULE_LANGUAGES, ft) then - local end_line = lines[end_row + 1] - local signature_end = vim.trim(end_line:sub(1, end_col)) - table.insert(outline_lines, string.rep(' ', depth) .. signature_end) - end - depth = depth - 1 - end - end - - parse_node(root) - - if #outline_lines == 0 then - return nil - end - return table.concat(outline_lines, '\n'), symbols -end - --- Get data for a file ---@param filename string ---@return string?, string? function M.get_file(filename) - local filetype = utils.filetype(filename) + local filetype = files.filetype(filename) if not filetype then return nil end - local modified = utils.file_mtime(filename) - if not modified then + local err, stat = async.uv.fs_stat(filename) + if err or not stat then return nil end + local modified = stat.mtime.sec local data = file_cache[filename] if not data or data._modified < modified then - local content = utils.read_file(filename) + local content = files.read_file(filename) if not content or content == '' then return nil end + -- Simple binary detection: reject files with null bytes + if content:find('\0') then + return nil + end data = { content = content, _modified = modified, @@ -413,7 +38,7 @@ function M.get_file(filename) file_cache[filename] = data end - return data.content, utils.filetype_to_mimetype(filetype) + return data.content, files.filetype_to_mimetype(filetype) end --- Get data for a buffer @@ -429,7 +54,7 @@ function M.get_buffer(bufnr) return nil end - return table.concat(content, '\n'), utils.filetype_to_mimetype(vim.bo[bufnr].filetype) + return table.concat(content, '\n'), files.filetype_to_mimetype(vim.bo[bufnr].filetype) end --- Get the content of an URL @@ -440,7 +65,7 @@ function M.get_url(url) return nil end - local ft = utils.filetype(url) + local ft = files.filetype(url) local content = url_cache[url] if not content then local ok, out = async.util.apcall(utils.system, { 'lynx', '-dump', url }) @@ -449,7 +74,7 @@ function M.get_url(url) content = out.stdout else -- Fallback to curl if lynx fails - local response = utils.curl_get(url, { raw = { '-L' } }) + local response = curl.get(url, { raw = { '-L' } }) if not response or not response.body then return nil end @@ -478,105 +103,7 @@ function M.get_url(url) url_cache[url] = content end - return content, utils.filetype_to_mimetype(ft) -end - ---- Transform a resource into a format suitable for the client ----@param resource CopilotChat.config.functions.Result ----@return CopilotChat.client.Resource -function M.to_resource(resource) - return { - name = utils.uri_to_filename(resource.uri), - type = utils.mimetype_to_filetype(resource.mimetype), - data = resource.data, - } -end - ---- Process resources based on the query ----@param prompt string ----@param model string ----@param resources table ----@return table -function M.process_resources(prompt, model, resources) - -- If we dont need to embed anything, just return directly - if #resources < MULTI_FILE_THRESHOLD then - return resources - end - - notify.publish(notify.STATUS, 'Preparing embedding outline') - - -- Get the outlines for each resource - for _, input in ipairs(resources) do - local hash = input.name .. utils.quick_hash(input.data) - input._hash = hash - - local outline = outline_cache[hash] - if not outline then - local outline_text, symbols = get_outline(input.data, input.type) - if outline_text then - outline = { - outline = outline_text, - symbols = symbols, - } - - outline_cache[hash] = outline - end - end - - if outline then - input.outline = outline.outline - input.symbols = outline.symbols - end - end - - notify.publish(notify.STATUS, 'Ranking embeddings') - - -- Build query from history and prompt - local query = prompt - - -- Rank embeddings by symbols - resources = data_ranked_by_symbols(query, resources) - log.debug('Ranked data:', #resources) - for i, item in ipairs(resources) do - log.debug(string.format('%s: %s - %s', i, item.score, item.name)) - end - - -- Prepare embeddings for processing - local to_process = {} - local results = {} - for _, input in ipairs(resources) do - local hash = input._hash - local embed = embedding_cache[hash] - if embed then - input.embedding = embed - table.insert(results, input) - else - table.insert(to_process, input) - end - end - table.insert(to_process, { - type = 'text', - data = query, - }) - - -- Embed the data and process the results - for _, input in ipairs(client:embed(to_process, model)) do - if input._hash then - embedding_cache[input._hash] = input.embedding - end - table.insert(results, input) - end - - -- Rate embeddings by relatedness to the query - local embedded_query = table.remove(results, #results) - log.debug('Embedded query:', embedded_query.content) - results = data_ranked_by_relatedness(embedded_query, results) - log.debug('Ranked embeddings:', #results) - for i, item in ipairs(results) do - log.debug(string.format('%s: %s - %s', i, item.score, item.filename)) - end - - return results + return content, files.filetype_to_mimetype(ft) end return M diff --git a/lua/CopilotChat/select.lua b/lua/CopilotChat/select.lua index 8bef366c..84722e9d 100644 --- a/lua/CopilotChat/select.lua +++ b/lua/CopilotChat/select.lua @@ -1,95 +1,91 @@ ---@class CopilotChat.select.Selection ---@field content string ----@field start_line number ----@field end_line number +---@field start_line integer +---@field end_line integer ---@field filename string ---@field filetype string ----@field bufnr number +---@field bufnr integer + +local log = require('plenary.log') +local utils = require('CopilotChat.utils') local M = {} ---- Select and process current visual selection ---- @param source CopilotChat.source ---- @return CopilotChat.select.Selection|nil -function M.visual(source) - local bufnr = source.bufnr - local start_line = unpack(vim.api.nvim_buf_get_mark(bufnr, '<')) - local finish_line = unpack(vim.api.nvim_buf_get_mark(bufnr, '>')) - if start_line == 0 or finish_line == 0 then - return nil - end - if start_line > finish_line then - start_line, finish_line = finish_line, start_line - end +--- Use #selection instead +---@deprecated +function M.visual(_) + log.warn('CopilotChat.select.visual is deprecated, use #selection instead') + return nil +end - local ok, lines = pcall(vim.api.nvim_buf_get_lines, bufnr, start_line - 1, finish_line, false) - if not ok then - return nil +--- Use #selection instead +---@deprecated use #selection instead +function M.buffer(_) + log.warn('CopilotChat.select.buffer is deprecated, use #selection instead') + return nil +end + +--- Use #selection instead +---@deprecated use #selection instead +function M.line(_) + log.warn('CopilotChat.select.line is deprecated, use #selection instead') + return nil +end + +--- Use #selection instead +---@deprecated use #selection instead +function M.unnamed(_) + log.warn('CopilotChat.select.unnamed is deprecated, use #selection instead') + return nil +end + +--- Get the marks used for selection +---@return string[] +function M.marks() + local config = require('CopilotChat.config') + local marks = { '<', '>' } + if config.selection == 'unnamed' then + marks = { '[', ']' } end - local lines_content = table.concat(lines, '\n') - if vim.trim(lines_content) == '' then - return nil + return marks +end + +--- Highlight selection in target buffer or clear it +---@param bufnr integer +---@param clear boolean? +function M.highlight(bufnr, clear) + local selection_ns = vim.api.nvim_create_namespace('copilot-chat-selection') + for _, buf in ipairs(vim.api.nvim_list_bufs()) do + vim.api.nvim_buf_clear_namespace(buf, selection_ns, 0, -1) end - return { - content = lines_content, - filename = vim.api.nvim_buf_get_name(bufnr), - filetype = vim.bo[bufnr].filetype, - start_line = start_line, - end_line = finish_line, - bufnr = bufnr, - } -end + if clear then + return + end ---- Select and process whole buffer ---- @param source CopilotChat.source ---- @return CopilotChat.select.Selection|nil -function M.buffer(source) - local bufnr = source.bufnr - local lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false) - if not lines or #lines == 0 then - return nil + local selection = M.get(bufnr) + if not selection then + return end - return { - content = table.concat(lines, '\n'), - filename = vim.api.nvim_buf_get_name(bufnr), - filetype = vim.bo[bufnr].filetype, - start_line = 1, - end_line = #lines, - bufnr = bufnr, - } + vim.api.nvim_buf_set_extmark(selection.bufnr, selection_ns, selection.start_line - 1, 0, { + hl_group = 'CopilotChatSelection', + end_row = selection.end_line, + strict = false, + }) end ---- Select and process current line ---- @param source CopilotChat.source ---- @return CopilotChat.select.Selection|nil -function M.line(source) - local bufnr = source.bufnr - local winnr = source.winnr - local cursor = vim.api.nvim_win_get_cursor(winnr) - local line = vim.api.nvim_buf_get_lines(bufnr, cursor[1] - 1, cursor[1], false)[1] - if not line then +--- Get the selection from the target buffer +---@param bufnr integer +---@return CopilotChat.select.Selection? +function M.get(bufnr) + if not utils.buf_valid(bufnr) then return nil end - return { - content = line, - filename = vim.api.nvim_buf_get_name(bufnr), - filetype = vim.bo[bufnr].filetype, - start_line = cursor[1], - end_line = cursor[1], - bufnr = bufnr, - } -end - ---- Select and process contents of unnamed register ("). This register contains last deleted, changed or yanked content. ---- @param source CopilotChat.source ---- @return CopilotChat.select.Selection|nil -function M.unnamed(source) - local bufnr = source.bufnr - local start_line = unpack(vim.api.nvim_buf_get_mark(bufnr, '[')) - local finish_line = unpack(vim.api.nvim_buf_get_mark(bufnr, ']')) + local marks = M.marks() + local start_line = unpack(vim.api.nvim_buf_get_mark(bufnr, marks[1])) + local finish_line = unpack(vim.api.nvim_buf_get_mark(bufnr, marks[2])) if start_line == 0 or finish_line == 0 then return nil end @@ -116,4 +112,31 @@ function M.unnamed(source) } end +--- Sets the selection to specific lines in buffer or clears it +---@param bufnr integer +---@param winnr integer? +---@param start_line integer? +---@param end_line integer? +function M.set(bufnr, winnr, start_line, end_line) + if not utils.buf_valid(bufnr) then + return + end + + local marks = M.marks() + + if not start_line or not end_line then + for _, mark in ipairs(marks) do + pcall(vim.api.nvim_buf_del_mark, bufnr, mark) + end + return + end + + pcall(vim.api.nvim_buf_set_mark, bufnr, marks[1], start_line, 0, {}) + pcall(vim.api.nvim_buf_set_mark, bufnr, marks[2], end_line, 0, {}) + + if winnr and vim.api.nvim_win_is_valid(winnr) then + pcall(vim.api.nvim_win_set_cursor, winnr, { start_line, 0 }) + end +end + return M diff --git a/lua/CopilotChat/tiktoken.lua b/lua/CopilotChat/tiktoken.lua index dde3d2b5..f7ea0de7 100644 --- a/lua/CopilotChat/tiktoken.lua +++ b/lua/CopilotChat/tiktoken.lua @@ -1,35 +1,25 @@ -local notify = require('CopilotChat.notify') +local notify = require('CopilotChat.utils.notify') local utils = require('CopilotChat.utils') -local current_tokenizer = nil +local curl = require('CopilotChat.utils.curl') +local class = require('CopilotChat.utils.class') +--- Get the library extension based on the operating system --- @return string local function get_lib_extension() - if jit.os:lower() == 'mac' or jit.os:lower() == 'osx' then + local os_name = vim.uv.os_uname().sysname:lower() + if os_name:find('darwin') then return '.dylib' - end - if jit.os:lower() == 'windows' then + elseif os_name:find('windows') then return '.dll' + else + return '.so' end - return '.so' -end - -package.cpath = package.cpath - .. ';' - .. debug.getinfo(1).source:match('@?(.*/)') - .. '../../build/?' - .. get_lib_extension() - -local tiktoken_ok, tiktoken_core = pcall(require, 'tiktoken_core') -if not tiktoken_ok then - tiktoken_core = nil end --- Load tiktoken data from cache or download it ---@param tokenizer string The tokenizer to load ---@async local function load_tiktoken_data(tokenizer) - utils.schedule_main() - local tiktoken_url = 'https://openaipublic.blob.core.windows.net/encodings/' .. tokenizer .. '.tiktoken' local cache_dir = vim.fn.stdpath('cache') @@ -42,27 +32,41 @@ local function load_tiktoken_data(tokenizer) notify.publish(notify.STATUS, 'Downloading tiktoken data from ' .. tiktoken_url) - utils.curl_get(tiktoken_url, { + curl.get(tiktoken_url, { output = cache_path, }) return cache_path end -local M = {} +---@class CopilotChat.tiktoken.Tiktoken : Class +---@field private tiktoken_core table? +---@field private tokenizer string? +local Tiktoken = class(function(self) + package.cpath = package.cpath + .. ';' + .. debug.getinfo(1).source:match('@?(.*/)') + .. '../../build/?' + .. get_lib_extension() + + local tiktoken_ok, tiktoken_core = pcall(require, 'tiktoken_core') + self.tiktoken_core = tiktoken_ok and tiktoken_core or nil + self.tokenizer = nil +end) --- Load the tiktoken module ---@param tokenizer string The tokenizer to load ---@async -M.load = function(tokenizer) - if not tiktoken_core then +function Tiktoken:load(tokenizer) + if not self.tiktoken_core then return end - if tokenizer == current_tokenizer then + if tokenizer == self.tokenizer then return end + utils.schedule_main() local path = load_tiktoken_data(tokenizer) local special_tokens = {} special_tokens['<|endoftext|>'] = 100257 @@ -74,26 +78,22 @@ M.load = function(tokenizer) "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+" utils.schedule_main() - tiktoken_core.new(path, special_tokens, pat_str) - current_tokenizer = tokenizer + self.tiktoken_core.new(path, special_tokens, pat_str) + self.tokenizer = tokenizer end --- Encode a prompt ---@param prompt string The prompt to encode ---@return table? -function M.encode(prompt) - if not tiktoken_core then +function Tiktoken:encode(prompt) + if not self.tiktoken_core then return nil end - if not prompt or prompt == '' then + if not prompt or prompt == '' or type(prompt) ~= 'string' then return nil end - -- Check if prompt is a string - if type(prompt) ~= 'string' then - error('Prompt must be a string') - end - local ok, result = pcall(tiktoken_core.encode, prompt) + local ok, result = pcall(self.tiktoken_core.encode, prompt) if not ok then return nil end @@ -104,16 +104,16 @@ end --- Count the tokens in a prompt ---@param prompt string The prompt to count ---@return number -function M.count(prompt) - if not tiktoken_core then - return math.ceil(#prompt * 0.5) -- Fallback to 1/2 character count +function Tiktoken:count(prompt) + if not self.tiktoken_core then + return math.ceil(#prompt / 4) end - local tokens = M.encode(prompt) + local tokens = self:encode(prompt) if not tokens then - return math.ceil(#prompt * 0.5) -- Fallback to 1/2 character count + return math.ceil(#prompt / 4) end return #tokens end -return M +return Tiktoken() diff --git a/lua/CopilotChat/ui/chat.lua b/lua/CopilotChat/ui/chat.lua index 8b22c32e..7f14475e 100644 --- a/lua/CopilotChat/ui/chat.lua +++ b/lua/CopilotChat/ui/chat.lua @@ -1,8 +1,10 @@ local Overlay = require('CopilotChat.ui.overlay') local Spinner = require('CopilotChat.ui.spinner') -local notify = require('CopilotChat.notify') +local constants = require('CopilotChat.constants') +local notify = require('CopilotChat.utils.notify') local utils = require('CopilotChat.utils') -local class = utils.class +local class = require('CopilotChat.utils.class') +local orderedmap = require('CopilotChat.utils.orderedmap') function CopilotChatFoldExpr(lnum, separator) local to_match = separator .. '$' @@ -14,86 +16,161 @@ function CopilotChatFoldExpr(lnum, separator) return '=' end -local HEADER_PATTERNS = { - '^```?(%w+)%s+path=(%S+)%s+start_line=(%d+)%s+end_line=(%d+)$', - '^```(%w+)$', -} +---@param headers table? +---@return string?, string? +local function match_section_header(headers, separator, line) + if not headers then + return + end + + for header_name, header_value in pairs(headers) do + local id = line:match('^' .. vim.pesc(header_value) .. ' %(([^)]+)%) ' .. vim.pesc(separator) .. '$') + if id then + return id, header_name + end + end +end ---@param header? string ---@return string?, string?, number?, number? -local function match_header(header) +local function match_block_header(header) if not header then return end - for _, pattern in ipairs(HEADER_PATTERNS) do + local patterns = { + '^(%w+)%s+path=(.-)%s+start_line=(%d+)%s+end_line=(%d+)$', + '^(%w+)%s+path=(%S+)%s+start_line=(%d+)%s+end_line=(%d+)$', + '^(%w+)$', + } + + for _, pattern in ipairs(patterns) do local type, path, start_line, end_line = header:match(pattern) if path then return type, path, tonumber(start_line) or 1, tonumber(end_line) or tonumber(start_line) or 1 elseif type then - return type, 'block' + return type, nil + end + end +end + +---@param header? CopilotChat.ui.chat.Header +---@param content? string +---@return string? +local function match_block_content(header, content) + if not header or header.filetype ~= 'diff' or not content then + return + end + + local lines = vim.split(content, '\n') + for _, line in ipairs(lines) do + local diff_filename = line:match('^%+%+%+%s+(.*)') + if diff_filename then + return vim.trim(diff_filename) end end end +--- Get the last line and column of the chat window. +---@param bufnr number +---@return number, number +---@protected +local function last(bufnr) + local line_count = vim.api.nvim_buf_line_count(bufnr) + if line_count == 0 then + return 0, 0 + end + local last_line = line_count - 1 + local last_line_content = vim.api.nvim_buf_get_lines(bufnr, last_line, last_line + 1, false) + local last_column = last_line_content[1] and #last_line_content[1] or 0 + return last_line, last_column +end + ---@class CopilotChat.ui.chat.Header ----@field filename string ----@field start_line number ----@field end_line number ---@field filetype string +---@field filename string +---@field start_line number? +---@field end_line number? ---@class CopilotChat.ui.chat.Block ---@field header CopilotChat.ui.chat.Header ---@field start_line number ---@field end_line number ----@field content string? +---@field content string ---@class CopilotChat.ui.chat.Section ----@field start_line number ----@field end_line number ----@field blocks table +---@field start_line integer +---@field end_line integer +---@field blocks CopilotChat.ui.chat.Block[] ---@class CopilotChat.ui.chat.Message : CopilotChat.client.Message ----@field id string +---@field id string? ---@field section CopilotChat.ui.chat.Section? +--- @class CopilotChat.ui.chat.Source +--- @field bufnr integer? +--- @field winnr integer? +--- @field cwd fun():string + ---@class CopilotChat.ui.chat.Chat : CopilotChat.ui.overlay.Overlay ----@field winnr number? +---@field winnr integer? ---@field config CopilotChat.config.Shared ---@field token_count number? ---@field token_max_count number? ----@field messages table +---@field private messages OrderedMap ---@field private layout CopilotChat.config.Layout? ---@field private headers table ---@field private separator string ---@field private spinner CopilotChat.ui.spinner.Spinner ---@field private chat_overlay CopilotChat.ui.overlay.Overlay -local Chat = class(function(self, config, help, on_buf_create) - Overlay.init(self, 'copilot-chat', help, on_buf_create) +---@field private last_changedtick number? +---@field private source CopilotChat.ui.chat.Source +---@field private sticky string[] +local Chat = class(function(self, config, on_buf_create) + Overlay.init(self, 'copilot-chat', utils.key_to_info('show_help', config.mappings.show_help), on_buf_create) self.winnr = nil self.config = config self.token_count = nil self.token_max_count = nil - self.messages = {} + self.messages = orderedmap() + + self.source = { + bufnr = nil, + winnr = nil, + cwd = function() + return '.' + end, + } + + self.sticky = {} self.layout = nil - self.headers = config.headers + self.headers = {} + for k, v in pairs(config.headers or {}) do + self.headers[k] = v:gsub('^#+', ''):gsub('^%s+', '') + end self.separator = config.separator self.spinner = Spinner() - self.chat_overlay = Overlay('copilot-overlay', 'q to close', function(bufnr) - vim.keymap.set('n', 'q', function() - self.chat_overlay:restore(self.winnr, self.bufnr) - end) - - vim.api.nvim_create_autocmd({ 'BufHidden', 'BufDelete' }, { - buffer = bufnr, - callback = function() + self.chat_overlay = Overlay( + 'copilot-overlay', + utils.key_to_info('close', { + normal = config.mappings.close.normal, + }), + function(bufnr) + vim.keymap.set('n', config.mappings.close.normal, function() self.chat_overlay:restore(self.winnr, self.bufnr) - end, - }) - end) + end, { buffer = bufnr }) + + vim.api.nvim_create_autocmd({ 'BufHidden', 'BufDelete' }, { + buffer = bufnr, + callback = function() + self.chat_overlay:restore(self.winnr, self.bufnr) + end, + }) + end + ) notify.listen(notify.MESSAGE, function(msg) utils.schedule_main() @@ -102,7 +179,11 @@ local Chat = class(function(self, config, help, on_buf_create) self:open(self.config) end - self:overlay({ text = msg }) + if not msg or msg == '' then + self.chat_overlay:restore(self.winnr, self.bufnr) + else + self.chat_overlay:show(msg, self.winnr) + end end) end, Overlay) @@ -119,121 +200,113 @@ function Chat:focused() return self:visible() and vim.api.nvim_get_current_win() == self.winnr end ---- Get the closest message to the cursor. +--- Get the closest code block to the cursor. ---@param role string? If specified, only considers sections of the given role ----@return CopilotChat.ui.chat.Message? -function Chat:get_closest_message(role) - if not self:visible() then - return nil - end - - self:render() - local cursor_pos = vim.api.nvim_win_get_cursor(self.winnr) - local cursor_line = cursor_pos[1] - local closest_message = nil - local max_line_below_cursor = -1 +---@param cursor boolean? If true, returns the block closest to the cursor position +---@return CopilotChat.ui.chat.Block? +function Chat:get_block(role, cursor) + local messages = self:get_messages() - for _, message in ipairs(self.messages) do - local section = message.section - local matches_role = not role or message.role == role - if matches_role and section.start_line <= cursor_line and section.start_line > max_line_below_cursor then - max_line_below_cursor = section.start_line - closest_message = message + if cursor then + if not self:visible() then + return nil end - end - - return closest_message -end - ---- Get the closest code block to the cursor. ----@return CopilotChat.ui.chat.Block? -function Chat:get_closest_block() - if not self:visible() then - return nil - end - self:render() - local cursor_pos = vim.api.nvim_win_get_cursor(self.winnr) - local cursor_line = cursor_pos[1] - local closest_block = nil - local max_line_below_cursor = -1 + local cursor_pos = vim.api.nvim_win_get_cursor(self.winnr) + local cursor_line = cursor_pos[1] + local closest_block = nil + local max_line_below_cursor = -1 - for _, message in pairs(self.messages) do - local section = message.section - for _, block in ipairs(section.blocks) do - if block.start_line <= cursor_line and block.start_line > max_line_below_cursor then - max_line_below_cursor = block.start_line - closest_block = block + for _, message in ipairs(messages) do + local section = message.section + local matches_role = not role or message.role == role + if matches_role and section and section.blocks then + for _, block in ipairs(section.blocks) do + if block.start_line <= cursor_line and block.start_line > max_line_below_cursor then + max_line_below_cursor = block.start_line + closest_block = block + end + end end end - end - - return closest_block -end ---- Get last message by role in the chat window. ----@return CopilotChat.ui.chat.Message? -function Chat:get_message(role) - if not self:visible() then - return + return closest_block end - for i = #self.messages, 1, -1 do - local message = self.messages[i] + for i = #messages, 1, -1 do + local message = messages[i] local matches_role = not role or message.role == role - if matches_role then - return message + if matches_role and message.section and message.section.blocks and #message.section.blocks > 0 then + return message.section.blocks[#message.section.blocks] end end end ---- Add a sticky line to the prompt in the chat window. ----@param sticky string -function Chat:add_sticky(sticky) - if not self:visible() then - return - end +--- Get list of all chat messages +---@return CopilotChat.ui.chat.Message[] +function Chat:get_messages() + self:parse() + return self.messages:values() +end - local prompt = self:get_message('user') - if not prompt or not prompt.section then - return - end +--- Get last message by role in the chat window. +---@param role string? If specified, only considers sections of the given role +---@param cursor boolean? If true, returns the message closest to the cursor position +---@return CopilotChat.ui.chat.Message? +function Chat:get_message(role, cursor) + local messages = self:get_messages() - local lines = vim.split(prompt.content, '\n') - local insert_line = 1 - local first_one = true - local found = false + if cursor then + if not self:visible() then + return nil + end - for i = insert_line, #lines do - local line = lines[i] - if line and line ~= '' then - if vim.startswith(line, '> ') then - if line:sub(3) == sticky then - found = true - break - end + local cursor_pos = vim.api.nvim_win_get_cursor(self.winnr) + local cursor_line = cursor_pos[1] + local closest_message = nil + local max_line_below_cursor = -1 - first_one = false - else - break + for _, message in ipairs(messages) do + local section = message.section + local matches_role = not role or message.role == role + if + matches_role + and section + and section.start_line <= cursor_line + and section.start_line > max_line_below_cursor + then + max_line_below_cursor = section.start_line + closest_message = message end - elseif i >= 2 then - break end - insert_line = insert_line + 1 + return closest_message end - if found then - return + for i = #messages, 1, -1 do + local message = messages[i] + local matches_role = not role or message.role == role + if matches_role then + return message + end end +end - insert_line = prompt.section.start_line + insert_line - 1 - local to_insert = first_one and { '> ' .. sticky, '' } or { '> ' .. sticky } - local modifiable = vim.bo[self.bufnr].modifiable - vim.bo[self.bufnr].modifiable = true - vim.api.nvim_buf_set_lines(self.bufnr, insert_line - 1, insert_line - 1, false, to_insert) - vim.bo[self.bufnr].modifiable = modifiable +--- Get the current sticky array. +---@return string[] +function Chat:get_sticky() + return self.sticky +end + +--- Set the sticky array. +---@param sticky string[] +function Chat:set_sticky(sticky) + self.sticky = sticky +end + +--- Clear the sticky array. +function Chat:clear_sticky() + self.sticky = {} end ---@class CopilotChat.ui.Chat.show_overlay @@ -341,15 +414,14 @@ function Chat:open(config) end local ns = vim.api.nvim_create_namespace('copilot-chat-local-hl') - vim.api.nvim_set_hl(ns, '@markup.quote.markdown', {}) + vim.api.nvim_set_hl(ns, '@markup.quote.markdown', {}) -- disable quote block overriding chat keywords + vim.api.nvim_set_hl(ns, '@markup.italic.markdown_inline', {}) -- disable italic messing up glob patterns vim.api.nvim_win_set_hl_ns(self.winnr, ns) vim.api.nvim_win_set_buf(self.winnr, self.bufnr) - self:render() end --- Close the chat window. ----@param bufnr number? -function Chat:close(bufnr) +function Chat:close() if not self:visible() then return end @@ -359,8 +431,8 @@ function Chat:close(bufnr) end if self.layout == 'replace' then - if bufnr then - self:restore(self.winnr, bufnr) + if self.source.bufnr then + self:restore(self.winnr, self.source.bufnr) end else vim.api.nvim_win_close(self.winnr, true) @@ -387,11 +459,7 @@ function Chat:follow() return end - local last_line, last_column, line_count = self:last() - if line_count == 0 then - return - end - + local last_line, last_column = last(self.bufnr) vim.api.nvim_win_set_cursor(self.winnr, { last_line + 1, last_column }) end @@ -403,19 +471,12 @@ function Chat:start() utils.return_to_normal_mode() end - if self.spinner then - self.spinner:start() - end - + self.spinner:start() vim.bo[self.bufnr].modifiable = false end --- Finish writing to the chat window. function Chat:finish() - if not self.spinner then - return - end - self.spinner:finish() vim.bo[self.bufnr].modifiable = true if self.config.auto_insert_mode and self:focused() then @@ -423,8 +484,11 @@ function Chat:finish() end end +--- Add a message to the chat window. +---@param message CopilotChat.ui.chat.Message +---@param replace boolean? If true, replaces the last message if it has same role function Chat:add_message(message, replace) - local current_message = self.messages[#self.messages] + local current_message = self:get_message() local is_new = not current_message or current_message.role ~= message.role or (message.id and current_message.id ~= message.id) @@ -433,17 +497,15 @@ function Chat:add_message(message, replace) -- Add appropriate header based on role and generate a new ID if not provided message.id = message.id or utils.uuid() local header = self.headers[message.role] + self.messages:set(message.id, message) + if current_message then - header = '\n' .. header + self:append('\n') end - - table.insert(self.messages, message) - self:append(header .. '(' .. message.id .. ')' .. self.separator .. '\n\n') + self:append('# ' .. header .. ' (' .. message.id .. ') ' .. self.separator .. '\n\n') self:append(message.content) elseif replace and current_message then -- Replace the content of the current message - self:render() - for k, v in pairs(message) do current_message[k] = v end @@ -470,13 +532,11 @@ function Chat:add_message(message, replace) end end -function Chat:remove_message(role) - if not self:visible() then - return - end - - self:render() - local message = self:get_closest_message(role) +--- Remove a message from the chat window by role. +---@param role string? If specified, only considers sections of the given role +---@param cursor boolean? If true, removes the message closest to the cursor position +function Chat:remove_message(role, cursor) + local message = self:get_message(role, cursor) if not message then return end @@ -493,14 +553,7 @@ function Chat:remove_message(role) vim.bo[self.bufnr].modifiable = modifiable -- Remove the message from the messages list - for i, msg in ipairs(self.messages) do - if msg.id == message.id then - table.remove(self.messages, i) - break - end - end - - self:render() + self.messages:remove(message.id) end --- Append text to the chat window. @@ -517,7 +570,7 @@ function Chat:append(str) should_follow_cursor = current_pos[1] >= line_count - 1 end - local last_line, last_column, _ = self:last() + local last_line, last_column, _ = last(self.bufnr) local modifiable = vim.bo[self.bufnr].modifiable vim.bo[self.bufnr].modifiable = true @@ -534,7 +587,7 @@ function Chat:clear() self:validate() self.token_count = nil self.token_max_count = nil - self.messages = {} + self.messages = orderedmap() local modifiable = vim.bo[self.bufnr].modifiable vim.bo[self.bufnr].modifiable = true @@ -548,17 +601,25 @@ function Chat:create() local bufnr = Overlay.create(self) vim.bo[bufnr].syntax = 'markdown' vim.bo[bufnr].textwidth = 0 + vim.bo[bufnr].undolevels = 10 + self.spinner.bufnr = bufnr + + vim.schedule(function() + if not vim.treesitter.get_parser(bufnr, 'markdown', {}) then + pcall(vim.treesitter.start, bufnr) + end + end) vim.api.nvim_create_autocmd({ 'TextChanged', 'InsertLeave' }, { buffer = bufnr, callback = function() - utils.debounce(self.name, function() + utils.debounce('chat-parse-' .. bufnr, function() + self:parse() self:render() - end, 100) + end, 150) end, }) - self.spinner.bufnr = bufnr return bufnr end @@ -571,149 +632,203 @@ function Chat:validate() end end ---- Render the chat window. +--- Parse the chat window buffer into structured messages. ---@protected -function Chat:render() +function Chat:parse() self:validate() - local highlight_ns = vim.api.nvim_create_namespace('copilot-chat-headers') - vim.api.nvim_buf_clear_namespace(self.bufnr, highlight_ns, 0, -1) + -- Skip parsing if buffer hasn't changed + local changedtick = vim.api.nvim_buf_get_changedtick(self.bufnr) + if self.last_changedtick == changedtick then + return false + end + self.last_changedtick = changedtick + + local parser = vim.treesitter.get_parser(self.bufnr, 'markdown') + if not parser then + return + end - local lines = vim.api.nvim_buf_get_lines(self.bufnr, 0, -1, false) + local query = vim.treesitter.query.get('markdown', 'copilotchat') + if not query then + return + end + local root = parser:parse()[1]:root() local new_messages = {} - local current_message = nil - local current_block = nil - - local function parse_header(header, line) - return line:match('^' .. vim.pesc(header) .. '%(([^)]+)%)' .. vim.pesc(self.separator) .. '$') - end - - for l, line in ipairs(lines) do - -- Detect section header with ID - for header_name, header_value in pairs(self.headers) do - local id = parse_header(header_value, line) - if id then - -- Draw the separator as virtual text over the header line, hiding the id and anything after the header - if self.config.highlight_headers then - local sep_col = vim.fn.strwidth(header_value) - vim.api.nvim_buf_set_extmark(self.bufnr, highlight_ns, l - 1, sep_col, { - virt_text = { - { string.rep(self.separator, vim.go.columns), 'CopilotChatSeparator' }, + local current_message = { + content = {}, + section = { + blocks = {}, + }, + } + + local current_block = { + content = {}, + } + + for id, node in query:iter_captures(root, self.bufnr, 0, -1) do + local name = query.captures[id] + local start_row, _, end_row, _ = node:range() + + -- Convert 0 based to 1 based indexing + start_row = start_row + 1 + end_row = end_row + 1 + + -- Skip header line at start of the section + start_row = start_row + 1 + + if name == 'section_header' then + local header_text = vim.treesitter.get_node_text(node, self.bufnr) + local id, role = match_section_header(self.headers, self.separator, header_text) + if role and id ~= current_message.id then + current_message.section.end_line = start_row - 2 + + current_message = { + id = id, + role = role, + content = {}, + section = { + blocks = {}, + start_line = start_row, + }, + } + table.insert(new_messages, current_message) + end + elseif name == 'section_content' then + local content = vim.treesitter.get_node_text(node, self.bufnr) + current_message.section.end_line = end_row + table.insert(current_message.content, content) + elseif current_message.role == constants.ROLE.ASSISTANT then + if name == 'block_header' then + local header_text = vim.treesitter.get_node_text(node, self.bufnr) + local filetype, filename, start_line, end_line = match_block_header(header_text) + + if filetype then + current_block = { + header = { + filetype = filetype, + filename = filename, + start_line = start_line, + end_line = end_line, }, - virt_text_win_col = sep_col, - priority = 200, - strict = false, - }) - vim.api.nvim_buf_set_extmark(self.bufnr, highlight_ns, l - 1, 0, { - end_col = sep_col, - hl_group = 'CopilotChatHeader', - priority = 100, - strict = false, - }) + start_line = start_row, + content = {}, + } + table.insert(current_message.section.blocks, current_block) end + elseif name == 'block_content' then + local content = vim.treesitter.get_node_text(node, self.bufnr) + current_block.end_line = end_row - -- Finish previous message - if current_message then - current_message.section.end_line = l - 1 - current_message.content = vim.trim( - table.concat( - vim.list_slice(lines, current_message.section.start_line, current_message.section.end_line), - '\n' - ) - ) + local filename = match_block_content(current_block.header, content) + if filename then + current_block.header.filename = filename end - -- Find existing message by id or create new - local old_msg = nil - for _, msg in ipairs(self.messages) do - if msg.id == id then - old_msg = msg - break - end - end - if not old_msg then - old_msg = { id = id, role = header_name } - end + table.insert(current_block.content, content) + end + end + end - -- Attach section info - old_msg.section = { - role = header_name, - start_line = l + 1, - blocks = {}, - } - table.insert(new_messages, old_msg) - current_message = old_msg - current_block = nil - break + -- Finish last message + current_message.section.end_line = vim.api.nvim_buf_line_count(self.bufnr) + + -- Format new messages and preserve extra fields from old messages + local messages = orderedmap() + for _, message in ipairs(new_messages) do + message.content = vim.trim(table.concat(message.content, '\n')) + if message.section then + for _, block in ipairs(message.section.blocks) do + block.content = table.concat(block.content, '\n') end end - -- Code blocks - if current_message and current_message.role == 'assistant' then - local filetype, filename, start_line, end_line = match_header(line) - if filetype and filename and not current_block then - current_block = { - header = { - filename = filename, - start_line = start_line, - end_line = end_line, - filetype = filetype, - }, - start_line = l + 1, - } - local text = string.format('[%s] %s', filetype, filename) - if start_line and end_line then - text = text .. string.format(' lines %d-%d', start_line, end_line) + local old = self.messages:get(message.id) + if old then + for k, v in pairs(old) do + if message[k] == nil then + message[k] = v end - vim.api.nvim_buf_set_extmark(self.bufnr, highlight_ns, l, 0, { + end + end + + messages:set(message.id, message) + end + + -- Update messages + self.messages = messages +end + +--- Render the chat window. +---@protected +function Chat:render() + self:validate() + + local highlight_ns = vim.api.nvim_create_namespace('copilot-chat-headers') + vim.api.nvim_buf_clear_namespace(self.bufnr, highlight_ns, 0, -1) -- Clear previous highlights + self:show_help() -- Clear previous help + + local messages = self:get_messages() + + for i, message in ipairs(messages) do + if self.config.highlight_headers then + -- Overlay section header with nice display + local header_value = self.headers[message.role] + local header_line = message.section.start_line - 2 + if message.model then + header_value = header_value .. ' (' .. message.model .. ')' + end + + vim.api.nvim_buf_set_extmark(self.bufnr, highlight_ns, header_line, 0, { + conceal = '', + virt_text = { + { ' ' .. header_value .. ' ', 'CopilotChatHeader' }, + { string.rep(self.separator, vim.go.columns - #header_value - 1), 'CopilotChatSeparator' }, + }, + virt_text_pos = 'overlay', + priority = 2000, -- High priority to override other plugins if enabled + strict = false, + }) + + -- Highlight code block headers and show file info as virtual lines + for _, block in ipairs(message.section.blocks) do + local header = block.header + local filetype = header.filetype + local filename = header.filename + local text = string.format('[%s] %s', filetype, filename or 'block') + if header.start_line and header.end_line then + text = text .. string.format(' lines %d-%d', header.start_line, header.end_line) + end + vim.api.nvim_buf_set_extmark(self.bufnr, highlight_ns, block.start_line - 1, 0, { virt_lines_above = true, virt_lines = { { { text, 'CopilotChatAnnotationHeader' } } }, priority = 100, strict = false, }) - elseif line == '```' and current_block then - current_block.end_line = l - 1 - current_block.content = - table.concat(vim.list_slice(lines, current_block.start_line, current_block.end_line), '\n') - table.insert(current_message.section.blocks, current_block) - current_block = nil end end - -- If last line, finish last message - if l == #lines and current_message then - current_message.section.end_line = l - current_message.content = vim.trim( - table.concat(vim.list_slice(lines, current_message.section.start_line, current_message.section.end_line), '\n') - ) - end - - -- Highlight response calls - for _, message in ipairs(self.messages) do - for _, tool_call in ipairs(message.tool_calls or {}) do - if line:match(string.format('#%s:%s', tool_call.name, vim.pesc(tool_call.id))) then - vim.api.nvim_buf_add_highlight(self.bufnr, highlight_ns, 'CopilotChatAnnotationHeader', l - 1, 0, #line) - if not utils.empty(tool_call.arguments) then - vim.api.nvim_buf_set_extmark(self.bufnr, highlight_ns, l - 1, 0, { - virt_lines = vim.tbl_map(function(json_line) - return { { json_line, 'CopilotChatAnnotation' } } - end, vim.split(vim.inspect(utils.json_decode(tool_call.arguments)), '\n')), - priority = 100, - strict = false, - }) - end - break - end + -- Show reasoning as virtual text above assistant messages + if + message.role == constants.ROLE.ASSISTANT + and not utils.empty(message.reasoning) + and message.section + and message.section.start_line + then + local virt_lines = {} + for _, line in ipairs(vim.split(message.reasoning, '\n')) do + table.insert(virt_lines, { { 'Reasoning: ' .. line, 'CopilotChatAnnotation' } }) end + vim.api.nvim_buf_set_extmark(self.bufnr, highlight_ns, message.section.start_line - 1, 0, { + virt_lines = virt_lines, + virt_lines_above = true, + priority = 100, + strict = false, + }) end - end - - -- Replace self.messages with new_messages (preserving tool_calls, etc.) - self.messages = new_messages - -- Show tool call details as virt lines - for _, message in ipairs(self.messages) do + -- Show tool call details as virt lines in assistant messages if message.tool_calls and #message.tool_calls > 0 then local section = message.section if section and section.end_line then @@ -733,13 +848,14 @@ function Chat:render() end end + -- Highlight tool calls in tool messages if message.tool_call_id then local section = message.section if section and section.start_line then local virt_lines = { { { 'Tool: ' .. message.tool_call_id, 'CopilotChatAnnotationHeader' } }, } - vim.api.nvim_buf_set_extmark(self.bufnr, highlight_ns, section.start_line, 0, { + vim.api.nvim_buf_set_extmark(self.bufnr, highlight_ns, section.start_line - 1, 0, { virt_lines = virt_lines, virt_lines_above = true, priority = 100, @@ -747,40 +863,86 @@ function Chat:render() }) end end - end - -- Show help as before, using last user message - local last_message = self.messages[#self.messages] - if last_message and last_message.role == 'user' then - local msg = self.config.show_help and self.help or '' - if self.token_count and self.token_max_count then - if msg ~= '' then - msg = msg .. '\n' + if i == #messages and message.role == constants.ROLE.USER then + -- Highlight tools in the last user message + local assistant_msg = self:get_message(constants.ROLE.ASSISTANT) + if assistant_msg and assistant_msg.tool_calls and #assistant_msg.tool_calls > 0 then + for j, line in ipairs(utils.split_lines(message.content)) do + for _, tool_call in ipairs(assistant_msg.tool_calls) do + if line:match(string.format('#%s:%s', tool_call.name, vim.pesc(tool_call.id))) then + local l = message.section.start_line + j + vim.api.nvim_buf_add_highlight(self.bufnr, highlight_ns, 'CopilotChatAnnotationHeader', l, 0, #line) + if not utils.empty(tool_call.arguments) then + vim.api.nvim_buf_set_extmark(self.bufnr, highlight_ns, l, 0, { + virt_lines = vim.tbl_map(function(json_line) + return { { json_line, 'CopilotChatAnnotation' } } + end, vim.split(vim.inspect(utils.json_decode(tool_call.arguments)), '\n')), + priority = 100, + strict = false, + }) + end + end + end + end + end + + -- Show help message and token usage below the last user message + local msg = self.config.show_help and self.help or '' + if self.token_count and self.token_max_count then + if msg ~= '' then + msg = msg .. '\n' + end + msg = msg .. self.token_count .. '/' .. self.token_max_count .. ' tokens used' + end + self:show_help(msg, message.section.start_line) + end + + -- Auto fold non-assistant messages if enabled + if self.config.auto_fold and self:visible() then + if message.role ~= constants.ROLE.ASSISTANT and message.section and i < #messages then + vim.api.nvim_win_call(self.winnr, function() + local fold_level = vim.fn.foldlevel(message.section.start_line) + if fold_level > 0 and vim.fn.foldclosed(message.section.start_line) == -1 then + vim.api.nvim_cmd({ cmd = 'foldclose', range = { message.section.start_line } }, {}) + end + end) end - msg = msg .. self.token_count .. '/' .. self.token_max_count .. ' tokens used' end - self:show_help(msg, last_message.section.start_line - last_message.section.end_line - 1) - else - self:show_help() end end ---- Get the last line and column of the chat window. ----@return number, number, number ----@protected -function Chat:last() - self:validate() - local line_count = vim.api.nvim_buf_line_count(self.bufnr) - local last_line = line_count - 1 - if last_line < 0 then - return 0, 0, line_count - end - local last_line_content = vim.api.nvim_buf_get_lines(self.bufnr, -2, -1, false) - if not last_line_content or #last_line_content == 0 then - return last_line, 0, line_count +--- Get the current source buffer and window. +function Chat:get_source() + return self.source +end + +--- Sets the source to the given window. +---@param source_winnr number +---@return boolean if the source was set +function Chat:set_source(source_winnr) + local source_bufnr = vim.api.nvim_win_get_buf(source_winnr) + + -- Check if the window is valid to use as a source + if source_winnr ~= self.winnr and source_bufnr ~= self.bufnr and vim.fn.win_gettype(source_winnr) == '' then + self.source = { + bufnr = source_bufnr, + winnr = source_winnr, + cwd = function() + local ok, dir = pcall(function() + return vim.w[source_winnr].cchat_cwd + end) + if not ok or not dir or dir == '' then + return '.' + end + return dir + end, + } + + return true end - local last_column = #last_line_content[1] - return last_line, last_column, line_count + + return false end return Chat diff --git a/lua/CopilotChat/ui/overlay.lua b/lua/CopilotChat/ui/overlay.lua index a23c022e..ace646c4 100644 --- a/lua/CopilotChat/ui/overlay.lua +++ b/lua/CopilotChat/ui/overlay.lua @@ -1,8 +1,8 @@ local utils = require('CopilotChat.utils') -local class = utils.class +local class = require('CopilotChat.utils.class') ---@class CopilotChat.ui.overlay.Overlay : Class ----@field bufnr number? +---@field bufnr integer? ---@field protected name string ---@field protected help string ---@field private cursor integer[]? @@ -23,11 +23,11 @@ end) --- Show the overlay buffer ---@param text string ----@param winnr number +---@param winnr integer ---@param filetype? string ---@param syntax string? ----@param on_show? fun(bufnr: number) ----@param on_hide? fun(bufnr: number) +---@param on_show? fun(bufnr: integer) +---@param on_hide? fun(bufnr: integer) function Overlay:show(text, winnr, filetype, syntax, on_show, on_hide) if not text or text == '' then return @@ -41,7 +41,7 @@ function Overlay:show(text, winnr, filetype, syntax, on_show, on_hide) vim.bo[self.bufnr].modifiable = true vim.api.nvim_buf_set_lines(self.bufnr, 0, -1, false, vim.split(text, '\n')) vim.bo[self.bufnr].modifiable = false - self:show_help(self.help, -1) + self:show_help(self.help, vim.api.nvim_buf_line_count(self.bufnr)) vim.api.nvim_win_set_cursor(winnr, { 1, 0 }) filetype = filetype or 'markdown' @@ -75,7 +75,7 @@ function Overlay:delete() end --- Create the overlay buffer ----@return number +---@return integer ---@protected function Overlay:create() local bufnr = vim.api.nvim_create_buf(false, true) @@ -116,6 +116,10 @@ function Overlay:restore(winnr, bufnr) self.on_hide(self.bufnr) end + if not vim.api.nvim_win_is_valid(winnr) then + return + end + vim.api.nvim_win_set_buf(winnr, bufnr) if self.cursor then @@ -124,26 +128,28 @@ function Overlay:restore(winnr, bufnr) -- Manually trigger BufEnter event as nvim_win_set_buf does not trigger it vim.schedule(function() - vim.cmd(string.format('doautocmd BufEnter %s', bufnr)) + if vim.api.nvim_buf_is_valid(bufnr) then + vim.api.nvim_exec_autocmds('BufEnter', { buffer = bufnr }) + end end) end --- Show help message in the overlay ---@param msg string? ----@param offset number? +---@param pos number? ---@protected -function Overlay:show_help(msg, offset) +function Overlay:show_help(msg, pos) if not msg or msg == '' then vim.api.nvim_buf_del_extmark(self.bufnr, self.help_ns, 1) return end self:validate() - local line = vim.api.nvim_buf_line_count(self.bufnr) + (offset or 0) - vim.api.nvim_buf_set_extmark(self.bufnr, self.help_ns, math.max(0, line - 1), 0, { + vim.api.nvim_buf_set_extmark(self.bufnr, self.help_ns, math.max(0, pos - 1), 0, { id = 1, hl_mode = 'combine', priority = 100, + virt_lines_above = true, virt_lines = vim.tbl_map(function(t) return { { t, 'CopilotChatHelp' } } end, vim.split(msg, '\n')), diff --git a/lua/CopilotChat/ui/spinner.lua b/lua/CopilotChat/ui/spinner.lua index 0f582032..06091a16 100644 --- a/lua/CopilotChat/ui/spinner.lua +++ b/lua/CopilotChat/ui/spinner.lua @@ -1,6 +1,7 @@ -local notify = require('CopilotChat.notify') +local notify = require('CopilotChat.utils.notify') local utils = require('CopilotChat.utils') -local class = utils.class +local class = require('CopilotChat.utils.class') + local spinner_frames = { '⠋', '⠙', diff --git a/lua/CopilotChat/utils.lua b/lua/CopilotChat/utils.lua index 251238ca..cbdced39 100644 --- a/lua/CopilotChat/utils.lua +++ b/lua/CopilotChat/utils.lua @@ -1,106 +1,21 @@ local async = require('plenary.async') -local curl = require('plenary.curl') -local scandir = require('plenary.scandir') local log = require('plenary.log') local M = {} M.timers = {} -M.scan_args = { - max_count = 2500, - max_depth = 50, - no_ignore = false, -} - -M.curl_args = { - timeout = 30000, - raw = { - '--retry', - '2', - '--retry-delay', - '1', - '--keepalive-time', - '60', - '--no-compressed', - '--connect-timeout', - '10', - '--tcp-nodelay', - '--no-buffer', - }, -} - ----@class Class ----@field new fun(...):table ----@field init fun(self, ...) - ---- Create class ----@param fn function The class constructor ----@param parent table? The parent class ----@return Class -function M.class(fn, parent) - local out = {} - out.__index = out - - local mt = { - __call = function(cls, ...) - return cls.new(...) - end, - } - - if parent then - mt.__index = parent - end - - setmetatable(out, mt) - - function out.new(...) - local self = setmetatable({}, out) - fn(self, ...) - return self - end - - function out.init(self, ...) - fn(self, ...) - end - - return out +--- Use CopilotChat.utils.curl.get instead +---@deprecated +function M.curl_get(url, opts) + log.warn('M.curl_get is deprecated, use CopilotChat.utils.curl.get instead') + return require('CopilotChat.utils.curl').get(url, opts) end ----@class OrderedMap ----@field set fun(self:OrderedMap, key:any, value:any) ----@field get fun(self:OrderedMap, key:any):any ----@field keys fun(self:OrderedMap):table ----@field values fun(self:OrderedMap):table - ---- Create an ordered map ----@return OrderedMap -function M.ordered_map() - return { - _keys = {}, - _data = {}, - set = function(self, key, value) - if not self._data[key] then - table.insert(self._keys, key) - end - self._data[key] = value - end, - - get = function(self, key) - return self._data[key] - end, - - keys = function(self) - return self._keys - end, - - values = function(self) - local result = {} - for _, key in ipairs(self._keys) do - table.insert(result, self._data[key]) - end - return result - end, - } +--- Use CopilotChat.utils.curl.post instead +---@deprecated +function M.curl_post(url, opts) + log.warn('M.curl_post is deprecated, use CopilotChat.utils.curl.post instead') + return require('CopilotChat.utils.curl').post(url, opts) end --- Convert arguments to a table @@ -121,61 +36,13 @@ function M.to_table(...) return result end ----@class StringBuffer ----@field add fun(self:StringBuffer, s:string) ----@field set fun(self:StringBuffer, s:string) ----@field tostring fun(self:StringBuffer):string - ---- Create a string buffer for efficient string concatenation ----@return StringBuffer -function M.string_buffer() - return { - _buf = { '' }, - - add = function(self, s) - table.insert(self._buf, s) - -- Keep track of lengths to know when to merge - for i = #self._buf - 1, 1, -1 do - if #self._buf[i] > #self._buf[i + 1] then - break - end - self._buf[i] = self._buf[i] .. table.remove(self._buf) - end - end, - - set = function(self, s) - self._buf = { s } - end, - - -- Get final string - tostring = function(self) - return table.concat(self._buf) - end, - } -end - ---- Writes text to a temporary file and returns path ----@param text string The text to write ----@return string? -function M.temp_file(text) - local temp_file = os.tmpname() - local f = io.open(temp_file, 'w+') - if f == nil then - error('Could not open file: ' .. temp_file) - end - f:write(text) - f:close() - return temp_file -end - --- Return to normal mode function M.return_to_normal_mode() local mode = vim.fn.mode():lower() if mode:find('v') then vim.cmd([[execute "normal! \"]]) - elseif mode ~= 'n' then - vim.cmd('stopinsert') end + vim.cmd('stopinsert') end --- Debounce a function @@ -199,88 +66,6 @@ function M.buf_valid(bufnr) or false end ---- Check if file paths are the same ----@param file1 string? The first file path ----@param file2 string? The second file path ----@return boolean -function M.filename_same(file1, file2) - if not file1 or not file2 then - return false - end - return vim.fn.fnamemodify(file1, ':p') == vim.fn.fnamemodify(file2, ':p') -end - ---- Get the filetype of a file ----@param filename string The file name ----@return string|nil -function M.filetype(filename) - local filetype = require('plenary.filetype') - - local ft = filetype.detect(filename, { - fs_access = false, - }) - - if ft == '' or not ft then - return vim.filetype.match({ filename = filename }) - end - - return ft -end - ---- Get the mimetype from filetype ----@param filetype string? ----@return string -function M.filetype_to_mimetype(filetype) - if not filetype or filetype == '' then - return 'text/plain' - end - if filetype == 'json' or filetype == 'yaml' then - return 'application/' .. filetype - end - if filetype == 'html' or filetype == 'css' then - return 'text/' .. filetype - end - return 'text/x-' .. filetype -end - ---- Get the filetype from mimetype ----@param mimetype string? ----@return string -function M.mimetype_to_filetype(mimetype) - if not mimetype or mimetype == '' then - return 'text' - end - - local out = mimetype:gsub('^text/x%-', '') - out = out:gsub('^text/', '') - out = out:gsub('^application/', '') - out = out:gsub('^image/', '') - out = out:gsub('^video/', '') - out = out:gsub('^audio/', '') - return out -end - ---- Convert a URI to a file name ----@param uri string The URI ----@return string -function M.uri_to_filename(uri) - if not uri or uri == '' then - return uri - end - local ok, fname = pcall(vim.uri_to_fname, uri) - if not ok or M.empty(fname) then - return uri - end - return fname -end - ---- Get the file name ----@param filepath string The file path ----@return string -function M.filename(filepath) - return vim.fs.basename(filepath) -end - --- Generate a UUID ---@return string function M.uuid() @@ -293,13 +78,6 @@ function M.uuid() ) end ---- Generate a quick hash ----@param str string The string to hash ----@return string -function M.quick_hash(str) - return #str .. str:sub(1, 64) .. str:sub(-64) -end - --- Make a string from arguments ---@vararg any The arguments ---@return string @@ -344,331 +122,12 @@ function M.json_decode(body) return {}, data end ---- Store curl global arguments ----@param args table The arguments ----@return table -function M.curl_store_args(args) - M.curl_args = vim.tbl_deep_extend('force', M.curl_args, args) - return M.curl_args -end - ---- Send curl get request ----@param url string The url ----@param opts table? The options ----@async -M.curl_get = async.wrap(function(url, opts, callback) - log.debug('GET request:', url, opts) - local args = { - on_error = function(err) - log.debug('GET error:', err) - callback(nil, err and err.stderr or err) - end, - } - - args = vim.tbl_deep_extend('force', M.curl_args, args) - args = vim.tbl_deep_extend('force', args, opts or {}) - - args.callback = function(response) - log.debug('GET response:', response) - if response and not vim.startswith(tostring(response.status), '20') then - callback(response, response.body) - return - end - - if not args.json_response then - callback(response) - return - end - - local body, err = M.json_decode(tostring(response.body)) - if err then - callback(response, err) - else - response.body = body - callback(response) - end - end - - curl.get(url, args) -end, 3) - ---- Send curl post request ----@param url string The url ----@param opts table? The options ----@async -M.curl_post = async.wrap(function(url, opts, callback) - log.debug('POST request:', url, opts) - local args = { - on_error = function(err) - log.debug('POST error:', err) - callback(nil, err and err.stderr or err) - end, - } - - args = vim.tbl_deep_extend('force', M.curl_args, args) - args = vim.tbl_deep_extend('force', args, opts or {}) - - local temp_file_path = nil - - args.callback = function(response) - log.debug('POST response:', url, response) - if temp_file_path then - local ok, err = pcall(os.remove, temp_file_path) - if not ok then - log.debug('Failed to remove temp file:', temp_file_path, err) - end - end - if response and not vim.startswith(tostring(response.status), '20') then - callback(response, response.body) - return - end - - if not args.json_response then - callback(response) - return - end - - local body, err = M.json_decode(tostring(response.body)) - if err then - callback(response, err) - else - response.body = body - callback(response) - end - end - - if args.json_response then - args.headers = vim.tbl_deep_extend('force', args.headers or {}, { - Accept = 'application/json', - }) - end - - if args.json_request then - args.headers = vim.tbl_deep_extend('force', args.headers or {}, { - ['Content-Type'] = 'application/json', - }) - - temp_file_path = M.temp_file(vim.json.encode(args.body)) - args.body = temp_file_path - end - - curl.post(url, args) -end, 3) - -local function filter_files(files, max_count) - local filetype = require('plenary.filetype') - - files = vim.tbl_filter(function(file) - if file == nil or file == '' then - return false - end - - local ft = filetype.detect(file, { - fs_access = false, - }) - - if ft == '' or not ft then - return false - end - - return true - end, files) - if max_count and max_count > 0 then - files = vim.list_slice(files, 1, max_count) - end - - return files -end - ----@class CopilotChat.utils.ScanOpts ----@field max_count number? The maximum number of files to scan ----@field max_depth number? The maximum depth to scan ----@field glob? string The glob pattern to match files ----@field hidden? boolean Whether to include hidden files ----@field no_ignore? boolean Whether to respect or ignore .gitignore - ---- Scan a directory ----@param path string ----@param opts CopilotChat.utils.ScanOpts? ----@async -M.glob = async.wrap(function(path, opts, callback) - opts = vim.tbl_deep_extend('force', M.scan_args, opts or {}) - - -- Use ripgrep if available - if vim.fn.executable('rg') == 1 then - local cmd = { 'rg' } - - if opts.pattern then - table.insert(cmd, '-g') - table.insert(cmd, opts.pattern) - end - - if opts.max_depth then - table.insert(cmd, '--max-depth') - table.insert(cmd, tostring(opts.max_depth)) - end - - if opts.no_ignore then - table.insert(cmd, '--no-ignore') - end - - if opts.hidden then - table.insert(cmd, '--hidden') - end - - table.insert(cmd, '--files') - table.insert(cmd, path) - - vim.system(cmd, { text = true }, function(result) - local files = {} - if result and result.code == 0 and result.stdout ~= '' then - files = filter_files(vim.split(result.stdout, '\n'), opts.max_count) - end - - callback(files) - end) - - return - end - - -- Fall back to scandir if rg is not available or fails - scandir.scan_dir_async( - path, - vim.tbl_deep_extend('force', opts, { - depth = opts.max_depth, - add_dirs = false, - search_pattern = opts.glob and M.glob_to_pattern(opts.glob) or nil, - respect_gitignore = not opts.no_ignore, - on_exit = function(files) - callback(filter_files(files, opts.max_count)) - end, - }) - ) -end, 3) - ---- Grep a directory ----@param path string The path to search ----@param opts CopilotChat.utils.ScanOpts? -M.grep = async.wrap(function(path, opts, callback) - opts = vim.tbl_deep_extend('force', M.scan_args, opts or {}) - local cmd = {} - - if vim.fn.executable('rg') == 1 then - table.insert(cmd, 'rg') - - if opts.max_depth then - table.insert(cmd, '--max-depth') - table.insert(cmd, tostring(opts.max_depth)) - end - - if opts.no_ignore then - table.insert(cmd, '--no-ignore') - end - - if opts.hidden then - table.insert(cmd, '--hidden') - end - - table.insert(cmd, '--files-with-matches') - table.insert(cmd, '--ignore-case') - - if opts.pattern then - table.insert(cmd, '-e') - table.insert(cmd, "'" .. opts.pattern .. "'") - end - - table.insert(cmd, path) - elseif vim.fn.executable('grep') == 1 then - table.insert(cmd, 'grep') - table.insert(cmd, '-rli') - - if opts.pattern then - table.insert(cmd, '-e') - table.insert(cmd, "'" .. opts.pattern .. "'") - end - - table.insert(cmd, path) - end - - if M.empty(cmd) then - error('No executable found for grep') - return - end - - vim.system(cmd, { text = true }, function(result) - local files = {} - if result and result.code == 0 and result.stdout ~= '' then - files = filter_files(vim.split(result.stdout, '\n'), opts.max_count) - end - - callback(files) - end) -end, 3) - ---- Get last modified time of a file ----@param path string The file path ----@return number? ----@async -function M.file_mtime(path) - local err, stat = async.uv.fs_stat(path) - if err or not stat then - return nil - end - return stat.mtime.sec -end - ---- Read a file ----@param path string The file path ----@async -function M.read_file(path) - local err, fd = async.uv.fs_open(path, 'r', 438) - if err or not fd then - return nil - end - - local err, stat = async.uv.fs_fstat(fd) - if err or not stat then - async.uv.fs_close(fd) - return nil - end - - local err, data = async.uv.fs_read(fd, stat.size, 0) - async.uv.fs_close(fd) - if err or not data then - return nil - end - return data -end - ---- Write data to a file ----@param path string The file path ----@param data string The data to write ----@return boolean -function M.write_file(path, data) - M.schedule_main() - vim.fn.mkdir(vim.fn.fnamemodify(path, ':p:h'), 'p') - - local err, fd = async.uv.fs_open(path, 'w', 438) - if err or not fd then - return false - end - - local err = async.uv.fs_write(fd, data, 0) - if err then - async.uv.fs_close(fd) - return false - end - - async.uv.fs_close(fd) - return true -end - --- Call a system command ---@param cmd table The command ---@async -M.system = async.wrap(function(cmd, callback) - vim.system(cmd, { text = true }, callback) -end, 2) +M.system = async.wrap(function(cmd, cwd, callback) + vim.system(cmd, { cwd = cwd, text = true }, callback) +end, 3) --- Schedule a function only when needed (not on main thread) ---@param callback function The callback @@ -685,46 +144,6 @@ M.schedule_main = async.wrap(function(callback) end end, 1) ---- Run parse on a treesitter parser asynchronously if possible ----@param parser vim.treesitter.LanguageTree The parser -M.ts_parse = async.wrap(function(parser, callback) - ---@diagnostic disable-next-line: invisible - if not parser._async_parse then - local fn = function() - local trees = parser:parse(false) - if not trees or #trees == 0 then - callback(nil) - return - end - callback(trees[1]:root()) - end - - if vim.in_fast_event() then - vim.schedule(fn) - else - fn() - end - - return - end - - local fn = function() - parser:parse(false, function(err, trees) - if err or not trees or #trees == 0 then - callback(nil) - return - end - callback(trees[1]:root()) - end) - end - - if vim.in_fast_event() then - vim.schedule(fn) - else - fn() - end -end, 2) - --- Wait for a user input M.input = async.wrap(function(opts, callback) local fn = function() @@ -816,136 +235,15 @@ function M.empty(v) return false end ---- Convert glob pattern to regex pattern ---- https://github.com/davidm/lua-glob-pattern/blob/master/lua/globtopattern.lua ----@param g string The glob pattern ----@return string -function M.glob_to_pattern(g) - local p = '^' -- pattern being built - local i = 0 -- index in g - local c -- char at index i in g. - - -- unescape glob char - local function unescape() - if c == '\\' then - i = i + 1 - c = g:sub(i, i) - if c == '' then - p = '[^]' - return false - end - end - return true - end - - -- escape pattern char - local function escape(c) - return c:match('^%w$') and c or '%' .. c - end - - -- Convert tokens at end of charset. - local function charset_end() - while 1 do - if c == '' then - p = '[^]' - return false - elseif c == ']' then - p = p .. ']' - break - else - if not unescape() then - break - end - local c1 = c - i = i + 1 - c = g:sub(i, i) - if c == '' then - p = '[^]' - return false - elseif c == '-' then - i = i + 1 - c = g:sub(i, i) - if c == '' then - p = '[^]' - return false - elseif c == ']' then - p = p .. escape(c1) .. '%-]' - break - else - if not unescape() then - break - end - p = p .. escape(c1) .. '-' .. escape(c) - end - elseif c == ']' then - p = p .. escape(c1) .. ']' - break - else - p = p .. escape(c1) - i = i - 1 -- put back - end - end - i = i + 1 - c = g:sub(i, i) - end - return true +--- Split text into lines +---@param text string The text to split +---@return string[] A table of lines +function M.split_lines(text) + if not text or text == '' then + return {} end - -- Convert tokens in charset. - local function charset() - i = i + 1 - c = g:sub(i, i) - if c == '' or c == ']' then - p = '[^]' - return false - elseif c == '^' or c == '!' then - i = i + 1 - c = g:sub(i, i) - if c == ']' then - -- ignored - else - p = p .. '[^' - if not charset_end() then - return false - end - end - else - p = p .. '[' - if not charset_end() then - return false - end - end - return true - end - - -- Convert tokens. - while 1 do - i = i + 1 - c = g:sub(i, i) - if c == '' then - p = p .. '$' - break - elseif c == '?' then - p = p .. '.' - elseif c == '*' then - p = p .. '.*' - elseif c == '[' then - if not charset() then - break - end - elseif c == '\\' then - i = i + 1 - c = g:sub(i, i) - if c == '' then - p = p .. '\\$' - break - end - p = p .. escape(c) - else - p = p .. escape(c) - end - end - return p + return vim.split(text, '\r?\n', { trimempty = false }) end return M diff --git a/lua/CopilotChat/utils/class.lua b/lua/CopilotChat/utils/class.lua new file mode 100644 index 00000000..b8dfce83 --- /dev/null +++ b/lua/CopilotChat/utils/class.lua @@ -0,0 +1,38 @@ +---@class Class +---@field new fun(...):table +---@field init fun(self, ...) + +--- Create class +---@param fn function The class constructor +---@param parent table? The parent class +---@return Class +local function class(fn, parent) + local out = {} + out.__index = out + + local mt = { + __call = function(cls, ...) + return cls.new(...) + end, + } + + if parent then + mt.__index = parent + end + + setmetatable(out, mt) + + function out.new(...) + local self = setmetatable({}, out) + fn(self, ...) + return self + end + + function out.init(self, ...) + fn(self, ...) + end + + return out +end + +return class diff --git a/lua/CopilotChat/utils/curl.lua b/lua/CopilotChat/utils/curl.lua new file mode 100644 index 00000000..2c2cf60e --- /dev/null +++ b/lua/CopilotChat/utils/curl.lua @@ -0,0 +1,148 @@ +local async = require('plenary.async') +local curl = require('plenary.curl') +local log = require('plenary.log') +local utils = require('CopilotChat.utils') + +local M = {} + +M.args = { + timeout = 30000, + raw = { + '--retry', + '2', + '--retry-delay', + '1', + '--keepalive-time', + '60', + '--no-compressed', + '--connect-timeout', + '10', + '--tcp-nodelay', + '--no-buffer', + }, +} + +--- Store curl global arguments +---@param args table The arguments +---@return table +function M.store_args(args) + M.args = vim.tbl_deep_extend('force', M.args, args) + return M.args +end + +--- Send curl get request +---@param url string The url +---@param opts table? The options +---@async +M.get = async.wrap(function(url, opts, callback) + log.debug('GET request:', url, opts) + local args = { + on_error = function(err) + log.debug('GET error:', err) + callback(nil, err and err.stderr or err) + end, + } + + args = vim.tbl_deep_extend('force', M.args, args) + args = vim.tbl_deep_extend('force', args, opts or {}) + + args.callback = function(response) + log.debug('GET response:', response) + -- HTTP status codes: 1xx (informational), 2xx (success) + -- Status 100 (Continue) is common with streaming responses + local status_str = tostring(response.status) + if response and not vim.startswith(status_str, '1') and not vim.startswith(status_str, '20') then + callback(response, response.body) + return + end + + if not args.json_response then + callback(response) + return + end + + local body, err = utils.json_decode(tostring(response.body)) + if err then + callback(response, err) + else + response.body = body + callback(response) + end + end + + curl.get(url, args) +end, 3) + +--- Send curl post request +---@param url string The url +---@param opts table? The options +---@async +M.post = async.wrap(function(url, opts, callback) + log.debug('POST request:', url, opts) + local args = { + on_error = function(err) + log.debug('POST error:', err) + callback(nil, err and err.stderr or err) + end, + } + + args = vim.tbl_deep_extend('force', M.args, args) + args = vim.tbl_deep_extend('force', args, opts or {}) + + local temp_file_path = nil + + args.callback = function(response) + log.debug('POST response:', url, response) + if temp_file_path then + local ok, err = pcall(os.remove, temp_file_path) + if not ok then + log.debug('Failed to remove temp file:', temp_file_path, err) + end + end + -- HTTP status codes: 1xx (informational), 2xx (success) + -- Status 100 (Continue) is common with streaming responses + local status_str = tostring(response.status) + if response and not vim.startswith(status_str, '1') and not vim.startswith(status_str, '20') then + callback(response, response.body) + return + end + + if not args.json_response then + callback(response) + return + end + + local body, err = utils.json_decode(tostring(response.body)) + if err then + callback(response, err) + else + response.body = body + callback(response) + end + end + + if args.json_response then + args.headers = vim.tbl_deep_extend('force', args.headers or {}, { + Accept = 'application/json', + }) + end + + if args.json_request then + args.headers = vim.tbl_deep_extend('force', args.headers or {}, { + ['Content-Type'] = 'application/json', + }) + + temp_file_path = os.tmpname() + local f = io.open(temp_file_path, 'w+') + if f == nil then + error('Could not open file: ' .. temp_file_path) + end + f:write(vim.json.encode(args.body)) + f:close() + args.body = temp_file_path + end + + curl.post(url, args) +end, 3) + +return M diff --git a/lua/CopilotChat/utils/diff.lua b/lua/CopilotChat/utils/diff.lua new file mode 100644 index 00000000..6a2384a6 --- /dev/null +++ b/lua/CopilotChat/utils/diff.lua @@ -0,0 +1,241 @@ +local log = require('plenary.log') + +local M = {} + +--- Parse unified diff hunks from diff text +---@param diff_text string +---@return table hunks +local function parse_hunks(diff_text) + local hunks = {} + local current_hunk = nil + for _, line in ipairs(vim.split(diff_text, '\n')) do + if line:match('^@@') then + if current_hunk then + table.insert(hunks, current_hunk) + end + local start_old, len_old, start_new, len_new = line:match('@@%s%-(%d+),?(%d*)%s%+(%d+),?(%d*)%s@@') + current_hunk = { + start_old = tonumber(start_old), + len_old = len_old == '' and 1 or tonumber(len_old), + start_new = tonumber(start_new), + len_new = len_new == '' and 1 or tonumber(len_new), + old_snippet = {}, + new_snippet = {}, + } + elseif current_hunk then + local prefix, rest = line:sub(1, 1), tostring(line:sub(2)) + if prefix == '-' then + table.insert(current_hunk.old_snippet, rest) + elseif prefix == '+' then + table.insert(current_hunk.new_snippet, rest) + elseif prefix == ' ' then + table.insert(current_hunk.old_snippet, rest) + table.insert(current_hunk.new_snippet, rest) + end + end + end + if current_hunk then + table.insert(hunks, current_hunk) + end + return hunks +end + +--- Try to match old_snippet in lines starting at approximate start_line +---@param lines table +---@param old_snippet table +---@param approx_start number +---@param search_range number +---@return number? matched_start +local function find_best_match(lines, old_snippet, approx_start, search_range) + local best_idx, best_score = nil, -1 + local old_len = #old_snippet + + if old_len == 0 then + return approx_start + end + + local min_start = math.max(1, approx_start - search_range) + local max_start = math.min(#lines - old_len + 1, approx_start + search_range) + + for start_idx = min_start, max_start do + local score = 0 + for i = 1, old_len do + if vim.trim(lines[start_idx + i - 1] or '') == vim.trim(old_snippet[i] or '') then + score = score + 1 + end + end + + if score > best_score then + best_score = score + best_idx = start_idx + end + + if score == old_len then + return best_idx + end + end + + if best_score >= math.ceil(old_len * 0.8) then + return best_idx + end + + return nil +end + +--- Apply a single hunk to content +---@param hunk table +---@param content string +---@return string patched_content, boolean applied_cleanly +local function apply_hunk(hunk, content) + local lines = vim.split(content, '\n') + local start_idx = hunk.start_old + + -- Handle insertions (len_old == 0) + if hunk.len_old == 0 then + -- For insertions, start_old indicates where to insert + -- start_old = 0 means insert at beginning + -- start_old = n means insert after line n + if start_idx == 0 then + start_idx = 1 + else + start_idx = start_idx + 1 + end + local new_lines = vim.list_slice(lines, 1, start_idx - 1) + vim.list_extend(new_lines, hunk.new_snippet) + vim.list_extend(new_lines, lines, start_idx, #lines) + -- Insertions are always applied cleanly if we reach this point + return table.concat(new_lines, '\n'), true + end + + -- Handle replacements and deletions (len_old > 0) + -- If we have a start line hint, try to find best match within +/- 2 lines + if start_idx and start_idx > 0 and start_idx <= #lines then + local match_idx = find_best_match(lines, hunk.old_snippet, start_idx, 2) + if match_idx then + start_idx = match_idx + end + else + -- No valid start line, search for best match in whole content + local match_idx = find_best_match(lines, hunk.old_snippet, 1, #lines) + if match_idx then + start_idx = match_idx + else + start_idx = 1 + end + end + + -- Replace old lines with new lines + local end_idx = start_idx + #hunk.old_snippet - 1 + local new_lines = vim.list_slice(lines, 1, start_idx - 1) + vim.list_extend(new_lines, hunk.new_snippet) + vim.list_extend(new_lines, lines, end_idx + 1, #lines) + + -- Check if we matched exactly at the hinted position + local applied_cleanly = find_best_match(lines, hunk.old_snippet, hunk.start_old or start_idx, 0) == start_idx + return table.concat(new_lines, '\n'), applied_cleanly +end + +--- Apply unified diff to a table of lines and return new lines +---@param diff_text string +---@param original_content string +---@return string[], boolean, integer?, integer? +function M.apply_unified_diff(diff_text, original_content) + local hunks = parse_hunks(diff_text) + local new_content = original_content + local applied = false + local offset = 0 -- Track cumulative line offset from previous hunks + + for _, hunk in ipairs(hunks) do + -- Adjust hunk start position based on accumulated offset + local adjusted_hunk = vim.deepcopy(hunk) + if adjusted_hunk.start_old then + adjusted_hunk.start_old = hunk.start_old + offset + end + + local patched, ok = apply_hunk(adjusted_hunk, new_content) + new_content = patched + applied = applied or ok + + -- Update offset: (new lines added) - (old lines removed) + offset = offset + (#hunk.new_snippet - #hunk.old_snippet) + end + + local new_lines = vim.split(new_content, '\n', { trimempty = true }) + local diff_hunks = vim.diff( + original_content, + new_content, + { algorithm = 'myers', ctxlen = 10, interhunkctxlen = 10, ignore_whitespace_change = true, result_type = 'indices' } + ) + if not diff_hunks or #diff_hunks == 0 then + return new_lines, applied, nil, nil + end + local first, last + for _, hunk in ipairs(diff_hunks) do + local hunk_start = hunk[1] + local hunk_end = hunk[1] + hunk[2] - 1 + if not first or hunk_start < first then + first = hunk_start + end + if not last or hunk_end > last then + last = hunk_end + end + end + return new_lines, applied, first, last +end + +--- Get diff from block content and buffer lines +---@param block CopilotChat.ui.chat.Block Block containing diff info +---@param lines table table of lines +---@return string diff, string content +function M.get_diff(block, lines) + local content = table.concat(lines, '\n') + if block.header.filetype == 'diff' then + return block.content, content + end + + local patched_lines = vim.split(block.content, '\n', { trimempty = true }) + local start_idx = block.header.start_line + local end_idx = block.header.end_line + local original_lines = lines + if start_idx and end_idx then + local new_lines = vim.list_slice(original_lines, 1, start_idx - 1) + vim.list_extend(new_lines, patched_lines) + vim.list_extend(new_lines, original_lines, end_idx + 1, #original_lines) + patched_lines = new_lines + end + + return tostring( + vim.diff( + table.concat(original_lines, '\n'), + table.concat(patched_lines, '\n'), + { algorithm = 'myers', ctxlen = 10, interhunkctxlen = 10, ignore_whitespace_change = true } + ) + ), + content +end + +--- Apply a diff (unified or indices) to buffer lines +---@param block CopilotChat.ui.chat.Block Block containing diff info +---@param lines table table of lines +---@return table new_lines +function M.apply_diff(block, lines) + local diff, content = M.get_diff(block, lines) + local new_lines, applied, _, _ = M.apply_unified_diff(diff, content) + if not applied then + log.debug('Diff for ' .. block.header.filename .. ' failed to apply cleanly for:\n' .. diff) + end + + return new_lines +end + +--- Get changed region for diff (unified or indices) +---@param block CopilotChat.ui.chat.Block Block containing diff info +---@param lines table table of lines +---@return number? first, number? last +function M.get_diff_region(block, lines) + local diff, content = M.get_diff(block, lines) + local _, _, first, last = M.apply_unified_diff(diff, content) + return first, last +end + +return M diff --git a/lua/CopilotChat/utils/files.lua b/lua/CopilotChat/utils/files.lua new file mode 100644 index 00000000..24b7b003 --- /dev/null +++ b/lua/CopilotChat/utils/files.lua @@ -0,0 +1,321 @@ +local async = require('plenary.async') + +local M = {} + +M.scan_args = { + max_count = 2500, + max_depth = 50, + no_ignore = false, +} + +local function filter_files(files, max_count) + -- Filter out empty entries + files = vim.tbl_filter(function(file) + return file ~= nil and file ~= '' + end, files) + + if max_count and max_count > 0 then + files = vim.list_slice(files, 1, max_count) + end + + return files +end + +---@class CopilotChat.utils.ScanOpts +---@field max_count number? The maximum number of files to scan +---@field max_depth number? The maximum depth to scan +---@field pattern? string The glob pattern to match files +---@field hidden? boolean Whether to include hidden files +---@field no_ignore? boolean Whether to respect or ignore .gitignore + +--- Scan a directory +---@param path string +---@param opts CopilotChat.utils.ScanOpts? +---@async +M.glob = async.wrap(function(path, opts, callback) + opts = vim.tbl_deep_extend('force', M.scan_args, opts or {}) + + -- Use ripgrep if available + if vim.fn.executable('rg') == 1 then + local cmd = { 'rg' } + + if opts.pattern then + table.insert(cmd, '-g') + table.insert(cmd, opts.pattern) + end + + if opts.max_depth then + table.insert(cmd, '--max-depth') + table.insert(cmd, tostring(opts.max_depth)) + end + + if opts.no_ignore then + table.insert(cmd, '--no-ignore') + end + + if opts.hidden then + table.insert(cmd, '--hidden') + end + + table.insert(cmd, '--files') + + vim.system(cmd, { cwd = path, text = true }, function(result) + local files = {} + if result and result.code == 0 and result.stdout ~= '' then + files = filter_files(vim.split(result.stdout, '\n'), opts.max_count) + end + + callback(files) + end) + + return + end + + -- Fallback to vim.uv.fs_scandir + local matchers = {} + if opts.pattern then + local file_pattern = vim.glob.to_lpeg(opts.pattern) + local path_pattern = vim.lpeg.P(path .. '/') * file_pattern + + table.insert(matchers, function(name, dir) + return file_pattern:match(name) or path_pattern:match(dir .. '/' .. name) + end) + end + + if not opts.hidden then + table.insert(matchers, function(name) + return not name:match('^%.') + end) + end + + local data = {} + local next_dir = { path } + local current_depths = { [path] = 1 } + + local function read_dir(err, fd) + local current_dir = table.remove(next_dir, 1) + local depth = current_depths[current_dir] or 1 + + if not err and fd then + while true do + local name, typ = vim.uv.fs_scandir_next(fd) + if name == nil then + break + end + + local full_path = current_dir .. '/' .. name + + if typ == 'directory' and not name:match('^%.git') then + if not opts.max_depth or depth < opts.max_depth then + table.insert(next_dir, full_path) + current_depths[full_path] = depth + 1 + end + else + local match = true + for _, matcher in ipairs(matchers) do + if not matcher(name, current_dir) then + match = false + break + end + end + + if match then + table.insert(data, full_path) + end + end + end + end + + if #next_dir == 0 then + callback(data) + else + vim.uv.fs_scandir(next_dir[1], read_dir) + end + end + + vim.uv.fs_scandir(path, read_dir) +end, 3) + +--- Grep a directory +---@param path string The path to search +---@param opts CopilotChat.utils.ScanOpts? +M.grep = async.wrap(function(path, opts, callback) + opts = vim.tbl_deep_extend('force', M.scan_args, opts or {}) + local cmd = {} + + if vim.fn.executable('rg') == 1 then + table.insert(cmd, 'rg') + + if opts.max_depth then + table.insert(cmd, '--max-depth') + table.insert(cmd, tostring(opts.max_depth)) + end + + if opts.no_ignore then + table.insert(cmd, '--no-ignore') + end + + if opts.hidden then + table.insert(cmd, '--hidden') + end + + table.insert(cmd, '--files-with-matches') + table.insert(cmd, '--ignore-case') + + if opts.pattern then + table.insert(cmd, '-e') + table.insert(cmd, opts.pattern) + end + elseif vim.fn.executable('grep') == 1 then + table.insert(cmd, 'grep') + table.insert(cmd, '-rli') + + if opts.pattern then + table.insert(cmd, '-e') + table.insert(cmd, opts.pattern) + end + end + + if vim.tbl_isempty(cmd) then + error('No executable found for grep') + return + end + + vim.system(cmd, { cwd = path, text = true }, function(result) + local files = {} + if result and result.code == 0 and result.stdout ~= '' then + files = filter_files(vim.split(result.stdout, '\n'), opts.max_count) + end + + callback(files) + end) +end, 3) + +--- Read a file +---@param path string The file path +---@async +function M.read_file(path) + local err, fd = async.uv.fs_open(path, 'r', 438) + if err or not fd then + return nil + end + + local err, stat = async.uv.fs_fstat(fd) + if err or not stat then + async.uv.fs_close(fd) + return nil + end + + local err, data = async.uv.fs_read(fd, stat.size, 0) + async.uv.fs_close(fd) + if err or not data then + return nil + end + return data +end + +--- Write data to a file +---@param path string The file path +---@param data string The data to write +---@return boolean +function M.write_file(path, data) + local err, fd = async.uv.fs_open(path, 'w', 438) + if err or not fd then + return false + end + + local err = async.uv.fs_write(fd, data, 0) + if err then + async.uv.fs_close(fd) + return false + end + + async.uv.fs_close(fd) + return true +end + +--- Check if file paths are the same +---@param file1 string? The first file path +---@param file2 string? The second file path +---@return boolean +function M.filename_same(file1, file2) + if not file1 or not file2 then + return false + end + return vim.fs.normalize(file1) == vim.fs.normalize(file2) +end + +--- Get the filetype of a file +---@param filename string The file name +---@return string|nil +function M.filetype(filename) + local filetype = require('plenary.filetype') + + local ft = filetype.detect(filename, { + fs_access = false, + }) + + if ft == '' or not ft and not vim.in_fast_event() then + ft = vim.filetype.match({ filename = filename }) + end + + -- If filetype still not detected, default to 'text' + -- Let content validation handle whether it's actually readable + if not ft or ft == '' then + return 'text' + end + + return ft +end + +--- Get the mimetype from filetype +---@param filetype string? +---@return string +function M.filetype_to_mimetype(filetype) + if not filetype or filetype == '' then + return 'text/plain' + end + if filetype == 'json' or filetype == 'yaml' then + return 'application/' .. filetype + end + if filetype == 'html' or filetype == 'css' then + return 'text/' .. filetype + end + if filetype:find('/') then + return filetype + end + return 'text/x-' .. filetype +end + +--- Get the filetype from mimetype +---@param mimetype string? +---@return string +function M.mimetype_to_filetype(mimetype) + if not mimetype or mimetype == '' then + return 'text' + end + + local out = mimetype:gsub('^text/x%-', '') + out = out:gsub('^text/', '') + out = out:gsub('^application/', '') + out = out:gsub('^image/', '') + out = out:gsub('^video/', '') + out = out:gsub('^audio/', '') + return out +end + +--- Convert a URI to a file name +---@param uri string The URI +---@return string +function M.uri_to_filename(uri) + if not uri or uri == '' then + return uri + end + local ok, fname = pcall(vim.uri_to_fname, uri) + if not ok or not fname or fname == '' then + return uri + end + return fname +end + +return M diff --git a/lua/CopilotChat/notify.lua b/lua/CopilotChat/utils/notify.lua similarity index 91% rename from lua/CopilotChat/notify.lua rename to lua/CopilotChat/utils/notify.lua index 99aa499a..b15b209a 100644 --- a/lua/CopilotChat/notify.lua +++ b/lua/CopilotChat/utils/notify.lua @@ -32,4 +32,9 @@ function M.listen(event_name, callback) table.insert(M.listeners[event_name], callback) end +--- Clear all listeners +function M.clear() + M.listeners = {} +end + return M diff --git a/lua/CopilotChat/utils/orderedmap.lua b/lua/CopilotChat/utils/orderedmap.lua new file mode 100644 index 00000000..1907c161 --- /dev/null +++ b/lua/CopilotChat/utils/orderedmap.lua @@ -0,0 +1,52 @@ +---@class OrderedMap +---@field set fun(self:OrderedMap, key:any, value:any) +---@field get fun(self:OrderedMap, key:any):any +---@field remove fun(self:OrderedMap, key:any) +---@field keys fun(self:OrderedMap):table +---@field values fun(self:OrderedMap):table + +--- Create ordered map +---@generic K, V +---@return OrderedMap +local function orderedmap() + return { + _keys = {}, + _data = {}, + set = function(self, key, value) + if not self._data[key] then + table.insert(self._keys, key) + end + self._data[key] = value + end, + + get = function(self, key) + return self._data[key] + end, + + remove = function(self, key) + if self._data[key] then + self._data[key] = nil + for i, k in ipairs(self._keys) do + if k == key then + table.remove(self._keys, i) + break + end + end + end + end, + + keys = function(self) + return self._keys + end, + + values = function(self) + local result = {} + for _, key in ipairs(self._keys) do + table.insert(result, self._data[key]) + end + return result + end, + } +end + +return orderedmap diff --git a/lua/CopilotChat/utils/stringbuffer.lua b/lua/CopilotChat/utils/stringbuffer.lua new file mode 100644 index 00000000..de89f2db --- /dev/null +++ b/lua/CopilotChat/utils/stringbuffer.lua @@ -0,0 +1,46 @@ +local ok, jit_buffer = pcall(require, 'string.buffer') + +---@class StringBuffer +---@field put fun(self:StringBuffer, s:string) +---@field set fun(self:StringBuffer, s:string) +---@field tostring fun(self:StringBuffer):string + +--- Create a string buffer for efficient string concatenation +---@return StringBuffer +local function stringbuffer() + if ok and jit_buffer then + return { + _buf = jit_buffer.new(), + put = function(self, s) + self._buf:put(s) + end, + set = function(self, s) + self._buf:set(s) + end, + tostring = function(self) + return self._buf:tostring() + end, + } + end + + return { + _buf = { '' }, + put = function(self, s) + table.insert(self._buf, s) + for i = #self._buf - 1, 1, -1 do + if #self._buf[i] > #self._buf[i + 1] then + break + end + self._buf[i] = self._buf[i] .. table.remove(self._buf) + end + end, + set = function(self, s) + self._buf = { s } + end, + tostring = function(self) + return table.concat(self._buf) + end, + } +end + +return stringbuffer diff --git a/plugin/CopilotChat.lua b/plugin/CopilotChat.lua index e59ee49e..db83d5b2 100644 --- a/plugin/CopilotChat.lua +++ b/plugin/CopilotChat.lua @@ -14,6 +14,7 @@ local group = vim.api.nvim_create_augroup('CopilotChat', {}) local function setup_highlights() vim.api.nvim_set_hl(0, 'CopilotChatHeader', { link = '@markup.heading.2.markdown', default = true }) vim.api.nvim_set_hl(0, 'CopilotChatSeparator', { link = '@punctuation.special.markdown', default = true }) + vim.api.nvim_set_hl(0, 'CopilotChatSelection', { link = 'Visual', default = true }) vim.api.nvim_set_hl(0, 'CopilotChatStatus', { link = 'DiagnosticHint', default = true }) vim.api.nvim_set_hl(0, 'CopilotChatHelp', { link = 'DiagnosticInfo', default = true }) vim.api.nvim_set_hl(0, 'CopilotChatResource', { link = 'Constant', default = true }) @@ -21,7 +22,6 @@ local function setup_highlights() vim.api.nvim_set_hl(0, 'CopilotChatPrompt', { link = 'Statement', default = true }) vim.api.nvim_set_hl(0, 'CopilotChatModel', { link = 'Type', default = true }) vim.api.nvim_set_hl(0, 'CopilotChatUri', { link = 'Underlined', default = true }) - vim.api.nvim_set_hl(0, 'CopilotChatSelection', { link = 'Visual', default = true }) vim.api.nvim_set_hl(0, 'CopilotChatAnnotation', { link = 'ColorColumn', default = true }) local fg = vim.api.nvim_get_hl(0, { name = 'CopilotChatStatus', link = false }).fg @@ -36,6 +36,18 @@ vim.api.nvim_create_autocmd('ColorScheme', { }) setup_highlights() +vim.api.nvim_create_autocmd('FileType', { + pattern = 'copilot-chat', + group = group, + callback = vim.schedule_wrap(function() + vim.cmd.syntax('match CopilotChatResource "#\\S\\+"') + vim.cmd.syntax('match CopilotChatTool "@\\S\\+"') + vim.cmd.syntax('match CopilotChatPrompt "/\\S\\+"') + vim.cmd.syntax('match CopilotChatModel "\\$\\S\\+"') + vim.cmd.syntax('match CopilotChatUri "##\\S\\+"') + end), +}) + -- Setup commands vim.api.nvim_create_user_command('CopilotChat', function(args) local chat = require('CopilotChat') @@ -79,18 +91,6 @@ vim.api.nvim_create_user_command('CopilotChatReset', function() chat.reset() end, { force = true }) -vim.api.nvim_create_autocmd('FileType', { - pattern = 'copilot-chat', - group = group, - callback = vim.schedule_wrap(function() - vim.cmd.syntax('match CopilotChatResource "#\\S\\+"') - vim.cmd.syntax('match CopilotChatTool "@\\S\\+"') - vim.cmd.syntax('match CopilotChatPrompt "/\\S\\+"') - vim.cmd.syntax('match CopilotChatModel "\\$\\S\\+"') - vim.cmd.syntax('match CopilotChatUri "##\\S\\+"') - end), -}) - local function complete_load() local chat = require('CopilotChat') local options = vim.tbl_map(function(file) diff --git a/queries/markdown/copilotchat.scm b/queries/markdown/copilotchat.scm new file mode 100644 index 00000000..f4ec8546 --- /dev/null +++ b/queries/markdown/copilotchat.scm @@ -0,0 +1,13 @@ +(section + (atx_heading + (atx_h1_marker) + heading_content: (_) @section_header + ) + (_)? @section_content +) +(section + (fenced_code_block + (info_string) @block_header + (code_fence_content) @block_content + ) +) diff --git a/scripts/minimal.lua b/scripts/minimal.lua new file mode 100644 index 00000000..69c5cefb --- /dev/null +++ b/scripts/minimal.lua @@ -0,0 +1,16 @@ +-- https://github.com/neovim/neovim/blob/master/contrib/minimal.lua +vim.opt.runtimepath:append(vim.fn.getcwd()) + +for name, url in pairs({ + 'https://github.com/nvim-lua/plenary.nvim', +}) do + local install_path = vim.fn.fnamemodify('.dependencies/' .. name, ':p') + if vim.fn.isdirectory(install_path) == 0 then + vim.fn.system({ 'git', 'clone', '--depth=1', url, install_path }) + end + vim.opt.runtimepath:append(install_path) +end + +require('CopilotChat').setup({ + -- Add your configuration here +}) diff --git a/scripts/test.lua b/scripts/test.lua new file mode 100644 index 00000000..5da43da3 --- /dev/null +++ b/scripts/test.lua @@ -0,0 +1,13 @@ +vim.opt.runtimepath:append(vim.fn.getcwd()) + +for name, url in pairs({ + 'https://github.com/nvim-lua/plenary.nvim', +}) do + local install_path = vim.fn.fnamemodify('.dependencies/' .. name, ':p') + if vim.fn.isdirectory(install_path) == 0 then + vim.fn.system({ 'git', 'clone', '--depth=1', url, install_path }) + end + vim.opt.runtimepath:append(install_path) +end + +require('plenary.test_harness').test_directory('tests') diff --git a/test/plugin_spec.lua b/test/plugin_spec.lua deleted file mode 100644 index 9497f016..00000000 --- a/test/plugin_spec.lua +++ /dev/null @@ -1,18 +0,0 @@ --- Mock packages -package.loaded['plenary.async'] = { - wrap = function(fn) - return function(...) - return fn(...) - end - end, -} -package.loaded['plenary.curl'] = {} -package.loaded['plenary.log'] = {} -package.loaded['plenary.scandir'] = {} -package.loaded['plenary.filetype'] = {} - -describe('CopilotChat plugin', function() - it('should be able to load', function() - assert.truthy(require('CopilotChat')) - end) -end) diff --git a/tests/class_spec.lua b/tests/class_spec.lua new file mode 100644 index 00000000..ef2f1657 --- /dev/null +++ b/tests/class_spec.lua @@ -0,0 +1,33 @@ +local class = require('CopilotChat.utils.class') + +describe('CopilotChat.utils.class', function() + it('creates a simple class', function() + local Foo = class(function(self, x) + self.x = x + end) + local obj = Foo(42) + assert.equals(42, obj.x) + end) + + it('supports init method', function() + local Bar = class(function(self, y) + self.y = y + end) + local obj = Bar.new(7) + assert.equals(7, obj.y) + obj:init(8) + assert.equals(8, obj.y) + end) + + it('supports inheritance', function() + local Parent = class(function(self) + self.val = 1 + end) + local Child = class(function(self) + self.val = 2 + end, Parent) + local obj = Child() + assert.equals(2, obj.val) + assert.equals(Parent, getmetatable(Child).__index) + end) +end) diff --git a/tests/diff_spec.lua b/tests/diff_spec.lua new file mode 100644 index 00000000..bfaa19a4 --- /dev/null +++ b/tests/diff_spec.lua @@ -0,0 +1,1043 @@ +local diff = require('CopilotChat.utils.diff') + +describe('CopilotChat.utils.diff', function() + it('applies unified diff', function() + local diff_text = [[ +--- a/foo.txt ++++ b/foo.txt +@@ ... @@ + context +-old ++new +]] + local original = { 'context', 'old', 'other' } + local original_content = table.concat(original, '\n') + local result, applied = diff.apply_unified_diff(diff_text, original_content) + assert.is_true(applied) + assert.are.same({ 'context', 'new', 'other' }, result) + end) + + it('applies unified diff with no context', function() + local diff_text = [[ +--- a/foo.txt ++++ b/foo.txt +@@ ... @@ +-old ++new +]] + local original = { 'old', 'other' } + local original_content = table.concat(original, '\n') + local result, applied = diff.apply_unified_diff(diff_text, original_content) + assert.is_true(applied) + assert.are.same({ 'new', 'other' }, result) + end) + + it('applies unified diff with multiline edits', function() + local diff_text = [[ +--- a/foo.txt ++++ b/foo.txt +@@ ... @@ + context1 + context2 +-old1 +-old2 ++new1 ++new2 +]] + local original = { + 'context1', + 'context2', + 'old1', + 'old2', + 'context3', + 'other', + } + local original_content = table.concat(original, '\n') + local result, applied = diff.apply_unified_diff(diff_text, original_content) + assert.is_true(applied) + assert.are.same({ + 'context1', + 'context2', + 'new1', + 'new2', + 'context3', + 'other', + }, result) + end) + + it('gets unified diff region', function() + local diff_text = [[ +--- a/foo.txt ++++ b/foo.txt +@@ ... @@ + context +-old ++new +]] + local original = { 'context', 'old', 'other' } + local original_content = table.concat(original, '\n') + local _, _, first, last = diff.apply_unified_diff(diff_text, original_content) + assert.equals(2, first) + assert.equals(2, last) + end) + + it('applies unified diff with only additions', function() + local diff_text = [[ +--- a/foo.txt ++++ b/foo.txt +@@ ... @@ + context ++added1 ++added2 +]] + local original = { 'context', 'other' } + local original_content = table.concat(original, '\n') + local result, applied = diff.apply_unified_diff(diff_text, original_content) + assert.is_true(applied) + assert.are.same({ 'context', 'added1', 'added2', 'other' }, result) + end) + + it('applies unified diff with only deletions', function() + local diff_text = [[ +--- a/foo.txt ++++ b/foo.txt +@@ ... @@ + context +-old1 +-old2 +]] + local original = { 'context', 'old1', 'old2', 'other' } + local original_content = table.concat(original, '\n') + local result, applied = diff.apply_unified_diff(diff_text, original_content) + assert.is_true(applied) + assert.are.same({ 'context', 'other' }, result) + end) + + it('applies unified diff with changes at start and end', function() + local diff_text = [[ +--- a/foo.txt ++++ b/foo.txt +@@ ... @@ +-oldstart ++newstart + context +-oldend ++newend +]] + local original = { 'oldstart', 'context', 'oldend' } + local original_content = table.concat(original, '\n') + local result, applied = diff.apply_unified_diff(diff_text, original_content) + assert.is_true(applied) + assert.are.same({ 'newstart', 'context', 'newend' }, result) + end) + + it('applies unified diff with multiple hunks', function() + local diff_text = [[ +--- a/foo.txt ++++ b/foo.txt +@@ ... @@ + context1 +-old1 ++new1 +@@ ... @@ + context2 +-old2 ++new2 +]] + local original = { 'context1', 'old1', 'context2', 'old2', 'other' } + local original_content = table.concat(original, '\n') + local result, applied = diff.apply_unified_diff(diff_text, original_content) + assert.is_true(applied) + assert.are.same({ 'context1', 'new1', 'context2', 'new2', 'other' }, result) + end) + + it('applies unified diff with no changes', function() + local diff_text = [[ +--- a/foo.txt ++++ b/foo.txt +@@ ... @@ + context + unchanged +]] + local original = { 'context', 'unchanged' } + local original_content = table.concat(original, '\n') + local result, applied = diff.apply_unified_diff(diff_text, original_content) + assert.is_true(applied) + assert.are.same(original, result) + end) + + it('applies unified diff with all lines deleted', function() + local diff_text = [[ +--- a/foo.txt ++++ b/foo.txt +@@ ... @@ +-old1 +-old2 +-old3 +]] + local original = { 'old1', 'old2', 'old3' } + local original_content = table.concat(original, '\n') + local result, applied = diff.apply_unified_diff(diff_text, original_content) + assert.is_true(applied) + assert.are.same({}, result) + end) + + it('applies unified diff with all lines added to empty file', function() + local diff_text = [[ +--- a/foo.txt ++++ b/foo.txt +@@ ... @@ ++new1 ++new2 ++new3 +]] + local original = {} + local original_content = table.concat(original, '\n') + local result, applied = diff.apply_unified_diff(diff_text, original_content) + assert.is_true(applied) + assert.are.same({ 'new1', 'new2', 'new3' }, result) + end) + + it('applies unified diff with changes at end of file', function() + local diff_text = [[ +--- a/foo.txt ++++ b/foo.txt +@@ ... @@ + context +-oldend ++newend +]] + local original = { 'context', 'oldend' } + local original_content = table.concat(original, '\n') + local result, applied = diff.apply_unified_diff(diff_text, original_content) + assert.is_true(applied) + assert.are.same({ 'context', 'newend' }, result) + end) + + it('applies unified diff with changes at start of file', function() + local diff_text = [[ +--- a/foo.txt ++++ b/foo.txt +@@ ... @@ +-oldstart ++newstart + context +]] + local original = { 'oldstart', 'context' } + local original_content = table.concat(original, '\n') + local result, applied = diff.apply_unified_diff(diff_text, original_content) + assert.is_true(applied) + assert.are.same({ 'newstart', 'context' }, result) + end) + + it('may confuse similar variable names', function() + local diff_text = [[ +--- a/foo.txt ++++ b/foo.txt +@@ -1,2 +1,2 @@ +-local x = 1 ++local x = 10 +]] + local original = { + 'local x = 1', + 'local y = 2', + 'local x = 3', + 'local z = 4', + } + local result, applied = diff.apply_unified_diff(diff_text, table.concat(original, '\n')) + assert.is_true(applied) + assert.are.same({ + 'local x = 10', + 'local y = 2', + 'local x = 3', + 'local z = 4', + }, result) + end) + + it('may match wrong substring with partial matches', function() + local diff_text = [[ +--- a/foo.txt ++++ b/foo.txt +@@ -1,2 +1,2 @@ +-old_value ++new_value +]] + local original = { + 'value', + 'old_value', + 'very_old_value', + } + local result, applied = diff.apply_unified_diff(diff_text, table.concat(original, '\n')) + assert.is_false(applied) -- not applied cleanly, but adjusted + assert.are.same({ + 'value', + 'new_value', + 'very_old_value', + }, result) + end) + + it('may apply to wrong instance of identical boilerplate code', function() + local diff_text = [[ +--- a/foo.txt ++++ b/foo.txt +@@ -1,3 +1,3 @@ + return { +- status = "old" ++ status = "new" +]] + local original = { + 'return {', + ' status = "old"', + '}', + 'return {', + ' status = "old"', + '}', + } + local result, applied = diff.apply_unified_diff(diff_text, table.concat(original, '\n')) + assert.is_true(applied) + assert.are.same({ + 'return {', + ' status = "new"', + '}', + 'return {', + ' status = "old"', + '}', + }, result) + end) + + it('allows adding at very start with zero original lines', function() + local diff_text = [[ +--- a/foo.txt ++++ b/foo.txt +@@ -0,0 +1,2 @@ ++first ++second +]] + local original = { 'x', 'y' } + local original_content = table.concat(original, '\n') + local result, applied = diff.apply_unified_diff(diff_text, original_content) + assert.is_true(applied) + assert.are.same({ 'first', 'second', 'x', 'y' }, result) + end) + + it('handles insertion at end without context', function() + local diff_text = [[ +--- a/foo.txt ++++ b/foo.txt +@@ -3,0 +4,2 @@ ++new1 ++new2 +]] + local original = { 'a', 'b', 'c' } + local original_content = table.concat(original, '\n') + local result, applied = diff.apply_unified_diff(diff_text, original_content) + assert.is_true(applied) + assert.are.same({ 'a', 'b', 'c', 'new1', 'new2' }, result) + end) + + it('supports multiple adjacent hunks modifying contiguous lines', function() + local diff_text = [[ +--- a/foo.txt ++++ b/foo.txt +@@ -1,1 +1,1 @@ +-a ++x +@@ -2,1 +2,1 @@ +-b ++y +]] + local original = { 'a', 'b', 'c' } + local original_content = table.concat(original, '\n') + local result, applied = diff.apply_unified_diff(diff_text, original_content) + assert.is_true(applied) + assert.are.same({ 'x', 'y', 'c' }, result) + end) + + it('handles diff with trailing newline missing in original', function() + local diff_text = [[ +--- a/foo.txt ++++ b/foo.txt +@@ -1,1 +1,1 @@ +-old ++new +]] + local original_content = 'old' -- no trailing newline + local result, applied = diff.apply_unified_diff(diff_text, original_content) + assert.is_true(applied) + assert.are.same({ 'new' }, result) + end) + + it('handles diff ending without newline on addition lines', function() + local diff_text = [[ +--- a/foo.txt ++++ b/foo.txt +@@ -1,1 +1,2 @@ + old ++new]] + local original = { 'old' } + local original_content = table.concat(original, '\n') + local result, applied = diff.apply_unified_diff(diff_text, original_content) + assert.is_true(applied) + assert.are.same({ 'old', 'new' }, result) + end) + + it('handles hunks with zero-context lines around changes', function() + local diff_text = [[ +--- a/foo.txt ++++ b/foo.txt +@@ -2,0 +3,1 @@ ++added +]] + local original = { 'a', 'b', 'c' } + local original_content = table.concat(original, '\n') + local result, applied = diff.apply_unified_diff(diff_text, original_content) + assert.is_true(applied) + assert.are.same({ 'a', 'b', 'added', 'c' }, result) + end) + + it('handles insertion of identical-to-context line', function() + local diff_text = [[ +--- a/foo.txt ++++ b/foo.txt +@@ -1,1 +1,2 @@ + context ++context +]] + local original = { 'context', 'other' } + local result, applied = diff.apply_unified_diff(diff_text, table.concat(original, '\n')) + assert.is_true(applied) + assert.are.same({ 'context', 'context', 'other' }, result) + end) + + it('rejects hunk with wrong header lengths', function() + local diff_text = [[ +--- a/foo.txt ++++ b/foo.txt +@@ -1,3 +1,3 @@ + context +-old ++new +]] + local original = { 'context' } + local result, applied = diff.apply_unified_diff(diff_text, table.concat(original, '\n')) + -- Fuzzy matching may still apply despite wrong header lengths + assert.is_not_nil(result) + end) + + it('handles CRLF original with unix diff', function() + local diff_text = [[ +--- a/foo.txt ++++ b/foo.txt +@@ -1,1 +1,1 @@ +-old ++new +]] + local original_content = 'old\r\n' + local result, applied = diff.apply_unified_diff(diff_text, original_content) + assert.is_true(applied) + assert.is_not_nil(result) + assert.is_true(#result >= 1) + end) + + it('handles large insertion with no context', function() + local lines = {} + for i = 1, 10 do + table.insert(lines, '+line' .. i) + end + local diff_text = '--- a/foo.txt\n+++ b/foo.txt\n@@ -4,0 +5,10 @@\n' .. table.concat(lines, '\n') .. '\n' + local original = { 'a', 'b', 'c', 'd', 'e' } + local result, applied = diff.apply_unified_diff(diff_text, table.concat(original, '\n')) + assert.is_true(applied) + local expected = { 'a', 'b', 'c', 'd' } + for i = 1, 10 do + table.insert(expected, 'line' .. i) + end + table.insert(expected, 'e') + assert.are.same(expected, result) + end) + + it('rejects mismatched deletion ranges', function() + local diff_text = [[ +--- a/foo.txt ++++ b/foo.txt +@@ -1,3 +0,0 @@ +-old1 +-old2 +-old3 +]] + local original = { 'single' } + local result, applied = diff.apply_unified_diff(diff_text, table.concat(original, '\n')) + -- Fuzzy matching may apply the deletion despite mismatch + assert.is_not_nil(result) + end) + + it('handles mixed operations in one hunk', function() + local diff_text = [[ +--- a/foo.txt ++++ b/foo.txt +@@ -1,5 +1,4 @@ + context1 +-old + unchanged +-old2 ++new2 + context2 +]] + local original = { 'context1', 'old', 'unchanged', 'old2', 'context2' } + local result, applied = diff.apply_unified_diff(diff_text, table.concat(original, '\n')) + assert.is_true(applied) + assert.are.same({ 'context1', 'unchanged', 'new2', 'context2' }, result) + end) + + it('handles leading tabs/spaces inside context lines', function() + local diff_text = [[ +--- a/x ++++ b/x +@@ -1,2 +1,2 @@ + indented +-old ++new +]] + local original = { '\tindented', 'old' } + local result, applied = diff.apply_unified_diff(diff_text, table.concat(original, '\n')) + assert.is_true(applied) + assert.are.same({ '\tindented', 'new' }, result) + end) + + it('respects diff markers even if content begins with + or -', function() + local diff_text = [[ +--- a/x ++++ b/x +@@ -1,2 +1,2 @@ +-+literalplus +--literalminus +++literalplus +++literalminus +]] + local original = { '+literalplus', '-literalminus' } + local result, applied = diff.apply_unified_diff(diff_text, table.concat(original, '\n')) + assert.is_true(applied) + assert.are.same({ '+literalplus', '+literalminus' }, result) + end) + + it('applies diff despite slight context mismatch with fuzzy matching', function() + local diff_text = [[ +--- a/foo.txt ++++ b/foo.txt +@@ -1,3 +1,3 @@ + slightly different context +-old ++new +]] + local original = { 'context', 'old', 'other' } + local result, applied = diff.apply_unified_diff(diff_text, table.concat(original, '\n')) + -- Fuzzy matching will replace context lines that don't match + assert.are.same({ 'slightly different context', 'new', 'other' }, result) + end) + + it('applies even when context is completely wrong due to fuzzy matching', function() + local diff_text = [[ +--- a/foo.txt ++++ b/foo.txt +@@ -1,3 +1,3 @@ + totally wrong line + another wrong line +-old ++new +]] + local original = { 'context1', 'context2', 'old' } + local result, applied = diff.apply_unified_diff(diff_text, table.concat(original, '\n')) + -- Fuzzy matching will replace all old_snippet lines (including wrong context) with new_snippet + assert.are.same({ 'totally wrong line', 'another wrong line', 'new' }, result) + end) + + it('applies with partial context match', function() + local diff_text = [[ +--- a/foo.txt ++++ b/foo.txt +@@ -2,3 +2,3 @@ + matching +-old ++new +]] + local original = { 'first', 'matching', 'old', 'last' } + local result, applied = diff.apply_unified_diff(diff_text, table.concat(original, '\n')) + assert.is_true(applied) + assert.are.same({ 'first', 'matching', 'new', 'last' }, result) + end) + + it('handles context with extra lines not in original', function() + local diff_text = [[ +--- a/foo.txt ++++ b/foo.txt +@@ -1,5 +1,5 @@ + context1 + context2 + context3 +-old ++new +]] + local original = { 'context1', 'old' } + local result, applied = diff.apply_unified_diff(diff_text, table.concat(original, '\n')) + -- Should fail or apply with fuzzy matching + assert.is_not_nil(result) + end) + + it('fails when deletion target does not exist', function() + local diff_text = [[ +--- a/foo.txt ++++ b/foo.txt +@@ -1,2 +1,1 @@ + context +-nonexistent +]] + local original = { 'context', 'actual' } + local result, applied = diff.apply_unified_diff(diff_text, table.concat(original, '\n')) + -- Fuzzy matching might still apply or fail + assert.is_not_nil(result) + end) + + it('applies when context lines are in different order', function() + local diff_text = [[ +--- a/foo.txt ++++ b/foo.txt +@@ -1,3 +1,3 @@ + line2 + line1 +-old ++new +]] + local original = { 'line1', 'line2', 'old' } + local result, applied = diff.apply_unified_diff(diff_text, table.concat(original, '\n')) + -- Fuzzy matching should handle reordered context + assert.is_not_nil(result) + end) + + it('adds max_retry_time and cumulative retry logic', function() + local diff_text = [[ +--- original.py ++++ modified.py +@@ -24,6 +24,7 @@ + import time + + retry_statuses = {HTTPStatus.TOO_MANY_REQUESTS, 502, 503, 504} ++ max_retry_time = 120 # Maximum cumulative retry time in seconds + retry_exceptions = ( + httpx.ReadTimeout, + httpx.ConnectTimeout, +@@ -34,6 +35,7 @@ + def deco(fn): + def wrapped(*args, **kwargs): + last_exc = None ++ total_retry_time = 0 # Track cumulative retry time + for attempt in range(retries): + try: + resp = fn(*args, **kwargs) +@@ -43,6 +45,9 @@ + delay = min(max_backoff, backoff * (2**attempt)) * ( + 1 + random.random() * 0.25 + ) ++ if total_retry_time + delay > max_retry_time: ++ raise TimeoutError("Exceeded maximum retry time of 120 seconds") ++ total_retry_time += delay + time.sleep(delay) + continue + +@@ -59,6 +64,9 @@ + delay = min(max_backoff, backoff * (2**attempt)) * ( + 1 + random.random() * 0.25 + ) ++ if total_retry_time + delay > max_retry_time: ++ raise TimeoutError("Exceeded maximum retry time of 120 seconds") ++ total_retry_time += delay + time.sleep(delay) + continue +]] + local original = [[ +import base64 +import json +import logging +import os +import random +from datetime import datetime, time +from http import HTTPStatus + +import geojson +import httpx +from cachetools import TTLCache, cached +from geopy.distance import geodesic +from shapely.geometry import MultiPolygon, Polygon, shape + +logger = logging.getLogger(__name__) + +httpx_client = httpx.Client( + timeout=10.0, + limits=httpx.Limits(max_keepalive_connections=20, max_connections=100), +) + + +def retry_request(retries=10, backoff=1, max_backoff=40.0): + import time + + retry_statuses = {HTTPStatus.TOO_MANY_REQUESTS, 502, 503, 504} + retry_exceptions = ( + httpx.ReadTimeout, + httpx.ConnectTimeout, + httpx.NetworkError, # includes transient connection errors + httpx.RemoteProtocolError, + ) + + def deco(fn): + def wrapped(*args, **kwargs): + last_exc = None + for attempt in range(retries): + try: + resp = fn(*args, **kwargs) + except retry_exceptions as exc: + last_exc = exc + # backoff and retry + delay = min(max_backoff, backoff * (2**attempt)) * ( + 1 + random.random() * 0.25 + ) + time.sleep(delay) + continue + + # Retry on selected HTTP status + if resp.status_code in retry_statuses: + # honor Retry-After if present + ra = resp.headers.get("Retry-After") + if ra: + try: + delay = min(max_backoff, float(ra)) + except ValueError: + delay = min(max_backoff, backoff * (2**attempt)) + else: + delay = min(max_backoff, backoff * (2**attempt)) * ( + 1 + random.random() * 0.25 + ) + time.sleep(delay) + continue + + return resp + + if last_exc: + raise last_exc + return resp + + return wrapped + + return deco +]] + local expected = [[ +import base64 +import json +import logging +import os +import random +from datetime import datetime, time +from http import HTTPStatus + +import geojson +import httpx +from cachetools import TTLCache, cached +from geopy.distance import geodesic +from shapely.geometry import MultiPolygon, Polygon, shape + +logger = logging.getLogger(__name__) + +httpx_client = httpx.Client( + timeout=10.0, + limits=httpx.Limits(max_keepalive_connections=20, max_connections=100), +) + + +def retry_request(retries=10, backoff=1, max_backoff=40.0): + import time + + retry_statuses = {HTTPStatus.TOO_MANY_REQUESTS, 502, 503, 504} + max_retry_time = 120 # Maximum cumulative retry time in seconds + retry_exceptions = ( + httpx.ReadTimeout, + httpx.ConnectTimeout, + httpx.NetworkError, # includes transient connection errors + httpx.RemoteProtocolError, + ) + + def deco(fn): + def wrapped(*args, **kwargs): + last_exc = None + total_retry_time = 0 # Track cumulative retry time + for attempt in range(retries): + try: + resp = fn(*args, **kwargs) + except retry_exceptions as exc: + last_exc = exc + # backoff and retry + delay = min(max_backoff, backoff * (2**attempt)) * ( + 1 + random.random() * 0.25 + ) + if total_retry_time + delay > max_retry_time: + raise TimeoutError("Exceeded maximum retry time of 120 seconds") + total_retry_time += delay + time.sleep(delay) + continue + + # Retry on selected HTTP status + if resp.status_code in retry_statuses: + # honor Retry-After if present + ra = resp.headers.get("Retry-After") + if ra: + try: + delay = min(max_backoff, float(ra)) + except ValueError: + delay = min(max_backoff, backoff * (2**attempt)) + else: + delay = min(max_backoff, backoff * (2**attempt)) * ( + 1 + random.random() * 0.25 + ) + if total_retry_time + delay > max_retry_time: + raise TimeoutError("Exceeded maximum retry time of 120 seconds") + total_retry_time += delay + time.sleep(delay) + continue + + return resp + + if last_exc: + raise last_exc + return resp + + return wrapped + + return deco +]] + local result, applied = diff.apply_unified_diff(diff_text, original) + local expected_lines = vim.split(expected, '\n', { trimempty = true }) + assert.are.same(expected_lines, result) + end) + + -- Tests for offset tracking in sequential hunk application + it('correctly applies offset when first hunk adds lines', function() + local diff_text = [[ +--- a/test.txt ++++ b/test.txt +@@ -1,2 +1,4 @@ + line1 ++added1 ++added2 + line2 +@@ -3,1 +5,1 @@ + line3 +]] + local original = { 'line1', 'line2', 'line3' } + local original_content = table.concat(original, '\n') + local result, applied = diff.apply_unified_diff(diff_text, original_content) + assert.is_true(applied) + assert.are.same({ 'line1', 'added1', 'added2', 'line2', 'line3' }, result) + end) + + it('correctly applies offset when first hunk removes lines', function() + local diff_text = [[ +--- a/test.txt ++++ b/test.txt +@@ -1,3 +1,1 @@ + line1 +-line2 +-line3 +@@ -4,1 +2,1 @@ + line4 +]] + local original = { 'line1', 'line2', 'line3', 'line4' } + local original_content = table.concat(original, '\n') + local result, applied = diff.apply_unified_diff(diff_text, original_content) + assert.is_true(applied) + assert.are.same({ 'line1', 'line4' }, result) + end) + + it('correctly tracks offset through multiple hunks with mixed add/remove', function() + local diff_text = [[ +--- a/test.txt ++++ b/test.txt +@@ -1,1 +1,2 @@ + a ++b +@@ -2,1 +3,1 @@ +-c ++C +@@ -3,1 +4,3 @@ + d ++e ++f +]] + local original = { 'a', 'c', 'd' } + local original_content = table.concat(original, '\n') + local result, applied = diff.apply_unified_diff(diff_text, original_content) + assert.is_true(applied) + assert.are.same({ 'a', 'b', 'C', 'd', 'e', 'f' }, result) + end) + + it('handles offset when hunks are far apart', function() + local diff_text = [[ +--- a/test.txt ++++ b/test.txt +@@ -2,1 +2,2 @@ + line2 ++inserted +@@ -10,1 +11,1 @@ +-line10 ++LINE10 +]] + local original = { + 'line1', + 'line2', + 'line3', + 'line4', + 'line5', + 'line6', + 'line7', + 'line8', + 'line9', + 'line10', + } + local original_content = table.concat(original, '\n') + local result, applied = diff.apply_unified_diff(diff_text, original_content) + assert.is_true(applied) + local expected = { + 'line1', + 'line2', + 'inserted', + 'line3', + 'line4', + 'line5', + 'line6', + 'line7', + 'line8', + 'line9', + 'LINE10', + } + assert.are.same(expected, result) + end) + + it('applies three consecutive hunks with positive offset accumulation', function() + local diff_text = [[ +--- a/test.txt ++++ b/test.txt +@@ -1,1 +1,2 @@ + a ++b +@@ -2,1 +3,2 @@ + c ++d +@@ -3,1 +5,2 @@ + e ++f +]] + local original = { 'a', 'c', 'e' } + local original_content = table.concat(original, '\n') + local result, applied = diff.apply_unified_diff(diff_text, original_content) + assert.is_true(applied) + assert.are.same({ 'a', 'b', 'c', 'd', 'e', 'f' }, result) + end) + + it('applies three consecutive hunks with negative offset accumulation', function() + local diff_text = [[ +--- a/test.txt ++++ b/test.txt +@@ -1,2 +1,1 @@ +-x + a +@@ -3,2 +2,1 @@ +-y + b +@@ -5,2 +3,1 @@ +-z + c +]] + local original = { 'x', 'a', 'y', 'b', 'z', 'c' } + local original_content = table.concat(original, '\n') + local result, applied = diff.apply_unified_diff(diff_text, original_content) + assert.is_true(applied) + assert.are.same({ 'a', 'b', 'c' }, result) + end) + + it('handles zero-offset hunks (replacements without size change)', function() + local diff_text = [[ +--- a/test.txt ++++ b/test.txt +@@ -1,1 +1,1 @@ +-old1 ++new1 +@@ -2,1 +2,1 @@ +-old2 ++new2 +@@ -3,1 +3,1 @@ +-old3 ++new3 +]] + local original = { 'old1', 'old2', 'old3' } + local original_content = table.concat(original, '\n') + local result, applied = diff.apply_unified_diff(diff_text, original_content) + assert.is_true(applied) + assert.are.same({ 'new1', 'new2', 'new3' }, result) + end) + + it('applies offset correctly when first hunk is pure insertion (len_old=0)', function() + local diff_text = [[ +--- a/test.txt ++++ b/test.txt +@@ -0,0 +1,2 @@ ++inserted1 ++inserted2 +@@ -1,1 +3,1 @@ + original +]] + local original = { 'original' } + local original_content = table.concat(original, '\n') + local result, applied = diff.apply_unified_diff(diff_text, original_content) + assert.is_true(applied) + assert.are.same({ 'inserted1', 'inserted2', 'original' }, result) + end) + + it('handles complex offset scenario with interleaved additions and deletions', function() + local diff_text = [[ +--- a/test.txt ++++ b/test.txt +@@ -1,2 +1,1 @@ +-delete1 + keep1 +@@ -3,1 +2,3 @@ + keep2 ++add1 ++add2 +@@ -4,2 +5,1 @@ +-delete2 + keep3 +]] + local original = { 'delete1', 'keep1', 'keep2', 'delete2', 'keep3' } + local original_content = table.concat(original, '\n') + local result, applied = diff.apply_unified_diff(diff_text, original_content) + assert.is_true(applied) + assert.are.same({ 'keep1', 'keep2', 'add1', 'add2', 'keep3' }, result) + end) + + it('offset tracking works with hunks that have context lines', function() + local diff_text = [[ +--- a/test.txt ++++ b/test.txt +@@ -1,3 +1,4 @@ + ctx1 + line1 ++inserted + ctx2 +@@ -5,2 +6,2 @@ + ctx3 +-line2 ++LINE2 +]] + local original = { 'ctx1', 'line1', 'ctx2', 'ctx3', 'line2' } + local original_content = table.concat(original, '\n') + local result, applied = diff.apply_unified_diff(diff_text, original_content) + assert.is_true(applied) + assert.are.same({ 'ctx1', 'line1', 'inserted', 'ctx2', 'ctx3', 'LINE2' }, result) + end) +end) diff --git a/tests/functions_spec.lua b/tests/functions_spec.lua new file mode 100644 index 00000000..93939cdb --- /dev/null +++ b/tests/functions_spec.lua @@ -0,0 +1,62 @@ +local functions = require('CopilotChat.functions') + +describe('CopilotChat.functions', function() + describe('uri_to_url', function() + it('replaces parameters in uri template', function() + local uri = 'file://{path}' + local input = { path = '/tmp/test.txt' } + assert.equals('file:///tmp/test.txt', functions.uri_to_url(uri, input)) + end) + it('leaves missing params empty', function() + local uri = 'file://{path}/{id}' + local input = { path = '/tmp' } + assert.equals('file:///tmp/', functions.uri_to_url(uri, input)) + end) + end) + + describe('match_uri', function() + it('matches uri and extracts parameters', function() + local uri = 'file:///tmp/test.txt' + local pattern = 'file://{path}' + local result = functions.match_uri(uri, pattern) + assert.are.same({ path = '/tmp/test.txt' }, result) + end) + it('returns nil for non-matching uri', function() + assert.is_nil(functions.match_uri('abc', 'file://{path}')) + end) + it('returns empty table for exact match with no params', function() + assert.are.same({}, functions.match_uri('abc', 'abc')) + end) + end) + + describe('parse_schema', function() + it('returns schema if present', function() + local fn = { schema = { type = 'object', properties = { foo = { type = 'string' } } } } + assert.equals(fn.schema, functions.parse_schema(fn)) + end) + it('generates schema from uri if missing', function() + local fn = { uri = 'file://{path}/{id}' } + local schema = functions.parse_schema(fn) + assert.are.same({ + type = 'object', + properties = { path = { type = 'string' }, id = { type = 'string' } }, + required = { 'path', 'id' }, + }, schema) + end) + end) + + describe('parse_input', function() + it('parses input string into table', function() + local schema = { properties = { a = {}, b = {} }, required = { 'a', 'b' } } + local input = 'foo;;bar' + assert.are.same({ a = 'foo', b = 'bar' }, functions.parse_input(input, schema)) + end) + it('returns input if already table', function() + local input = { a = 1 } + assert.equals(input, functions.parse_input(input)) + end) + it('returns empty table if no schema', function() + assert.are.same({}, functions.parse_input('foo')) + end) + end) +end) diff --git a/tests/init_spec.lua b/tests/init_spec.lua new file mode 100644 index 00000000..995a84c3 --- /dev/null +++ b/tests/init_spec.lua @@ -0,0 +1,14 @@ +describe('CopilotChat module', function() + it('should be able to load', function() + assert.has_no.errors(function() + require('CopilotChat') + end) + end) + + it('should be able to set up', function() + assert.has_no.errors(function() + require('CopilotChat').setup({}) + end) + assert.is_not_nil(require('CopilotChat').chat) + end) +end) diff --git a/tests/notify_spec.lua b/tests/notify_spec.lua new file mode 100644 index 00000000..020c9391 --- /dev/null +++ b/tests/notify_spec.lua @@ -0,0 +1,126 @@ +local notify = require('CopilotChat.utils.notify') + +describe('CopilotChat.notify', function() + before_each(function() + -- Clear all listeners before each test + notify.clear() + end) + + describe('publish and listen', function() + it('calls listener when event is published', function() + local called = false + local received_data = nil + + notify.listen('test_event', function(data) + called = true + received_data = data + end) + + notify.publish('test_event', 'test_data') + + assert.is_true(called) + assert.equals('test_data', received_data) + end) + + it('supports multiple listeners for same event', function() + local count = 0 + + notify.listen('test_event', function(data) + count = count + 1 + end) + notify.listen('test_event', function(data) + count = count + 10 + end) + + notify.publish('test_event', 'data') + + assert.equals(11, count) + end) + + it('does not call listeners for different events', function() + local called = false + + notify.listen('event_a', function(data) + called = true + end) + + notify.publish('event_b', 'data') + + assert.is_false(called) + end) + + it('passes correct data to listeners', function() + local received = nil + + notify.listen('test_event', function(data) + received = data + end) + + notify.publish('test_event', { foo = 'bar', num = 123 }) + + assert.are.same({ foo = 'bar', num = 123 }, received) + end) + + it('handles nil and empty data', function() + local received = 'not_called' + + notify.listen('test_event', function(data) + received = data + end) + + notify.publish('test_event', nil) + assert.is_nil(received) + + notify.publish('test_event', '') + assert.equals('', received) + end) + + it('handles publishing to events with no listeners', function() + -- Should not error + assert.has_no.errors(function() + notify.publish('nonexistent_event', 'data') + end) + end) + end) + + describe('clear', function() + it('removes all listeners', function() + local called = false + + notify.listen('test_event', function(data) + called = true + end) + + notify.clear() + notify.publish('test_event', 'data') + + assert.is_false(called) + end) + + it('allows adding new listeners after clear', function() + local called = false + + notify.listen('test_event', function(data) + called = true + end) + notify.clear() + + notify.listen('test_event', function(data) + called = true + end) + notify.publish('test_event', 'data') + + assert.is_true(called) + end) + end) + + describe('constants', function() + it('defines STATUS constant', function() + assert.equals('status', notify.STATUS) + end) + + it('defines MESSAGE constant', function() + assert.equals('message', notify.MESSAGE) + end) + end) +end) diff --git a/tests/orderedmap_spec.lua b/tests/orderedmap_spec.lua new file mode 100644 index 00000000..b5fa5a37 --- /dev/null +++ b/tests/orderedmap_spec.lua @@ -0,0 +1,37 @@ +local orderedmap = require('CopilotChat.utils.orderedmap') + +describe('CopilotChat.utils.orderedmap', function() + it('sets and gets values', function() + local map = orderedmap() + map:set('a', 1) + map:set('b', 2) + assert.equals(1, map:get('a')) + assert.equals(2, map:get('b')) + end) + + it('preserves insertion order', function() + local map = orderedmap() + map:set('x', 10) + map:set('y', 20) + map:set('z', 30) + assert.are.same({ 'x', 'y', 'z' }, map:keys()) + assert.are.same({ 10, 20, 30 }, map:values()) + end) + + it('overwrites value but not order', function() + local map = orderedmap() + map:set('a', 1) + map:set('a', 2) + assert.are.same({ 'a' }, map:keys()) + assert.are.same({ 2 }, map:values()) + end) + + it('removes values and updates order', function() + local map = orderedmap() + map:set('a', 1) + map:set('b', 2) + map:remove('a') + assert.are.same({ 'b' }, map:keys()) + assert.are.same({ 2 }, map:values()) + end) +end) diff --git a/tests/stringbuffer_spec.lua b/tests/stringbuffer_spec.lua new file mode 100644 index 00000000..d491fd43 --- /dev/null +++ b/tests/stringbuffer_spec.lua @@ -0,0 +1,23 @@ +local stringbuffer = require('CopilotChat.utils.stringbuffer') + +describe('CopilotChat.utils.stringbuffer', function() + it('concatenates strings with put', function() + local buf = stringbuffer() + buf:put('hello') + buf:put(' ') + buf:put('world') + assert.equals('hello world', buf:tostring()) + end) + + it('sets buffer with set', function() + local buf = stringbuffer() + buf:put('foo') + buf:set('bar') + assert.equals('bar', buf:tostring()) + end) + + it('handles empty buffer', function() + local buf = stringbuffer() + assert.equals('', buf:tostring()) + end) +end) diff --git a/tests/utils_spec.lua b/tests/utils_spec.lua new file mode 100644 index 00000000..5352395d --- /dev/null +++ b/tests/utils_spec.lua @@ -0,0 +1,40 @@ +local utils = require('CopilotChat.utils') + +describe('CopilotChat.utils', function() + it('empty', function() + assert.is_true(utils.empty(nil)) + assert.is_true(utils.empty('')) + assert.is_true(utils.empty(' ')) + assert.is_true(utils.empty({})) + assert.is_false(utils.empty({ 1 })) + assert.is_false(utils.empty('abc')) + assert.is_false(utils.empty(0)) + end) + + it('split_lines', function() + assert.are.same(utils.split_lines(''), {}) + assert.are.same(utils.split_lines('a\nb'), { 'a', 'b' }) + assert.are.same(utils.split_lines('a\r\nb'), { 'a', 'b' }) + assert.are.same(utils.split_lines('a\nb\n'), { 'a', 'b', '' }) + end) + + it('make_string', function() + assert.equals('a b 1', utils.make_string('a', 'b', 1)) + assert.equals(vim.inspect({ x = 1 }), utils.make_string({ x = 1 })) + assert.equals('msg', utils.make_string('error:1: msg')) + end) + + it('uuid', function() + local uuid1 = utils.uuid() + local uuid2 = utils.uuid() + assert.equals('string', type(uuid1)) + assert.not_equals(uuid1, uuid2) + assert.equals(36, #uuid1) + end) + + it('to_table', function() + assert.are.same({ 1, 2, 3 }, utils.to_table(1, 2, 3)) + assert.are.same({ 1, 2, 3 }, utils.to_table({ 1, 2 }, 3)) + assert.are.same({ 1 }, utils.to_table(nil, 1)) + end) +end) diff --git a/version.txt b/version.txt index fcdb2e10..b48b2de9 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -4.0.0 +4.7.4