Skip to content

Commit

Permalink
feat: add subscriber option for hooks
Browse files Browse the repository at this point in the history
  • Loading branch information
pouriya authored and badlop committed Jul 1, 2024
1 parent 576a3f5 commit b8e3eb0
Showing 1 changed file with 191 additions and 26 deletions.
217 changes: 191 additions & 26 deletions src/ejabberd_hooks.erl
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@
delete/3,
delete/4,
delete/5,
subscribe/4,
subscribe/5,
unsubscribe/4,
unsubscribe/5,
run/2,
run/3,
run_fold/3,
Expand All @@ -59,6 +63,8 @@
-include("ejabberd_stacktrace.hrl").

-record(state, {}).
-type subscriber() :: {Module :: atom(), Function :: atom(), InitArg :: any()}.
-type subscriber_event() :: before | 'after' | before_callback | after_callback.
-type hook() :: {Seq :: integer(), Module :: atom(), Function :: atom() | fun()}.

-define(TRACE_HOOK_KEY, '$trace_hook').
Expand Down Expand Up @@ -87,6 +93,34 @@ add(Hook, Module, Function, Seq) ->
add(Hook, Host, Module, Function, Seq) ->
gen_server:call(?MODULE, {add, Hook, Host, Module, Function, Seq}).

-spec subscribe(atom(), atom(), atom(), any()) -> ok.
%% @doc Add a subscriber to this hook.
%%
%% Before running any hook callback, the subscriber will be called in form of
%% Module:Function(InitArg, 'before', Host :: binary() | global, Hook, HookArgs)
%% Above function should return new state.
%%
%% Before running each callback, the subscriber will be called in form of
%% Module:Function(State, 'before_callback', Host :: binary() | global, Hook, {CallbackMod, CallbackArg, Seq, HookArgs})
%% Above function should return new state.
%%
%% After running each callback, the subscriber will be called in form of
%% Module:Function(State, 'after_callback', Host :: binary() | global, Hook, {CallbackMod, CallbackArg, Seq, HookArgs})
%% Above function should return new state.
%%
%% After running any hook callback, the subscriber will be called in form of
%% Module:Function(State, 'after', Host :: binary() | global, Hook, HookArgs)
%% Return value of this function call will be dropped.
%%
%% For every ejabberd_hooks:[run|run_fold] above functions for each subscriber will be called and the hook runner
%% maintains State in above four calls.
subscribe(Hook, Module, Function, InitArg) ->
subscribe(Hook, global, Module, Function, InitArg).

-spec subscribe(atom(), binary() | global, atom(), atom(), any()) -> ok.
subscribe(Hook, Host, Module, Function, InitArg) ->
gen_server:call(?MODULE, {subscribe, Hook, Host, Module, Function, InitArg}).

-spec delete(atom(), fun(), integer()) -> ok.
%% @doc See del/4.
delete(Hook, Function, Seq) when is_function(Function) ->
Expand All @@ -105,6 +139,18 @@ delete(Hook, Module, Function, Seq) ->
delete(Hook, Host, Module, Function, Seq) ->
gen_server:call(?MODULE, {delete, Hook, Host, Module, Function, Seq}).



-spec unsubscribe(atom(), atom(), atom(), any()) -> ok.
%% @doc Add a subscriber to this hook.
unsubscribe(Hook, Module, Function, InitArg) ->
unsubscribe(Hook, global, Module, Function, InitArg).

-spec unsubscribe(atom(), binary() | global, atom(), atom(), any()) -> ok.
unsubscribe(Hook, Host, Module, Function, InitArg) ->
gen_server:call(?MODULE, {unsubscribe, Hook, Host, Module, Function, InitArg}).


