diff --git a/lua/CopilotChat/config.lua b/lua/CopilotChat/config.lua index f3a0d291..1f35f33f 100644 --- a/lua/CopilotChat/config.lua +++ b/lua/CopilotChat/config.lua @@ -14,7 +14,7 @@ ---@field blend number? ---@class CopilotChat.config.Shared ----@field system_prompt string? +---@field system_prompt string|fun(source: CopilotChat.source):string|nil ---@field model string? ---@field tools string|table|nil ---@field sticky string|table|nil diff --git a/lua/CopilotChat/init.lua b/lua/CopilotChat/init.lua index 48703c74..8a660183 100644 --- a/lua/CopilotChat/init.lua +++ b/lua/CopilotChat/init.lua @@ -189,6 +189,26 @@ local function list_prompts() return prompts_to_use end +--- Resolve system prompt - handle both string and function types +---@param system_prompt string|function|nil +---@return string? +local function resolve_system_prompt(system_prompt) + if not system_prompt then + return nil + end + + if type(system_prompt) == 'function' then + local ok, result = pcall(system_prompt) + if not ok then + log.warn('Failed to resolve system prompt function: ' .. result) + return nil + end + return result + end + + return system_prompt +end + --- Finish writing to chat buffer. ---@param start_of_chat boolean? local function finish(start_of_chat) @@ -503,6 +523,9 @@ function M.resolve_prompt(prompt, config) config = vim.tbl_deep_extend('force', M.config, config or {}) config, prompt = resolve(config, prompt or '') + -- Resolve system prompt (handle functions) + config.system_prompt = resolve_system_prompt(config.system_prompt, state.source) + if config.system_prompt then for name, prompt in pairs(prompts_to_use) do if prompt.system_prompt then