From 537395b9cc0df2483b57a973f0dd1c68f172c1e8 Mon Sep 17 00:00:00 2001 From: Aleksandr Karbyshev Date: Mon, 14 Oct 2024 17:14:14 +0200 Subject: [PATCH] WIP HPaxos 2.0 specs --- hpaxos/hpaxos.sml | 557 ++++++++++++++++++++++++++++------------------ 1 file changed, 335 insertions(+), 222 deletions(-) diff --git a/hpaxos/hpaxos.sml b/hpaxos/hpaxos.sml index e966a90..0e976c6 100644 --- a/hpaxos/hpaxos.sml +++ b/hpaxos/hpaxos.sml @@ -1,31 +1,20 @@ -signature HPAXOS_NODE = -sig - type t - type node_id - type learner_graph - type mailbox - val hpaxos_node : node_id -> learner_graph -> mailbox -> t - val run : t -> unit -end - functor HPaxos (structure Msg : HPAXOS_MESSAGE - structure Mailbox : HPAXOS_MAILBOX structure LearnerGraph : LEARNER_GRAPH sharing Msg.Learner = LearnerGraph.Learner - and Msg.Acceptor = LearnerGraph.Acceptor - and Mailbox.Message = Msg) - :> HPAXOS_NODE = + and Msg.Acceptor = LearnerGraph.Acceptor) + :> GEN_SERVER_IMPL = struct - type msg = Msg.t - type mailbox = Mailbox.t + infix |> + fun x |> f = f x - type acceptor_id = word type acceptor = Msg.Acceptor.t type ballot = Msg.Ballot.t type value = Msg.Value.t type learner = Msg.Learner.t type learner_graph = LearnerGraph.t + type msg = Msg.t + structure MsgUtil = MessageUtil (Msg) structure MsgSet : ORD_SET = RedBlackSetFn (MessageOrdKey (Msg)) @@ -33,6 +22,7 @@ struct structure AcceptorSet : ORD_SET = RedBlackSetFn (AcceptorOrdKey (Msg.Acceptor)) structure AcceptorMap : ORD_MAP = RedBlackMapFn (AcceptorOrdKey (Msg.Acceptor)) structure LearnerSet : ORD_SET = RedBlackSetFn (LearnerOrdKey (Msg.Learner)) + structure LearnerMap : ORD_MAP = RedBlackMapFn (LearnerOrdKey (Msg.Learner)) structure LearnerAcceptorMap : ORD_MAP = RedBlackMapFn ( @@ -90,10 +80,10 @@ struct fun add_non_wellformed (AlgoState (k, r, p, NonWellformedMsgs nw, maxb)) m = AlgoState (k, r, p, NonWellformedMsgs (MsgSet.add (nw, m)), maxb) - fun get_max (AlgoState (_, _, _, _, MaxBal maxb)) = maxb + (* fun get_max (AlgoState (_, _, _, _, MaxBal maxb)) = maxb *) - fun set_max (AlgoState (k, r, p, nw, _)) bal = - AlgoState (k, r, p, nw, MaxBal bal) + (* fun set_max (AlgoState (k, r, p, nw, _)) bal = + AlgoState (k, r, p, nw, MaxBal bal) *) end (* AlgoState *) (* message info state *) @@ -103,8 +93,8 @@ struct | Uncaught of msg type t = status - fun is_uncaught (Uncaught _) = true - | is_uncaught _ = false + fun is_caught Caught = true + | is_caught _ = false fun join (bal : msg -> ballot) (Uncaught m1, Uncaught m2) = if MsgUtil.PrevTran.is_prev_reachable' bal (m1, m2) then @@ -121,8 +111,9 @@ struct type info_entry = { info_bal_val : ballot * value, info_W : (msg * msg option) LearnerAcceptorMap.map, info_acc_status : AcceptorStatus.t AcceptorMap.map, - info_unburied_2as : MsgSet.set, - info_q : acceptor list } + info_unburied_2as : MsgSet.set LearnerMap.map, + info_q : (acceptor list) LearnerMap.map + } datatype msg_info = MsgInfo of info_entry MsgMap.map @@ -130,6 +121,8 @@ struct fun mk () = MsgInfo MsgMap.empty + fun has ((MsgInfo info), m) = MsgMap.inDomain (info, m) + fun get_bal_val (MsgInfo info) m = #info_bal_val (MsgMap.lookup (info, m)) fun get_W (MsgInfo info) m = #info_W (MsgMap.lookup (info, m)) fun get_acc_status (MsgInfo info) m = #info_acc_status (MsgMap.lookup (info, m)) @@ -171,42 +164,44 @@ struct fun is_non_wellformed (State (s, _, _)) = AlgoState.is_non_wellformed s - fun add_known (State (s, i, c)) m = - State (AlgoState.add_known s m, i, c) + (* fun add_known (State (s, i, c)) m = + State (AlgoState.add_known s m, i, c) *) fun add_recent (State (s, i, c)) m = State (AlgoState.add_recent s m, i, c) - fun add_known_recent (State (s, i, c)) m = - State (AlgoState.add_recent (AlgoState.add_known s m) m, i, c) + (* fun add_known_recent (State (s, i, c)) m = + State (AlgoState.add_recent (AlgoState.add_known s m) m, i, c) *) fun add_non_wellformed (State (s, i, c)) m = State (AlgoState.add_non_wellformed s m, i, c) - fun get_max (State (s, _, _)) = AlgoState.get_max s + (* fun get_max (State (s, _, _)) = AlgoState.get_max s *) - fun get_bal_val (State (s, i, c)) m = + fun get_bal_val (State (_, i, _)) m = if Msg.is_one_a m then valOf (Msg.get_bal_val m) else MessageInfo.get_bal_val i m - fun update_max (state as State (s, i, c)) m = + (* fun update_max (state as State (s, i, c)) m = let fun max (a, b) = - end *) + case Msg.Ballot.compare (a, b) of LESS => b | _ => a val m_bal = fst (get_bal_val state m) val cur_max = AlgoState.get_max s val new_max = max (m_bal, cur_max) in State (AlgoState.set_max s new_max, i, c) - end + end *) fun get_W (State (_, i, _)) = MessageInfo.get_W i fun get_acc_status (State (_, i, _)) = MessageInfo.get_acc_status i fun get_unburied_2as (State (_, i, _)) = MessageInfo.get_unburied_2as i fun get_q (State (_, i, _)) = MessageInfo.get_q i + fun has_info (State (_, i, _)) m = MessageInfo.has (i, m) + fun store_info_entry (State (a, i, c)) mi = State (a, MessageInfo.store_entry i mi, c) @@ -214,15 +209,26 @@ struct fun put_is_fresh (State (_, _, c)) = Cache.put_is_fresh c end (* State *) - type state = State.t + structure ServerState = + struct + datatype state = State of learner_graph * State.t + type t = state - (* learner graph *) - datatype graph = Graph of learner_graph + fun mk (g, s) = State (g, s) + end (* ServerState *) - datatype acceptor_node = Acc of acceptor_id * graph * state * mailbox + type param = learner_graph + type state = ServerState.t - type t = acceptor_node - type node_id = acceptor_id + fun senders (ms : msg list) : acceptor list = + let + val empty = AcceptorSet.empty + val add = AcceptorSet.add' + in + ms + |> List.foldl (fn (x, accu) => add (Msg.sender x, accu)) empty + |> AcceptorSet.toList + end (* [msg_to_bal_val] returns a pair (ballot, value) for each known message, including 1a messages *) (* REQUIRES: m is not 1a *) @@ -234,9 +240,10 @@ struct LESS => (max_bal, max_val) | _ => (b, v) end - val refs = Msg.get_refs m (* refs is non-empty since m is not 1a *) in - foldl helper (Msg.Ballot.zero, Msg.Value.default) refs + (* refs is non-empty since m is not 1a *) + Msg.get_refs m + |> List.foldl helper (Msg.Ballot.zero, Msg.Value.default) end (* [msg_to_bal_val] returns a pair (ballot, value) for each known message and the message m *) @@ -245,6 +252,10 @@ struct fun compute_W (m : msg) msg_to_bal_val msg_to_w : (msg * msg option) LearnerAcceptorMap.map = let + val empty = LearnerAcceptorMap.empty + val insert = LearnerAcceptorMap.insert + val unionWith = LearnerAcceptorMap.unionWith + fun pick_best_two_from_list (ms : msg list) : (msg * msg option) option = let val ballot = fst o msg_to_bal_val @@ -258,10 +269,10 @@ struct | SOME cur_best => case cmp (cur_best, x) of LESS => SOME x - | _ => cur_best_o + | _ => cur_best_o else cur_best_o in - foldl choose NONE lst + List.foldl choose NONE lst end fun cmp_by_ballot (x, y) = Msg.Ballot.compare (ballot x, ballot y) fun pick_first_best ms = pick_best (Fn.const true) cmp_by_ballot ms @@ -272,26 +283,34 @@ struct pick_best pred cmp_by_ballot ms end in - Option.map (fn x => (x, pick_second_best ms x)) (pick_first_best ms) + ms + |> pick_first_best + |> Option.map (fn x => (x, pick_second_best ms x)) end fun pick_best_two (a : msg * msg option, b : msg * msg option) = - let fun to_list (best1, NONE) = [best1] + let + fun to_list (best1, NONE) = [best1] | to_list (best1, SOME best2) = [best1, best2] in - valOf (pick_best_two_from_list (to_list a @ to_list b)) + to_list a @ to_list b |> pick_best_two_from_list |> valOf end - fun join (r, w) = LearnerAcceptorMap.unionWith pick_best_two (msg_to_w r, w) val w0 = if Msg.is_two_a m then - LearnerAcceptorMap.insert - (LearnerAcceptorMap.empty, - (valOf (Msg.learner m), Msg.sender m), - (m, NONE)) + let + val m_acc = Msg.sender m + in + Msg.learners m + |> List.foldl + (fn (alpha, u) => insert (u, (alpha, m_acc), (m, NONE))) + empty + end else - LearnerAcceptorMap.empty - val refs = List.filter (not o Msg.is_one_a) (Msg.get_refs m) + empty in - foldl join w0 refs + m + |> Msg.get_refs + |> List.filter (not o Msg.is_one_a) + |> List.foldl (fn (r, w) => unionWith pick_best_two (msg_to_w r, w)) w0 end (* [msg_to_bal] returns a ballot for each known message, excluding 1a, and the message m *) @@ -300,26 +319,31 @@ struct fun compute_acceptor_status (m : msg) msg_to_bal msg_to_acc_status : AcceptorStatus.t AcceptorMap.map = let - fun helper (r, s) = - AcceptorMap.unionWith (AcceptorStatus.join msg_to_bal) (msg_to_acc_status r, s) + val unionWith = AcceptorMap.unionWith + val join = AcceptorStatus.join msg_to_bal val s0 = AcceptorMap.singleton (Msg.sender m, AcceptorStatus.Uncaught m) - val refs = List.filter (not o Msg.is_one_a) (Msg.get_refs m) in - foldl helper s0 refs + Msg.get_refs m + |> List.filter (not o Msg.is_one_a) + |> List.foldl (fn (r, s) => unionWith join (msg_to_acc_status r, s)) s0 end (* [msg_to_bal_val] returns a pair (ballot, value) for each known message and the message m *) (* [msg_to_w] returns a (msg * msg option) LearnerAcceptorMap.map for each known message, excluding 1a *) (* [msg_to_unburied] returns a set MsgSet.set for each known message, excluding 1a *) (* REQUIRES: m is not 1a *) - fun compute_unburied_2as (m : msg) (g : learner_graph) msg_to_bal_val msg_to_w msg_to_unburied - : MsgSet.set = + fun compute_unburied_2as (m : msg) (g : learner_graph) + (msg_to_bal_val : msg -> ballot * value) msg_to_w msg_to_unburied + : MsgSet.set LearnerMap.map = let - val m_lrn = valOf (Msg.learner m) + fun is_learner_of (alpha : learner, x : msg) = + Msg.learners x + |> List.find (Fn.curry Msg.Learner.eq alpha) + |> Option.isSome + (* z is burying x *) - fun burying (x, z) = - Msg.Learner.eq - (valOf (Msg.learner x), valOf (Msg.learner z)) andalso + fun burying (alpha : learner) (x, z) = + is_learner_of (alpha, z) andalso let val (x_bal, x_val) = msg_to_bal_val x val (z_bal, z_val) = msg_to_bal_val z @@ -327,87 +351,173 @@ struct Msg.Ballot.compare (x_bal, z_bal) = LESS andalso not (Msg.Value.eq (x_val, z_val)) end - val u0 = if Msg.is_two_a m then MsgSet.singleton m else MsgSet.empty - val refs = List.filter (not o Msg.is_one_a) (Msg.get_refs m) - val u = foldl (fn (r, u) => MsgSet.union (msg_to_unburied r, u)) u0 refs + val m_w = msg_to_w m val all_acceptors = LearnerGraph.acceptors g - fun buried x = + + fun compute_unburied_2as_for_learner (beta : learner) = let - val x_lrn = valOf (Msg.learner x) - fun check acc = - case LearnerAcceptorMap.lookup (m_w, (x_lrn, acc)) of - (best1, o_best2) => - burying (x, best1) orelse (isSome o_best2 andalso burying (x, (valOf o_best2))) - val acceptors = List.filter check all_acceptors + fun buried x = + let + fun check acc = + case LearnerAcceptorMap.lookup (m_w, (beta, acc)) of + (best1, o_best2) => + burying beta (x, best1) orelse + (isSome o_best2 andalso burying beta (x, (valOf o_best2))) + val acceptors = List.filter check all_acceptors + in + LearnerGraph.is_quorum g (beta, acceptors) + end + val u0 = if Msg.is_two_a m then MsgSet.singleton m else MsgSet.empty in - LearnerGraph.is_quorum g (m_lrn, acceptors) + m + |> Msg.get_refs + |> List.filter (not o Msg.is_one_a) + |> List.foldl + (fn (r, u) => + MsgSet.union (LearnerMap.lookup (msg_to_unburied r, beta), u)) + u0 + |> MsgSet.filter (not o buried) end in - MsgSet.filter (not o buried) u + LearnerGraph.learners g + |> List.foldl + (fn (beta, u) => + LearnerMap.insert (u, beta, compute_unburied_2as_for_learner beta)) + LearnerMap.empty end (* [msg_to_bal] returns a ballot for each known message, excluding 1a, and the message m *) (* [msg_to_acc_status] returns a map (AcceptorStatus.t AcceptorMap.map) for each known message, excluding 1a *) (* [msg_to_unburied] returns a set MsgSet.set for each known message, excluding 1a *) (* REQUIRES: m is 2a *) - fun compute_q (s : state) (m : msg) (g : learner_graph) msg_to_bal msg_to_acc_status msg_to_unburied - : acceptor list = + fun compute_q (s : State.t) (m : msg) (g : learner_graph) + msg_to_bal_val + msg_to_acc_status + msg_to_unburied + : (acceptor list) LearnerMap.map = let - fun compute_connected (l : learner, m : msg) = + val empty = LearnerMap.empty + val insert = LearnerMap.insert + + val msg_to_bal = fst o msg_to_bal_val + val msg_to_val = snd o msg_to_bal_val + + fun compute_connected (l : learner, m : msg) : learner list = (* REQUIRES: m is 1b *) - let val caught = - AcceptorMap.listKeys ( - AcceptorMap.filter AcceptorStatus.is_uncaught (msg_to_acc_status m) - ) + let + val caught = + msg_to_acc_status m + |> AcceptorMap.filter AcceptorStatus.is_caught + |> AcceptorMap.listKeys in + (* TODO explore optimization options, e.g., caching *) LearnerGraph.get_connected g (l, caught) end - fun compute_connected_2as (l : learner, m : msg) = + + fun compute_connected_2as (alpha : learner, m : msg) = (* REQUIRES: m is 1b *) let - val connected = LearnerSet.fromList (compute_connected (l, m)) - val m_sender = Msg.sender m - fun pred x = - Msg.Acceptor.eq ((Msg.sender x), m_sender) andalso - LearnerSet.member (connected, valOf (Msg.learner x)) + val m_acc = Msg.sender m + val m_lrn = Msg.learners m + val connected = compute_connected (alpha, m) |> LearnerSet.fromList + fun from_this_sender x = Msg.Acceptor.eq ((Msg.sender x), m_acc) in - MsgSet.filter pred (msg_to_unburied m) + Msg.learners m + |> List.filter (fn beta => LearnerSet.member (connected, beta)) + |> List.foldl + (fn (beta, accu) => + MsgSet.union (accu, LearnerMap.lookup (msg_to_unburied m, beta))) + MsgSet.empty + |> MsgSet.filter from_this_sender end + fun is_fresh (l : learner, m : msg) = (* REQUIRES: m is 1b *) - let - val connected_2as = compute_connected_2as (l, m) - val m_bal = msg_to_bal m - fun from_this_sender x = Msg.Ballot.eq (msg_to_bal x, m_bal) - in - MsgSet.all from_this_sender connected_2as - end + let + val m_val = msg_to_val m + fun same_value x = Msg.Value.eq (msg_to_val x, m_val) + in + compute_connected_2as (l, m) |> MsgSet.all same_value + end + (* cached `is_fresh` predicate *) fun is_fresh' s (l : learner, m : msg) = (* REQUIRES: m is 1b *) case State.get_is_fresh s (l, m) of SOME b => b | NONE => - let val res = is_fresh (l, m) + let + val res = is_fresh (l, m) val _ = State.put_is_fresh s ((l, m), res) in res end - val m_tran = + + fun compute_q_for_learner (alpha : learner) = let - val m_lrn = valOf (Msg.learner m) val m_bal = msg_to_bal m - fun pred x = Msg.is_one_b x andalso is_fresh' s (m_lrn, x) + fun pred x = Msg.is_one_b x andalso is_fresh' s (alpha, x) fun cont x = Msg.Ballot.eq (msg_to_bal x, m_bal) in - MsgUtil.tran pred cont m - end - fun senders ms = - let fun helper (x, accu) = AcceptorSet.add' (Msg.sender x, accu) in - foldl helper AcceptorSet.empty ms + m + |> MsgUtil.tran pred cont + |> senders end in - AcceptorSet.toList (senders m_tran) + LearnerGraph.learners g + |> List.map (fn alpha => (alpha, compute_q_for_learner alpha)) + |> List.filter (fn (alpha, xs) => not (null xs)) + |> List.foldl + (fn ((alpha, xs), accu) => insert (accu, alpha, xs)) + empty + end + + (* REQUIRES: m is not 1a *) + fun compute_msg_info_entry (s : State.t) (g : learner_graph) (m : msg) + get_bal_val + get_W + get_acc_status + get_unburied_2as + : MessageInfo.info_entry = + let + val m_bal_val = compute_bal_val m get_bal_val + fun get_bal_val_with_m x = + if Msg.eq (x, m) then m_bal_val else get_bal_val x + (* TODO rename W to something more meaningful *) + (* compute W values as per message m *) + val m_W = compute_W m get_bal_val_with_m get_W + (* compute acceptor status as per message m *) + val m_acc_status = + compute_acceptor_status m (fst o get_bal_val_with_m) get_acc_status + (* compute a set of unburied 2a-messages as per message m *) + val m_unburied_2as = + compute_unburied_2as m g get_bal_val_with_m get_W get_unburied_2as + (* list of quorums per each learner *) + val m_q = + if Msg.is_one_b m then LearnerMap.empty else + compute_q s m g get_bal_val_with_m get_acc_status get_unburied_2as + in + { + info_bal_val = m_bal_val, + info_W = m_W, + info_acc_status = m_acc_status, + info_unburied_2as = m_unburied_2as, + info_q = m_q + } + end + + fun compute_learners_for_2a (s : State.t) (g : learner_graph) (prev : msg option) (recent : msg list) = + let + val get_bal_val = State.get_bal_val s + val get_W = State.get_W s + val get_acc_status = State.get_acc_status s + val get_unburied_2as = State.get_unburied_2as s + + val pre_msg = Msg.mk_two_a (prev, recent, []) + val info_entry = compute_msg_info_entry s g pre_msg + get_bal_val get_W get_acc_status get_unburied_2as + in + #info_q info_entry |> LearnerMap.listKeys end fun prev_wellformed (m : msg) : bool = @@ -427,7 +537,7 @@ struct end end - fun all_refs_known (s : state) (m : msg) = + fun all_refs_known (s : State.t) (m : msg) = List.all (State.is_known s) (Msg.get_refs m) fun has_non_wellformed_ref s m = @@ -438,34 +548,13 @@ struct (* ...further actions possible *) (* REQUIRES: every direct reference is known *) - fun is_wellformed (s : state) (g : learner_graph) (m : msg) : bool * MessageInfo.info_entry option = + fun is_wellformed (s : State.t) (g : learner_graph) (m : msg) : bool * MessageInfo.info_entry option = let - fun compute_msg_info_entry s m : MessageInfo.info_entry = - (* REQUIRES: m is not 1a *) - let - val get_bal_val = State.get_bal_val s - val get_W = State.get_W s - val get_acc_status = State.get_acc_status s - val get_unburied_2as = State.get_unburied_2as s - val m_bal_val = compute_bal_val m get_bal_val - fun get_bal_val_with_m x = - if Msg.eq (x, m) then m_bal_val else get_bal_val x - val m_W = compute_W m get_bal_val_with_m get_W - val m_acc_status = - compute_acceptor_status m (fst o get_bal_val_with_m) get_acc_status - val m_unburied_2as = - compute_unburied_2as m g get_bal_val_with_m get_W get_unburied_2as - val m_q = - if Msg.is_one_b m then [] else - (* case 2a *) - compute_q s m g (fst o get_bal_val_with_m) get_acc_status get_unburied_2as - in - { info_bal_val = m_bal_val, - info_W = m_W, - info_acc_status = m_acc_status, - info_unburied_2as = m_unburied_2as, - info_q = m_q } - end + val get_bal_val = State.get_bal_val s + val get_W = State.get_W s + val get_acc_status = State.get_acc_status s + val get_unburied_2as = State.get_unburied_2as s + fun is_wellformed_1a m = (* TODO this check might be redundant depending on how `get_prev` is defined *) not (isSome (Msg.get_prev m)) andalso @@ -474,7 +563,7 @@ struct fun is_wellformed_1b m (m_info_entry : MessageInfo.info_entry) = MsgUtil.references_exactly_one_1a m andalso let - val ballot = fst o (State.get_bal_val s) + val ballot = fst o get_bal_val val (m_bal, _) = #info_bal_val m_info_entry fun check_ref x = Msg.is_one_a x orelse @@ -484,112 +573,136 @@ struct end fun is_wellformed_2a m (m_info_entry : MessageInfo.info_entry) = not (null (Msg.get_refs m)) andalso - LearnerGraph.is_quorum g (valOf (Msg.learner m), #info_q m_info_entry) + not (null (Msg.learners m)) andalso + let + val m_lrns = Msg.learners m + fun q (alpha : learner) = LearnerMap.lookup (#info_q m_info_entry, alpha) + val q_lrns = LearnerMap.listKeys (#info_q m_info_entry) + in + list_equal (Msg.Learner.eq) m_lrns q_lrns andalso + List.all (fn alpha => LearnerGraph.is_quorum g (alpha, q alpha)) q_lrns + end in (* optionally, we might want to check that every reference occurs at most once (call to refs_nondup) *) - if prev_wellformed m andalso MsgUtil.refs_nondup m then - case Msg.typ m of - Msg.OneA => (is_wellformed_1a m, NONE) - | Msg.OneB => - let val m_info_entry = compute_msg_info_entry s m in - if is_wellformed_1b m m_info_entry then - (true, SOME m_info_entry) - else - (false, NONE) - end - | Msg.TwoA => - let val m_info_entry = compute_msg_info_entry s m in - if is_wellformed_2a m m_info_entry then - (true, SOME m_info_entry) - else - (false, NONE) - end + (* TODO actually, do it in the mailbox implementation *) + if prev_wellformed m andalso MsgUtil.refs_nondup m then ( + if Msg.is_one_a m then + (is_wellformed_1a m, NONE) + else ( + if Msg.is_one_b m then + let + val m_info_entry = + compute_msg_info_entry s g m + get_bal_val get_W get_acc_status get_unburied_2as + in + if is_wellformed_1b m m_info_entry then + (true, SOME m_info_entry) + else + (false, NONE) + end + else + let + val m_info_entry = + compute_msg_info_entry s g m + get_bal_val get_W get_acc_status get_unburied_2as + in + if is_wellformed_2a m m_info_entry then + (true, SOME m_info_entry) + else + (false, NONE) + end + ) + ) else (false, NONE) end - fun check_wellformed_and_update_info (s : state) (g : learner_graph) m : bool * state = - case is_wellformed s g m of - (false, _) => (false, s) - | (true, info_entry_o) => - (true, Option.fold (fn (e, s) => State.store_info_entry s (m, e)) s info_entry_o) - - fun hpaxos_node (id : node_id) (g : learner_graph) (mbox : mailbox) : t = - Acc (id, Graph g, State.mk (), mbox) - - fun run (Acc (id, Graph g, s, mbox)) = + fun check_wellformed_and_update_info (s : State.t, g : learner_graph, m : msg) + : bool * State.t = + (* first, check if the message info is already stored, meaning that the message was sent by us *) + (* TODO currently, broken; to make it really work, we need to use hashes as keys *) + if State.has_info s m then + (true, s) + else + case is_wellformed s g m of + (false, _) => (false, s) + | (true, info_entry_o) => + (true, Option.fold + (fn (e, s) => State.store_info_entry s (m, e)) + s + info_entry_o) + + fun init (g : learner_graph) = ServerState.mk (g, State.mk()) + + fun handle_msg (ServerState.State (g, s), m) = let - fun get_next_wellformed_msg s : msg * state = - case Mailbox.recv mbox of - NONE => get_next_wellformed_msg s - | SOME m => - let val _ = assert (all_refs_known s m) "wrong message ordering" in - if has_non_wellformed_ref s m then - get_next_wellformed_msg (process_non_wellformed s m) - else - let val (res, s) = check_wellformed_and_update_info s g m in - if res then (m, s) else - get_next_wellformed_msg (process_non_wellformed s m) - end - end - fun process_1a s m : state = + fun process_1a s m : State.t * msg option = let val prev = State.get_prev s - val recent = MsgSet.toList (State.get_recent s) + val recent = MsgSet.add (State.get_recent s, m) |> MsgSet.toList val new_1b = Msg.mk_one_b (prev, recent) - val (is_wf, s) = check_wellformed_and_update_info s g new_1b + val (is_wf, s) = check_wellformed_and_update_info (s, g, new_1b) in if is_wf then - let val s = State.clear_recent s - val s = State.set_prev s new_1b + let + fun update_state s = + s + |> State.clear_recent + |> flip State.add_recent m + |> flip State.set_prev new_1b in - process_message s new_1b + (* broadcast new_1b *) + (update_state s, SOME new_1b) end - else s + else (s, NONE) end - and process_1b s m : state = + + fun process_1b s m : State.t * msg option = let val m_bal = fst (State.get_bal_val s m) - val cur_max = State.get_max s - fun process_learner (lrn, s) = + val prev = State.get_prev s + val recent = MsgSet.toList (State.get_recent s) + (* TODO this actually computes an info entry for the new message, which is currently not stored *) + val lrns = compute_learners_for_2a s g prev recent + val new_2a = Msg.mk_two_a (prev, recent, lrns) + val (is_wf, s) = check_wellformed_and_update_info (s, g, new_2a) + in + if is_wf then let - val prev = State.get_prev s - val recent = MsgSet.toList (State.get_recent s) - val new_2a = Msg.mk_two_a (prev, recent, lrn) - val (is_wf, s) = check_wellformed_and_update_info s g new_2a + fun update_state s = + s + |> State.clear_recent + |> flip State.add_recent m + |> flip State.set_prev new_2a in - if is_wf then - let val s = State.clear_recent s - val s = State.set_prev s new_2a - in - process_message s new_2a - end - else s + (* broadcast new_2a *) + (update_state s, SOME new_2a) end - in - if Msg.Ballot.eq (m_bal, cur_max) then - foldl process_learner s (LearnerGraph.learners g) - else s - end - and process_message s m = - let - val _ = Mailbox.broadcast mbox m - val s = State.add_known s m - val s = State.add_recent s m - val s = State.update_max s m - in - case Msg.typ m of - Msg.OneA => process_1a s m - | Msg.OneB => process_1b s m - | Msg.TwoA => s - end - fun loop s = - let val (m, s) = get_next_wellformed_msg s in - if State.is_known s m then - loop s - else - loop (process_message s m) + else (s, NONE) end + + fun process_2a s m : State.t * msg option = + (State.add_recent s m, NONE) in - let val _ = loop s in () end + let + val (res, s) = + if has_non_wellformed_ref s m then + (false, s) + else + check_wellformed_and_update_info (s, g, m) + val (s, new_msg) = + if res then + if Msg.is_one_a m then + process_1a s m + else ( + if Msg.is_one_b m then + process_1b s m + else + process_2a s m + ) + else + (process_non_wellformed s m, NONE) + in + (ServerState.mk (g, s), new_msg) + end end end (* HPaxos *)