diff --git a/hpaxos/hpaxos.sml b/hpaxos/hpaxos.sml index 0e976c6..7d3e9b1 100644 --- a/hpaxos/hpaxos.sml +++ b/hpaxos/hpaxos.sml @@ -106,14 +106,30 @@ struct | join bal (_, _) = Caught end (* AcceptorStatus *) + structure MessageType = + struct + datatype msg_type = OneA | OneB | TwoA + type t = msg_type + + (* fun is_one_a OneA = true + | is_one_a _ = false *) + + fun is_one_b OneB = true + | is_one_b _ = false + + fun is_two_a TwoA = true + | is_two_a _ = false + end + structure MessageInfo = struct - type info_entry = { info_bal_val : ballot * value, + type info_entry = { info_type : MessageType.t, + info_bal_val : ballot * value, info_W : (msg * msg option) LearnerAcceptorMap.map, info_acc_status : AcceptorStatus.t AcceptorMap.map, info_unburied_2as : MsgSet.set LearnerMap.map, - info_q : (acceptor list) LearnerMap.map - } + info_q : (acceptor list) LearnerMap.map, + info_learners : learner list } datatype msg_info = MsgInfo of info_entry MsgMap.map @@ -123,7 +139,14 @@ struct fun has ((MsgInfo info), m) = MsgMap.inDomain (info, m) + fun get_type (MsgInfo info) m = #info_type (MsgMap.lookup (info, m)) + + (* fun is_one_a info m = get_type info m |> MessageType.is_one_a *) + fun is_one_b info m = get_type info m |> MessageType.is_one_b + fun is_two_a info m = get_type info m |> MessageType.is_two_a + fun get_bal_val (MsgInfo info) m = #info_bal_val (MsgMap.lookup (info, m)) + fun get_learners (MsgInfo info) m = #info_learners (MsgMap.lookup (info, m)) fun get_W (MsgInfo info) m = #info_W (MsgMap.lookup (info, m)) fun get_acc_status (MsgInfo info) m = #info_acc_status (MsgMap.lookup (info, m)) fun get_unburied_2as (MsgInfo info) m = #info_unburied_2as (MsgMap.lookup (info, m)) @@ -178,8 +201,14 @@ struct (* fun get_max (State (s, _, _)) = AlgoState.get_max s *) + (* TODO clean *) + (* fun is_one_a (State (_, i, _)) m = MessageInfo.is_one_a i m *) + fun is_one_b (State (_, i, _)) m = MessageInfo.is_one_b i m + (* TODO unused *) + fun is_two_a (State (_, i, _)) m = MessageInfo.is_two_a i m + fun get_bal_val (State (_, i, _)) m = - if Msg.is_one_a m then + if Msg.is_proposal m then valOf (Msg.get_bal_val m) else MessageInfo.get_bal_val i m @@ -195,6 +224,7 @@ struct State (AlgoState.set_max s new_max, i, c) end *) + fun get_learners (State (_, i, _)) m = MessageInfo.get_learners i m fun get_W (State (_, i, _)) = MessageInfo.get_W i fun get_acc_status (State (_, i, _)) = MessageInfo.get_acc_status i fun get_unburied_2as (State (_, i, _)) = MessageInfo.get_unburied_2as i @@ -230,6 +260,20 @@ struct |> AcceptorSet.toList end + fun compute_type (m : msg) : MessageType.t = + if Msg.is_proposal m then + MessageType.OneA + else + let + fun has_proposal_ref m = + Msg.get_refs m |> List.exists Msg.is_proposal + in + if has_proposal_ref m then + MessageType.OneB + else + MessageType.TwoA + end + (* [msg_to_bal_val] returns a pair (ballot, value) for each known message, including 1a messages *) (* REQUIRES: m is not 1a *) fun compute_bal_val (m : msg) msg_to_bal_val : ballot * value = @@ -249,7 +293,7 @@ struct (* [msg_to_bal_val] returns a pair (ballot, value) for each known message and the message m *) (* [msg_to_w] returns a (msg * msg option) LearnerAcceptorMap.map for each known message, excluding 1a *) (* REQUIRES: m is not 1a *) - fun compute_W (m : msg) msg_to_bal_val msg_to_w + fun compute_W (m : msg) (m_type : MessageType.t) (g : learner_graph) msg_to_bal_val msg_to_w : (msg * msg option) LearnerAcceptorMap.map = let val empty = LearnerAcceptorMap.empty @@ -295,11 +339,11 @@ struct to_list a @ to_list b |> pick_best_two_from_list |> valOf end val w0 = - if Msg.is_two_a m then + if MessageType.is_two_a m_type then let val m_acc = Msg.sender m in - Msg.learners m + LearnerGraph.learners g |> List.foldl (fn (alpha, u) => insert (u, (alpha, m_acc), (m, NONE))) empty @@ -309,7 +353,7 @@ struct in m |> Msg.get_refs - |> List.filter (not o Msg.is_one_a) + |> List.filter (not o Msg.is_proposal) |> List.foldl (fn (r, w) => unionWith pick_best_two (msg_to_w r, w)) w0 end @@ -324,7 +368,7 @@ struct val s0 = AcceptorMap.singleton (Msg.sender m, AcceptorStatus.Uncaught m) in Msg.get_refs m - |> List.filter (not o Msg.is_one_a) + |> List.filter (not o Msg.is_proposal) |> List.foldl (fn (r, s) => unionWith join (msg_to_acc_status r, s)) s0 end @@ -332,12 +376,16 @@ struct (* [msg_to_w] returns a (msg * msg option) LearnerAcceptorMap.map for each known message, excluding 1a *) (* [msg_to_unburied] returns a set MsgSet.set for each known message, excluding 1a *) (* REQUIRES: m is not 1a *) - fun compute_unburied_2as (m : msg) (g : learner_graph) - (msg_to_bal_val : msg -> ballot * value) msg_to_w msg_to_unburied + fun compute_unburied_2as + (m : msg) (g : learner_graph) (m_type : MessageType.t) m_W + (msg_to_bal_val : msg -> ballot * value) + (* TODO check if msg_to_learners has to be defined for m *) + msg_to_learners + msg_to_unburied : MsgSet.set LearnerMap.map = let fun is_learner_of (alpha : learner, x : msg) = - Msg.learners x + msg_to_learners x |> List.find (Fn.curry Msg.Learner.eq alpha) |> Option.isSome @@ -352,7 +400,7 @@ struct not (Msg.Value.eq (x_val, z_val)) end - val m_w = msg_to_w m + val get_w = Fn.curry LearnerAcceptorMap.lookup m_W val all_acceptors = LearnerGraph.acceptors g fun compute_unburied_2as_for_learner (beta : learner) = @@ -360,19 +408,21 @@ struct fun buried x = let fun check acc = - case LearnerAcceptorMap.lookup (m_w, (beta, acc)) of - (best1, o_best2) => - burying beta (x, best1) orelse - (isSome o_best2 andalso burying beta (x, (valOf o_best2))) + let + val (best1, o_best2) = get_w (beta, acc) + in + burying beta (x, best1) orelse + (isSome o_best2 andalso burying beta (x, (valOf o_best2))) + end val acceptors = List.filter check all_acceptors in LearnerGraph.is_quorum g (beta, acceptors) end - val u0 = if Msg.is_two_a m then MsgSet.singleton m else MsgSet.empty + val u0 = if MessageType.is_two_a m_type then MsgSet.singleton m else MsgSet.empty in m |> Msg.get_refs - |> List.filter (not o Msg.is_one_a) + |> List.filter (not o Msg.is_proposal) |> List.foldl (fn (r, u) => MsgSet.union (LearnerMap.lookup (msg_to_unburied r, beta), u)) @@ -393,6 +443,7 @@ struct (* REQUIRES: m is 2a *) fun compute_q (s : State.t) (m : msg) (g : learner_graph) msg_to_bal_val + msg_to_learners msg_to_acc_status msg_to_unburied : (acceptor list) LearnerMap.map = @@ -419,11 +470,11 @@ struct (* REQUIRES: m is 1b *) let val m_acc = Msg.sender m - val m_lrn = Msg.learners m val connected = compute_connected (alpha, m) |> LearnerSet.fromList fun from_this_sender x = Msg.Acceptor.eq ((Msg.sender x), m_acc) in - Msg.learners m + m + |> msg_to_learners |> List.filter (fn beta => LearnerSet.member (connected, beta)) |> List.foldl (fn (beta, accu) => @@ -440,7 +491,7 @@ struct in compute_connected_2as (l, m) |> MsgSet.all same_value end - (* cached `is_fresh` predicate *) + (* cached `is_fresh` predicate, impure *) fun is_fresh' s (l : learner, m : msg) = (* REQUIRES: m is 1b *) case State.get_is_fresh s (l, m) of @@ -456,7 +507,7 @@ struct fun compute_q_for_learner (alpha : learner) = let val m_bal = msg_to_bal m - fun pred x = Msg.is_one_b x andalso is_fresh' s (alpha, x) + fun pred x = State.is_one_b s x andalso is_fresh' s (alpha, x) fun cont x = Msg.Ballot.eq (msg_to_bal x, m_bal) in m @@ -475,54 +526,48 @@ struct (* REQUIRES: m is not 1a *) fun compute_msg_info_entry (s : State.t) (g : learner_graph) (m : msg) get_bal_val + get_learners get_W get_acc_status get_unburied_2as : MessageInfo.info_entry = let + val m_type = compute_type m val m_bal_val = compute_bal_val m get_bal_val fun get_bal_val_with_m x = if Msg.eq (x, m) then m_bal_val else get_bal_val x (* TODO rename W to something more meaningful *) (* compute W values as per message m *) - val m_W = compute_W m get_bal_val_with_m get_W + val m_W = compute_W m m_type g get_bal_val_with_m get_W (* compute acceptor status as per message m *) val m_acc_status = compute_acceptor_status m (fst o get_bal_val_with_m) get_acc_status (* compute a set of unburied 2a-messages as per message m *) val m_unburied_2as = - compute_unburied_2as m g get_bal_val_with_m get_W get_unburied_2as + compute_unburied_2as m g m_type m_W get_bal_val_with_m get_learners get_unburied_2as (* list of quorums per each learner *) val m_q = - if Msg.is_one_b m then LearnerMap.empty else - compute_q s m g get_bal_val_with_m get_acc_status get_unburied_2as + if MessageType.is_one_b m_type then + LearnerMap.empty + else + compute_q s m g get_bal_val_with_m get_learners get_acc_status get_unburied_2as + (* list of learner values *) + val m_learners = m_q |> LearnerMap.listKeys in { + info_type = m_type, info_bal_val = m_bal_val, info_W = m_W, info_acc_status = m_acc_status, info_unburied_2as = m_unburied_2as, - info_q = m_q + info_q = m_q, + info_learners = m_learners } end - fun compute_learners_for_2a (s : State.t) (g : learner_graph) (prev : msg option) (recent : msg list) = - let - val get_bal_val = State.get_bal_val s - val get_W = State.get_W s - val get_acc_status = State.get_acc_status s - val get_unburied_2as = State.get_unburied_2as s - - val pre_msg = Msg.mk_two_a (prev, recent, []) - val info_entry = compute_msg_info_entry s g pre_msg - get_bal_val get_W get_acc_status get_unburied_2as - in - #info_q info_entry |> LearnerMap.listKeys - end - fun prev_wellformed (m : msg) : bool = let - val m_refs = List.filter (not o Msg.is_one_a) (Msg.get_refs m) + val m_refs = List.filter (not o Msg.is_proposal) (Msg.get_refs m) val m_acc = Msg.sender m fun from_this_sender x = Msg.Acceptor.eq ((Msg.sender x), m_acc) in @@ -551,69 +596,57 @@ struct fun is_wellformed (s : State.t) (g : learner_graph) (m : msg) : bool * MessageInfo.info_entry option = let val get_bal_val = State.get_bal_val s + val get_learners = State.get_learners s val get_W = State.get_W s val get_acc_status = State.get_acc_status s val get_unburied_2as = State.get_unburied_2as s - fun is_wellformed_1a m = - (* TODO this check might be redundant depending on how `get_prev` is defined *) - not (isSome (Msg.get_prev m)) andalso - (* TODO this check might be redundant depending on how `get_refs` is defined *) - null (Msg.get_refs m) fun is_wellformed_1b m (m_info_entry : MessageInfo.info_entry) = MsgUtil.references_exactly_one_1a m andalso let val ballot = fst o get_bal_val val (m_bal, _) = #info_bal_val m_info_entry fun check_ref x = - Msg.is_one_a x orelse + Msg.is_proposal x orelse Msg.Ballot.compare (ballot x, m_bal) = LESS in - List.all check_ref (Msg.get_refs m) + m + |> Msg.get_refs + |> List.all check_ref end + fun is_wellformed_2a m (m_info_entry : MessageInfo.info_entry) = not (null (Msg.get_refs m)) andalso - not (null (Msg.learners m)) andalso - let - val m_lrns = Msg.learners m - fun q (alpha : learner) = LearnerMap.lookup (#info_q m_info_entry, alpha) - val q_lrns = LearnerMap.listKeys (#info_q m_info_entry) - in - list_equal (Msg.Learner.eq) m_lrns q_lrns andalso - List.all (fn alpha => LearnerGraph.is_quorum g (alpha, q alpha)) q_lrns - end + not (null (#info_learners m_info_entry)) + + fun is_wellformed_non_proposal m (m_info_entry : MessageInfo.info_entry) = + if MessageType.is_one_b (#info_type m_info_entry) then + is_wellformed_1b m m_info_entry + else + is_wellformed_2a m m_info_entry in (* optionally, we might want to check that every reference occurs at most once (call to refs_nondup) *) (* TODO actually, do it in the mailbox implementation *) - if prev_wellformed m andalso MsgUtil.refs_nondup m then ( - if Msg.is_one_a m then - (is_wellformed_1a m, NONE) - else ( - if Msg.is_one_b m then - let - val m_info_entry = - compute_msg_info_entry s g m - get_bal_val get_W get_acc_status get_unburied_2as - in - if is_wellformed_1b m m_info_entry then - (true, SOME m_info_entry) - else - (false, NONE) - end - else - let - val m_info_entry = - compute_msg_info_entry s g m - get_bal_val get_W get_acc_status get_unburied_2as - in - if is_wellformed_2a m m_info_entry then - (true, SOME m_info_entry) - else - (false, NONE) - end + if prev_wellformed m andalso MsgUtil.refs_nondup m then + ( + if Msg.is_proposal m then + (true, NONE) + else + ( + let + val m_info_entry = + compute_msg_info_entry s g m + get_bal_val get_learners get_W get_acc_status get_unburied_2as + in + if is_wellformed_non_proposal m m_info_entry then + (true, SOME m_info_entry) + else + (false, NONE) + end ) ) - else (false, NONE) + else + (false, NONE) end fun check_wellformed_and_update_info (s : State.t, g : learner_graph, m : msg) @@ -639,7 +672,7 @@ struct let val prev = State.get_prev s val recent = MsgSet.add (State.get_recent s, m) |> MsgSet.toList - val new_1b = Msg.mk_one_b (prev, recent) + val new_1b = Msg.mk_non_proposal (prev, recent) val (is_wf, s) = check_wellformed_and_update_info (s, g, new_1b) in if is_wf then @@ -658,12 +691,9 @@ struct fun process_1b s m : State.t * msg option = let - val m_bal = fst (State.get_bal_val s m) val prev = State.get_prev s val recent = MsgSet.toList (State.get_recent s) - (* TODO this actually computes an info entry for the new message, which is currently not stored *) - val lrns = compute_learners_for_2a s g prev recent - val new_2a = Msg.mk_two_a (prev, recent, lrns) + val new_2a = Msg.mk_non_proposal (prev, recent) val (is_wf, s) = check_wellformed_and_update_info (s, g, new_2a) in if is_wf then @@ -677,7 +707,9 @@ struct (* broadcast new_2a *) (update_state s, SOME new_2a) end - else (s, NONE) + else + (* XXX add to recent? *) + (s, NONE) end fun process_2a s m : State.t * msg option = @@ -691,10 +723,10 @@ struct check_wellformed_and_update_info (s, g, m) val (s, new_msg) = if res then - if Msg.is_one_a m then + if Msg.is_proposal m then process_1a s m else ( - if Msg.is_one_b m then + if State.is_one_b s m then process_1b s m else process_2a s m