Skip to content

Commit

Permalink
Merge pull request #234 from talex5/bootstrap-crash
Browse files Browse the repository at this point in the history
Don't crash if the peer disconnects before the bootstrap reply is ready
  • Loading branch information
talex5 authored Jun 10, 2021
2 parents fe55298 + db9dd09 commit 9f69fe7
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 13 deletions.
14 changes: 9 additions & 5 deletions capnp-rpc/capTP.ml
Original file line number Diff line number Diff line change
Expand Up @@ -1477,18 +1477,22 @@ module Make (EP : Message_types.ENDPOINT) = struct
let answer = Answer.create id ~answer:promise in
Answers.set t.answers id answer;
object_id |> t.restore @@ fun service ->
let results =
match service with
if Answer.needs_return answer && t.disconnected = None then (
let results =
match service with
| Error ex -> Error (`Exception ex)
| Ok service ->
let msg =
Wire.Response.bootstrap ()
|> Core_types.Response_payload.with_caps (RO_array.of_list [service])
in
Ok msg
in
Core_types.resolve_payload answer_resolver results;
Send.return t answer results
in
Core_types.resolve_payload answer_resolver results;
Send.return t answer results
) else (
Result.iter dec_ref service
)

let return_results t question msg descrs =
let caps_used = Question.paths_used question |> caps_used ~msg in
Expand Down
43 changes: 35 additions & 8 deletions test-lwt/test_lwt.ml
Original file line number Diff line number Diff line change
Expand Up @@ -57,27 +57,30 @@ let () = Logs.(set_level (Some Logs.Info))

let server_pem = `PEM (Auth.Secret_key.to_pem_data server_key)

let make_vats ?(serve_tls=false) ~switch ~service () =
let id = Restorer.Id.public "" in
let restore = Restorer.single id service in
let make_vats_full ?(serve_tls=false) ~client_switch ~server_switch ~restore () =
let server_config =
let socket_path = Filename.(concat (Filename.get_temp_dir_name ())) "capnp-rpc-test-server" in
Lwt_switch.add_hook (Some switch) (fun () -> Lwt.return @@ ensure_removed socket_path);
Lwt_switch.add_hook (Some server_switch) (fun () -> Lwt.return @@ ensure_removed socket_path);
Capnp_rpc_unix.Vat_config.create ~secret_key:server_pem ~serve_tls (`Unix socket_path)
in
let server_switch = Lwt_switch.create () in
Capnp_rpc_unix.serve ~switch:server_switch ~tags:Test_utils.server_tags ~restore server_config >>= fun server ->
Lwt_switch.add_hook (Some switch) (fun () -> Lwt_switch.turn_off server_switch);
Lwt_switch.add_hook (Some switch) (fun () -> Capability.dec_ref service; Lwt.return_unit);
Lwt.return {
client = Vat.create ~switch ~tags:Test_utils.client_tags ~secret_key:(lazy client_key) ();
client = Vat.create ~switch:client_switch ~tags:Test_utils.client_tags ~secret_key:(lazy client_key) ();
server;
client_key;
server_key;
serve_tls;
server_switch;
}

let make_vats ?serve_tls ~switch ~service () =
let server_switch = Lwt_switch.create () in
Lwt_switch.add_hook (Some switch) (fun () -> Lwt_switch.turn_off server_switch);
let id = Restorer.Id.public "" in
let restore = Restorer.single id service in
Lwt_switch.add_hook (Some switch) (fun () -> Capability.dec_ref service; Lwt.return_unit);
make_vats_full ?serve_tls ~client_switch:switch ~server_switch ~restore ()

(* Generic Lwt running for Alcotest. *)
let run_lwt name ?(expected_warnings=0) fn =
Alcotest_lwt.test_case name `Quick @@ fun sw () ->
Expand Down Expand Up @@ -665,6 +668,29 @@ let test_await_settled _switch =
Alcotest.(check (result unit capnp_error)) "Check await failure" (Error err) check;
Lwt.return_unit

(* The client disconnects before the server has finished loading the bootstrap object. *)
let test_late_bootstrap switch =
let connected, set_connected = Lwt.wait () in
let service, set_service = Lwt.wait () in
let module Loader = struct
type t = unit
let hash () = `SHA256
let make_sturdy () _id = assert false
let load () _sr _name =
Lwt.wakeup_later set_connected ();
service
end in
let table = Capnp_rpc_net.Restorer.Table.of_loader (module Loader) () in
let restore = Restorer.of_table table in
let client_switch = Lwt_switch.create () in
make_vats_full ~client_switch ~server_switch:switch ~restore () >>= fun cs ->
let service = get_bootstrap cs in
connected >>= fun () ->
Lwt_switch.turn_off client_switch >>= fun () ->
Lwt.wakeup set_service @@ Capnp_rpc_net.Restorer.grant @@ Echo.local ();
service >>= fun _ ->
Lwt.return ()

let run name fn = Alcotest_lwt.test_case_sync name `Quick fn

let rpc_tests = [
Expand Down Expand Up @@ -694,6 +720,7 @@ let rpc_tests = [
run_lwt "Store" test_store;
run_lwt "File store" test_file_store;
run_lwt "Await settled" test_await_settled;
run_lwt "Late bootstrap" test_late_bootstrap;
]

let () =
Expand Down

0 comments on commit 9f69fe7

Please sign in to comment.