1
0
Fork 0
mirror of https://github.com/processone/ejabberd synced 2025-10-05 10:39:29 +02:00

Add support for websockets to mqtt bridge

This commit is contained in:
Paweł Chmielowski 2023-01-13 19:40:53 +01:00
parent c103182bc7
commit 4311a5646f
4 changed files with 392 additions and 227 deletions

View file

@ -21,7 +21,7 @@
-vsn(?VSN).
%% API
-export([start/8, start_link/8]).
-export([start/9, start_link/9]).
%% gen_server callbacks
-export([init/1, handle_call/3, handle_cast/2, handle_info/2,
terminate/2, code_change/3]).
@ -66,6 +66,7 @@
stop_reason :: undefined | error_reason(),
subscriptions = #{},
publish = #{},
ws_codec = none,
id = 0 :: non_neg_integer(),
codec :: mqtt_codec:state(),
authentication :: #{username => binary(), password => binary(), certfile => binary()}}).
@ -75,23 +76,25 @@
%%%===================================================================
%%% API
%%%===================================================================
start(Proc, Transport, Host, Port, Publish, Subscribe, Authentication, ReplicationUser) ->
p1_server:start({local, Proc}, ?MODULE, [Proc, Transport, Host, Port, Publish, Subscribe, Authentication,
start(Proc, Transport, Host, Port, Path, Publish, Subscribe, Authentication, ReplicationUser) ->
p1_server:start({local, Proc}, ?MODULE, [Proc, Transport, Host, Port, Path, Publish, Subscribe, Authentication,
ReplicationUser], []).
start_link(Proc, Transport, Host, Port, Publish, Subscribe, Authentication, ReplicationUser) ->
p1_server:start_link({local, Proc}, ?MODULE, [Proc, Transport, Host, Port, Publish, Subscribe,
start_link(Proc, Transport, Host, Port, Path, Publish, Subscribe, Authentication, ReplicationUser) ->
p1_server:start_link({local, Proc}, ?MODULE, [Proc, Transport, Host, Port, Path, Publish, Subscribe,
Authentication, ReplicationUser], []).
%%%===================================================================
%%% gen_server callbacks
%%%===================================================================
init([_Proc, Proto, Host, Port, Publish, Subscribe, Authentication, ReplicationUser]) ->
init([_Proc, Proto, Host, Port, Path, Publish, Subscribe, Authentication, ReplicationUser]) ->
{Version, Transport} = case Proto of
mqtt -> {4, gen_tcp};
mqtts -> {4, ssl};
mqtt5 -> {5, gen_tcp};
mqtt5s -> {5, ssl}
mqtt5s -> {5, ssl};
ws -> {4, gen_tcp};
wss -> {4, ssl}
end,
State = #state{version = Version,
id = p1_rand:uniform(65535),
@ -101,12 +104,20 @@ init([_Proc, Proto, Host, Port, Publish, Subscribe, Authentication, ReplicationU
usr = jid:tolower(ReplicationUser),
publish = Publish},
case Authentication of
#{certfile := Cert} when Proto == mqtts; Proto == mqtt5s ->
connect(ssl:connect(Host, Port, [binary, {certfile, Cert}]), State, ssl, none);
#{certfile := Cert} when Proto == mqtts; Proto == mqtt5s; Proto == wss ->
Sock = ssl:connect(Host, Port, [binary, {active, true}, {certfile, Cert}]),
if Proto == ws orelse Proto == wss ->
connect_ws(Host, Port, Path, Sock, State, ssl, none);
true -> connect(Sock, State, ssl, none)
end;
#{username := User, password := Pass} ->
connect(Transport:connect(Host, Port, [binary]), State, Transport, {User, Pass});
Sock = Transport:connect(Host, Port, [binary, {active, true}]),
if Proto == ws orelse Proto == wss ->
connect_ws(Host, Port, Path, Sock, State, Transport, {User, Pass});
true -> connect(Sock, State, Transport, {User, Pass})
end;
_ ->
{stop, {error, <<"Certificate can be only used for encrypted connections">>}}
{stop, {error, <<"Certificate can be only used for encrypted connections">> }}
end.
handle_call(Request, From, State) ->
@ -118,20 +129,108 @@ handle_cast(Msg, State) ->
{noreply, State}.
handle_info({Tag, TCPSock, TCPData},
#state{codec = Codec, socket = Socket} = State) when Tag == tcp; Tag == ssl ->
#state{ws_codec = {init, Hash, Auth, Last}} = State)
when (Tag == tcp orelse Tag == ssl) ->
Data = <<Last/binary, TCPData/binary>>,
case erlang:decode_packet(http_bin, Data, []) of
{ok, {http_response, _, 101, _}, Rest} ->
handle_info({tcp, TCPSock, Rest}, State#state{ws_codec = {inith, Hash, none, Auth, <<>>}});
{ok, {http_response, _, _, _}, _Rest} ->
stop(State, {socket, closed});
{ok, {http_error, _}, _} ->
stop(State, {socket, closed});
{error, _} ->
stop(State, {socket, closed});
{more, _} ->
{noreply, State#state{ws_codec = {init, Hash, Auth, Data}}}
end;
handle_info({Tag, TCPSock, TCPData},
#state{ws_codec = {inith, Hash, Upgrade, Auth, Last},
socket = {Transport, _}} = State)
when (Tag == tcp orelse Tag == ssl) ->
Data = <<Last/binary, TCPData/binary>>,
case erlang:decode_packet(httph_bin, Data, []) of
{ok, {http_header, _, <<"Sec-Websocket-Accept">>, _, Val}, Rest} ->
case str:to_lower(Val) of
Hash ->
handle_info({tcp, TCPSock, Rest},
State#state{ws_codec = {inith, ok, Upgrade, Auth, <<>>}});
_ ->
stop(State, {socket, closed})
end;
{ok, {http_header, _, 'Connection', _, Val}, Rest} ->
case str:to_lower(Val) of
<<"upgrade">> ->
handle_info({tcp, TCPSock, Rest},
State#state{ws_codec = {inith, Hash, ok, Auth, <<>>}});
_ ->
stop(State, {socket, closed})
end;
{ok, {http_header, _, _, _, _}, Rest} ->
handle_info({tcp, TCPSock, Rest}, State);
{ok, {http_error, _}, _} ->
stop(State, {socket, closed});
{ok, http_eoh, Rest} ->
case {Hash, Upgrade} of
{ok, ok} ->
{ok, State2} = connect({ok, TCPSock},
State#state{ws_codec = ejabberd_websocket_codec:new_client()},
Transport, Auth),
handle_info({tcp, TCPSock, Rest}, State2);
_ ->
stop(State, {socket, closed})
end;
{error, _} ->
stop(State, {socket, closed});
{more, _} ->
{noreply, State#state{ws_codec = {inith, Hash, Upgrade, Data}}}
end;
handle_info({Tag, TCPSock, TCPData},
#state{ws_codec = WSCodec} = State)
when (Tag == tcp orelse Tag == ssl) andalso WSCodec /= none ->
{Packets, Acc0} =
case ejabberd_websocket_codec:decode(WSCodec, TCPData) of
{ok, NewWSCodec, Packets0} ->
{Packets0, {State#state{ws_codec = NewWSCodec}, ok}};
{error, _Error, Packets0} ->
{Packets0, {State, stop}}
end,
Res2 =
lists:foldl(
fun(_, {stop, _, _} = Res) -> Res;
({_Op, Data}, {S, Res}) ->
case handle_info({tcp_decoded, TCPSock, Data}, S) of
{stop, _, _} = Stop ->
Stop;
{_, NewState, _} ->
{NewState, Res};
{_, NewState} ->
{NewState, Res}
end
end, Acc0, Packets),
case Res2 of
{stop, _, _} ->
Res2;
{NewState2, ok} ->
{noreply, NewState2};
{NewState2, stop} ->
stop(NewState2, {socket, closed})
end;
handle_info({Tag, TCPSock, TCPData},
#state{codec = Codec} = State)
when Tag == tcp; Tag == ssl; Tag == tcp_decoded ->
case mqtt_codec:decode(Codec, TCPData) of
{ok, Pkt, Codec1} ->
?DEBUG("Got MQTT packet:~n~ts", [pp(Pkt)]),
State1 = State#state{codec = Codec1},
case handle_packet(Pkt, State1) of
{ok, State2} ->
handle_info({tcp, TCPSock, <<>>}, State2);
handle_info({tcp_decoded, TCPSock, <<>>}, State2);
{error, State2, Reason} ->
stop(State2, Reason)
end;
{more, Codec1} ->
State1 = State#state{codec = Codec1},
activate(Socket),
{noreply, State1};
{error, Why} ->
stop(State, {codec, Why})
@ -156,7 +255,7 @@ handle_info({publish, #publish{topic = Topic} = Pkt}, #state{publish = Publish}
{noreply, State2}
end;
_ ->
State
{noreply, State}
end;
handle_info({timeout, _TRef, ping_timeout}, State) ->
case send(State, #pingreq{}) of
@ -230,6 +329,22 @@ connect({ok, Sock}, State0, Transport, Auth) ->
{ok, _, Codec2} = mqtt_codec:decode(State#state.codec, Pkt),
{ok, State#state{codec = Codec2}}.
connect_ws(_Host, _Port, _Path, {error, Reason}, _State, _Transport, _Auth) ->
{stop, {error, Reason}};
connect_ws(Host, Port, Path, {ok, Sock}, State0, Transport, Auth) ->
Key = base64:encode(p1_rand:get_string()),
Hash = str:to_lower(base64:encode(crypto:hash(sha, <<Key/binary, "258EAFA5-E914-47DA-95CA-C5AB0DC85B11">>))),
Data = <<"GET ", (list_to_binary(Path))/binary, " HTTP/1.1\r\n",
"Host: ", (list_to_binary(Host))/binary, ":", (integer_to_binary(Port))/binary,"\r\n",
"Upgrade: websocket\r\n",
"Connection: Upgrade\r\n",
"Sec-WebSocket-Protocol: mqtt\r\n",
"Sec-WebSocket-Key: ", Key/binary, "\r\n",
"Sec-WebSocket-Version: 13\r\n\r\n">>,
Res = Transport:send(Sock, Data),
check_sock_result({Transport, Sock}, Res),
{ok, State0#state{ws_codec = {init, Hash, Auth, <<>>}, socket = {Transport, Sock}}}.
-spec stop(state(), error_reason()) ->
{noreply, state(), infinity} |
{stop, normal, state()}.
@ -286,6 +401,14 @@ send(State, Pkt) ->
{ok, do_send(State, Pkt)}.
-spec do_send(state(), mqtt_packet()) -> state().
do_send(#state{ws_codec = WSCodec, socket = {SockMod, Sock} = Socket} = State, Pkt)
when WSCodec /= none ->
?DEBUG("Send MQTT packet:~n~ts", [pp(Pkt)]),
Data = mqtt_codec:encode(State#state.version, Pkt),
WSData = ejabberd_websocket_codec:encode(WSCodec, 2, Data),
Res = SockMod:send(Sock, WSData),
check_sock_result(Socket, Res),
reset_ping_timer(State);
do_send(#state{socket = {SockMod, Sock} = Socket} = State, Pkt) ->
?DEBUG("Send MQTT packet:~n~ts", [pp(Pkt)]),
Data = mqtt_codec:encode(State#state.version, Pkt),
@ -295,14 +418,6 @@ do_send(#state{socket = {SockMod, Sock} = Socket} = State, Pkt) ->
do_send(State, _Pkt) ->
State.
-spec activate(socket()) -> ok.
activate(Socket) ->
Res = case Socket of
{gen_tcp, Sock} -> inet:setopts(Sock, [{active, once}]);
{SockMod, Sock} -> SockMod:setopts(Sock, [{active, once}])
end,
check_sock_result(Socket, Res).
-spec disconnect(state(), error_reason()) -> state().
disconnect(#state{socket = {SockMod, Sock}} = State, Err) ->
State1 = case Err of