diff --git a/build-with.sh b/build-with.sh index c54d999..5252f23 100755 --- a/build-with.sh +++ b/build-with.sh @@ -20,5 +20,5 @@ $builder build -t qubes-mirage-firewall . echo Building Firewall... $builder run --rm -i -v `pwd`:/tmp/orb-build:Z qubes-mirage-firewall echo "SHA2 of build: $(sha256sum ./dist/qubes-firewall.xen)" -echo "SHA2 last known: 4b1f743bf4540bc8a9366cf8f23a78316e4f2d477af77962e50618753c4adf10" +echo "SHA2 last known: 2392386d9056b17a648f26b0c5d1c72b93f8a197964c670b2b45e71707727317" echo "(hashes should match for released versions)" diff --git a/client_eth.ml b/client_eth.ml index de41f70..fc0b01a 100644 --- a/client_eth.ml +++ b/client_eth.ml @@ -8,7 +8,7 @@ let src = Logs.Src.create "client_eth" ~doc:"Ethernet networks for NetVM clients module Log = (val Logs.src_log src : Logs.LOG) type t = { - mutable iface_of_ip : client_link IpMap.t; + mutable iface_of_ip : client_link Ipaddr.V4.Map.t; changed : unit Lwt_condition.t; (* Fires when [iface_of_ip] changes. *) my_ip : Ipaddr.V4.t; (* The IP that clients are given as their default gateway. *) } @@ -21,21 +21,21 @@ type host = let create config = let changed = Lwt_condition.create () in let my_ip = config.Dao.our_ip in - Lwt.return { iface_of_ip = IpMap.empty; my_ip; changed } + Lwt.return { iface_of_ip = Ipaddr.V4.Map.empty; my_ip; changed } let client_gw t = t.my_ip let add_client t iface = let ip = iface#other_ip in let rec aux () = - match IpMap.find ip t.iface_of_ip with + match Ipaddr.V4.Map.find_opt ip t.iface_of_ip with | Some old -> (* Wait for old client to disappear before adding one with the same IP address. Otherwise, its [remove_client] call will remove the new client instead. *) Log.info (fun f -> f ~header:iface#log_header "Waiting for old client %s to go away before accepting new one" old#log_header); Lwt_condition.wait t.changed >>= aux | None -> - t.iface_of_ip <- t.iface_of_ip |> IpMap.add ip iface; + t.iface_of_ip <- t.iface_of_ip |> Ipaddr.V4.Map.add ip iface; Lwt_condition.broadcast t.changed (); Lwt.return_unit in @@ -43,11 +43,11 @@ let add_client t iface = let remove_client t iface = let ip = iface#other_ip in - assert (IpMap.mem ip t.iface_of_ip); - t.iface_of_ip <- t.iface_of_ip |> IpMap.remove ip; + assert (Ipaddr.V4.Map.mem ip t.iface_of_ip); + t.iface_of_ip <- t.iface_of_ip |> Ipaddr.V4.Map.remove ip; Lwt_condition.broadcast t.changed () -let lookup t ip = IpMap.find ip t.iface_of_ip +let lookup t ip = Ipaddr.V4.Map.find_opt ip t.iface_of_ip let classify t ip = match ip with @@ -79,7 +79,7 @@ module ARP = struct (* We're now treating client networks as point-to-point links, so we no longer respond on behalf of other clients. *) (* - else match IpMap.find ip t.net.iface_of_ip with + else match Ipaddr.V4.Map.find_opt ip t.net.iface_of_ip with | Some client_iface -> Some client_iface#other_mac | None -> None *) diff --git a/dao.ml b/dao.ml index 2361630..9344c1f 100644 --- a/dao.ml +++ b/dao.ml @@ -65,43 +65,40 @@ let read_rules rules client_ip = number = 0;})] let vifs client domid = + let open Lwt.Syntax in match int_of_string_opt domid with | None -> Log.err (fun f -> f "Invalid domid %S" domid); Lwt.return [] | Some domid -> - let path = Printf.sprintf "backend/vif/%d" domid in - Xen_os.Xs.immediate client (fun handle -> - directory ~handle path >>= - Lwt_list.filter_map_p (fun device_id -> - match int_of_string_opt device_id with - | None -> Log.err (fun f -> f "Invalid device ID %S for domid %d" device_id domid); Lwt.return_none - | Some device_id -> - let vif = { ClientVif.domid; device_id } in - Lwt.try_bind - (fun () -> Xen_os.Xs.read handle (Printf.sprintf "%s/%d/ip" path device_id)) - (fun client_ip -> - let client_ip' = match String.split_on_char ' ' client_ip with - | [] -> Log.err (fun m -> m "unexpected empty list"); "" - | [ ip ] -> ip - | ip::rest -> - Log.warn (fun m -> m "ignoring IPs %s from %a, we support one IP per client" - (String.concat " " rest) ClientVif.pp vif); - ip - in - match Ipaddr.V4.of_string client_ip' with - | Ok ip -> Lwt.return (Some (vif, ip)) - | Error `Msg msg -> - Log.err (fun f -> f "Error parsing IP address of %a from %s: %s" - ClientVif.pp vif client_ip msg); - Lwt.return None - ) - (function - | Xs_protocol.Enoent _ -> Lwt.return None - | ex -> - Log.err (fun f -> f "Error getting IP address of %a: %s" - ClientVif.pp vif (Printexc.to_string ex)); - Lwt.return None - ) - )) + let path = Fmt.str "backend/vif/%d" domid in + let vifs_of_domain handle = + let* devices = directory ~handle path in + let ip_of_vif device_id = match int_of_string_opt device_id with + | None -> + Log.err (fun f -> f "Invalid device ID %S for domid %d" device_id domid); + Lwt.return_none + | Some device_id -> + let vif = { ClientVif.domid; device_id } in + let get_client_ip () = + let* str = Xen_os.Xs.read handle (Fmt.str "%s/%d/ip" path device_id) in + let client_ip = List.hd (String.split_on_char ' ' str) in + (* NOTE(dinosaure): it's safe to use [List.hd] here, + [String.split_on_char] can not return an empty list. *) + Lwt.return_some (vif, Ipaddr.V4.of_string_exn client_ip) + in + Lwt.catch get_client_ip @@ function + | Xs_protocol.Enoent _ -> Lwt.return_none + | Ipaddr.Parse_error (msg, client_ip) -> + Log.err (fun f -> f "Error parsing IP address of %a from %s: %s" + ClientVif.pp vif client_ip msg); + Lwt.return_none + | exn -> + Log.err (fun f -> f "Error getting IP address of %a: %s" + ClientVif.pp vif (Printexc.to_string exn)); + Lwt.return_none + in + Lwt_list.filter_map_p ip_of_vif devices + in + Xen_os.Xs.immediate client vifs_of_domain let watch_clients fn = Xen_os.Xs.make () >>= fun xs -> @@ -116,7 +113,7 @@ let watch_clients fn = end >>= fun items -> Xen_os.Xs.make () >>= fun xs -> Lwt_list.map_p (vifs xs) items >>= fun items -> - fn (List.concat items |> VifMap.of_list); + fn (List.concat items |> VifMap.of_list) >>= fun () -> (* Wait for further updates *) Lwt.fail Xs_protocol.Eagain ) diff --git a/dao.mli b/dao.mli index bff4cbf..c278d16 100644 --- a/dao.mli +++ b/dao.mli @@ -15,7 +15,7 @@ module VifMap : sig val find : key -> 'a t -> 'a option end -val watch_clients : (Ipaddr.V4.t VifMap.t -> unit) -> 'a Lwt.t +val watch_clients : (Ipaddr.V4.t VifMap.t -> unit Lwt.t) -> 'a Lwt.t (** [watch_clients fn] calls [fn clients] with the list of backend clients in XenStore, and again each time XenStore updates. *) diff --git a/dispatcher.ml b/dispatcher.ml index 3768863..60927f6 100644 --- a/dispatcher.ml +++ b/dispatcher.ml @@ -17,8 +17,6 @@ struct module I = Static_ipv4.Make (R) (Clock) (UplinkEth) (Arp) module U = Udp.Make (I) (R) - let clients : Cleanup.t Dao.VifMap.t ref = ref Dao.VifMap.empty - class client_iface eth ~domid ~gateway_ip ~client_ip client_mac : client_link = let log_header = Fmt.str "dom%d:%a" domid Ipaddr.V4.pp client_ip in @@ -344,11 +342,12 @@ struct (** Connect to a new client's interface and listen for incoming frames and firewall rule changes. *) let add_vif get_ts { Dao.ClientVif.domid; device_id } dns_client dns_servers - ~client_ip ~router ~cleanup_tasks qubesDB = - Netback.make ~domid ~device_id >>= fun backend -> + ~client_ip ~router ~cleanup_tasks qubesDB () = + let open Lwt.Syntax in + let* backend = Netback.make ~domid ~device_id in Log.info (fun f -> f "Client %d (IP: %s) ready" domid (Ipaddr.V4.to_string client_ip)); - ClientEth.connect backend >>= fun eth -> + let* eth = ClientEth.connect backend in let client_mac = Netback.frontend_mac backend in let client_eth = router.clients in let gateway_ip = Client_eth.client_gw client_eth in @@ -404,46 +403,54 @@ struct (function Lwt.Canceled -> Lwt.return_unit | e -> Lwt.fail e) in Cleanup.on_cleanup cleanup_tasks (fun () -> Lwt.cancel listener); - Lwt.pick [ qubesdb_updater; listener ] + (* NOTE(dinosaure): [qubes_updater] and [listener] can be forgotten, our [cleanup_task] + will cancel them if the client is disconnected. *) + Lwt.async (fun () -> Lwt.pick [ qubesdb_updater; listener ]); + Lwt.return_unit (** A new client VM has been found in XenStore. Find its interface and connect to it. *) let add_client get_ts dns_client dns_servers ~router vif client_ip qubesDB = + let open Lwt.Syntax in let cleanup_tasks = Cleanup.create () in Log.info (fun f -> f "add client vif %a with IP %a" Dao.ClientVif.pp vif Ipaddr.V4.pp client_ip); - Lwt.async (fun () -> - Lwt.catch - (fun () -> - add_vif get_ts vif dns_client dns_servers ~client_ip ~router - ~cleanup_tasks qubesDB) - (fun ex -> - Log.warn (fun f -> - f "Error with client %a: %s" Dao.ClientVif.pp vif - (Printexc.to_string ex)); - Lwt.return_unit)); - cleanup_tasks + let* () = + Lwt.catch (add_vif get_ts vif dns_client dns_servers ~client_ip ~router + ~cleanup_tasks qubesDB) + @@ fun exn -> + Log.warn (fun f -> + f "Error with client %a: %s" Dao.ClientVif.pp vif + (Printexc.to_string exn)); + Lwt.return_unit + in + Lwt.return cleanup_tasks (** Watch XenStore for notifications of new clients. *) let wait_clients get_ts dns_client dns_servers qubesDB router = - Dao.watch_clients (fun new_set -> - (* Check for removed clients *) - !clients - |> Dao.VifMap.iter (fun key cleanup -> - if not (Dao.VifMap.mem key new_set) then ( - clients := !clients |> Dao.VifMap.remove key; - Log.info (fun f -> f "client %a has gone" Dao.ClientVif.pp key); - Cleanup.cleanup cleanup)); - (* Check for added clients *) - new_set - |> Dao.VifMap.iter (fun key ip_addr -> - if not (Dao.VifMap.mem key !clients) then ( - let cleanup = - add_client get_ts dns_client dns_servers ~router key ip_addr - qubesDB - in - Log.debug (fun f -> f "client %a arrived" Dao.ClientVif.pp key); - clients := !clients |> Dao.VifMap.add key cleanup))) + let open Lwt.Syntax in + let clients : Cleanup.t Dao.VifMap.t ref = ref Dao.VifMap.empty in + Dao.watch_clients @@ fun new_set -> + (* Check for removed clients *) + let clean_up_clients key cleanup = + if not (Dao.VifMap.mem key new_set) then begin + clients := !clients |> Dao.VifMap.remove key; + Log.info (fun f -> f "client %a has gone" Dao.ClientVif.pp key); + Cleanup.cleanup cleanup + end + in + Dao.VifMap.iter clean_up_clients !clients; + (* Check for added clients *) + let rec go seq = match Seq.uncons seq with + | None -> Lwt.return_unit + | Some ((key, ipaddr), seq) when not (Dao.VifMap.mem key !clients) -> + let* cleanup = add_client get_ts dns_client dns_servers ~router key ipaddr qubesDB in + Log.debug (fun f -> f "client %a arrived" Dao.ClientVif.pp key); + clients := Dao.VifMap.add key cleanup !clients; + go seq + | Some (_, seq) -> go seq + in + go (Dao.VifMap.to_seq new_set) let send_dns_client_query t ~src_port ~dst ~dst_port buf = match t.uplink with diff --git a/fw_utils.ml b/fw_utils.ml index 0307810..f20c63a 100644 --- a/fw_utils.ml +++ b/fw_utils.ml @@ -3,14 +3,6 @@ (** General utility functions. *) -module IpMap = struct - include Map.Make(Ipaddr.V4) - let find x map = - try Some (find x map) - with Not_found -> None - | _ -> Logs.err( fun f -> f "uncaught exception in find...%!"); None -end - (** An Ethernet interface. *) class type interface = object method my_mac : Macaddr.t diff --git a/unikernel.ml b/unikernel.ml index b64fd4e..f0e12df 100644 --- a/unikernel.ml +++ b/unikernel.ml @@ -46,15 +46,12 @@ module Main (R : Mirage_crypto_rng_mirage.S)(Clock : Mirage_clock.MCLOCK)(Time : (* Main unikernel entry point (called from auto-generated main.ml). *) let start _random _clock _time = + let open Lwt.Syntax in let start_time = Clock.elapsed_ns () in (* Start qrexec agent and QubesDB agent in parallel *) - let qrexec = RExec.connect ~domid:0 () in - let qubesDB = DB.connect ~domid:0 () in - - (* Wait for clients to connect *) - qrexec >>= fun qrexec -> + let* qrexec = RExec.connect ~domid:0 () in let agent_listener = RExec.listen qrexec Command.handler in - qubesDB >>= fun qubesDB -> + let* qubesDB = DB.connect ~domid:0 () in let startup_time = let (-) = Int64.sub in let time_in_ns = Clock.elapsed_ns () - start_time in @@ -93,7 +90,7 @@ module Main (R : Mirage_crypto_rng_mirage.S)(Clock : Mirage_clock.MCLOCK)(Time : Dao.print_network_config config ; (* Set up client-side networking *) - Client_eth.create config >>= fun clients -> + let* clients = Client_eth.create config in (* Set up routing between networks and hosts *) let router = Dispatcher.create