refactor(lsp): extract rpc client from rpc.start

Makes the previously inner functions re-usable for a TCP client
This commit is contained in:
Mathias Fussenegger
2022-08-24 19:36:37 +02:00
parent 7d3e4aee6a
commit 46bb34e26b

View File

@ -270,6 +270,292 @@ local function create_read_loop(handle_body, on_no_chunk, on_error)
end
end
---@class RpcClient
---@field message_index number
---@field message_callbacks table
---@field notify_reply_callbacks table
---@field transport table
---@field dispatchers table
---@class RpcClient
local Client = {}
---@private
function Client:encode_and_send(payload)
local _ = log.debug() and log.debug('rpc.send', payload)
if self.transport.is_closing() then
return false
end
local encoded = vim.json.encode(payload)
self.transport.write(format_message_with_content_length(encoded))
return true
end
---@private
--- Sends a notification to the LSP server.
---@param method (string) The invoked LSP method
---@param params (table|nil): Parameters for the invoked LSP method
---@returns (bool) `true` if notification could be sent, `false` if not
function Client:notify(method, params)
return self:encode_and_send({
jsonrpc = '2.0',
method = method,
params = params,
})
end
---@private
--- sends an error object to the remote LSP process.
function Client:send_response(request_id, err, result)
return self:encode_and_send({
id = request_id,
jsonrpc = '2.0',
error = err,
result = result,
})
end
---@private
--- Sends a request to the LSP server and runs {callback} upon response.
---
---@param method (string) The invoked LSP method
---@param params (table|nil) Parameters for the invoked LSP method
---@param callback (function) Callback to invoke
---@param notify_reply_callback (function|nil) Callback to invoke as soon as a request is no longer pending
---@returns (bool, number) `(true, message_id)` if request could be sent, `false` if not
function Client:request(method, params, callback, notify_reply_callback)
validate({
callback = { callback, 'f' },
notify_reply_callback = { notify_reply_callback, 'f', true },
})
self.message_index = self.message_index + 1
local message_id = self.message_index
local result = self:encode_and_send({
id = message_id,
jsonrpc = '2.0',
method = method,
params = params,
})
local message_callbacks = self.message_callbacks
local notify_reply_callbacks = self.notify_reply_callbacks
if result then
if message_callbacks then
message_callbacks[message_id] = schedule_wrap(callback)
else
return false
end
if notify_reply_callback and notify_reply_callbacks then
notify_reply_callbacks[message_id] = schedule_wrap(notify_reply_callback)
end
return result, message_id
else
return false
end
end
---@private
function Client:on_error(errkind, ...)
assert(client_errors[errkind])
-- TODO what to do if this fails?
pcall(self.dispatchers.on_error, errkind, ...)
end
---@private
function Client:pcall_handler(errkind, status, head, ...)
if not status then
self:on_error(errkind, head, ...)
return status, head
end
return status, head, ...
end
---@private
function Client:try_call(errkind, fn, ...)
return self:pcall_handler(errkind, pcall(fn, ...))
end
-- TODO periodically check message_callbacks for old requests past a certain
-- time and log them. This would require storing the timestamp. I could call
-- them with an error then, perhaps.
---@private
function Client:handle_body(body)
local ok, decoded = pcall(vim.json.decode, body, { luanil = { object = true } })
if not ok then
self:on_error(client_errors.INVALID_SERVER_JSON, decoded)
return
end
local _ = log.debug() and log.debug('rpc.receive', decoded)
if type(decoded.method) == 'string' and decoded.id then
local err
-- Schedule here so that the users functions don't trigger an error and
-- we can still use the result.
schedule(function()
local status, result
status, result, err = self:try_call(
client_errors.SERVER_REQUEST_HANDLER_ERROR,
self.dispatchers.server_request,
decoded.method,
decoded.params
)
local _ = log.debug()
and log.debug(
'server_request: callback result',
{ status = status, result = result, err = err }
)
if status then
if not (result or err) then
-- TODO this can be a problem if `null` is sent for result. needs vim.NIL
error(
string.format(
'method %q: either a result or an error must be sent to the server in response',
decoded.method
)
)
end
if err then
assert(
type(err) == 'table',
'err must be a table. Use rpc_response_error to help format errors.'
)
local code_name = assert(
protocol.ErrorCodes[err.code],
'Errors must use protocol.ErrorCodes. Use rpc_response_error to help format errors.'
)
err.message = err.message or code_name
end
else
-- On an exception, result will contain the error message.
err = rpc_response_error(protocol.ErrorCodes.InternalError, result)
result = nil
end
self:send_response(decoded.id, err, result)
end)
-- This works because we are expecting vim.NIL here
elseif decoded.id and (decoded.result ~= vim.NIL or decoded.error ~= vim.NIL) then
-- We sent a number, so we expect a number.
local result_id = assert(tonumber(decoded.id), 'response id must be a number')
-- Notify the user that a response was received for the request
local notify_reply_callbacks = self.notify_reply_callbacks
local notify_reply_callback = notify_reply_callbacks and notify_reply_callbacks[result_id]
if notify_reply_callback then
validate({
notify_reply_callback = { notify_reply_callback, 'f' },
})
notify_reply_callback(result_id)
notify_reply_callbacks[result_id] = nil
end
local message_callbacks = self.message_callbacks
-- Do not surface RequestCancelled to users, it is RPC-internal.
if decoded.error then
local mute_error = false
if decoded.error.code == protocol.ErrorCodes.RequestCancelled then
local _ = log.debug() and log.debug('Received cancellation ack', decoded)
mute_error = true
end
if mute_error then
-- Clear any callback since this is cancelled now.
-- This is safe to do assuming that these conditions hold:
-- - The server will not send a result callback after this cancellation.
-- - If the server sent this cancellation ACK after sending the result, the user of this RPC
-- client will ignore the result themselves.
if result_id and message_callbacks then
message_callbacks[result_id] = nil
end
return
end
end
local callback = message_callbacks and message_callbacks[result_id]
if callback then
message_callbacks[result_id] = nil
validate({
callback = { callback, 'f' },
})
if decoded.error then
decoded.error = setmetatable(decoded.error, {
__tostring = format_rpc_error,
})
end
self:try_call(
client_errors.SERVER_RESULT_CALLBACK_ERROR,
callback,
decoded.error,
decoded.result
)
else
self:on_error(client_errors.NO_RESULT_CALLBACK_FOUND, decoded)
local _ = log.error() and log.error('No callback found for server response id ' .. result_id)
end
elseif type(decoded.method) == 'string' then
-- Notification
self:try_call(
client_errors.NOTIFICATION_HANDLER_ERROR,
self.dispatchers.notification,
decoded.method,
decoded.params
)
else
-- Invalid server message
self:on_error(client_errors.INVALID_SERVER_MESSAGE, decoded)
end
end
---@private
---@return RpcClient
local function new_client(dispatchers, transport)
local state = {
message_index = 0,
message_callbacks = {},
notify_reply_callbacks = {},
transport = transport,
dispatchers = dispatchers,
}
return setmetatable(state, { __index = Client })
end
---@private
---@param client RpcClient
local function public_client(client)
local result = {}
---@private
function result.is_closing()
return client.transport.is_closing()
end
---@private
function result.terminate()
client.transport.terminate()
end
--- Sends a request to the LSP server and runs {callback} upon response.
---
---@param method (string) The invoked LSP method
---@param params (table|nil) Parameters for the invoked LSP method
---@param callback (function) Callback to invoke
---@param notify_reply_callback (function|nil) Callback to invoke as soon as a request is no longer pending
---@returns (bool, number) `(true, message_id)` if request could be sent, `false` if not
function result.request(method, params, callback, notify_reply_callback)
return client:request(method, params, callback, notify_reply_callback)
end
--- Sends a notification to the LSP server.
---@param method (string) The invoked LSP method
---@param params (table|nil): Parameters for the invoked LSP method
---@returns (bool) `true` if notification could be sent, `false` if not
function result.notify(method, params)
return client:notify(method, params)
end
return result
end
--- Starts an LSP server process and create an LSP RPC client object to
--- interact with it. Communication with the server is currently limited to stdio.
---
@ -334,134 +620,59 @@ local function start(cmd, cmd_args, dispatchers, extra_spawn_params)
local stdin = uv.new_pipe(false)
local stdout = uv.new_pipe(false)
local stderr = uv.new_pipe(false)
local message_index = 0
local message_callbacks = {}
local notify_reply_callbacks = {}
local handle, pid
do
---@private
--- Callback for |vim.loop.spawn()| Closes all streams and runs the `on_exit` dispatcher.
---@param code (number) Exit code
---@param signal (number) Signal that was used to terminate (if any)
local function onexit(code, signal)
stdin:close()
stdout:close()
stderr:close()
handle:close()
-- Make sure that message_callbacks/notify_reply_callbacks can be gc'd.
message_callbacks = nil
notify_reply_callbacks = nil
dispatchers.on_exit(code, signal)
end
local spawn_params = {
args = cmd_args,
stdio = { stdin, stdout, stderr },
detached = not is_win,
}
if extra_spawn_params then
spawn_params.cwd = extra_spawn_params.cwd
spawn_params.env = env_merge(extra_spawn_params.env)
if extra_spawn_params.detached ~= nil then
spawn_params.detached = extra_spawn_params.detached
local client = new_client(dispatchers, {
write = function(msg)
stdin:write(msg)
end,
is_closing = function()
return handle == nil or handle:is_closing()
end,
terminate = function()
if handle then
handle:kill(15)
end
end
handle, pid = uv.spawn(cmd, spawn_params, onexit)
if handle == nil then
stdin:close()
stdout:close()
stderr:close()
local msg = string.format('Spawning language server with cmd: `%s` failed', cmd)
if string.match(pid, 'ENOENT') then
msg = msg
.. '. The language server is either not installed, missing from PATH, or not executable.'
else
msg = msg .. string.format(' with error message: %s', pid)
end
vim.notify(msg, vim.log.levels.WARN)
return
end
end
end,
})
---@private
--- Encodes {payload} into a JSON-RPC message and sends it to the remote
--- process.
---
---@param payload table
---@returns true if the payload could be scheduled, false if the main event-loop is in the process of closing.
local function encode_and_send(payload)
local _ = log.debug() and log.debug('rpc.send', payload)
if handle == nil or handle:is_closing() then
return false
--- Callback for |vim.loop.spawn()| Closes all streams and runs the `on_exit` dispatcher.
---@param code (number) Exit code
---@param signal (number) Signal that was used to terminate (if any)
local function onexit(code, signal)
stdin:close()
stdout:close()
stderr:close()
handle:close()
dispatchers.on_exit(code, signal)
end
local spawn_params = {
args = cmd_args,
stdio = { stdin, stdout, stderr },
detached = not is_win,
}
if extra_spawn_params then
spawn_params.cwd = extra_spawn_params.cwd
spawn_params.env = env_merge(extra_spawn_params.env)
if extra_spawn_params.detached ~= nil then
spawn_params.detached = extra_spawn_params.detached
end
local encoded = vim.json.encode(payload)
stdin:write(format_message_with_content_length(encoded))
return true
end
-- FIXME: DOC: Should be placed on the RPC client object returned by
-- `start()`
--
--- Sends a notification to the LSP server.
---@param method (string) The invoked LSP method
---@param params (table|nil): Parameters for the invoked LSP method
---@returns (bool) `true` if notification could be sent, `false` if not
local function notify(method, params)
return encode_and_send({
jsonrpc = '2.0',
method = method,
params = params,
})
end
---@private
--- sends an error object to the remote LSP process.
local function send_response(request_id, err, result)
return encode_and_send({
id = request_id,
jsonrpc = '2.0',
error = err,
result = result,
})
end
-- FIXME: DOC: Should be placed on the RPC client object returned by
-- `start()`
--
--- Sends a request to the LSP server and runs {callback} upon response.
---
---@param method (string) The invoked LSP method
---@param params (table|nil) Parameters for the invoked LSP method
---@param callback (function) Callback to invoke
---@param notify_reply_callback (function|nil) Callback to invoke as soon as a request is no longer pending
---@returns (bool, number) `(true, message_id)` if request could be sent, `false` if not
local function request(method, params, callback, notify_reply_callback)
validate({
callback = { callback, 'f' },
notify_reply_callback = { notify_reply_callback, 'f', true },
})
message_index = message_index + 1
local message_id = message_index
local result = encode_and_send({
id = message_id,
jsonrpc = '2.0',
method = method,
params = params,
})
if result then
if message_callbacks then
message_callbacks[message_id] = schedule_wrap(callback)
else
return false
end
if notify_reply_callback and notify_reply_callbacks then
notify_reply_callbacks[message_id] = schedule_wrap(notify_reply_callback)
end
return result, message_id
handle, pid = uv.spawn(cmd, spawn_params, onexit)
if handle == nil then
stdin:close()
stdout:close()
stderr:close()
local msg = string.format('Spawning language server with cmd: `%s` failed', cmd)
if string.match(pid, 'ENOENT') then
msg = msg
.. '. The language server is either not installed, missing from PATH, or not executable.'
else
return false
msg = msg .. string.format(' with error message: %s', pid)
end
vim.notify(msg, vim.log.levels.WARN)
return
end
stderr:read_start(function(_, chunk)
@ -470,171 +681,14 @@ local function start(cmd, cmd_args, dispatchers, extra_spawn_params)
end
end)
---@private
local function on_error(errkind, ...)
assert(client_errors[errkind])
-- TODO what to do if this fails?
pcall(dispatchers.on_error, errkind, ...)
local handle_body = function(body)
client:handle_body(body)
end
---@private
local function pcall_handler(errkind, status, head, ...)
if not status then
on_error(errkind, head, ...)
return status, head
end
return status, head, ...
end
---@private
local function try_call(errkind, fn, ...)
return pcall_handler(errkind, pcall(fn, ...))
end
-- TODO periodically check message_callbacks for old requests past a certain
-- time and log them. This would require storing the timestamp. I could call
-- them with an error then, perhaps.
---@private
local function handle_body(body)
local ok, decoded = pcall(vim.json.decode, body, { luanil = { object = true } })
if not ok then
on_error(client_errors.INVALID_SERVER_JSON, decoded)
return
end
local _ = log.debug() and log.debug('rpc.receive', decoded)
if type(decoded.method) == 'string' and decoded.id then
local err
-- Schedule here so that the users functions don't trigger an error and
-- we can still use the result.
schedule(function()
local status, result
status, result, err = try_call(
client_errors.SERVER_REQUEST_HANDLER_ERROR,
dispatchers.server_request,
decoded.method,
decoded.params
)
local _ = log.debug()
and log.debug(
'server_request: callback result',
{ status = status, result = result, err = err }
)
if status then
if not (result or err) then
-- TODO this can be a problem if `null` is sent for result. needs vim.NIL
error(
string.format(
'method %q: either a result or an error must be sent to the server in response',
decoded.method
)
)
end
if err then
assert(
type(err) == 'table',
'err must be a table. Use rpc_response_error to help format errors.'
)
local code_name = assert(
protocol.ErrorCodes[err.code],
'Errors must use protocol.ErrorCodes. Use rpc_response_error to help format errors.'
)
err.message = err.message or code_name
end
else
-- On an exception, result will contain the error message.
err = rpc_response_error(protocol.ErrorCodes.InternalError, result)
result = nil
end
send_response(decoded.id, err, result)
end)
-- This works because we are expecting vim.NIL here
elseif decoded.id and (decoded.result ~= vim.NIL or decoded.error ~= vim.NIL) then
-- We sent a number, so we expect a number.
local result_id = assert(tonumber(decoded.id), 'response id must be a number')
-- Notify the user that a response was received for the request
local notify_reply_callback = notify_reply_callbacks and notify_reply_callbacks[result_id]
if notify_reply_callback then
validate({
notify_reply_callback = { notify_reply_callback, 'f' },
})
notify_reply_callback(result_id)
notify_reply_callbacks[result_id] = nil
end
-- Do not surface RequestCancelled to users, it is RPC-internal.
if decoded.error then
local mute_error = false
if decoded.error.code == protocol.ErrorCodes.RequestCancelled then
local _ = log.debug() and log.debug('Received cancellation ack', decoded)
mute_error = true
end
if mute_error then
-- Clear any callback since this is cancelled now.
-- This is safe to do assuming that these conditions hold:
-- - The server will not send a result callback after this cancellation.
-- - If the server sent this cancellation ACK after sending the result, the user of this RPC
-- client will ignore the result themselves.
if result_id and message_callbacks then
message_callbacks[result_id] = nil
end
return
end
end
local callback = message_callbacks and message_callbacks[result_id]
if callback then
message_callbacks[result_id] = nil
validate({
callback = { callback, 'f' },
})
if decoded.error then
decoded.error = setmetatable(decoded.error, {
__tostring = format_rpc_error,
})
end
try_call(
client_errors.SERVER_RESULT_CALLBACK_ERROR,
callback,
decoded.error,
decoded.result
)
else
on_error(client_errors.NO_RESULT_CALLBACK_FOUND, decoded)
local _ = log.error()
and log.error('No callback found for server response id ' .. result_id)
end
elseif type(decoded.method) == 'string' then
-- Notification
try_call(
client_errors.NOTIFICATION_HANDLER_ERROR,
dispatchers.notification,
decoded.method,
decoded.params
)
else
-- Invalid server message
on_error(client_errors.INVALID_SERVER_MESSAGE, decoded)
end
end
local request_parser = coroutine.wrap(request_parser_loop)
request_parser()
stdout:read_start(create_read_loop(handle_body, nil, function(err)
on_error(client_errors.READ_ERROR, err)
client:on_error(client_errors.READ_ERROR, err)
end))
return {
is_closing = function()
return handle:is_closing()
end,
terminate = function()
handle:kill(15)
end,
request = request,
notify = notify,
}
return public_client(client)
end
return {