THRIFT-2113 Erlang SSL Socket Support
Client: Erlang
Patch: David Robakowski
diff --git a/lib/erl/src/thrift_client_util.erl b/lib/erl/src/thrift_client_util.erl
index 7a11f3a..b51a0b4 100644
--- a/lib/erl/src/thrift_client_util.erl
+++ b/lib/erl/src/thrift_client_util.erl
@@ -41,7 +41,9 @@
   when OptKey =:= framed;
        OptKey =:= connect_timeout;
        OptKey =:= recv_timeout;
-       OptKey =:= sockopts ->
+       OptKey =:= sockopts;
+       OptKey =:= ssltransport;
+       OptKey =:= ssloptions->
     split_options(Rest, ProtoIn, [Opt | TransIn]).
 
 
@@ -49,10 +51,15 @@
 %% with the binary protocol
 new(Host, Port, Service, Options)
   when is_integer(Port), is_atom(Service), is_list(Options) ->
-    {ProtoOpts, TransOpts} = split_options(Options),
+    {ProtoOpts, TransOpts0} = split_options(Options),
+
+    {TransportModule, TransOpts2} = case lists:keytake(ssltransport, 1, TransOpts0) of
+                                        {value, {_, true}, TransOpts1} -> {thrift_sslsocket_transport, TransOpts1};
+                                        false -> {thrift_socket_transport, TransOpts0}
+                                    end,
 
     {ok, TransportFactory} =
-        thrift_socket_transport:new_transport_factory(Host, Port, TransOpts),
+        TransportModule:new_transport_factory(Host, Port, TransOpts2),
 
     {ok, ProtocolFactory} = thrift_binary_protocol:new_protocol_factory(
                               TransportFactory, ProtoOpts),
diff --git a/lib/erl/src/thrift_socket_server.erl b/lib/erl/src/thrift_socket_server.erl
index f7c7a02..233b992 100644
--- a/lib/erl/src/thrift_socket_server.erl
+++ b/lib/erl/src/thrift_socket_server.erl
@@ -38,7 +38,9 @@
          listen=null,
          acceptor=null,
          socket_opts=[{recv_timeout, 500}],
