THRIFT-2103 [python] Support for SSL certificates with Subject Alternative Names
diff --git a/lib/py/CMakeLists.txt b/lib/py/CMakeLists.txt
index 7bb91fe..2ec8b56 100755
--- a/lib/py/CMakeLists.txt
+++ b/lib/py/CMakeLists.txt
@@ -20,6 +20,7 @@
 include_directories(${PYTHON_INCLUDE_DIRS})
 
 add_custom_target(python_build ALL
+    COMMAND ${PIP_EXECUTABLE} install -r requirements.txt || ${PIP_EXECUTABLE} install --user -r requirements.txt
     COMMAND ${PYTHON_EXECUTABLE} setup.py build
     WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
     COMMENT "Building Python library"
diff --git a/lib/py/Makefile.am b/lib/py/Makefile.am
index 5183b9c..a075eb1 100755
--- a/lib/py/Makefile.am
+++ b/lib/py/Makefile.am
@@ -21,6 +21,7 @@
 
 if WITH_PY3
 py3-build:
+	$(PIP3) install -r requirements.txt || $(PIP3) install --user -r requirements.txt
 	$(PYTHON3) setup.py build
 py3-test: py3-build
 	$(PYTHON3) test/thrift_json.py
@@ -31,6 +32,7 @@
 endif
 
 all-local: py3-build
+	$(PIP) install -r requirements.txt || $(PIP) install --user -r requirements.txt
 	$(PYTHON) setup.py build
 
 # We're ignoring prefix here because site-packages seems to be
@@ -38,6 +40,7 @@
 # Old version (can't put inline because it's not portable).
 #$(PYTHON) setup.py install --prefix=$(prefix) --root=$(DESTDIR) $(PYTHON_SETUPUTIL_ARGS)
 install-exec-hook:
+	$(PIP) install -r requirements.txt
 	$(PYTHON) setup.py install --root=$(DESTDIR) --prefix=$(PY_PREFIX) $(PYTHON_SETUPUTIL_ARGS)
 
 clean-local:
diff --git a/lib/py/requirements.txt b/lib/py/requirements.txt
new file mode 100644
index 0000000..7cf8b31
--- /dev/null
+++ b/lib/py/requirements.txt
@@ -0,0 +1,3 @@
+six
+backports.ssl_match_hostname
+ipaddress
diff --git a/lib/py/src/transport/TSSLSocket.py b/lib/py/src/transport/TSSLSocket.py
index 3f1a909..dfaa5db 100644
--- a/lib/py/src/transport/TSSLSocket.py
+++ b/lib/py/src/transport/TSSLSocket.py
@@ -23,12 +23,14 @@
 import ssl
 import sys
 import warnings
+from backports.ssl_match_hostname import match_hostname
 
 from thrift.transport import TSocket
 from thrift.transport.TTransport import TTransportException
 
 logger = logging.getLogger(__name__)
-warnings.filterwarnings('default', category=DeprecationWarning, module=__name__)
+warnings.filterwarnings(
+    'default', category=DeprecationWarning, module=__name__)
 
 
 class TSSLBase(object):
@@ -38,10 +40,13 @@
     # ciphers argument is not available for Python < 2.7.0
     _has_ciphers = sys.hexversion >= 0x020700F0
 
-    # For pythoon >= 2.7.9, use latest TLS that both client and server supports.
+    # For Python >= 2.7.9, use latest TLS that both client and server
+    # support.
     # SSL 2.0 and 3.0 are disabled via ssl.OP_NO_SSLv2 and ssl.OP_NO_SSLv3.
-    # For pythoon < 2.7.9, use TLS 1.0 since TLSv1_X nare OP_NO_SSLvX are unavailable.
-    _default_protocol = ssl.PROTOCOL_SSLv23 if _has_ssl_context else ssl.PROTOCOL_TLSv1
+    # For Python < 2.7.9, use TLS 1.0 since neither TLSv1_X nor OP_NO_SSLvX
+    # is available.
+    _default_protocol = ssl.PROTOCOL_SSLv23 if _has_ssl_context else \
+        ssl.PROTOCOL_TLSv1
 
     def _init_context(self, ssl_version):
         if self._has_ssl_context:
@@ -54,6 +59,13 @@
             self._ssl_version = ssl_version
 
     @property
+    def _should_verify(self):
+        if self._has_ssl_context:
+            return self._context.verify_mode != ssl.CERT_NONE
+        else:
+            return self.cert_reqs != ssl.CERT_NONE
+
+    @property
     def ssl_version(self):
         if self._has_ssl_context:
             return self.ssl_context.protocol
@@ -76,10 +88,14 @@
             return
         real_pos = pos + 3
         warnings.warn(
-            '%dth positional argument is deprecated. Use keyward argument insteand.' % real_pos,
+            '%dth positional argument is deprecated. Use keyward argument insteand.'
+            % real_pos,
             DeprecationWarning)
+
         if key in kwargs:
-            raise TypeError('Duplicate argument: %dth argument and %s keyward argument.', (real_pos, key))
+            raise TypeError(
+                'Duplicate argument: %dth argument and %s keyward argument.'
+                % (real_pos, key))
         kwargs[key] = args[pos]
 
     def _unix_socket_arg(self, host, port, args, kwargs):
@@ -91,13 +107,16 @@
 
     def __getattr__(self, key):
         if key == 'SSL_VERSION':
-            warnings.warn('Use ssl_version attribute instead.', DeprecationWarning)
+            warnings.warn('Use ssl_version attribute instead.',
+                          DeprecationWarning)
             return self.ssl_version
 
     def __init__(self, server_side, host, ssl_opts):
         self._server_side = server_side
         if TSSLBase.SSL_VERSION != self._default_protocol:
-            warnings.warn('SSL_VERSION is deprecated. Use ssl_version keyward argument instead.', DeprecationWarning)
+            warnings.warn(
+                'SSL_VERSION is deprecated. Use ssl_version keyward argument instead.',
+                DeprecationWarning)
         self._context = ssl_opts.pop('ssl_context', None)
         self._server_hostname = None
         if not self._server_side:
@@ -105,9 +124,12 @@
         if self._context:
             self._custom_context = True
             if ssl_opts:
-                raise ValueError('Incompatible arguments: ssl_context and %s' % ' '.join(ssl_opts.keys()))
+                raise ValueError(
+                    'Incompatible arguments: ssl_context and %s'
+                    % ' '.join(ssl_opts.keys()))
             if not self._has_ssl_context:
-                raise ValueError('ssl_context is not available for this version of Python')
+                raise ValueError(
+                    'ssl_context is not available for this version of Python')
         else:
             self._custom_context = False
             ssl_version = ssl_opts.pop('ssl_version', TSSLBase.SSL_VERSION)
@@ -119,11 +141,13 @@
             self.ciphers = ssl_opts.pop('ciphers', None)
 
             if ssl_opts:
-                raise ValueError('Unknown keyword arguments: ', ' '.join(ssl_opts.keys()))
+                raise ValueError('Unknown keyword arguments: %s'
+                                 % ' '.join(ssl_opts.keys()))
 
             if self.cert_reqs != ssl.CERT_NONE:
                 if not self.ca_certs:
-                    raise ValueError('ca_certs is needed when cert_reqs is not ssl.CERT_NONE')
+                    raise ValueError(
+                        'ca_certs is needed when cert_reqs is not ssl.CERT_NONE')
                 if not os.access(self.ca_certs, os.R_OK):
                     raise IOError('Certificate Authority ca_certs file "%s" '
                                   'is not readable, cannot validate SSL '
@@ -146,13 +170,15 @@
             if not self._custom_context:
                 self.ssl_context.verify_mode = self.cert_reqs
                 if self.certfile:
-                    self.ssl_context.load_cert_chain(self.certfile, self.keyfile)
+                    self.ssl_context.load_cert_chain(self.certfile,
+                                                     self.keyfile)
                 if self.ciphers:
                     self.ssl_context.set_ciphers(self.ciphers)
                 if self.ca_certs:
                     self.ssl_context.load_verify_locations(self.ca_certs)
-            return self.ssl_context.wrap_socket(sock, server_side=self._server_side,
-                                                server_hostname=self._server_hostname)
+            return self.ssl_context.wrap_socket(
+                sock, server_side=self._server_side,
+                server_hostname=self._server_hostname)
         else:
             ssl_opts = {
                 'ssl_version': self._ssl_version,
@@ -166,7 +192,8 @@
                 if self._has_ciphers:
                     ssl_opts['ciphers'] = self.ciphers
                 else:
-                    logger.warning('ciphers is specified but ignored due to old Python version')
+                    logger.warning(
+                        'ciphers is specified but ignored due to old Python version')
             return ssl.wrap_socket(sock, **ssl_opts)
 
 
@@ -179,20 +206,29 @@
     """
 
     # New signature
-    # def __init__(self, host='localhost', port=9090, unix_socket=None, **ssl_args):
+    # def __init__(self, host='localhost', port=9090, unix_socket=None,
+    #              **ssl_args):
     # Deprecated signature
-    # def __init__(self, host='localhost', port=9090, validate=True, ca_certs=None, keyfile=None, certfile=None, unix_socket=None, ciphers=None):
+    # def __init__(self, host='localhost', port=9090, validate=True,
+    #              ca_certs=None, keyfile=None, certfile=None,
+    #              unix_socket=None, ciphers=None):
     def __init__(self, host='localhost', port=9090, *args, **kwargs):
         """Positional arguments: ``host``, ``port``, ``unix_socket``
 
-        Keyword arguments: ``keyfile``, ``certfile``, ``cert_reqs``, ``ssl_version``,
-                           ``ca_certs``, ``ciphers`` (Python 2.7.0 or later),
+        Keyword arguments: ``keyfile``, ``certfile``, ``cert_reqs``,
+                           ``ssl_version``, ``ca_certs``,
+                           ``ciphers`` (Python 2.7.0 or later),
                            ``server_hostname`` (Python 2.7.9 or later)
         Passed to ssl.wrap_socket. See ssl.wrap_socket documentation.
 
-        Alternative keywoard arguments: (Python 2.7.9 or later)
+        Alternative keyword arguments: (Python 2.7.9 or later)
           ``ssl_context``: ssl.SSLContext to be used for SSLContext.wrap_socket
           ``server_hostname``: Passed to SSLContext.wrap_socket
+
+        Common keyword argument:
+          ``validate_callback`` (cert, hostname) -> None:
+              Called after the SSL handshake when certificate verification is
+              enabled (default: match_hostname); raise to reject the peer.
         """
         self.is_valid = False
         self.peercert = None
@@ -212,13 +248,15 @@
         if validate is not None:
             cert_reqs_name = 'CERT_REQUIRED' if validate else 'CERT_NONE'
             warnings.warn(
-                'validate is deprecated. Use cert_reqs=ssl.%s instead' % cert_reqs_name,
+                'validate is deprecated. Use cert_reqs=ssl.%s instead'
+                % cert_reqs_name,
                 DeprecationWarning)
             if 'cert_reqs' in kwargs:
                 raise TypeError('Cannot specify both validate and cert_reqs')
             kwargs['cert_reqs'] = ssl.CERT_REQUIRED if validate else ssl.CERT_NONE
 
         unix_socket = kwargs.pop('unix_socket', None)
+        self._validate_callback = kwargs.pop('validate_callback', match_hostname)
         TSSLBase.__init__(self, False, host, kwargs)
         TSocket.TSocket.__init__(self, host, port, unix_socket)
 
@@ -245,7 +283,9 @@
                     self.handle.connect(ip_port)
                 except socket.error as e:
                     if res is not res0[-1]:
-                        logger.warning('Error while connecting with %s. Trying next one.', ip_port, exc_info=True)
+                        logger.warning(
+                            'Error while connecting with %s. Trying next one.',
+                            ip_port, exc_info=True)
                         continue
                     else:
                         raise
@@ -255,27 +295,35 @@
                 message = 'Could not connect to secure socket %s: %s' \
                           % (self._unix_socket, e)
             else:
-                message = 'Could not connect to %s:%d: %s' % (self.host, self.port, e)
-            logger.error('Error while connecting with %s.', ip_port, exc_info=True)
-            raise TTransportException(type=TTransportException.NOT_OPEN,
-                                      message=message)
-        if self.validate:
-            self._validate_cert()
+                message = 'Could not connect to %s:%d: %s' \
+                          % (self.host, self.port, e)
+            logger.error(
+                'Error while connecting with %s.', ip_port, exc_info=True)
+            raise TTransportException(TTransportException.NOT_OPEN, message)
 
-    def _validate_cert(self):
-        """internal method to validate the peer's SSL certificate, and to check the
-        commonName of the certificate to ensure it matches the hostname we
+        if self._should_verify:
+            self.peercert = self.handle.getpeercert()
+            try:
+                self._validate_callback(self.peercert, self._server_hostname)
+                self.is_valid = True
+            except TTransportException:
+                raise
+            except Exception as ex:
+                raise TTransportException(TTransportException.UNKNOWN, str(ex))
+
+    @staticmethod
+    def legacy_validate_callback(cert, hostname):
+        """Legacy ``validate_callback``: checks that the commonName of the
+        peer's SSL certificate (``cert``) matches ``hostname``, the host we
         used to make this connection.  Does not support subjectAltName records
         in certificates.
 
         raises TTransportException if the certificate fails validation.
         """
-        cert = self.handle.getpeercert()
-        self.peercert = cert
         if 'subject' not in cert:
             raise TTransportException(
-                type=TTransportException.NOT_OPEN,
-                message='No SSL certificate found from %s:%s' % (self.host, self.port))
+                TTransportException.NOT_OPEN,
+                'No SSL certificate found from %s' % hostname)
         fields = cert['subject']
         for field in fields:
             # ensure structure we get back is what we expect
@@ -289,19 +337,18 @@
                 continue
             certhost = cert_value
             # this check should be performed by some sort of Access Manager
-            if certhost == self.host:
+            if certhost == hostname:
                 # success, cert commonName matches desired hostname
-                self.is_valid = True
                 return
             else:
                 raise TTransportException(
-                    type=TTransportException.UNKNOWN,
-                    message='Hostname we connected to "%s" doesn\'t match certificate '
-                    'provided commonName "%s"' % (self.host, certhost))
+                    TTransportException.UNKNOWN,
+                    'Hostname we connected to "%s" doesn\'t match certificate '
+                    'provided commonName "%s"' % (hostname, certhost))
         raise TTransportException(
-            type=TTransportException.UNKNOWN,
-            message='Could not validate SSL certificate from '
-            'host "%s".  Cert=%s' % (self.host, cert))
+            TTransportException.UNKNOWN,
+            'Could not validate SSL certificate from host "%s".  Cert=%s'
+            % (hostname, cert))
 
 
 class TSSLServerSocket(TSocket.TServerSocket, TSSLBase):
