Skip to content

Commit 2e00dba

Browse files
authored
feat: add copilot.api module (zbirenbaum#48)
* feat: cleanup and fix some util * feat: add copilot.api module
1 parent 2e3cd13 commit 2e00dba

File tree

4 files changed

+282
-63
lines changed

4 files changed

+282
-63
lines changed

lua/copilot/api.lua

Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,187 @@
1+
local mod = {}
2+
3+
---@param callback? fun(err: any|nil, data: table, ctx: table): nil
4+
---@return any|nil err
5+
---@return table data
6+
---@return table ctx
7+
function mod.request(client, method, params, callback)
8+
-- hack to convert empty table to json object,
9+
-- empty table is convert to json array by default.
10+
params._ = true
11+
12+
local bufnr = params.bufnr
13+
params.bufnr = nil
14+
15+
if callback then
16+
return client.request(method, params, callback, bufnr)
17+
end
18+
19+
local co = coroutine.running()
20+
client.request(method, params, function(err, data, ctx)
21+
coroutine.resume(co, err, data, ctx)
22+
end, bufnr)
23+
return coroutine.yield()
24+
end
25+
26+
---@return boolean sent
27+
function mod.notify(client, method, params)
28+
return client.notify(method, params)
29+
end
30+
31+
---@alias copilot_check_status_data { user?: string }
32+
33+
---@return any|nil err
34+
---@return copilot_check_status_data data
35+
---@return table ctx
36+
function mod.check_status(client, callback)
37+
return mod.request(client, "checkStatus", {}, callback)
38+
end
39+
40+
---@alias copilot_sign_in_initiate_data { verificationUri?: string, userCode?: string }
41+
42+
---@return any|nil err
43+
---@return copilot_sign_in_initiate_data data
44+
---@return table ctx
45+
function mod.sign_in_initiate(client, callback)
46+
return mod.request(client, "signInInitiate", {}, callback)
47+
end
48+
49+
---@alias copilot_sign_in_confirm_params { userId: string }
50+
---@alias copilot_sign_in_confirm_data { status: string, error: { message: string }, user: string }
51+
52+
---@param params copilot_sign_in_confirm_params
53+
---@return any|nil err
54+
---@return copilot_sign_in_confirm_data data
55+
---@return table ctx
56+
function mod.sign_in_confirm(client, params, callback)
57+
return mod.request(client, "signInConfirm", params, callback)
58+
end
59+
60+
---@alias copilot_notify_accepted_params { uuid: string }
61+
62+
---@param params copilot_notify_accepted_params
63+
function mod.notify_accepted(client, params, callback)
64+
return mod.request(client, "notifyAccepted", params, callback)
65+
end
66+
67+
---@alias copilot_notify_rejected_params { uuids: string[] }
68+
69+
---@param params copilot_notify_rejected_params
70+
function mod.notify_rejected(client, params, callback)
71+
return mod.request(client, "notifyRejected", params, callback)
72+
end
73+
74+
---@alias copilot_notify_shown_params { uuid: string }
75+
76+
---@param params copilot_notify_shown_params
77+
function mod.notify_shown(client, params, callback)
78+
return mod.request(client, "notifyShown", params, callback)
79+
end
80+
81+
---@alias copilot_get_completions_data { completions: { displayText: string, position: { character: integer, line: integer }, range: { ['end']: { character: integer, line: integer }, start: { character: integer, line: integer } }, text: string, uuid: string }[] }
82+
83+
---@return any|nil err
84+
---@return copilot_get_completions_data data
85+
---@return table ctx
86+
function mod.get_completions(client, params, callback)
87+
return mod.request(client, "getCompletions", params, callback)
88+
end
89+
90+
function mod.get_completions_cycling(client, params, callback)
91+
return mod.request(client, "getCompletionsCycling", params, callback)
92+
end
93+
94+
---@alias copilot_get_panel_completions_data { solutionCountTarget: integer }
95+
---@alias copilot_panel_solution_data { panelId: string, completionText: string, displayText: string, range: { ['end']: { character: integer, line: integer }, start: { character: integer, line: integer } }, score: number, solutionId: string }
96+
---@alias copilot_panel_on_solution_handler fun(result: copilot_panel_solution_data): nil
97+
---@alias copilot_panel_solutions_done_data { panelId: string, status: 'OK'|'Error', message?: string }
98+
---@alias copilot_panel_on_solutions_done_handler fun(result: copilot_panel_solutions_done_data): nil
99+
100+
---@return any|nil err
101+
---@return copilot_get_panel_completions_data data
102+
---@return table ctx
103+
function mod.get_panel_completions(client, params, callback)
104+
return mod.request(client, "getPanelCompletions", params, callback)
105+
end
106+
107+
local panel = {
108+
callback = {
109+
PanelSolution = {},
110+
PanelSolutionsDone = {},
111+
},
112+
}
113+
114+
panel.handlers = {
115+
---@param result copilot_panel_solution_data
116+
PanelSolution = function(_, result)
117+
if panel.callback.PanelSolution[result.panelId] then
118+
panel.callback.PanelSolution[result.panelId](result)
119+
end
120+
end,
121+
122+
---@param result copilot_panel_solutions_done_data
123+
PanelSolutionsDone = function(_, result)
124+
if panel.callback.PanelSolutionsDone[result.panelId] then
125+
panel.callback.PanelSolutionsDone[result.panelId](result)
126+
end
127+
end,
128+
}
129+
130+
---@param panelId string
131+
---@param handlers { on_solution: copilot_panel_on_solution_handler, on_solutions_done: copilot_panel_on_solutions_done_handler }
132+
function mod.register_panel_handlers(panelId, handlers)
133+
assert(type(panelId) == "string", "missing panelId")
134+
panel.callback.PanelSolution[panelId] = handlers.on_solution
135+
panel.callback.PanelSolutionsDone[panelId] = handlers.on_solutions_done
136+
end
137+
138+
---@param panelId string
139+
function mod.unregister_panel_handlers(panelId)
140+
assert(type(panelId) == "string", "missing panelId")
141+
panel.callback.PanelSolution[panelId] = nil
142+
panel.callback.PanelSolutionsDone[panelId] = nil
143+
end
144+
145+
---@alias copilot_status_notification_data { status: string, message: string }
146+
147+
local status = {
148+
client_id = nil,
149+
data = {
150+
status = "",
151+
message = "",
152+
},
153+
callback = {},
154+
}
155+
156+
status.handlers = {
157+
---@param result copilot_status_notification_data
158+
---@param ctx { client_id: integer, method: string }
159+
statusNotification = function(_, result, ctx)
160+
status.client_id = ctx.client_id
161+
status.data = result
162+
163+
for callback in pairs(status.callback) do
164+
callback(status.data)
165+
end
166+
end,
167+
}
168+
169+
---@param handler fun(data: copilot_status_notification_data): nil
170+
function mod.register_status_notification_handler(handler)
171+
status.callback[handler] = true
172+
end
173+
174+
---@param handler fun(data: copilot_status_notification_data): nil
175+
function mod.unregister_status_notification_handler(handler)
176+
status.callback[handler] = nil
177+
end
178+
179+
mod.handlers = {
180+
PanelSolution = panel.handlers.PanelSolution,
181+
PanelSolutionsDone = panel.handlers.PanelSolutionsDone,
182+
statusNotification = status.handlers.statusNotification,
183+
}
184+
mod.panel = panel
185+
mod.status = status
186+
187+
return mod

lua/copilot/auth.lua

Lines changed: 21 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
1-
local util = require("copilot.util")
1+
local api = require("copilot.api")
22

33
local M = {}
44

55
function M.setup(client)
66
local function echo(message)
7-
vim.cmd('echom "[Copilot] ' .. message .. '"')
7+
vim.cmd('echom "[Copilot] ' .. tostring(message):gsub('"', '\\"') .. '"')
88
end
99

1010
local function copy_to_clipboard(str)
@@ -18,20 +18,6 @@ function M.setup(client)
1818
))
1919
end
2020