-         framed=false
+         framed=false,
+         ssltransport=false,
+         ssloptions=[]
         }).
 
 start(State=#thrift_socket_server{}) ->
@@ -103,8 +105,14 @@
                      Max
              end,
     parse_options(Rest, State#thrift_socket_server{max=MaxInt});
+
 parse_options([{framed, Framed} | Rest], State) when is_boolean(Framed) ->
-    parse_options(Rest, State#thrift_socket_server{framed=Framed}).
+    parse_options(Rest, State#thrift_socket_server{framed=Framed});
+
+parse_options([{ssltransport, SSLTransport} | Rest], State) when is_boolean(SSLTransport) ->
+    parse_options(Rest, State#thrift_socket_server{ssltransport=SSLTransport});
+parse_options([{ssloptions, SSLOptions} | Rest], State) when is_list(SSLOptions) ->
+    parse_options(Rest, State#thrift_socket_server{ssloptions=SSLOptions}).
 
 start_server(State=#thrift_socket_server{name=Name}) ->
     case Name of
@@ -168,25 +176,28 @@
     State#thrift_socket_server{acceptor=null};
 new_acceptor(State=#thrift_socket_server{listen=Listen,
                                          service=Service, handler=Handler,
-                                         socket_opts=Opts, framed=Framed
+                                         socket_opts=Opts, framed=Framed,
+                                         ssltransport=SslTransport, ssloptions=SslOptions
                                         }) ->
     Pid = proc_lib:spawn_link(?MODULE, acceptor_loop,
-                              [{self(), Listen, Service, Handler, Opts, Framed}]),
+                              [{self(), Listen, Service, Handler, Opts, Framed, SslTransport, SslOptions}]),
     State#thrift_socket_server{acceptor=Pid}.
 
-acceptor_loop({Server, Listen, Service, Handler, SocketOpts, Framed})
+acceptor_loop({Server, Listen, Service, Handler, SocketOpts, Framed, SslTransport, SslOptions})
   when is_pid(Server), is_list(SocketOpts) ->
     case catch gen_tcp:accept(Listen) of % infinite timeout
         {ok, Socket} ->
             gen_server:cast(Server, {accepted, self()}),
             ProtoGen = fun() ->
-                               {ok, SocketTransport}   = thrift_socket_transport:new(Socket, SocketOpts),
-                               {ok, Transport}         =
-                                   case Framed of
-                                       true  -> thrift_framed_transport:new(SocketTransport);
-                                       false -> thrift_buffered_transport:new(SocketTransport)
-                                   end,
-                               {ok, Protocol}          = thrift_binary_protocol:new(Transport),
+                               {ok, SocketTransport} = case SslTransport of
+                                                           true  -> thrift_sslsocket_transport:new(Socket, SocketOpts, SslOptions);
+                                                           false -> thrift_socket_transport:new(Socket, SocketOpts)
+                                                       end,
+                               {ok, Transport}       = case Framed of
+                                                           true  -> thrift_framed_transport:new(SocketTransport);
+                                                           false -> thrift_buffered_transport:new(SocketTransport)
+                                                       end,
+                               {ok, Protocol}        = thrift_binary_protocol:new(Transport),
                                {ok, Protocol}
                        end,
             thrift_processor:init({Server, ProtoGen, Service, Handler});
diff --git a/lib/erl/src/thrift_sslsocket_transport.erl b/lib/erl/src/thrift_sslsocket_transport.erl
new file mode 100644
index 0000000..211153f
--- /dev/null
+++ b/lib/erl/src/thrift_sslsocket_transport.erl
@@ -0,0 +1,147 @@
+%%
+%% Licensed to the Apache Software Foundation (ASF) under one
+%% or more contributor license agreements. See the NOTICE file
+%% distributed with this work for additional information
+%% regarding copyright ownership. The ASF licenses this file
+%% to you under the Apache License, Version 2.0 (the
+%% "License"); you may not use this file except in compliance
+%% with the License. You may obtain a copy of the License at
+%%
+%%   http://www.apache.org/licenses/LICENSE-2.0
+%%
+%% Unless required by applicable law or agreed to in writing,
+%% software distributed under the License is distributed on an
+%% "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+%% KIND, either express or implied. See the License for the
+%% specific language governing permissions and limitations
+%% under the License.
+%%
+-module(thrift_sslsocket_transport).
+
+-include("thrift_transport_behaviour.hrl").
+
+-behaviour(thrift_transport).
+
+-export([new/3,
+         write/2, read/2, flush/1, close/1,
+
+         new_transport_factory/3]).
+
+%% Export only for the transport factory
+-export([new/2]).
+
+-record(data, {socket,
+               recv_timeout=infinity}).
+-type state() :: #data{}.
+
+%% The following "local" record is filled in by parse_factory_options/2
+%% below. These options can be passed to new_protocol_factory/3 in a
+%% proplists-style option list. They're parsed like this so it is an O(n)
+%% operation instead of O(n^2)
+-record(factory_opts, {connect_timeout = infinity,
+                       sockopts = [],
+                       framed = false,
+                       ssloptions = []}).
+
+parse_factory_options([], Opts) ->
+    Opts;
+parse_factory_options([{framed, Bool} | Rest], Opts) when is_boolean(Bool) ->
+    parse_factory_options(Rest, Opts#factory_opts{framed=Bool});
+parse_factory_options([{sockopts, OptList} | Rest], Opts) when is_list(OptList) ->
+    parse_factory_options(Rest, Opts#factory_opts{sockopts=OptList});
+parse_factory_options([{connect_timeout, TO} | Rest], Opts) when TO =:= infinity; is_integer(TO) ->
+    parse_factory_options(Rest, Opts#factory_opts{connect_timeout=TO});
+parse_factory_options([{ssloptions, SslOptions} | Rest], Opts) when is_list(SslOptions) ->
+    parse_factory_options(Rest, Opts#factory_opts{ssloptions=SslOptions}).
+
+new(Socket, SockOpts, SslOptions) when is_list(SockOpts), is_list(SslOptions) ->
+    inet:setopts(Socket, [{active, false}]), %% => prevent the ssl handshake messages get lost
+
+    %% upgrade to an ssl socket
+    case catch ssl:ssl_accept(Socket, SslOptions) of % infinite timeout
+        {ok, SslSocket} ->
+            new(SslSocket, SockOpts);
+        {error, Reason} ->
+            exit({error, Reason});
+        Other ->
+            error_logger:error_report(
+              [{application, thrift},
+               "SSL accept failed error",
+               lists:flatten(io_lib:format("~p", [Other]))]),
+            exit({error, ssl_accept_failed})
+    end.
+
+new(SslSocket, SockOpts) ->
+    State =
+        case lists:keysearch(recv_timeout, 1, SockOpts) of
+            {value, {recv_timeout, Timeout}}
+              when is_integer(Timeout), Timeout > 0 ->
+                #data{socket=SslSocket, recv_timeout=Timeout};
+            _ ->
+                #data{socket=SslSocket}
+        end,
+    thrift_transport:new(?MODULE, State).
+
+%% Data :: iolist()
+write(This = #data{socket = Socket}, Data) ->
+    {This, ssl:send(Socket, Data)}.
+
+read(This = #data{socket=Socket, recv_timeout=Timeout}, Len)
+  when is_integer(Len), Len >= 0 ->
+    case ssl:recv(Socket, Len, Timeout) of
+        Err = {error, timeout} ->
+            error_logger:info_msg("read timeout: peer conn ~p", [inet:peername(Socket)]),
+            ssl:close(Socket),
+            {This, Err};
+        Data ->
+            {This, Data}
+    end.
+
+%% We can't really flush - everything is flushed when we write
+flush(This) ->
+    {This, ok}.
+
+close(This = #data{socket = Socket}) ->
+    {This, ssl:close(Socket)}.
+
+%%%% FACTORY GENERATION %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
+
+%%
+%% Generates a "transport factory" function - a fun which returns a thrift_transport()
+%% instance.
+%% This can be passed into a protocol factory to generate a connection to a
+%% thrift server over a socket.
+%%
+new_transport_factory(Host, Port, Options) ->
+    ParsedOpts = parse_factory_options(Options, #factory_opts{}),
+
+    F = fun() ->
+                SockOpts = [binary,
+                            {packet, 0},
+                            {active, false},
+                            {nodelay, true} |
+                            ParsedOpts#factory_opts.sockopts],
+                case catch gen_tcp:connect(Host, Port, SockOpts,
+                                           ParsedOpts#factory_opts.connect_timeout) of
+                    {ok, Sock} ->
+                        SslSock = case catch ssl:connect(Sock, ParsedOpts#factory_opts.ssloptions,
+                                                         ParsedOpts#factory_opts.connect_timeout) of
+                                      {ok, SslSocket} ->
+                                          SslSocket;
+                                      Other ->
+                                          error_logger:info_msg("error while connecting over ssl - reason: ~p~n", [Other]),
+                                          catch gen_tcp:close(Sock),
+                                          exit(error)
+                                  end,
+                        {ok, Transport} = thrift_sslsocket_transport:new(SslSock, SockOpts),
+                        {ok, BufTransport} =
+                            case ParsedOpts#factory_opts.framed of
+                                true  -> thrift_framed_transport:new(Transport);
+                                false -> thrift_buffered_transport:new(Transport)
+                            end,
+                        {ok, BufTransport};
+                    Error  ->
+                        Error
+                end
+        end,
+    {ok, F}.
\ No newline at end of file
diff --git a/test/erl/src/test_client.erl b/test/erl/src/test_client.erl
index 8cfeb8b..7b9efd6 100644
--- a/test/erl/src/test_client.erl
+++ b/test/erl/src/test_client.erl
@@ -47,6 +47,14 @@
                     _Else ->
                         Opts
                 end;
+            "--ssl" ->
+                ssl:start(),
+                SslOptions =
+                    {ssloptions, [
+                        {certfile, "../keys/client.crt"}
+                        ,{keyfile, "../keys/server.key"}
+                    ]},
+                Opts#options{client_opts = [{ssltransport, true} | [SslOptions | Opts#options.client_opts]]};
             "--protocol=binary" ->
                 % TODO: Enable JSON protocol
                 Opts;
diff --git a/test/erl/src/test_thrift_server.erl b/test/erl/src/test_thrift_server.erl
index 51457f5..884eb9e 100644
--- a/test/erl/src/test_thrift_server.erl
+++ b/test/erl/src/test_thrift_server.erl
@@ -47,6 +47,14 @@
                     _Else ->
                         Opts
                 end;
+            "--ssl" ->
+                ssl:start(),
+                SslOptions =
+                    {ssloptions, [
+                         {certfile, "../keys/server.crt"}
+                        ,{keyfile,  "../keys/server.key"}
+                    ]},
+                Opts#options{server_opts = [{ssltransport, true} | [SslOptions | Opts#options.server_opts]]};
             "--protocol=" ++ _ -> Opts;
             _Else ->
                 erlang:error({bad_arg, Head})