From 8b2a6a8857be4e50a4affb82b3d6521bed1c2053 Mon Sep 17 00:00:00 2001 From: Tomas Slusny Date: Wed, 30 Jul 2025 11:10:54 +0200 Subject: [PATCH] feat(functions): automatically parse schema from url templates Signed-off-by: Tomas Slusny --- lua/CopilotChat/functions.lua | 52 ++++++++++++++++++++++++++++++----- lua/CopilotChat/init.lua | 34 ++++++++++++++--------- lua/CopilotChat/tiktoken.lua | 10 +++++-- 3 files changed, 74 insertions(+), 22 deletions(-) diff --git a/lua/CopilotChat/functions.lua b/lua/CopilotChat/functions.lua index dbcf5bcd..04c5b103 100644 --- a/lua/CopilotChat/functions.lua +++ b/lua/CopilotChat/functions.lua @@ -3,6 +3,7 @@ local utils = require('CopilotChat.utils') local M = {} local INPUT_SEPARATOR = ';;' +local URI_PARAM_PATTERN = '{([^}:*]+)[^}]*}' local function sorted_propnames(schema) local prop_names = vim.tbl_keys(schema.properties) @@ -63,6 +64,17 @@ local function filter_schema(tbl) return result end +--- Convert a URI template to a URL by replacing parameters with values from input +---@param uri_template string The URI template containing parameters in the form {param} +---@param input table A table containing parameter values, e.g., { path = '/my/file.txt' } +---@return string The resulting URL with parameters replaced +function M.uri_to_url(uri_template, input) + -- Replace {param} in the template with input[param] or empty string + return (uri_template:gsub(URI_PARAM_PATTERN, function(param) + return input[param] or '' + end)) +end + ---@param uri string The URI to parse ---@param pattern string The pattern to match against (e.g., 'file://{path}') ---@return table|nil inputs Extracted parameters or nil if no match @@ -73,7 +85,7 @@ function M.match_uri(uri, pattern) -- Extract parameter names from the pattern local param_names = {} - for param in pattern:gmatch('{([^}:*]+)[^}]*}') do + for param in pattern:gmatch(URI_PARAM_PATTERN) do table.insert(param_names, param) -- Replace {param} with a capture group in our Lua pattern -- Use non-greedy capture to handle multiple params properly @@ -102,6 +114,37 @@ function M.match_uri(uri, pattern) return result end +---@param tool CopilotChat.config.functions.Function +function M.parse_schema(tool) + local schema = tool.schema + + -- If schema is missing but uri is present, generate a default schema from uri + if not schema and tool.uri then + -- Extract parameter names from the uri pattern, e.g. file://{path} + local param_names = {} + for param in tool.uri:gmatch(URI_PARAM_PATTERN) do + table.insert(param_names, param) + end + if #param_names > 0 then + schema = { + type = 'object', + properties = {}, + required = {}, + } + for _, param in ipairs(param_names) do + schema.properties[param] = { type = 'string' } + table.insert(schema.required, param) + end + end + end + + if schema then + schema = filter_schema(schema) + end + + return schema +end + --- Prepare the schema for use ---@param tools table ---@return table @@ -110,16 +153,11 @@ function M.parse_tools(tools) table.sort(tool_names) return vim.tbl_map(function(name) local tool = tools[name] - local schema = tool.schema - - if schema then - schema = filter_schema(schema) - end return { name = name, description = tool.description, - schema = schema, + schema = M.parse_schema(tool), } end, tool_names) end diff --git a/lua/CopilotChat/init.lua b/lua/CopilotChat/init.lua index 71e80f3f..91aaf143 100644 --- a/lua/CopilotChat/init.lua +++ b/lua/CopilotChat/init.lua @@ -245,6 +245,12 @@ end ---@async function M.resolve_functions(prompt, config) config, prompt = M.resolve_prompt(prompt, config) + + local tools = {} + for _, tool in ipairs(functions.parse_tools(M.config.functions)) do + tools[tool.name] = tool + end + local enabled_tools = {} local resolved_resources = {} local resolved_tools = {} @@ -271,7 +277,7 @@ function M.resolve_functions(prompt, config) for _, match in ipairs(matches) do for name, tool in pairs(M.config.functions) do if name == match or tool.group == match then - enabled_tools[name] = tool + enabled_tools[name] = true end end end @@ -311,7 +317,7 @@ function M.resolve_functions(prompt, config) local tool_id = nil if not utils.empty(tool_calls) then for _, tool_call in ipairs(tool_calls) do - if tool_call.name == name and vim.trim(tool_call.id) == vim.trim(input) and enabled_tools[name] then + if tool_call.name == name and vim.trim(tool_call.id) == vim.trim(input) then input = utils.empty(tool_call.arguments) and {} or utils.json_decode(tool_call.arguments) tool_id = tool_call.id break @@ -319,7 +325,7 @@ function M.resolve_functions(prompt, config) end end - local tool = enabled_tools[name] + local tool = M.config.functions[name] if not tool then -- Check if input matches uri for tool_name, tool_spec in pairs(M.config.functions) do @@ -334,20 +340,16 @@ function M.resolve_functions(prompt, config) end end end - if not tool and not tool_id then - tool = M.config.functions[name] - end if not tool then - -- If tool is not found, return the original pattern return nil end - if not tool_id and not tool.uri then - -- If this is a tool that is not resource and was not called by LLM, reject it + if tool_id and not enabled_tools[name] and not tool.uri then return nil end + local schema = tools[name] and tools[name].schema or nil local result = '' - local ok, output = pcall(tool.resolve, functions.parse_input(input, tool.schema), state.source or {}, prompt) + local ok, output = pcall(tool.resolve, functions.parse_input(input, schema), state.source or {}, input) if not ok then result = string.format(BLOCK_OUTPUT_FORMAT, 'error', utils.make_string(output)) else @@ -394,7 +396,12 @@ function M.resolve_functions(prompt, config) end end - return functions.parse_tools(enabled_tools), resolved_resources, resolved_tools, prompt + return vim.tbl_map(function(name) + return tools[name] + end, vim.tbl_keys(enabled_tools)), + resolved_resources, + resolved_tools, + prompt end --- Resolve the final prompt and config from prompt template. @@ -574,9 +581,10 @@ function M.trigger_complete(without_input) if not without_input and vim.startswith(prefix, '#') and vim.endswith(prefix, ':') then local found_tool = M.config.functions[prefix:sub(2, -2)] - if found_tool and found_tool.schema then + local found_schema = found_tool and functions.parse_schema(found_tool) + if found_tool and found_schema then async.run(function() - local value = functions.enter_input(found_tool.schema, state.source) + local value = functions.enter_input(found_schema, state.source) if not value then return end diff --git a/lua/CopilotChat/tiktoken.lua b/lua/CopilotChat/tiktoken.lua index a4582cb4..dde3d2b5 100644 --- a/lua/CopilotChat/tiktoken.lua +++ b/lua/CopilotChat/tiktoken.lua @@ -92,7 +92,13 @@ function M.encode(prompt) if type(prompt) ~= 'string' then error('Prompt must be a string') end - return tiktoken_core.encode(prompt) + + local ok, result = pcall(tiktoken_core.encode, prompt) + if not ok then + return nil + end + + return result end --- Count the tokens in a prompt @@ -105,7 +111,7 @@ function M.count(prompt) local tokens = M.encode(prompt) if not tokens then - return 0 + return math.ceil(#prompt * 0.5) -- Fallback to 1/2 character count end return #tokens end