21-
local request = function(method, params)
22-
local co = coroutine.running()
23-
params.id = util.get_next_id()
24-
client.rpc.request(method, params, function(err, data)
25-
coroutine.resume(co, err, data)
26-
end)
27-
local err, data = coroutine.yield()
28-
if err then
29-
echo("Error: " .. err)
30-
error(err)
31-
end
32-
return data
33-
end
34-
3521
local function open_signin_popup(code, url)
3622
local lines = {
3723
" [Copilot] ",
@@ -68,16 +54,24 @@ function M.setup(client)
6854
end
6955

7056
local initiate_setup = coroutine.wrap(function()
71-
local data = request("checkStatus", {})
57+
local cserr, status = api.check_status(client)
58+
if cserr then
59+
echo(cserr)
60+
return
61+
end
7262

73-
if data.user then
74-
echo("Authenticated as GitHub user: " .. data.user)
63+
if status.user then
64+
echo("Authenticated as GitHub user: " .. status.user)
7565
return
7666
end
7767

78-
local signin = request("signInInitiate", {})
68+
local siierr, signin = api.sign_in_initiate(client)
69+
if siierr then
70+
echo(siierr)
71+
return
72+
end
7973

80-
if not signin.verificationUri then
74+
if not signin.verificationUri or not signin.userCode then
8175
echo("Failed to setup")
8276
return
8377
end
@@ -86,10 +80,15 @@ function M.setup(client)
8680

8781
local close_signin_popup = open_signin_popup(signin.userCode, signin.verificationUri)
8882

89-
local confirm = request("signInConfirm", { userCode = signin.userCode })
83+
local sicerr, confirm = api.sign_in_confirm(client, { userCode = signin.userCode })
9084

9185
close_signin_popup()
9286

87+
if sicerr then
88+
echo(sicerr)
89+
return
90+
end
91+
9392
if string.lower(confirm.status) ~= "ok" then
9493
echo("Authentication failure: " .. confirm.error.message)
9594
return

lua/copilot/client.lua

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1-
local M = { params = {} }
1+
local api = require("copilot.api")
22
local util = require("copilot.util")
33

4+
local M = { params = {} }
5+
46
local register_autocmd = function ()
57
vim.api.nvim_create_autocmd({ "BufEnter" }, {
68
callback = vim.schedule_wrap(M.buf_attach_copilot),
@@ -36,6 +38,11 @@ M.merge_server_opts = function (params)
3638
vim.schedule(M.buf_attach_copilot)
3739
vim.schedule(register_autocmd)
3840
end,
41+
handlers = {
42+
-- PanelSolution = api.handlers.PanelSolution,
43+
-- PanelSolutionsDone = api.handlers.PanelSolutionsDone,
44+
statusNotification = api.handlers.statusNotification,
45+
}
3946
}, params.server_opts_overrides or {})
4047
end
4148

0 commit comments

Comments
 (0)