-spec run(atom(), list()) -> ok.
%% @doc Run the calls of this hook in order, don't care about function results.
%% If a call returns stop, no more calls are performed.
Expand All @@ -114,17 +160,28 @@ run(Hook, Args) ->
-spec run(atom(), binary() | global, list()) -> ok.
run(Hook, Host, Args) ->
try ets:lookup(hooks, {Hook, Host}) of
[{_, Ls}] ->
[{_, Ls, Subs}] ->
case erlang:get(?TRACE_HOOK_KEY) of
undefined ->
undefined when Subs == [] ->
run1(Ls, Hook, Args);
undefined ->
Subs2 = call_subscriber_list(Subs, Host, Hook, Args, before, []),
Subs3 = run1(Ls, Hook, Args, Host, Subs2),
_Subs4 = call_subscriber_list(Subs3, Host, Hook, Args, 'after', []),
ok;
TracingHooksOpts ->
case do_get_tracing_options(Hook, Host, TracingHooksOpts) of
undefined ->
run1(Ls, Hook, Args);
Subs2 = call_subscriber_list(Subs, Host, Hook, Args, before, []),
Subs3 = run1(Ls, Hook, Args, Host, Subs2),
_Subs4 = call_subscriber_list(Subs3, Host, Hook, Args, 'after', []),
ok;
TracingOpts ->
foreach_start_hook_tracing(TracingOpts, Hook, Host, Args),
run2(Ls, Hook, Args, Host, TracingOpts)
Subs2 = call_subscriber_list(Subs, Host, Hook, Args, before, []),
Subs3 = run2(Ls, Hook, Args, Host, TracingOpts, Subs2),
_Subs4 = call_subscriber_list(Subs3, Host, Hook, Args, 'after', []),
ok
end
end;
[] ->
Expand All @@ -145,17 +202,28 @@ run_fold(Hook, Val, Args) ->
-spec run_fold(atom(), binary() | global, T, list()) -> T.
run_fold(Hook, Host, Val, Args) ->
try ets:lookup(hooks, {Hook, Host}) of
[{_, Ls}] ->
[{_, Ls, Subs}] ->
case erlang:get(?TRACE_HOOK_KEY) of
undefined ->
undefined when Subs == [] ->
run_fold1(Ls, Hook, Val, Args);
undefined ->
Subs2 = call_subscriber_list(Subs, Host, Hook, [Val | Args], before, []),
{Val2, Subs3} = run_fold1(Ls, Hook, Val, Args, Host, Subs2),
_Subs4 = call_subscriber_list(Subs3, Host, Hook, [Val2 | Args], 'after', []),
Val2;
TracingHooksOpts ->
case do_get_tracing_options(Hook, Host, TracingHooksOpts) of
undefined ->
run_fold1(Ls, Hook, Val, Args);
Subs2 = call_subscriber_list(Subs, Host, Hook, [Val | Args], before, []),
{Val2, Subs3} = run_fold1(Ls, Hook, Val, Args, Host, Subs2),
_Subs4 = call_subscriber_list(Subs3, Host, Hook, [Val2 | Args], 'after', []),
Val2;
TracingOpts ->
fold_start_hook_tracing(TracingOpts, Hook, Host, [Val | Args]),
run_fold2(Ls, Hook, Val, Args, Host, TracingOpts)
Subs2 = call_subscriber_list(Subs, Host, Hook, [Val | Args], before, []),
{Val2, Subs3} = run_fold2(Ls, Hook, Val, Args, Host, TracingOpts, Subs2),
_Subs4 = call_subscriber_list(Subs3, Host, Hook, [Val2 | Args], 'after', []),
Val2
end
end;
[] ->
Expand Down Expand Up @@ -230,34 +298,68 @@ handle_call({delete, Hook, Host, Module, Function, Seq}, _From, State) ->
HookFormat = {Seq, Module, Function},
Reply = handle_delete(Hook, Host, HookFormat),
{reply, Reply, State};
handle_call({subscribe, Hook, Host, Module, Function, InitArg}, _From, State) ->
SubscriberFormat = {Module, Function, InitArg},
Reply = handle_subscribe(Hook, Host, SubscriberFormat),
{reply, Reply, State};
handle_call({unsubscribe, Hook, Host, Module, Function, InitArg}, _From, State) ->
SubscriberFormat = {Module, Function, InitArg},
Reply = handle_unsubscribe(Hook, Host, SubscriberFormat),
{reply, Reply, State};
handle_call(Request, From, State) ->
?WARNING_MSG("Unexpected call from ~p: ~p", [From, Request]),
{noreply, State}.

-spec handle_add(atom(), atom(), hook()) -> ok.
handle_add(Hook, Host, El) ->
case ets:lookup(hooks, {Hook, Host}) of
[{_, Ls}] ->
[{_, Ls, Subs}] ->
case lists:member(El, Ls) of
true ->
ok;
false ->
NewLs = lists:merge(Ls, [El]),
ets:insert(hooks, {{Hook, Host}, NewLs}),
ets:insert(hooks, {{Hook, Host}, NewLs, Subs}),
ok
end;
[] ->
NewLs = [El],
ets:insert(hooks, {{Hook, Host}, NewLs}),
ets:insert(hooks, {{Hook, Host}, NewLs, []}),
ok
end.

