diff --git a/src/ejabberd_auth_jwt.erl b/src/ejabberd_auth_jwt.erl index 911cae963..c8f1b786b 100644 --- a/src/ejabberd_auth_jwt.erl +++ b/src/ejabberd_auth_jwt.erl @@ -33,6 +33,8 @@ store_type/1, plain_password_required/1, user_exists/2, use_cache/1 ]). +%% 'ejabberd_hooks' callback: +-export([check_decoded_jwt/5]). -include_lib("xmpp/include/xmpp.hrl"). -include("logger.hrl"). @@ -41,6 +43,12 @@ %%% API %%%---------------------------------------------------------------------- start(Host) -> + %% We add our default JWT verifier with hook priority 100. + %% So if you need to check or verify your custom JWT before the + %% default verifier, It's better to use this hook with priority + %% little than 100 and return bool() or {stop, bool()} in your own + %% callback function. + ejabberd_hooks:add(check_decoded_jwt, Host, ?MODULE, check_decoded_jwt, 100), case ejabberd_option:jwt_key(Host) of undefined -> ?ERROR_MSG("Option jwt_key is not configured for ~ts: " @@ -49,7 +57,8 @@ start(Host) -> ok end. -stop(_Host) -> ok. +stop(Host) -> + ejabberd_hooks:delete(check_decoded_jwt, Host, ?MODULE, check_decoded_jwt, 100). plain_password_required(_Host) -> true. @@ -81,36 +90,47 @@ user_exists(_User, _Host) -> {nocache, false}. use_cache(_) -> false. +%%%---------------------------------------------------------------------- +%%% 'ejabberd_hooks' callback +%%%---------------------------------------------------------------------- +check_decoded_jwt(true, Fields, _Signature, Server, User) -> + JidField = ejabberd_option:jwt_jid_field(Server), + case maps:find(JidField, Fields) of + {ok, SJid} when is_binary(SJid) -> + try + JID = jid:decode(SJid), + JID#jid.luser == User andalso JID#jid.lserver == Server + catch error:{bad_jid, _} -> + false + end; + _ -> % error | {ok, _UnknownType} + false + end; +check_decoded_jwt(Acc, _, _, _, _) -> + Acc. + %%%---------------------------------------------------------------------- %%% Internal functions %%%---------------------------------------------------------------------- check_jwt_token(User, Server, Token) -> JWK = ejabberd_option:jwt_key(Server), - JidField = ejabberd_option:jwt_jid_field(Server), try jose_jwt:verify(JWK, Token) of {true, {jose_jwt, Fields}, Signature} -> - ?DEBUG("jwt verify: ~p - ~p~n", [Fields, Signature]), + Now = erlang:system_time(second), + ?DEBUG("jwt verify at system timestamp ~p: ~p - ~p~n", [Now, Fields, Signature]), case maps:find(<<"exp">>, Fields) of error -> %% No expiry in token => We consider token invalid: false; {ok, Exp} -> - Now = erlang:system_time(second), if Exp > Now -> - case maps:find(JidField, Fields) of - error -> - false; - {ok, SJID} -> - try jid:decode(SJID) of - JID -> - (JID#jid.luser == User) andalso - (JID#jid.lserver == Server) andalso - ejabberd_hooks:run_fold(check_decoded_jwt, Server, true, [Fields, Signature, User]) - catch error:{bad_jid, _} -> - false - end - end; + ejabberd_hooks:run_fold( + check_decoded_jwt, + Server, + true, + [Fields, Signature, Server, User] + ); true -> %% return false, if token has expired false @@ -122,4 +142,3 @@ check_jwt_token(User, Server, Token) -> error:{badarg, _} -> false end. -