@@ -322,7 +369,7 @@
                            ``ca_certs``, ``ciphers`` (Python 2.7.0 or later)
         See ssl.wrap_socket documentation.
 
-        Alternative keywoard arguments: (Python 2.7.9 or later)
+        Alternative keyword arguments: (Python 2.7.9 or later)
           ``ssl_context``: ssl.SSLContext to be used for SSLContext.wrap_socket
           ``server_hostname``: Passed to SSLContext.wrap_socket
         """
@@ -346,7 +393,8 @@
         TSocket.TServerSocket.__init__(self, host, port, unix_socket)
 
     def setCertfile(self, certfile):
-        """Set or change the server certificate file used to wrap new connections.
+        """Set or change the server certificate file used to wrap new
+        connections.
 
         @param certfile: The filename of the server certificate,
                          i.e. '/etc/certs/server.pem'
diff --git a/lib/py/test/test_sslsocket.py b/lib/py/test/test_sslsocket.py
index fa156a0..fe03961 100644
--- a/lib/py/test/test_sslsocket.py
+++ b/lib/py/test/test_sslsocket.py
@@ -41,7 +41,7 @@
 TEST_PORT = 23458
 TEST_ADDR = '/tmp/.thrift.domain.sock.%d' % TEST_PORT
 CONNECT_DELAY = 0.5
-CONNECT_TIMEOUT = 10.0
+CONNECT_TIMEOUT = 20.0
 TEST_CIPHERS = 'DES-CBC3-SHA'
 
 
@@ -72,19 +72,21 @@
 
 class TSSLSocketTest(unittest.TestCase):
     def _assert_connection_failure(self, server, client):
+        acc = ServerAcceptor(server)
         try:
-            acc = ServerAcceptor(server)
             acc.start()
-            time.sleep(CONNECT_DELAY)
-            client.setTimeout(CONNECT_TIMEOUT)
+            time.sleep(CONNECT_DELAY / 2)
+            client.setTimeout(CONNECT_TIMEOUT / 2)
             with self._assert_raises(Exception):
                 client.open()
-                select.select([], [client.handle], [], CONNECT_TIMEOUT)
+                select.select([], [client.handle], [], CONNECT_TIMEOUT / 2)
             # self.assertIsNone(acc.client)
             self.assertTrue(acc.client is None)
         finally:
-            server.close()
             client.close()
+            if acc.client:
+                acc.client.close()
+            server.close()
 
     def _assert_raises(self, exc):
         if sys.hexversion >= 0x020700F0:
@@ -93,18 +95,20 @@
             return AssertRaises(exc)
 
     def _assert_connection_success(self, server, client):
+        acc = ServerAcceptor(server)
         try:
-            acc = ServerAcceptor(server)
             acc.start()
-            time.sleep(0.15)
+            time.sleep(CONNECT_DELAY)
             client.setTimeout(CONNECT_TIMEOUT)
             client.open()
             select.select([], [client.handle], [], CONNECT_TIMEOUT)
             # self.assertIsNotNone(acc.client)
             self.assertTrue(acc.client is not None)
         finally:
-            server.close()
             client.close()
+            if acc.client:
+                acc.client.close()
+            server.close()
 
     # deprecated feature
     def test_deprecation(self):