-spec handle_delete(atom(), atom(), hook()) -> ok.
handle_delete(Hook, Host, El) ->
case ets:lookup(hooks, {Hook, Host}) of
[{_, Ls}] ->
[{_, Ls, Subs}] ->
NewLs = lists:delete(El, Ls),
ets:insert(hooks, {{Hook, Host}, NewLs}),
ets:insert(hooks, {{Hook, Host}, NewLs, Subs}),
ok;
[] ->
ok
end.

-spec handle_subscribe(atom(), atom(), subscriber()) -> ok.
handle_subscribe(Hook, Host, El) ->
case ets:lookup(hooks, {Hook, Host}) of
[{_, Ls, Subs}] ->
case lists:member(El, Subs) of
true ->
ok;
false ->
ets:insert(hooks, {{Hook, Host}, Ls, Subs ++ [El]}),
ok
end;
[] ->
ets:insert(hooks, {{Hook, Host}, [], [El]}),
ok
end.

-spec handle_unsubscribe(atom(), atom(), subscriber()) -> ok.
handle_unsubscribe(Hook, Host, El) ->
case ets:lookup(hooks, {Hook, Host}) of
[{_, Ls, Subs}] ->
ets:insert(hooks, {{Hook, Host}, Ls, lists:delete(El, Subs)}),
ok;
[] ->
ok
Expand Down Expand Up @@ -310,6 +412,40 @@ run_fold1([{_Seq, Module, Function} | Ls], Hook, Val, Args) ->
run_fold1(Ls, Hook, NewVal, Args)
end.

-spec run1([hook()], atom(), list(), binary() | global, [subscriber()]) -> [subscriber()].
run1([], _Hook, _Args, _Host, SubscriberList) ->
SubscriberList;
run1([{Seq, Module, Function} | Ls], Hook, Args, Host, SubscriberList) ->
SubscriberList2 = call_subscriber_list(SubscriberList, Host, Hook, {Module, Function, Seq, Args}, before_callback, []),
Res = safe_apply(Hook, Module, Function, Args),
SubscriberList3 = call_subscriber_list(SubscriberList2, Host, Hook, {Module, Function, Seq, Args}, after_callback, []),
case Res of
'EXIT' ->
run1(Ls, Hook, Args, Host, SubscriberList3);
stop ->
SubscriberList3;
_ ->
run1(Ls, Hook, Args, Host, SubscriberList3)
end.

-spec run_fold1([hook()], atom(), T, list(), binary() | global, [subscriber()]) -> {T, [subscriber()]}.
run_fold1([], _Hook, Val, _Args, _Host, SubscriberList) ->
{Val, SubscriberList};
run_fold1([{Seq, Module, Function} | Ls], Hook, Val, Args, Host, SubscriberList) ->
SubscriberList2 = call_subscriber_list(SubscriberList, Host, Hook, {Module, Function, Seq, [Val | Args]}, before_callback, []),
Res = safe_apply(Hook, Module, Function, [Val | Args]),
SubscriberList3 = call_subscriber_list(SubscriberList2, Host, Hook, {Module, Function, Seq, [Val | Args]}, after_callback, []),
case Res of
'EXIT' ->
run_fold1(Ls, Hook, Val, Args, Host, SubscriberList3);
stop ->
{Val, SubscriberList3};
{stop, NewVal} ->
{NewVal, SubscriberList3};
NewVal ->
run_fold1(Ls, Hook, NewVal, Args, Host, SubscriberList3)
end.

