diff --git a/gen-server-impl.sig b/gen-server-impl.sig new file mode 100644 index 0000000..1d0e5fe --- /dev/null +++ b/gen-server-impl.sig @@ -0,0 +1,9 @@ +signature GEN_SERVER_IMPL = +sig + type param + type state + type msg + + val init : param -> state + val handle_msg : state * msg -> state * msg option +end diff --git a/hashing.sml b/hashing.sml new file mode 100644 index 0000000..b48929e --- /dev/null +++ b/hashing.sml @@ -0,0 +1,10 @@ +structure Hashing = struct + infix |> + fun x |> f = f x + + fun hash words = + words + |> List.map Word.toString + |> String.concat + |> FNVHash.hashString +end diff --git a/hpaxos.sml b/hpaxos.sml deleted file mode 100644 index 0635798..0000000 --- a/hpaxos.sml +++ /dev/null @@ -1,595 +0,0 @@ -signature HPAXOS_NODE = -sig - type t - type node_id - type learner_graph - type mailbox - val hpaxos_node : node_id -> learner_graph -> mailbox -> t - val run : t -> unit -end - -functor HPaxos (structure Msg : HPAXOS_MESSAGE - structure Mailbox : HPAXOS_MAILBOX - structure LearnerGraph : LEARNER_GRAPH - sharing Msg.Learner = LearnerGraph.Learner - and Msg.Acceptor = LearnerGraph.Acceptor - and Mailbox.Message = Msg) - :> HPAXOS_NODE = -struct - type msg = Msg.t - type mailbox = Mailbox.t - - type acceptor_id = word - type acceptor = Msg.Acceptor.t - type ballot = Msg.Ballot.t - type value = Msg.Value.t - type learner = Msg.Learner.t - type learner_graph = LearnerGraph.t - - structure MsgUtil = MessageUtil (Msg) - - structure MsgSet : ORD_SET = RedBlackSetFn (MessageOrdKey (Msg)) - structure MsgMap : ORD_MAP = RedBlackMapFn (MessageOrdKey (Msg)) - structure AcceptorSet : ORD_SET = RedBlackSetFn (AcceptorOrdKey (Msg.Acceptor)) - structure AcceptorMap : ORD_MAP = RedBlackMapFn (AcceptorOrdKey (Msg.Acceptor)) - structure LearnerSet : ORD_SET = RedBlackSetFn (LearnerOrdKey (Msg.Learner)) - - structure LearnerAcceptorMap : ORD_MAP = - RedBlackMapFn ( - ProdLexOrdKey - (LearnerOrdKey (Msg.Learner)) - (AcceptorOrdKey (Msg.Acceptor))) - - structure LearnerMsgMap : ORD_MAP = - RedBlackMapFn ( - ProdLexOrdKey - (LearnerOrdKey (Msg.Learner)) - (MessageOrdKey (Msg))) - - (* algorithm state *) - structure AlgoState = - struct - datatype known_msgs = KnownMsgs of MsgSet.set - datatype recent_msgs = RecentMsgs of MsgSet.set - datatype prev_msg = PrevMsg of msg option - datatype non_wellformed_msgs = NonWellformedMsgs of MsgSet.set - datatype max_ballot = MaxBal of ballot - - datatype state = AlgoState of known_msgs * recent_msgs * prev_msg * - non_wellformed_msgs * max_ballot - type t = state - - fun mk () = AlgoState (KnownMsgs MsgSet.empty, - RecentMsgs MsgSet.empty, - PrevMsg NONE, - NonWellformedMsgs MsgSet.empty, - MaxBal Msg.Ballot.zero) - - fun is_known (AlgoState (KnownMsgs k, _, _, _, _)) = - Fn.curry MsgSet.member k - - fun add_known (AlgoState (KnownMsgs k, r, p, nw, maxb)) m = - AlgoState (KnownMsgs (MsgSet.add (k, m)), r, p, nw, maxb) - - fun get_recent (AlgoState (_, RecentMsgs r, _, _, _)) = r - - fun add_recent (AlgoState (k, RecentMsgs r, p, nw, maxb)) m = - AlgoState (k, RecentMsgs (MsgSet.add (r, m)), p, nw, maxb) - - fun clear_recent (AlgoState (k, _, p, nw, maxb)) = - AlgoState (k, RecentMsgs MsgSet.empty, p, nw, maxb) - - fun get_prev (AlgoState (_, _, PrevMsg p, _, _)) = p - - fun set_prev (AlgoState (k, r, _, nw, maxb)) m = - AlgoState (k, r, PrevMsg (SOME m), nw, maxb) - - fun is_non_wellformed (AlgoState (_, _, _, NonWellformedMsgs nw, _)) = - Fn.curry MsgSet.member nw - - fun add_non_wellformed (AlgoState (k, r, p, NonWellformedMsgs nw, maxb)) m = - AlgoState (k, r, p, NonWellformedMsgs (MsgSet.add (nw, m)), maxb) - - fun get_max (AlgoState (_, _, _, _, MaxBal maxb)) = maxb - - fun set_max (AlgoState (k, r, p, nw, _)) bal = - AlgoState (k, r, p, nw, MaxBal bal) - end (* AlgoState *) - - (* message info state *) - structure AcceptorStatus = - struct - datatype status = Caught - | Uncaught of msg - type t = status - - fun is_uncaught (Uncaught _) = true - | is_uncaught _ = false - - fun join (bal : msg -> ballot) (Uncaught m1, Uncaught m2) = - if MsgUtil.PrevTran.is_prev_reachable' bal (m1, m2) then - Uncaught m1 - else if MsgUtil.PrevTran.is_prev_reachable' bal (m2, m1) then - Uncaught m2 - else - Caught - | join bal (_, _) = Caught - end (* AcceptorStatus *) - - structure MessageInfo = - struct - type info_entry = { info_bal_val : ballot * value, - info_W : (msg * msg option) LearnerAcceptorMap.map, - info_acc_status : AcceptorStatus.t AcceptorMap.map, - info_unburied_2as : MsgSet.set, - info_q : acceptor list } - - datatype msg_info = MsgInfo of info_entry MsgMap.map - - type t = msg_info - - fun mk () = MsgInfo MsgMap.empty - - fun get_bal_val (MsgInfo info) m = #info_bal_val (MsgMap.lookup (info, m)) - fun get_W (MsgInfo info) m = #info_W (MsgMap.lookup (info, m)) - fun get_acc_status (MsgInfo info) m = #info_acc_status (MsgMap.lookup (info, m)) - fun get_unburied_2as (MsgInfo info) m = #info_unburied_2as (MsgMap.lookup (info, m)) - fun get_q (MsgInfo info) m = #info_q (MsgMap.lookup (info, m)) - - fun store_entry (MsgInfo info) (m, info_entry : info_entry) = - MsgInfo (MsgMap.insert (info, m, info_entry)) - end (* MessageInfo *) - - (* memo state *) - structure Cache = - struct - datatype is_fresh_cache = IsFresh of (bool LearnerMsgMap.map) ref - datatype cache = Cache of is_fresh_cache - type t = cache - - fun mk () = Cache (IsFresh (ref LearnerMsgMap.empty)) - - fun get_is_fresh (Cache (IsFresh f)) = Fn.curry LearnerMsgMap.find (!f) - fun put_is_fresh (Cache (IsFresh f)) (lm, v) = - Ref.modify (fn map => LearnerMsgMap.insert (map, lm, v)) f - end - - structure State = - struct - datatype state = State of AlgoState.t * MessageInfo.t * Cache.t - type t = state - - fun mk () = State (AlgoState.mk (), MessageInfo.mk (), (Cache.mk ())) - - fun is_known (State (s, _, _)) = AlgoState.is_known s - - fun get_recent (State (s, _, _)) = AlgoState.get_recent s - fun clear_recent (State (s, i, c)) = State (AlgoState.clear_recent s, i, c) - - fun get_prev (State (s, _, _)) = AlgoState.get_prev s - fun set_prev (State (s, i, c)) m = State (AlgoState.set_prev s m, i, c) - - fun is_non_wellformed (State (s, _, _)) = AlgoState.is_non_wellformed s - - fun add_known (State (s, i, c)) m = - State (AlgoState.add_known s m, i, c) - - fun add_recent (State (s, i, c)) m = - State (AlgoState.add_recent s m, i, c) - - fun add_known_recent (State (s, i, c)) m = - State (AlgoState.add_recent (AlgoState.add_known s m) m, i, c) - - fun add_non_wellformed (State (s, i, c)) m = - State (AlgoState.add_non_wellformed s m, i, c) - - fun get_max (State (s, _, _)) = AlgoState.get_max s - - fun get_bal_val (State (s, i, c)) m = - if Msg.is_one_a m then - valOf (Msg.get_bal_val m) - else - MessageInfo.get_bal_val i m - - fun update_max (state as State (s, i, c)) m = - let - fun max (a, b) = - case Msg.Ballot.compare (a, b) of LESS => b | _ => a - val m_bal = fst (get_bal_val state m) - val cur_max = AlgoState.get_max s - val new_max = max (m_bal, cur_max) - in - State (AlgoState.set_max s new_max, i, c) - end - - fun get_W (State (_, i, _)) = MessageInfo.get_W i - fun get_acc_status (State (_, i, _)) = MessageInfo.get_acc_status i - fun get_unburied_2as (State (_, i, _)) = MessageInfo.get_unburied_2as i - fun get_q (State (_, i, _)) = MessageInfo.get_q i - - fun store_info_entry (State (a, i, c)) mi = - State (a, MessageInfo.store_entry i mi, c) - - fun get_is_fresh (State (_, _, c)) = Cache.get_is_fresh c - fun put_is_fresh (State (_, _, c)) = Cache.put_is_fresh c - end (* State *) - - type state = State.t - - (* learner graph *) - datatype graph = Graph of learner_graph - - datatype acceptor_node = Acc of acceptor_id * graph * state * mailbox - - type t = acceptor_node - type node_id = acceptor_id - - (* [msg_to_bal_val] returns a pair (ballot, value) for each known message, including 1a messages *) - (* REQUIRES: m is not 1a *) - fun compute_bal_val (m : msg) msg_to_bal_val : ballot * value = - let - fun helper (x, (max_bal, max_val)) = - let val (b, v) = msg_to_bal_val x in - case Msg.Ballot.compare (b, max_bal) of - LESS => (max_bal, max_val) - | _ => (b, v) - end - val refs = Msg.get_refs m (* refs is non-empty since m is not 1a *) - in - foldl helper (Msg.Ballot.zero, Msg.Value.default) refs - end - - (* [msg_to_bal_val] returns a pair (ballot, value) for each known message and the message m *) - (* [msg_to_w] returns a (msg * msg option) LearnerAcceptorMap.map for each known message, excluding 1a *) - (* REQUIRES: m is not 1a *) - fun compute_W (m : msg) msg_to_bal_val msg_to_w - : (msg * msg option) LearnerAcceptorMap.map = - let - fun pick_best_two_from_list (ms : msg list) : (msg * msg option) option = - let - val ballot = fst o msg_to_bal_val - val value = snd o msg_to_bal_val - fun pick_best pred cmp lst = - let - fun choose (x, cur_best_o) = - if pred x then - case cur_best_o of - NONE => SOME x - | SOME cur_best => - case cmp (cur_best, x) of - LESS => SOME x - | _ => cur_best_o - else cur_best_o - in - foldl choose NONE lst - end - fun cmp_by_ballot (x, y) = Msg.Ballot.compare (ballot x, ballot y) - fun pick_first_best ms = pick_best (Fn.const true) cmp_by_ballot ms - fun pick_second_best ms fst_best = - let val fst_best_val = value fst_best - fun pred x = not (Msg.Value.eq (value x, fst_best_val)) - in - pick_best pred cmp_by_ballot ms - end - in - Option.map (fn x => (x, pick_second_best ms x)) (pick_first_best ms) - end - fun pick_best_two (a : msg * msg option, b : msg * msg option) = - let fun to_list (best1, NONE) = [best1] - | to_list (best1, SOME best2) = [best1, best2] - in - valOf (pick_best_two_from_list (to_list a @ to_list b)) - end - fun join (r, w) = LearnerAcceptorMap.unionWith pick_best_two (msg_to_w r, w) - val w0 = - if Msg.is_two_a m then - LearnerAcceptorMap.insert - (LearnerAcceptorMap.empty, - (valOf (Msg.learner m), Msg.sender m), - (m, NONE)) - else - LearnerAcceptorMap.empty - val refs = List.filter (not o Msg.is_one_a) (Msg.get_refs m) - in - foldl join w0 refs - end - - (* [msg_to_bal] returns a ballot for each known message, excluding 1a, and the message m *) - (* [msg_to_acc_status] returns a map (AcceptorStatus.t AcceptorMap.map) for each known message, excluding 1a *) - (* REQUIRES: m is not 1a *) - fun compute_acceptor_status (m : msg) msg_to_bal msg_to_acc_status - : AcceptorStatus.t AcceptorMap.map = - let - fun helper (r, s) = - AcceptorMap.unionWith (AcceptorStatus.join msg_to_bal) (msg_to_acc_status r, s) - val s0 = AcceptorMap.singleton (Msg.sender m, AcceptorStatus.Uncaught m) - val refs = List.filter (not o Msg.is_one_a) (Msg.get_refs m) - in - foldl helper s0 refs - end - - (* [msg_to_bal_val] returns a pair (ballot, value) for each known message and the message m *) - (* [msg_to_w] returns a (msg * msg option) LearnerAcceptorMap.map for each known message, excluding 1a *) - (* [msg_to_unburied] returns a set MsgSet.set for each known message, excluding 1a *) - (* REQUIRES: m is not 1a *) - fun compute_unburied_2as (m : msg) (g : learner_graph) msg_to_bal_val msg_to_w msg_to_unburied - : MsgSet.set = - let - val m_lrn = valOf (Msg.learner m) - (* z is burying x *) - fun burying (x, z) = - Msg.Learner.eq - (valOf (Msg.learner x), valOf (Msg.learner z)) andalso - let - val (x_bal, x_val) = msg_to_bal_val x - val (z_bal, z_val) = msg_to_bal_val z - in - Msg.Ballot.compare (x_bal, z_bal) = LESS andalso - not (Msg.Value.eq (x_val, z_val)) - end - val u0 = if Msg.is_two_a m then MsgSet.singleton m else MsgSet.empty - val refs = List.filter (not o Msg.is_one_a) (Msg.get_refs m) - val u = foldl (fn (r, u) => MsgSet.union (msg_to_unburied r, u)) u0 refs - val m_w = msg_to_w m - val all_acceptors = LearnerGraph.acceptors g - fun buried x = - let - val x_lrn = valOf (Msg.learner x) - fun check acc = - case LearnerAcceptorMap.lookup (m_w, (x_lrn, acc)) of - (best1, o_best2) => - burying (x, best1) orelse (isSome o_best2 andalso burying (x, (valOf o_best2))) - val acceptors = List.filter check all_acceptors - in - LearnerGraph.is_quorum g (m_lrn, acceptors) - end - in - MsgSet.filter (not o buried) u - end - - (* [msg_to_bal] returns a ballot for each known message, excluding 1a, and the message m *) - (* [msg_to_acc_status] returns a map (AcceptorStatus.t AcceptorMap.map) for each known message, excluding 1a *) - (* [msg_to_unburied] returns a set MsgSet.set for each known message, excluding 1a *) - (* REQUIRES: m is 2a *) - fun compute_q (s : state) (m : msg) (g : learner_graph) msg_to_bal msg_to_acc_status msg_to_unburied - : acceptor list = - let - fun compute_connected (l : learner, m : msg) = - (* REQUIRES: m is 1b *) - let val caught = - AcceptorMap.listKeys ( - AcceptorMap.filter AcceptorStatus.is_uncaught (msg_to_acc_status m) - ) - in - LearnerGraph.get_connected g (l, caught) - end - fun compute_connected_2as (l : learner, m : msg) = - (* REQUIRES: m is 1b *) - let - val connected = LearnerSet.fromList (compute_connected (l, m)) - val m_sender = Msg.sender m - fun pred x = - Msg.Acceptor.eq ((Msg.sender x), m_sender) andalso - LearnerSet.member (connected, valOf (Msg.learner x)) - in - MsgSet.filter pred (msg_to_unburied m) - end - fun is_fresh (l : learner, m : msg) = - (* REQUIRES: m is 1b *) - let - val connected_2as = compute_connected_2as (l, m) - val m_bal = msg_to_bal m - fun from_this_sender x = Msg.Ballot.eq (msg_to_bal x, m_bal) - in - MsgSet.all from_this_sender connected_2as - end - fun is_fresh' s (l : learner, m : msg) = - (* REQUIRES: m is 1b *) - case State.get_is_fresh s (l, m) of - SOME b => b - | NONE => - let val res = is_fresh (l, m) - val _ = State.put_is_fresh s ((l, m), res) - in - res - end - val m_tran = - let - val m_lrn = valOf (Msg.learner m) - val m_bal = msg_to_bal m - fun pred x = Msg.is_one_b x andalso is_fresh' s (m_lrn, x) - fun cont x = Msg.Ballot.eq (msg_to_bal x, m_bal) - in - MsgUtil.tran pred cont m - end - fun senders ms = - let fun helper (x, accu) = AcceptorSet.add' (Msg.sender x, accu) in - foldl helper AcceptorSet.empty ms - end - in - AcceptorSet.toList (senders m_tran) - end - - fun prev_wellformed (m : msg) : bool = - let - val m_refs = List.filter (not o Msg.is_one_a) (Msg.get_refs m) - val m_acc = Msg.sender m - fun from_this_sender x = Msg.Acceptor.eq ((Msg.sender x), m_acc) - in - case Msg.get_prev m of - NONE => List.all (not o from_this_sender) m_refs - | SOME prev => - isSome (List.find (Fn.curry Msg.eq prev) m_refs) andalso - let fun check_ref x = - not (from_this_sender x) orelse Msg.eq (x, prev) - in - List.all check_ref m_refs - end - end - - fun all_refs_known (s : state) (m : msg) = - List.all (State.is_known s) (Msg.get_refs m) - - fun has_non_wellformed_ref s m = - List.exists (State.is_non_wellformed s) (Msg.get_refs m) - - fun process_non_wellformed s m = - State.add_non_wellformed s m - (* ...further actions possible *) - - (* REQUIRES: every direct reference is known *) - fun is_wellformed (s : state) (g : learner_graph) (m : msg) : bool * MessageInfo.info_entry option = - let - fun compute_msg_info_entry s m : MessageInfo.info_entry = - (* REQUIRES: m is not 1a *) - let - val get_bal_val = State.get_bal_val s - val get_W = State.get_W s - val get_acc_status = State.get_acc_status s - val get_unburied_2as = State.get_unburied_2as s - val m_bal_val = compute_bal_val m get_bal_val - fun get_bal_val_with_m x = - if Msg.eq (x, m) then m_bal_val else get_bal_val x - val m_W = compute_W m get_bal_val_with_m get_W - val m_acc_status = - compute_acceptor_status m (fst o get_bal_val_with_m) get_acc_status - val m_unburied_2as = - compute_unburied_2as m g get_bal_val_with_m get_W get_unburied_2as - val m_q = - if Msg.is_one_b m then [] else - (* case 2a *) - compute_q s m g (fst o get_bal_val_with_m) get_acc_status get_unburied_2as - in - { info_bal_val = m_bal_val, - info_W = m_W, - info_acc_status = m_acc_status, - info_unburied_2as = m_unburied_2as, - info_q = m_q } - end - fun is_wellformed_1a m = - (* TODO this check might be redundant depending on how `get_prev` is defined *) - not (isSome (Msg.get_prev m)) andalso - (* TODO this check might be redundant depending on how `get_refs` is defined *) - null (Msg.get_refs m) - fun is_wellformed_1b m (m_info_entry : MessageInfo.info_entry) = - MsgUtil.references_exactly_one_1a m andalso - let - val ballot = fst o (State.get_bal_val s) - val (m_bal, _) = #info_bal_val m_info_entry - fun check_ref x = - Msg.is_one_a x orelse - Msg.Ballot.compare (ballot x, m_bal) = LESS - in - List.all check_ref (Msg.get_refs m) - end - fun is_wellformed_2a m (m_info_entry : MessageInfo.info_entry) = - not (null (Msg.get_refs m)) andalso - LearnerGraph.is_quorum g (valOf (Msg.learner m), #info_q m_info_entry) - in - (* optionally, we might want to check that every reference occurs at most once (call to refs_nondup) *) - if prev_wellformed m andalso MsgUtil.refs_nondup m then - case Msg.typ m of - Msg.OneA => (is_wellformed_1a m, NONE) - | Msg.OneB => - let val m_info_entry = compute_msg_info_entry s m in - if is_wellformed_1b m m_info_entry then - (true, SOME m_info_entry) - else - (false, NONE) - end - | Msg.TwoA => - let val m_info_entry = compute_msg_info_entry s m in - if is_wellformed_2a m m_info_entry then - (true, SOME m_info_entry) - else - (false, NONE) - end - else (false, NONE) - end - - fun check_wellformed_and_update_info (s : state) (g : learner_graph) m : bool * state = - case is_wellformed s g m of - (false, _) => (false, s) - | (true, info_entry_o) => - (true, Option.fold (fn (e, s) => State.store_info_entry s (m, e)) s info_entry_o) - - fun hpaxos_node (id : node_id) (g : learner_graph) (mbox : mailbox) : t = - Acc (id, Graph g, State.mk (), mbox) - - fun run (Acc (id, Graph g, s, mbox)) = - let - fun get_next_wellformed_msg s : msg * state = - case Mailbox.recv mbox of - NONE => get_next_wellformed_msg s - | SOME m => - let val _ = assert (all_refs_known s m) "wrong message ordering" in - if has_non_wellformed_ref s m then - get_next_wellformed_msg (process_non_wellformed s m) - else - let val (res, s) = check_wellformed_and_update_info s g m in - if res then (m, s) else - get_next_wellformed_msg (process_non_wellformed s m) - end - end - fun process_1a s m : state = - let - val prev = State.get_prev s - val recent = MsgSet.toList (State.get_recent s) - val new_1b = Msg.mk_one_b (prev, recent) - val (is_wf, s) = check_wellformed_and_update_info s g new_1b - in - if is_wf then - let val s = State.clear_recent s - val s = State.set_prev s new_1b - in - process_message s new_1b - end - else s - end - and process_1b s m : state = - let - val m_bal = fst (State.get_bal_val s m) - val cur_max = State.get_max s - fun process_learner (lrn, s) = - let - val prev = State.get_prev s - val recent = MsgSet.toList (State.get_recent s) - val new_2a = Msg.mk_two_a (prev, recent, lrn) - val (is_wf, s) = check_wellformed_and_update_info s g new_2a - in - if is_wf then - let val s = State.clear_recent s - val s = State.set_prev s new_2a - in - process_message s new_2a - end - else s - end - in - if Msg.Ballot.eq (m_bal, cur_max) then - foldl process_learner s (LearnerGraph.learners g) - else s - end - and process_message s m = - let - val _ = Mailbox.broadcast mbox m - val s = State.add_known s m - val s = State.add_recent s m - val s = State.update_max s m - in - case Msg.typ m of - Msg.OneA => process_1a s m - | Msg.OneB => process_1b s m - | Msg.TwoA => s - end - fun loop s = - let val (m, s) = get_next_wellformed_msg s in - if State.is_known s m then - loop s - else - loop (process_message s m) - end - in - let val _ = loop s in () end - end -end (* HPaxos *) diff --git a/hpaxos/hpaxos-acceptor.sig b/hpaxos/hpaxos-acceptor.sig new file mode 100644 index 0000000..fca056d --- /dev/null +++ b/hpaxos/hpaxos-acceptor.sig @@ -0,0 +1,7 @@ +signature ACCEPTOR = +sig + type t + + val pubkey : t -> word + val eq : t * t -> bool +end diff --git a/hpaxos-acceptor.sml b/hpaxos/hpaxos-acceptor.sml similarity index 58% rename from hpaxos-acceptor.sml rename to hpaxos/hpaxos-acceptor.sml index 3974b9d..7d3e71a 100644 --- a/hpaxos-acceptor.sml +++ b/hpaxos/hpaxos-acceptor.sml @@ -1,8 +1,9 @@ -signature ACCEPTOR = -sig - type t - val pubkey : t -> word - val eq : t * t -> bool +structure HPaxosAcceptor = +struct + type t = word + + fun pubkey a = a + val eq : t * t -> bool = (op =) end functor AcceptorOrdKey (A : ACCEPTOR) : ORD_KEY = diff --git a/hpaxos/hpaxos-ballot.sig b/hpaxos/hpaxos-ballot.sig new file mode 100644 index 0000000..3d9a2eb --- /dev/null +++ b/hpaxos/hpaxos-ballot.sig @@ -0,0 +1,13 @@ +signature HPAXOS_BALLOT = +sig + type t + type value + + val default : t + + val eq : t * t -> bool + val compare : t * t -> order + val hash : t -> word + + val value : t -> value +end diff --git a/hpaxos/hpaxos-ballot.sml b/hpaxos/hpaxos-ballot.sml new file mode 100644 index 0000000..b453113 --- /dev/null +++ b/hpaxos/hpaxos-ballot.sml @@ -0,0 +1,18 @@ +structure HPaxosBallot : HPAXOS_BALLOT = +struct + type t = word + + structure Value = HPaxosValue + type value = Value.t + + val default = Word.fromInt 0 + + val eq : t * t -> bool = (op =) + val compare = Word.compare + + (* TODO *) + fun hash b = b + + (* TODO *) + fun value b = Value.default +end diff --git a/hpaxos-mailbox.sml b/hpaxos/hpaxos-mailbox.sml similarity index 100% rename from hpaxos-mailbox.sml rename to hpaxos/hpaxos-mailbox.sml diff --git a/hpaxos/hpaxos-message.sig b/hpaxos/hpaxos-message.sig new file mode 100644 index 0000000..4085af9 --- /dev/null +++ b/hpaxos/hpaxos-message.sig @@ -0,0 +1,35 @@ +signature HPAXOS_MESSAGE = +sig + type t + + structure Value : HPAXOS_VALUE + type value = Value.t + + structure Ballot : HPAXOS_BALLOT + type ballot = Ballot.t + + structure Learner : LEARNER + type learner = Learner.t + + structure Acceptor : ACCEPTOR + type acceptor = Acceptor.t + + val hash : t -> word + val eq : t * t -> bool + + val is_proposal : t -> bool + + val mk_non_proposal : t option * t list -> t + + (* returns message sender *) + val sender : t -> acceptor + + (* if the message is a proposal, return its ballot and value; otherwise, return NONE *) + val get_bal_val : t -> (ballot * value) option + + (* returns a previous message of the sender *) + val get_prev : t -> t option + + (* returns a list of direct references *) + val get_refs : t -> t list +end diff --git a/hpaxos-message.sml b/hpaxos/hpaxos-message.sml similarity index 54% rename from hpaxos-message.sml rename to hpaxos/hpaxos-message.sml index 1e10e0a..f545b45 100644 --- a/hpaxos-message.sml +++ b/hpaxos/hpaxos-message.sml @@ -1,65 +1,76 @@ -(* HPaxos Message *) - -signature HPAXOS_VALUE = -sig - type t - val default : t (* default value *) - val eq : t * t -> bool (* equality *) -end +structure HPaxosMessage (*: HPAXOS_MESSAGE*) = +struct + infix |> + fun x |> f = f x -signature HPAXOS_BALLOT = -sig - type t - val zero : t (* the smallest ballot *) - val eq : t * t -> bool - val compare : t * t -> order -end + type hash = word -signature HPAXOS_MESSAGE = -sig - type t - datatype typ = OneA - | OneB - | TwoA + structure Learner = Learner + type learner = Learner.t - structure Value : HPAXOS_VALUE + structure Value = HPaxosValue type value = Value.t - structure Ballot : HPAXOS_BALLOT + structure Ballot = HPaxosBallot type ballot = Ballot.t - structure Learner : LEARNER - type learner = Learner.t - - structure Acceptor : ACCEPTOR + structure Acceptor = HPaxosAcceptor type acceptor = Acceptor.t - val hash : t -> word - val eq : t * t -> bool - - val typ : t -> typ - - val is_one_a : t -> bool - val is_one_b : t -> bool - val is_two_a : t -> bool - - val mk_one_b : t option * t list -> t - val mk_two_a : t option * t list * learner -> t - - (* if the message is 2a, return its learner instance; otherwise, return NONE *) - val learner : t -> learner option + datatype msg = + Proposal of ( + ballot * (* ballot *) + hash (* message hash *) + ) + | NonProposal of ( + acceptor * (* sender *) + msg list * (* references *) + msg option * (* previous message *) + hash (* message hash *) + ) + + type t = msg + + fun hash (Proposal (_, h) | NonProposal (_, _, _, h)) = h + + fun compute_hash (Proposal (bal, _)) = + Ballot.hash bal + | compute_hash (NonProposal (_, refs, prev, _)) = + let + val refs_hash = List.map hash refs + val prev_hash = map_or prev [] (fn p => [hash p]) + in + prev_hash @ refs_hash |> Hashing.hash + end - (* returns message sender *) - val sender : t -> acceptor + (* TODO check if equality is used *) + fun eq (Proposal (_, h1), Proposal (_, h2)) = h1 = h2 + | eq (NonProposal (_, _, _, h1), NonProposal (_, _, _, h2)) = h1 = h2 + | eq (_, _) = false - (* if the message is 1a, return its ballot and value; otherwise, return NONE *) - val get_bal_val : t -> (ballot * value) option + (* fun typ (Msg (t, _, _, _, _)) = t *) - (* returns a previous message of the sender *) - val get_prev : t -> t option + fun is_proposal (Proposal _) = true + | is_proposal _ = false - (* returns a list of direct references *) - val get_refs : t -> t list + (* TODO this should be raw? *) + fun mk_non_proposal (sender, prev_msg, recent_msgs) = + let val hash = Word.fromInt 42 in + NonProposal (sender, recent_msgs, prev_msg, hash) + end + + fun sender (NonProposal (sender, _, _, _)) = sender + | sender _ = raise Fail "sender not defined" + + fun get_bal_val (Proposal (b, _)) = + let val v = Ballot.value b in SOME (b, v) end + | get_bal_val _ = NONE + + fun get_prev (Proposal _) = NONE + | get_prev (NonProposal (_, _ , prev, _)) = prev + + fun get_refs (Proposal _) = [] + | get_refs (NonProposal (_, refs , _, _)) = refs end functor MessageOrdKey (Msg : HPAXOS_MESSAGE) : ORD_KEY = @@ -72,15 +83,18 @@ functor MessageUtil (Msg : HPAXOS_MESSAGE) = struct structure MsgSet : ORD_SET = RedBlackSetFn (MessageOrdKey (Msg)) + type msg = Msg.t + type ballot = Msg.Ballot.t + (* fun does_reference_1a m : bool = *) (* isSome (List.find Msg.is_one_a (Msg.get_refs m)) *) fun references_exactly_one_1a m : bool = let fun check (x, (found, false)) = (found, false) | check (x, (false, true)) = - if Msg.is_one_a x then (true, true) else (false, true) + if Msg.is_proposal x then (true, true) else (false, true) | check (x, (true, true)) = - if Msg.is_one_a x then (true, false) else (true, true) + if Msg.is_proposal x then (true, false) else (true, true) in case foldl check (false, true) (Msg.get_refs m) of (found, no_second) => found andalso no_second @@ -101,8 +115,8 @@ struct (* checks if m2 is in transitive closure of prev for m1 *) structure PrevTran :> sig - val is_prev_reachable : Msg.t * Msg.t -> bool - val is_prev_reachable' : (Msg.t -> Msg.Ballot.t) -> Msg.t * Msg.t -> bool + val is_prev_reachable : msg * msg -> bool + val is_prev_reachable' : (msg -> ballot) -> msg * msg -> bool end = struct fun is_prev_reachable_aux cont (m1, m2) = @@ -129,7 +143,7 @@ struct end (* PrevTran *) (* compute transitive references of the message *) - fun tran pred cont m = + fun tran (pred : msg -> bool) (cont : msg -> bool) (m : msg) = let fun loop accu visited [] = accu | loop accu visited (x :: tl) = diff --git a/hpaxos/hpaxos-value.sig b/hpaxos/hpaxos-value.sig new file mode 100644 index 0000000..8f63947 --- /dev/null +++ b/hpaxos/hpaxos-value.sig @@ -0,0 +1,7 @@ +signature HPAXOS_VALUE = +sig + type t + + val default : t (* default value *) + val eq : t * t -> bool (* equality *) +end diff --git a/hpaxos/hpaxos-value.sml b/hpaxos/hpaxos-value.sml new file mode 100644 index 0000000..ed51266 --- /dev/null +++ b/hpaxos/hpaxos-value.sml @@ -0,0 +1,7 @@ +structure HPaxosValue : HPAXOS_VALUE = +struct + type t = word + + val default = Word.fromInt 0 + val eq : t * t -> bool = (op =) +end diff --git a/hpaxos.cm b/hpaxos/hpaxos.cm similarity index 100% rename from hpaxos.cm rename to hpaxos/hpaxos.cm diff --git a/hpaxos/hpaxos.sml b/hpaxos/hpaxos.sml new file mode 100644 index 0000000..7d3e9b1 --- /dev/null +++ b/hpaxos/hpaxos.sml @@ -0,0 +1,740 @@ +functor HPaxos (structure Msg : HPAXOS_MESSAGE + structure LearnerGraph : LEARNER_GRAPH + sharing Msg.Learner = LearnerGraph.Learner + and Msg.Acceptor = LearnerGraph.Acceptor) + :> GEN_SERVER_IMPL = +struct + infix |> + fun x |> f = f x + + type acceptor = Msg.Acceptor.t + type ballot = Msg.Ballot.t + type value = Msg.Value.t + type learner = Msg.Learner.t + type learner_graph = LearnerGraph.t + + type msg = Msg.t + + structure MsgUtil = MessageUtil (Msg) + + structure MsgSet : ORD_SET = RedBlackSetFn (MessageOrdKey (Msg)) + structure MsgMap : ORD_MAP = RedBlackMapFn (MessageOrdKey (Msg)) + structure AcceptorSet : ORD_SET = RedBlackSetFn (AcceptorOrdKey (Msg.Acceptor)) + structure AcceptorMap : ORD_MAP = RedBlackMapFn (AcceptorOrdKey (Msg.Acceptor)) + structure LearnerSet : ORD_SET = RedBlackSetFn (LearnerOrdKey (Msg.Learner)) + structure LearnerMap : ORD_MAP = RedBlackMapFn (LearnerOrdKey (Msg.Learner)) + + structure LearnerAcceptorMap : ORD_MAP = + RedBlackMapFn ( + ProdLexOrdKey + (LearnerOrdKey (Msg.Learner)) + (AcceptorOrdKey (Msg.Acceptor))) + + structure LearnerMsgMap : ORD_MAP = + RedBlackMapFn ( + ProdLexOrdKey + (LearnerOrdKey (Msg.Learner)) + (MessageOrdKey (Msg))) + + (* algorithm state *) + structure AlgoState = + struct + datatype known_msgs = KnownMsgs of MsgSet.set + datatype recent_msgs = RecentMsgs of MsgSet.set + datatype prev_msg = PrevMsg of msg option + datatype non_wellformed_msgs = NonWellformedMsgs of MsgSet.set + datatype max_ballot = MaxBal of ballot + + datatype state = AlgoState of known_msgs * recent_msgs * prev_msg * + non_wellformed_msgs * max_ballot + type t = state + + fun mk () = AlgoState (KnownMsgs MsgSet.empty, + RecentMsgs MsgSet.empty, + PrevMsg NONE, + NonWellformedMsgs MsgSet.empty, + MaxBal Msg.Ballot.zero) + + fun is_known (AlgoState (KnownMsgs k, _, _, _, _)) = + Fn.curry MsgSet.member k + + fun add_known (AlgoState (KnownMsgs k, r, p, nw, maxb)) m = + AlgoState (KnownMsgs (MsgSet.add (k, m)), r, p, nw, maxb) + + fun get_recent (AlgoState (_, RecentMsgs r, _, _, _)) = r + + fun add_recent (AlgoState (k, RecentMsgs r, p, nw, maxb)) m = + AlgoState (k, RecentMsgs (MsgSet.add (r, m)), p, nw, maxb) + + fun clear_recent (AlgoState (k, _, p, nw, maxb)) = + AlgoState (k, RecentMsgs MsgSet.empty, p, nw, maxb) + + fun get_prev (AlgoState (_, _, PrevMsg p, _, _)) = p + + fun set_prev (AlgoState (k, r, _, nw, maxb)) m = + AlgoState (k, r, PrevMsg (SOME m), nw, maxb) + + fun is_non_wellformed (AlgoState (_, _, _, NonWellformedMsgs nw, _)) = + Fn.curry MsgSet.member nw + + fun add_non_wellformed (AlgoState (k, r, p, NonWellformedMsgs nw, maxb)) m = + AlgoState (k, r, p, NonWellformedMsgs (MsgSet.add (nw, m)), maxb) + + (* fun get_max (AlgoState (_, _, _, _, MaxBal maxb)) = maxb *) + + (* fun set_max (AlgoState (k, r, p, nw, _)) bal = + AlgoState (k, r, p, nw, MaxBal bal) *) + end (* AlgoState *) + + (* message info state *) + structure AcceptorStatus = + struct + datatype status = Caught + | Uncaught of msg + type t = status + + fun is_caught Caught = true + | is_caught _ = false + + fun join (bal : msg -> ballot) (Uncaught m1, Uncaught m2) = + if MsgUtil.PrevTran.is_prev_reachable' bal (m1, m2) then + Uncaught m1 + else if MsgUtil.PrevTran.is_prev_reachable' bal (m2, m1) then + Uncaught m2 + else + Caught + | join bal (_, _) = Caught + end (* AcceptorStatus *) + + structure MessageType = + struct + datatype msg_type = OneA | OneB | TwoA + type t = msg_type + + (* fun is_one_a OneA = true + | is_one_a _ = false *) + + fun is_one_b OneB = true + | is_one_b _ = false + + fun is_two_a TwoA = true + | is_two_a _ = false + end + + structure MessageInfo = + struct + type info_entry = { info_type : MessageType.t, + info_bal_val : ballot * value, + info_W : (msg * msg option) LearnerAcceptorMap.map, + info_acc_status : AcceptorStatus.t AcceptorMap.map, + info_unburied_2as : MsgSet.set LearnerMap.map, + info_q : (acceptor list) LearnerMap.map, + info_learners : learner list } + + datatype msg_info = MsgInfo of info_entry MsgMap.map + + type t = msg_info + + fun mk () = MsgInfo MsgMap.empty + + fun has ((MsgInfo info), m) = MsgMap.inDomain (info, m) + + fun get_type (MsgInfo info) m = #info_type (MsgMap.lookup (info, m)) + + (* fun is_one_a info m = get_type info m |> MessageType.is_one_a *) + fun is_one_b info m = get_type info m |> MessageType.is_one_b + fun is_two_a info m = get_type info m |> MessageType.is_two_a + + fun get_bal_val (MsgInfo info) m = #info_bal_val (MsgMap.lookup (info, m)) + fun get_learners (MsgInfo info) m = #info_learners (MsgMap.lookup (info, m)) + fun get_W (MsgInfo info) m = #info_W (MsgMap.lookup (info, m)) + fun get_acc_status (MsgInfo info) m = #info_acc_status (MsgMap.lookup (info, m)) + fun get_unburied_2as (MsgInfo info) m = #info_unburied_2as (MsgMap.lookup (info, m)) + fun get_q (MsgInfo info) m = #info_q (MsgMap.lookup (info, m)) + + fun store_entry (MsgInfo info) (m, info_entry : info_entry) = + MsgInfo (MsgMap.insert (info, m, info_entry)) + end (* MessageInfo *) + + (* memo state *) + structure Cache = + struct + datatype is_fresh_cache = IsFresh of (bool LearnerMsgMap.map) ref + datatype cache = Cache of is_fresh_cache + type t = cache + + fun mk () = Cache (IsFresh (ref LearnerMsgMap.empty)) + + fun get_is_fresh (Cache (IsFresh f)) = Fn.curry LearnerMsgMap.find (!f) + fun put_is_fresh (Cache (IsFresh f)) (lm, v) = + Ref.modify (fn map => LearnerMsgMap.insert (map, lm, v)) f + end + + structure State = + struct + datatype state = State of AlgoState.t * MessageInfo.t * Cache.t + type t = state + + fun mk () = State (AlgoState.mk (), MessageInfo.mk (), (Cache.mk ())) + + fun is_known (State (s, _, _)) = AlgoState.is_known s + + fun get_recent (State (s, _, _)) = AlgoState.get_recent s + fun clear_recent (State (s, i, c)) = State (AlgoState.clear_recent s, i, c) + + fun get_prev (State (s, _, _)) = AlgoState.get_prev s + fun set_prev (State (s, i, c)) m = State (AlgoState.set_prev s m, i, c) + + fun is_non_wellformed (State (s, _, _)) = AlgoState.is_non_wellformed s + + (* fun add_known (State (s, i, c)) m = + State (AlgoState.add_known s m, i, c) *) + + fun add_recent (State (s, i, c)) m = + State (AlgoState.add_recent s m, i, c) + + (* fun add_known_recent (State (s, i, c)) m = + State (AlgoState.add_recent (AlgoState.add_known s m) m, i, c) *) + + fun add_non_wellformed (State (s, i, c)) m = + State (AlgoState.add_non_wellformed s m, i, c) + + (* fun get_max (State (s, _, _)) = AlgoState.get_max s *) + + (* TODO clean *) + (* fun is_one_a (State (_, i, _)) m = MessageInfo.is_one_a i m *) + fun is_one_b (State (_, i, _)) m = MessageInfo.is_one_b i m + (* TODO unused *) + fun is_two_a (State (_, i, _)) m = MessageInfo.is_two_a i m + + fun get_bal_val (State (_, i, _)) m = + if Msg.is_proposal m then + valOf (Msg.get_bal_val m) + else + MessageInfo.get_bal_val i m + + (* fun update_max (state as State (s, i, c)) m = + let + fun max (a, b) = + case Msg.Ballot.compare (a, b) of LESS => b | _ => a + val m_bal = fst (get_bal_val state m) + val cur_max = AlgoState.get_max s + val new_max = max (m_bal, cur_max) + in + State (AlgoState.set_max s new_max, i, c) + end *) + + fun get_learners (State (_, i, _)) m = MessageInfo.get_learners i m + fun get_W (State (_, i, _)) = MessageInfo.get_W i + fun get_acc_status (State (_, i, _)) = MessageInfo.get_acc_status i + fun get_unburied_2as (State (_, i, _)) = MessageInfo.get_unburied_2as i + fun get_q (State (_, i, _)) = MessageInfo.get_q i + + fun has_info (State (_, i, _)) m = MessageInfo.has (i, m) + + fun store_info_entry (State (a, i, c)) mi = + State (a, MessageInfo.store_entry i mi, c) + + fun get_is_fresh (State (_, _, c)) = Cache.get_is_fresh c + fun put_is_fresh (State (_, _, c)) = Cache.put_is_fresh c + end (* State *) + + structure ServerState = + struct + datatype state = State of learner_graph * State.t + type t = state + + fun mk (g, s) = State (g, s) + end (* ServerState *) + + type param = learner_graph + type state = ServerState.t + + fun senders (ms : msg list) : acceptor list = + let + val empty = AcceptorSet.empty + val add = AcceptorSet.add' + in + ms + |> List.foldl (fn (x, accu) => add (Msg.sender x, accu)) empty + |> AcceptorSet.toList + end + + fun compute_type (m : msg) : MessageType.t = + if Msg.is_proposal m then + MessageType.OneA + else + let + fun has_proposal_ref m = + Msg.get_refs m |> List.exists Msg.is_proposal + in + if has_proposal_ref m then + MessageType.OneB + else + MessageType.TwoA + end + + (* [msg_to_bal_val] returns a pair (ballot, value) for each known message, including 1a messages *) + (* REQUIRES: m is not 1a *) + fun compute_bal_val (m : msg) msg_to_bal_val : ballot * value = + let + fun helper (x, (max_bal, max_val)) = + let val (b, v) = msg_to_bal_val x in + case Msg.Ballot.compare (b, max_bal) of + LESS => (max_bal, max_val) + | _ => (b, v) + end + in + (* refs is non-empty since m is not 1a *) + Msg.get_refs m + |> List.foldl helper (Msg.Ballot.zero, Msg.Value.default) + end + + (* [msg_to_bal_val] returns a pair (ballot, value) for each known message and the message m *) + (* [msg_to_w] returns a (msg * msg option) LearnerAcceptorMap.map for each known message, excluding 1a *) + (* REQUIRES: m is not 1a *) + fun compute_W (m : msg) (m_type : MessageType.t) (g : learner_graph) msg_to_bal_val msg_to_w + : (msg * msg option) LearnerAcceptorMap.map = + let + val empty = LearnerAcceptorMap.empty + val insert = LearnerAcceptorMap.insert + val unionWith = LearnerAcceptorMap.unionWith + + fun pick_best_two_from_list (ms : msg list) : (msg * msg option) option = + let + val ballot = fst o msg_to_bal_val + val value = snd o msg_to_bal_val + fun pick_best pred cmp lst = + let + fun choose (x, cur_best_o) = + if pred x then + case cur_best_o of + NONE => SOME x + | SOME cur_best => + case cmp (cur_best, x) of + LESS => SOME x + | _ => cur_best_o + else cur_best_o + in + List.foldl choose NONE lst + end + fun cmp_by_ballot (x, y) = Msg.Ballot.compare (ballot x, ballot y) + fun pick_first_best ms = pick_best (Fn.const true) cmp_by_ballot ms + fun pick_second_best ms fst_best = + let val fst_best_val = value fst_best + fun pred x = not (Msg.Value.eq (value x, fst_best_val)) + in + pick_best pred cmp_by_ballot ms + end + in + ms + |> pick_first_best + |> Option.map (fn x => (x, pick_second_best ms x)) + end + fun pick_best_two (a : msg * msg option, b : msg * msg option) = + let + fun to_list (best1, NONE) = [best1] + | to_list (best1, SOME best2) = [best1, best2] + in + to_list a @ to_list b |> pick_best_two_from_list |> valOf + end + val w0 = + if MessageType.is_two_a m_type then + let + val m_acc = Msg.sender m + in + LearnerGraph.learners g + |> List.foldl + (fn (alpha, u) => insert (u, (alpha, m_acc), (m, NONE))) + empty + end + else + empty + in + m + |> Msg.get_refs + |> List.filter (not o Msg.is_proposal) + |> List.foldl (fn (r, w) => unionWith pick_best_two (msg_to_w r, w)) w0 + end + + (* [msg_to_bal] returns a ballot for each known message, excluding 1a, and the message m *) + (* [msg_to_acc_status] returns a map (AcceptorStatus.t AcceptorMap.map) for each known message, excluding 1a *) + (* REQUIRES: m is not 1a *) + fun compute_acceptor_status (m : msg) msg_to_bal msg_to_acc_status + : AcceptorStatus.t AcceptorMap.map = + let + val unionWith = AcceptorMap.unionWith + val join = AcceptorStatus.join msg_to_bal + val s0 = AcceptorMap.singleton (Msg.sender m, AcceptorStatus.Uncaught m) + in + Msg.get_refs m + |> List.filter (not o Msg.is_proposal) + |> List.foldl (fn (r, s) => unionWith join (msg_to_acc_status r, s)) s0 + end + + (* [msg_to_bal_val] returns a pair (ballot, value) for each known message and the message m *) + (* [msg_to_w] returns a (msg * msg option) LearnerAcceptorMap.map for each known message, excluding 1a *) + (* [msg_to_unburied] returns a set MsgSet.set for each known message, excluding 1a *) + (* REQUIRES: m is not 1a *) + fun compute_unburied_2as + (m : msg) (g : learner_graph) (m_type : MessageType.t) m_W + (msg_to_bal_val : msg -> ballot * value) + (* TODO check if msg_to_learners has to be defined for m *) + msg_to_learners + msg_to_unburied + : MsgSet.set LearnerMap.map = + let + fun is_learner_of (alpha : learner, x : msg) = + msg_to_learners x + |> List.find (Fn.curry Msg.Learner.eq alpha) + |> Option.isSome + + (* z is burying x *) + fun burying (alpha : learner) (x, z) = + is_learner_of (alpha, z) andalso + let + val (x_bal, x_val) = msg_to_bal_val x + val (z_bal, z_val) = msg_to_bal_val z + in + Msg.Ballot.compare (x_bal, z_bal) = LESS andalso + not (Msg.Value.eq (x_val, z_val)) + end + + val get_w = Fn.curry LearnerAcceptorMap.lookup m_W + val all_acceptors = LearnerGraph.acceptors g + + fun compute_unburied_2as_for_learner (beta : learner) = + let + fun buried x = + let + fun check acc = + let + val (best1, o_best2) = get_w (beta, acc) + in + burying beta (x, best1) orelse + (isSome o_best2 andalso burying beta (x, (valOf o_best2))) + end + val acceptors = List.filter check all_acceptors + in + LearnerGraph.is_quorum g (beta, acceptors) + end + val u0 = if MessageType.is_two_a m_type then MsgSet.singleton m else MsgSet.empty + in + m + |> Msg.get_refs + |> List.filter (not o Msg.is_proposal) + |> List.foldl + (fn (r, u) => + MsgSet.union (LearnerMap.lookup (msg_to_unburied r, beta), u)) + u0 + |> MsgSet.filter (not o buried) + end + in + LearnerGraph.learners g + |> List.foldl + (fn (beta, u) => + LearnerMap.insert (u, beta, compute_unburied_2as_for_learner beta)) + LearnerMap.empty + end + + (* [msg_to_bal] returns a ballot for each known message, excluding 1a, and the message m *) + (* [msg_to_acc_status] returns a map (AcceptorStatus.t AcceptorMap.map) for each known message, excluding 1a *) + (* [msg_to_unburied] returns a set MsgSet.set for each known message, excluding 1a *) + (* REQUIRES: m is 2a *) + fun compute_q (s : State.t) (m : msg) (g : learner_graph) + msg_to_bal_val + msg_to_learners + msg_to_acc_status + msg_to_unburied + : (acceptor list) LearnerMap.map = + let + val empty = LearnerMap.empty + val insert = LearnerMap.insert + + val msg_to_bal = fst o msg_to_bal_val + val msg_to_val = snd o msg_to_bal_val + + fun compute_connected (l : learner, m : msg) : learner list = + (* REQUIRES: m is 1b *) + let + val caught = + msg_to_acc_status m + |> AcceptorMap.filter AcceptorStatus.is_caught + |> AcceptorMap.listKeys + in + (* TODO explore optimization options, e.g., caching *) + LearnerGraph.get_connected g (l, caught) + end + + fun compute_connected_2as (alpha : learner, m : msg) = + (* REQUIRES: m is 1b *) + let + val m_acc = Msg.sender m + val connected = compute_connected (alpha, m) |> LearnerSet.fromList + fun from_this_sender x = Msg.Acceptor.eq ((Msg.sender x), m_acc) + in + m + |> msg_to_learners + |> List.filter (fn beta => LearnerSet.member (connected, beta)) + |> List.foldl + (fn (beta, accu) => + MsgSet.union (accu, LearnerMap.lookup (msg_to_unburied m, beta))) + MsgSet.empty + |> MsgSet.filter from_this_sender + end + + fun is_fresh (l : learner, m : msg) = + (* REQUIRES: m is 1b *) + let + val m_val = msg_to_val m + fun same_value x = Msg.Value.eq (msg_to_val x, m_val) + in + compute_connected_2as (l, m) |> MsgSet.all same_value + end + (* cached `is_fresh` predicate, impure *) + fun is_fresh' s (l : learner, m : msg) = + (* REQUIRES: m is 1b *) + case State.get_is_fresh s (l, m) of + SOME b => b + | NONE => + let + val res = is_fresh (l, m) + val _ = State.put_is_fresh s ((l, m), res) + in + res + end + + fun compute_q_for_learner (alpha : learner) = + let + val m_bal = msg_to_bal m + fun pred x = State.is_one_b s x andalso is_fresh' s (alpha, x) + fun cont x = Msg.Ballot.eq (msg_to_bal x, m_bal) + in + m + |> MsgUtil.tran pred cont + |> senders + end + in + LearnerGraph.learners g + |> List.map (fn alpha => (alpha, compute_q_for_learner alpha)) + |> List.filter (fn (alpha, xs) => not (null xs)) + |> List.foldl + (fn ((alpha, xs), accu) => insert (accu, alpha, xs)) + empty + end + + (* REQUIRES: m is not 1a *) + fun compute_msg_info_entry (s : State.t) (g : learner_graph) (m : msg) + get_bal_val + get_learners + get_W + get_acc_status + get_unburied_2as + : MessageInfo.info_entry = + let + val m_type = compute_type m + val m_bal_val = compute_bal_val m get_bal_val + fun get_bal_val_with_m x = + if Msg.eq (x, m) then m_bal_val else get_bal_val x + (* TODO rename W to something more meaningful *) + (* compute W values as per message m *) + val m_W = compute_W m m_type g get_bal_val_with_m get_W + (* compute acceptor status as per message m *) + val m_acc_status = + compute_acceptor_status m (fst o get_bal_val_with_m) get_acc_status + (* compute a set of unburied 2a-messages as per message m *) + val m_unburied_2as = + compute_unburied_2as m g m_type m_W get_bal_val_with_m get_learners get_unburied_2as + (* list of quorums per each learner *) + val m_q = + if MessageType.is_one_b m_type then + LearnerMap.empty + else + compute_q s m g get_bal_val_with_m get_learners get_acc_status get_unburied_2as + (* list of learner values *) + val m_learners = m_q |> LearnerMap.listKeys + in + { + info_type = m_type, + info_bal_val = m_bal_val, + info_W = m_W, + info_acc_status = m_acc_status, + info_unburied_2as = m_unburied_2as, + info_q = m_q, + info_learners = m_learners + } + end + + fun prev_wellformed (m : msg) : bool = + let + val m_refs = List.filter (not o Msg.is_proposal) (Msg.get_refs m) + val m_acc = Msg.sender m + fun from_this_sender x = Msg.Acceptor.eq ((Msg.sender x), m_acc) + in + case Msg.get_prev m of + NONE => List.all (not o from_this_sender) m_refs + | SOME prev => + isSome (List.find (Fn.curry Msg.eq prev) m_refs) andalso + let fun check_ref x = + not (from_this_sender x) orelse Msg.eq (x, prev) + in + List.all check_ref m_refs + end + end + + fun all_refs_known (s : State.t) (m : msg) = + List.all (State.is_known s) (Msg.get_refs m) + + fun has_non_wellformed_ref s m = + List.exists (State.is_non_wellformed s) (Msg.get_refs m) + + fun process_non_wellformed s m = + State.add_non_wellformed s m + (* ...further actions possible *) + + (* REQUIRES: every direct reference is known *) + fun is_wellformed (s : State.t) (g : learner_graph) (m : msg) : bool * MessageInfo.info_entry option = + let + val get_bal_val = State.get_bal_val s + val get_learners = State.get_learners s + val get_W = State.get_W s + val get_acc_status = State.get_acc_status s + val get_unburied_2as = State.get_unburied_2as s + + fun is_wellformed_1b m (m_info_entry : MessageInfo.info_entry) = + MsgUtil.references_exactly_one_1a m andalso + let + val ballot = fst o get_bal_val + val (m_bal, _) = #info_bal_val m_info_entry + fun check_ref x = + Msg.is_proposal x orelse + Msg.Ballot.compare (ballot x, m_bal) = LESS + in + m + |> Msg.get_refs + |> List.all check_ref + end + + fun is_wellformed_2a m (m_info_entry : MessageInfo.info_entry) = + not (null (Msg.get_refs m)) andalso + not (null (#info_learners m_info_entry)) + + fun is_wellformed_non_proposal m (m_info_entry : MessageInfo.info_entry) = + if MessageType.is_one_b (#info_type m_info_entry) then + is_wellformed_1b m m_info_entry + else + is_wellformed_2a m m_info_entry + in + (* optionally, we might want to check that every reference occurs at most once (call to refs_nondup) *) + (* TODO actually, do it in the mailbox implementation *) + if prev_wellformed m andalso MsgUtil.refs_nondup m then + ( + if Msg.is_proposal m then + (true, NONE) + else + ( + let + val m_info_entry = + compute_msg_info_entry s g m + get_bal_val get_learners get_W get_acc_status get_unburied_2as + in + if is_wellformed_non_proposal m m_info_entry then + (true, SOME m_info_entry) + else + (false, NONE) + end + ) + ) + else + (false, NONE) + end + + fun check_wellformed_and_update_info (s : State.t, g : learner_graph, m : msg) + : bool * State.t = + (* first, check if the message info is already stored, meaning that the message was sent by us *) + (* TODO currently, broken; to make it really work, we need to use hashes as keys *) + if State.has_info s m then + (true, s) + else + case is_wellformed s g m of + (false, _) => (false, s) + | (true, info_entry_o) => + (true, Option.fold + (fn (e, s) => State.store_info_entry s (m, e)) + s + info_entry_o) + + fun init (g : learner_graph) = ServerState.mk (g, State.mk()) + + fun handle_msg (ServerState.State (g, s), m) = + let + fun process_1a s m : State.t * msg option = + let + val prev = State.get_prev s + val recent = MsgSet.add (State.get_recent s, m) |> MsgSet.toList + val new_1b = Msg.mk_non_proposal (prev, recent) + val (is_wf, s) = check_wellformed_and_update_info (s, g, new_1b) + in + if is_wf then + let + fun update_state s = + s + |> State.clear_recent + |> flip State.add_recent m + |> flip State.set_prev new_1b + in + (* broadcast new_1b *) + (update_state s, SOME new_1b) + end + else (s, NONE) + end + + fun process_1b s m : State.t * msg option = + let + val prev = State.get_prev s + val recent = MsgSet.toList (State.get_recent s) + val new_2a = Msg.mk_non_proposal (prev, recent) + val (is_wf, s) = check_wellformed_and_update_info (s, g, new_2a) + in + if is_wf then + let + fun update_state s = + s + |> State.clear_recent + |> flip State.add_recent m + |> flip State.set_prev new_2a + in + (* broadcast new_2a *) + (update_state s, SOME new_2a) + end + else + (* XXX add to recent? *) + (s, NONE) + end + + fun process_2a s m : State.t * msg option = + (State.add_recent s m, NONE) + in + let + val (res, s) = + if has_non_wellformed_ref s m then + (false, s) + else + check_wellformed_and_update_info (s, g, m) + val (s, new_msg) = + if res then + if Msg.is_proposal m then + process_1a s m + else ( + if State.is_one_b s m then + process_1b s m + else + process_2a s m + ) + else + (process_non_wellformed s m, NONE) + in + (ServerState.mk (g, s), new_msg) + end + end +end (* HPaxos *) diff --git a/learner-graph.sml b/hpaxos/learner-graph.sig similarity index 80% rename from learner-graph.sml rename to hpaxos/learner-graph.sig index 0cb6327..708cdc1 100644 --- a/learner-graph.sml +++ b/hpaxos/learner-graph.sig @@ -2,8 +2,8 @@ signature LEARNER_GRAPH = sig type t - structure Epoch : EPOCH - type epoch = Epoch.t + (* structure Epoch : EPOCH + type epoch = Epoch.t *) structure Learner : LEARNER type learner = Learner.t @@ -11,7 +11,7 @@ sig structure Acceptor : ACCEPTOR type acceptor = Acceptor.t - val epoch : t -> epoch + (* val epoch : t -> epoch *) val learners : t -> learner list val acceptors : t -> acceptor list diff --git a/hpaxos/learner.sig b/hpaxos/learner.sig new file mode 100644 index 0000000..c4b0b11 --- /dev/null +++ b/hpaxos/learner.sig @@ -0,0 +1,10 @@ +signature LEARNER = +sig + type t + val id : t -> word + val hash : t -> word + + val eq : t * t -> bool + val gt : t * t -> bool + val compare : t * t -> order +end diff --git a/hpaxos/learner.sml b/hpaxos/learner.sml new file mode 100644 index 0000000..7737e74 --- /dev/null +++ b/hpaxos/learner.sml @@ -0,0 +1,16 @@ +structure Learner : LEARNER = +struct + type t = word + fun id a = a + val eq : t * t -> bool = (op =) + fun hash a = a + + fun compare (a, b) = Word.compare (id a, id b) + fun gt (a, b) = case compare (a, b) of GREATER => true | _ => false +end + +functor LearnerOrdKey (L : LEARNER) : ORD_KEY = +struct + type ord_key = L.t + fun compare (a, b) = L.compare (a, b) +end diff --git a/learner.sml b/learner.sml deleted file mode 100644 index b438e38..0000000 --- a/learner.sml +++ /dev/null @@ -1,12 +0,0 @@ -signature LEARNER = -sig - type t - val id : t -> word - val eq : t * t -> bool -end - -functor LearnerOrdKey (L : LEARNER) : ORD_KEY = -struct - type ord_key = L.t - fun compare (l1, l2) = Word.compare (L.id l1, L.id l2) -end diff --git a/util.sml b/util.sml index 8f8783f..c687508 100644 --- a/util.sml +++ b/util.sml @@ -2,9 +2,24 @@ fun fst (x, _) = x fun snd (_, x) = x +fun map_or (x : 'a option) (default : 'b) (f : 'a -> 'b) = + case x of SOME v => f v | NONE => default + +fun flip f x y = f y x + fun assert cond str = if cond then () else raise Fail ("assert " ^ str) +fun list_equal eq xs ys = + let + fun doit ([], []) = true + | doit (x :: tx, y :: ty) = + if eq (x, y) then doit (tx, ty) else false + | doit (_, _) = false + in + doit (xs, ys) + end + functor ProdLexOrdKey (A : ORD_KEY) (B : ORD_KEY) : ORD_KEY =