Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 45 additions & 7 deletions lua/CopilotChat/functions.lua
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ local utils = require('CopilotChat.utils')
local M = {}

local INPUT_SEPARATOR = ';;'
local URI_PARAM_PATTERN = '{([^}:*]+)[^}]*}'

local function sorted_propnames(schema)
local prop_names = vim.tbl_keys(schema.properties)
Expand Down Expand Up @@ -63,6 +64,17 @@ local function filter_schema(tbl)
return result
end

--- Convert a URI template to a URL by replacing parameters with values from input
---@param uri_template string The URI template containing parameters in the form {param}
---@param input table A table containing parameter values, e.g., { path = '/my/file.txt' }
---@return string The resulting URL with parameters replaced
function M.uri_to_url(uri_template, input)
-- Replace {param} in the template with input[param] or empty string
return (uri_template:gsub(URI_PARAM_PATTERN, function(param)
return input[param] or ''
end))
end

---@param uri string The URI to parse
---@param pattern string The pattern to match against (e.g., 'file://{path}')
---@return table|nil inputs Extracted parameters or nil if no match
Expand All @@ -73,7 +85,7 @@ function M.match_uri(uri, pattern)

-- Extract parameter names from the pattern
local param_names = {}
for param in pattern:gmatch('{([^}:*]+)[^}]*}') do
for param in pattern:gmatch(URI_PARAM_PATTERN) do
table.insert(param_names, param)
-- Replace {param} with a capture group in our Lua pattern
-- Use non-greedy capture to handle multiple params properly
Expand Down Expand Up @@ -102,6 +114,37 @@ function M.match_uri(uri, pattern)
return result
end

---@param tool CopilotChat.config.functions.Function
function M.parse_schema(tool)
local schema = tool.schema

-- If schema is missing but uri is present, generate a default schema from uri
if not schema and tool.uri then
-- Extract parameter names from the uri pattern, e.g. file://{path}
local param_names = {}
for param in tool.uri:gmatch(URI_PARAM_PATTERN) do
table.insert(param_names, param)
end
if #param_names > 0 then
schema = {
type = 'object',
properties = {},
required = {},
}
for _, param in ipairs(param_names) do
schema.properties[param] = { type = 'string' }
table.insert(schema.required, param)
end
end
end

if schema then
schema = filter_schema(schema)
end

return schema
end

--- Prepare the schema for use
---@param tools table<string, CopilotChat.config.functions.Function>
---@return table<CopilotChat.client.Tool>
Expand All @@ -110,16 +153,11 @@ function M.parse_tools(tools)
table.sort(tool_names)
return vim.tbl_map(function(name)
local tool = tools[name]
local schema = tool.schema

if schema then
schema = filter_schema(schema)
end

return {
name = name,
description = tool.description,
schema = schema,
schema = M.parse_schema(tool),
}
end, tool_names)
end
Expand Down
34 changes: 21 additions & 13 deletions lua/CopilotChat/init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,12 @@ end
---@async
function M.resolve_functions(prompt, config)
config, prompt = M.resolve_prompt(prompt, config)

local tools = {}
for _, tool in ipairs(functions.parse_tools(M.config.functions)) do
tools[tool.name] = tool
end

local enabled_tools = {}
local resolved_resources = {}
local resolved_tools = {}
Expand All @@ -271,7 +277,7 @@ function M.resolve_functions(prompt, config)
for _, match in ipairs(matches) do
for name, tool in pairs(M.config.functions) do
if name == match or tool.group == match then
enabled_tools[name] = tool
enabled_tools[name] = true
end
end
end
Expand Down Expand Up @@ -311,15 +317,15 @@ function M.resolve_functions(prompt, config)
local tool_id = nil
if not utils.empty(tool_calls) then
for _, tool_call in ipairs(tool_calls) do
if tool_call.name == name and vim.trim(tool_call.id) == vim.trim(input) and enabled_tools[name] then
if tool_call.name == name and vim.trim(tool_call.id) == vim.trim(input) then
input = utils.empty(tool_call.arguments) and {} or utils.json_decode(tool_call.arguments)
tool_id = tool_call.id
break
end
end
end

local tool = enabled_tools[name]
local tool = M.config.functions[name]
if not tool then
-- Check if input matches uri
for tool_name, tool_spec in pairs(M.config.functions) do
Expand All @@ -334,20 +340,16 @@ function M.resolve_functions(prompt, config)
end
end
end
if not tool and not tool_id then
tool = M.config.functions[name]
end
if not tool then
-- If tool is not found, return the original pattern
return nil
end
if not tool_id and not tool.uri then
-- If this is a tool that is not resource and was not called by LLM, reject it
if tool_id and not enabled_tools[name] and not tool.uri then
return nil
end

local schema = tools[name] and tools[name].schema or nil
local result = ''
local ok, output = pcall(tool.resolve, functions.parse_input(input, tool.schema), state.source or {}, prompt)
local ok, output = pcall(tool.resolve, functions.parse_input(input, schema), state.source or {}, input)
if not ok then
result = string.format(BLOCK_OUTPUT_FORMAT, 'error', utils.make_string(output))
else
Expand Down Expand Up @@ -394,7 +396,12 @@ function M.resolve_functions(prompt, config)
end
end

return functions.parse_tools(enabled_tools), resolved_resources, resolved_tools, prompt
return vim.tbl_map(function(name)
return tools[name]
end, vim.tbl_keys(enabled_tools)),
resolved_resources,
resolved_tools,
prompt
end

--- Resolve the final prompt and config from prompt template.
Expand Down Expand Up @@ -574,9 +581,10 @@ function M.trigger_complete(without_input)

if not without_input and vim.startswith(prefix, '#') and vim.endswith(prefix, ':') then
local found_tool = M.config.functions[prefix:sub(2, -2)]
if found_tool and found_tool.schema then
local found_schema = found_tool and functions.parse_schema(found_tool)
if found_tool and found_schema then
async.run(function()
local value = functions.enter_input(found_tool.schema, state.source)
local value = functions.enter_input(found_schema, state.source)
if not value then
return
end
Expand Down
10 changes: 8 additions & 2 deletions lua/CopilotChat/tiktoken.lua
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,13 @@ function M.encode(prompt)
if type(prompt) ~= 'string' then
error('Prompt must be a string')
end
return tiktoken_core.encode(prompt)

local ok, result = pcall(tiktoken_core.encode, prompt)
if not ok then
return nil
end

return result
end

--- Count the tokens in a prompt
Expand All @@ -105,7 +111,7 @@ function M.count(prompt)

local tokens = M.encode(prompt)
if not tokens then
return 0
return math.ceil(#prompt * 0.5) -- Fallback to 1/2 character count
end
return #tokens
end
Expand Down
Loading