-spec safe_apply(atom(), atom(), atom() | fun(), list()) -> any().
safe_apply(Hook, Module, Function, Args) ->
?DEBUG("Running hook ~p: ~p:~p/~B",
Expand All @@ -332,6 +468,31 @@ safe_apply(Hook, Module, Function, Args) ->
'EXIT'
end.

-spec call_subscriber_list([subscriber()], binary() | global, atom(), {atom(), atom(), integer(), list()} | list(), subscriber_event(), [subscriber()]) -> any().
call_subscriber_list([], _Host, _Hook, _CallbackOrArgs, _Event, []) ->
[];
call_subscriber_list([], _Host, _Hook, _CallbackOrArgs, _Event, Result) ->
lists:reverse(Result);
call_subscriber_list([{Mod, Func, InitArg} | SubscriberList], Host, Hook, CallbackOrArgs, Event, Result) ->
SubscriberArgs = [InitArg, Event, Host, Hook, CallbackOrArgs],
?DEBUG("Running hook subsciber ~p: ~p:~p/~B with event ~p",
[Hook, Mod, Func, length(SubscriberArgs), Event]),
try apply(Mod, Func, SubscriberArgs) of
State ->
call_subscriber_list(SubscriberList, Host, Hook, CallbackOrArgs, Event, [{Mod, Func, State} | Result])
catch ?EX_RULE(E, R, St) when E /= exit; R /= normal ->
Stack = ?EX_STACK(St),
?ERROR_MSG("Hook subscriber ~p crashed when running ~p:~p/~p:~n" ++
string:join(
["** ~ts"|
["** Arg " ++ integer_to_list(I) ++ " = ~p"
|| I <- lists:seq(1, length(SubscriberArgs))]],
"~n"),
[Hook, Mod, Func, length(SubscriberArgs),
misc:format_exception(2, E, R, Stack)|SubscriberArgs]),
call_subscriber_list(SubscriberList, Host, Hook, CallbackOrArgs, Event, Result)
end.

%%%----------------------------------------------------------------------
%%% Internal tracing functions
%%%----------------------------------------------------------------------
Expand Down Expand Up @@ -453,41 +614,45 @@ do_get_tracing_options(Hook, Host, MaybeMap) ->
end
end.

run2([], Hook, Args, Host, Opts) ->
run2([], Hook, Args, Host, Opts, SubscriberList) ->
foreach_stop_hook_tracing(Opts, Hook, Host, Args, undefined),
ok;
run2([{Seq, Module, Function} | Ls], Hook, Args, Host, TracingOpts) ->
SubscriberList;
run2([{Seq, Module, Function} | Ls], Hook, Args, Host, TracingOpts, SubscriberList) ->
foreach_start_callback_tracing(TracingOpts, Hook, Host, Module, Function, Args, Seq),
SubscriberList2 = call_subscriber_list(SubscriberList, Host, Hook, {Module, Function, Seq, Args}, before_callback, []),
Res = safe_apply(Hook, Module, Function, Args),
SubscriberList3 = call_subscriber_list(SubscriberList2, Host, Hook, {Module, Function, Seq, Args}, after_callback, []),
foreach_stop_callback_tracing(TracingOpts, Hook, Host, Module, Function, Args, Seq, Res),
case Res of
'EXIT' ->
run2(Ls, Hook, Args, Host, TracingOpts);
run2(Ls, Hook, Args, Host, TracingOpts, SubscriberList3);
stop ->
foreach_stop_hook_tracing(TracingOpts, Hook, Host, Args, {Module, Function, Seq, Ls}),
ok;
SubscriberList3;
_ ->
run2(Ls, Hook, Args, Host, TracingOpts)
run2(Ls, Hook, Args, Host, TracingOpts, SubscriberList3)
end.

run_fold2([], Hook, Val, Args, Host, Opts) ->
run_fold2([], Hook, Val, Args, Host, Opts, SubscriberList) ->
fold_stop_hook_tracing(Opts, Hook, Host, [Val | Args], undefined),
Val;
run_fold2([{Seq, Module, Function} | Ls], Hook, Val, Args, Host, TracingOpts) ->
{Val, SubscriberList};
run_fold2([{Seq, Module, Function} | Ls], Hook, Val, Args, Host, TracingOpts, SubscriberList) ->
fold_start_callback_tracing(TracingOpts, Hook, Host, Module, Function, [Val | Args], Seq),
SubscriberList2 = call_subscriber_list(SubscriberList, Host, Hook, {Module, Function, Seq, [Val | Args]}, before_callback, []),
Res = safe_apply(Hook, Module, Function, [Val | Args]),
SubscriberList3 = call_subscriber_list(SubscriberList2, Host, Hook, {Module, Function, Seq, [Val | Args]}, after_callback, []),
fold_stop_callback_tracing(TracingOpts, Hook, Host, Module, Function, [Val | Args], Seq, Res),
case Res of
'EXIT' ->
run_fold2(Ls, Hook, Val, Args, Host, TracingOpts);
run_fold2(Ls, Hook, Val, Args, Host, TracingOpts, SubscriberList3);
stop ->
fold_stop_hook_tracing(TracingOpts, Hook, Host, [Val | Args], {Module, Function, Seq, {old, Val}, Ls}),
Val;
{Val, SubscriberList3};
{stop, NewVal} ->
fold_stop_hook_tracing(TracingOpts, Hook, Host, [Val | Args], {Module, Function, Seq, {new, NewVal}, Ls}),
NewVal;
{NewVal, SubscriberList3};
NewVal ->
run_fold2(Ls, Hook, NewVal, Args, Host, TracingOpts)
run_fold2(Ls, Hook, NewVal, Args, Host, TracingOpts, SubscriberList3)
end.

foreach_start_hook_tracing(TracingOpts, Hook, Host, Args) ->
Expand Down

0 comments on commit b8e3eb0

Please sign in to comment.