THRIFT-2113 Erlang SSL Socket Support
Client: Erlang
Patch: David Robakowski
diff --git a/lib/erl/src/thrift_client_util.erl b/lib/erl/src/thrift_client_util.erl
index 7a11f3a..b51a0b4 100644
--- a/lib/erl/src/thrift_client_util.erl
+++ b/lib/erl/src/thrift_client_util.erl
@@ -41,7 +41,9 @@
when OptKey =:= framed;
OptKey =:= connect_timeout;
OptKey =:= recv_timeout;
- OptKey =:= sockopts ->
+ OptKey =:= sockopts;
+ OptKey =:= ssltransport;
+ OptKey =:= ssloptions->
split_options(Rest, ProtoIn, [Opt | TransIn]).
@@ -49,10 +51,15 @@
%% with the binary protocol
new(Host, Port, Service, Options)
when is_integer(Port), is_atom(Service), is_list(Options) ->
- {ProtoOpts, TransOpts} = split_options(Options),
+ {ProtoOpts, TransOpts0} = split_options(Options),
+
+ {TransportModule, TransOpts2} = case lists:keytake(ssltransport, 1, TransOpts0) of
+ {value, {_, true}, TransOpts1} -> {thrift_sslsocket_transport, TransOpts1};
+ false -> {thrift_socket_transport, TransOpts0}
+ end,
{ok, TransportFactory} =
- thrift_socket_transport:new_transport_factory(Host, Port, TransOpts),
+ TransportModule:new_transport_factory(Host, Port, TransOpts2),
{ok, ProtocolFactory} = thrift_binary_protocol:new_protocol_factory(
TransportFactory, ProtoOpts),
diff --git a/lib/erl/src/thrift_socket_server.erl b/lib/erl/src/thrift_socket_server.erl
index f7c7a02..233b992 100644
--- a/lib/erl/src/thrift_socket_server.erl
+++ b/lib/erl/src/thrift_socket_server.erl
@@ -38,7 +38,9 @@
listen=null,
acceptor=null,
socket_opts=[{recv_timeout, 500}],
- framed=false
+ framed=false,
+ ssltransport=false,
+ ssloptions=[]
}).
start(State=#thrift_socket_server{}) ->
@@ -103,8 +105,14 @@
Max
end,
parse_options(Rest, State#thrift_socket_server{max=MaxInt});
+
parse_options([{framed, Framed} | Rest], State) when is_boolean(Framed) ->
- parse_options(Rest, State#thrift_socket_server{framed=Framed}).
+ parse_options(Rest, State#thrift_socket_server{framed=Framed});
+
+parse_options([{ssltransport, SSLTransport} | Rest], State) when is_boolean(SSLTransport) ->
+ parse_options(Rest, State#thrift_socket_server{ssltransport=SSLTransport});
+parse_options([{ssloptions, SSLOptions} | Rest], State) when is_list(SSLOptions) ->
+ parse_options(Rest, State#thrift_socket_server{ssloptions=SSLOptions}).
start_server(State=#thrift_socket_server{name=Name}) ->
case Name of
@@ -168,25 +176,28 @@
State#thrift_socket_server{acceptor=null};
new_acceptor(State=#thrift_socket_server{listen=Listen,
service=Service, handler=Handler,
- socket_opts=Opts, framed=Framed
+ socket_opts=Opts, framed=Framed,
+ ssltransport=SslTransport, ssloptions=SslOptions
}) ->
Pid = proc_lib:spawn_link(?MODULE, acceptor_loop,
- [{self(), Listen, Service, Handler, Opts, Framed}]),
+ [{self(), Listen, Service, Handler, Opts, Framed, SslTransport, SslOptions}]),
State#thrift_socket_server{acceptor=Pid}.
-acceptor_loop({Server, Listen, Service, Handler, SocketOpts, Framed})
+acceptor_loop({Server, Listen, Service, Handler, SocketOpts, Framed, SslTransport, SslOptions})
when is_pid(Server), is_list(SocketOpts) ->
case catch gen_tcp:accept(Listen) of % infinite timeout
{ok, Socket} ->
gen_server:cast(Server, {accepted, self()}),
ProtoGen = fun() ->
- {ok, SocketTransport} = thrift_socket_transport:new(Socket, SocketOpts),
- {ok, Transport} =
- case Framed of
- true -> thrift_framed_transport:new(SocketTransport);
- false -> thrift_buffered_transport:new(SocketTransport)
- end,
- {ok, Protocol} = thrift_binary_protocol:new(Transport),
+ {ok, SocketTransport} = case SslTransport of
+ true -> thrift_sslsocket_transport:new(Socket, SocketOpts, SslOptions);
+ false -> thrift_socket_transport:new(Socket, SocketOpts)
+ end,
+ {ok, Transport} = case Framed of
+ true -> thrift_framed_transport:new(SocketTransport);
+ false -> thrift_buffered_transport:new(SocketTransport)
+ end,
+ {ok, Protocol} = thrift_binary_protocol:new(Transport),
{ok, Protocol}
end,
thrift_processor:init({Server, ProtoGen, Service, Handler});
diff --git a/lib/erl/src/thrift_sslsocket_transport.erl b/lib/erl/src/thrift_sslsocket_transport.erl
new file mode 100644
index 0000000..211153f
--- /dev/null
+++ b/lib/erl/src/thrift_sslsocket_transport.erl
@@ -0,0 +1,147 @@
+%%
+%% Licensed to the Apache Software Foundation (ASF) under one
+%% or more contributor license agreements. See the NOTICE file
+%% distributed with this work for additional information
+%% regarding copyright ownership. The ASF licenses this file
+%% to you under the Apache License, Version 2.0 (the
+%% "License"); you may not use this file except in compliance
+%% with the License. You may obtain a copy of the License at
+%%
+%% http://www.apache.org/licenses/LICENSE-2.0
+%%
+%% Unless required by applicable law or agreed to in writing,
+%% software distributed under the License is distributed on an
+%% "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+%% KIND, either express or implied. See the License for the
+%% specific language governing permissions and limitations
+%% under the License.
+%%
+-module(thrift_sslsocket_transport).
+
+-include("thrift_transport_behaviour.hrl").
+
+-behaviour(thrift_transport).
+
+-export([new/3,
+ write/2, read/2, flush/1, close/1,
+
+ new_transport_factory/3]).
+
+%% Export only for the transport factory
+-export([new/2]).
+
+-record(data, {socket,
+ recv_timeout=infinity}).
+-type state() :: #data{}.
+
+%% The following "local" record is filled in by parse_factory_options/2
+%% below. These options can be passed to new_protocol_factory/3 in a
+%% proplists-style option list. They're parsed like this so it is an O(n)
+%% operation instead of O(n^2)
+-record(factory_opts, {connect_timeout = infinity,
+ sockopts = [],
+ framed = false,
+ ssloptions = []}).
+
+parse_factory_options([], Opts) ->
+ Opts;
+parse_factory_options([{framed, Bool} | Rest], Opts) when is_boolean(Bool) ->
+ parse_factory_options(Rest, Opts#factory_opts{framed=Bool});
+parse_factory_options([{sockopts, OptList} | Rest], Opts) when is_list(OptList) ->
+ parse_factory_options(Rest, Opts#factory_opts{sockopts=OptList});
+parse_factory_options([{connect_timeout, TO} | Rest], Opts) when TO =:= infinity; is_integer(TO) ->
+ parse_factory_options(Rest, Opts#factory_opts{connect_timeout=TO});
+parse_factory_options([{ssloptions, SslOptions} | Rest], Opts) when is_list(SslOptions) ->
+ parse_factory_options(Rest, Opts#factory_opts{ssloptions=SslOptions}).
+
+new(Socket, SockOpts, SslOptions) when is_list(SockOpts), is_list(SslOptions) ->
+ inet:setopts(Socket, [{active, false}]), %% => prevent the ssl handshake messages get lost
+
+ %% upgrade to an ssl socket
+ case catch ssl:ssl_accept(Socket, SslOptions) of % infinite timeout
+ {ok, SslSocket} ->
+ new(SslSocket, SockOpts);
+ {error, Reason} ->
+ exit({error, Reason});
+ Other ->
+ error_logger:error_report(
+ [{application, thrift},
+ "SSL accept failed error",
+ lists:flatten(io_lib:format("~p", [Other]))]),
+ exit({error, ssl_accept_failed})
+ end.
+
+new(SslSocket, SockOpts) ->
+ State =
+ case lists:keysearch(recv_timeout, 1, SockOpts) of
+ {value, {recv_timeout, Timeout}}
+ when is_integer(Timeout), Timeout > 0 ->
+ #data{socket=SslSocket, recv_timeout=Timeout};
+ _ ->
+ #data{socket=SslSocket}
+ end,
+ thrift_transport:new(?MODULE, State).
+
+%% Data :: iolist()
+write(This = #data{socket = Socket}, Data) ->
+ {This, ssl:send(Socket, Data)}.
+
+read(This = #data{socket=Socket, recv_timeout=Timeout}, Len)
+ when is_integer(Len), Len >= 0 ->
+ case ssl:recv(Socket, Len, Timeout) of
+ Err = {error, timeout} ->
+ error_logger:info_msg("read timeout: peer conn ~p", [inet:peername(Socket)]),
+ ssl:close(Socket),
+ {This, Err};
+ Data ->
+ {This, Data}
+ end.
+
+%% We can't really flush - everything is flushed when we write
+flush(This) ->
+ {This, ok}.
+
+close(This = #data{socket = Socket}) ->
+ {This, ssl:close(Socket)}.
+
+%%%% FACTORY GENERATION %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
+
+%%
+%% Generates a "transport factory" function - a fun which returns a thrift_transport()
+%% instance.
+%% This can be passed into a protocol factory to generate a connection to a
+%% thrift server over a socket.
+%%
+new_transport_factory(Host, Port, Options) ->
+ ParsedOpts = parse_factory_options(Options, #factory_opts{}),
+
+ F = fun() ->
+ SockOpts = [binary,
+ {packet, 0},
+ {active, false},
+ {nodelay, true} |
+ ParsedOpts#factory_opts.sockopts],
+ case catch gen_tcp:connect(Host, Port, SockOpts,
+ ParsedOpts#factory_opts.connect_timeout) of
+ {ok, Sock} ->
+ SslSock = case catch ssl:connect(Sock, ParsedOpts#factory_opts.ssloptions,
+ ParsedOpts#factory_opts.connect_timeout) of
+ {ok, SslSocket} ->
+ SslSocket;
+ Other ->
+ error_logger:info_msg("error while connecting over ssl - reason: ~p~n", [Other]),
+ catch gen_tcp:close(Sock),
+ exit(error)
+ end,
+ {ok, Transport} = thrift_sslsocket_transport:new(SslSock, SockOpts),
+ {ok, BufTransport} =
+ case ParsedOpts#factory_opts.framed of
+ true -> thrift_framed_transport:new(Transport);
+ false -> thrift_buffered_transport:new(Transport)
+ end,
+ {ok, BufTransport};
+ Error ->
+ Error
+ end
+ end,
+ {ok, F}.
\ No newline at end of file
diff --git a/test/erl/src/test_client.erl b/test/erl/src/test_client.erl
index 8cfeb8b..7b9efd6 100644
--- a/test/erl/src/test_client.erl
+++ b/test/erl/src/test_client.erl
@@ -47,6 +47,14 @@
_Else ->
Opts
end;
+ "--ssl" ->
+ ssl:start(),
+ SslOptions =
+ {ssloptions, [
+ {certfile, "../keys/client.crt"}
+ ,{keyfile, "../keys/server.key"}
+ ]},
+ Opts#options{client_opts = [{ssltransport, true} | [SslOptions | Opts#options.client_opts]]};
"--protocol=binary" ->
% TODO: Enable JSON protocol
Opts;
diff --git a/test/erl/src/test_thrift_server.erl b/test/erl/src/test_thrift_server.erl
index 51457f5..884eb9e 100644
--- a/test/erl/src/test_thrift_server.erl
+++ b/test/erl/src/test_thrift_server.erl
@@ -47,6 +47,14 @@
_Else ->
Opts
end;
+ "--ssl" ->
+ ssl:start(),
+ SslOptions =
+ {ssloptions, [
+ {certfile, "../keys/server.crt"}
+ ,{keyfile, "../keys/server.key"}
+ ]},
+ Opts#options{server_opts = [{ssltransport, true} | [SslOptions | Opts#options.server_opts]]};
"--protocol=" ++ _ -> Opts;
_Else ->
erlang:error({bad_arg, Head})