THRIFT-3599 Validate client IP address against cert's SubjectAltName
diff --git a/lib/py/src/transport/TSSLSocket.py b/lib/py/src/transport/TSSLSocket.py
index dfaa5db..763e533 100644
--- a/lib/py/src/transport/TSSLSocket.py
+++ b/lib/py/src/transport/TSSLSocket.py
@@ -372,6 +372,11 @@
         Alternative keyword arguments: (Python 2.7.9 or later)
           ``ssl_context``: ssl.SSLContext to be used for SSLContext.wrap_socket
           ``server_hostname``: Passed to SSLContext.wrap_socket
+
+        Common keyword argument:
+          ``validate_callback`` (cert, hostname) -> None:
+              Called after the SSL handshake with the peer certificate and the
+              client's IP address. May raise to reject the connection.
         """
         if args:
             if len(args) > 3:
@@ -389,6 +394,8 @@
                 kwargs['certfile'] = 'cert.pem'
 
         unix_socket = kwargs.pop('unix_socket', None)
+        self._validate_callback = \
+            kwargs.pop('validate_callback', match_hostname)
         TSSLBase.__init__(self, True, None, kwargs)
         TSocket.TServerSocket.__init__(self, host, port, unix_socket)
 
@@ -419,6 +426,19 @@
             # Instead, return None, and let the TServer instance deal with it in
             # other exception handling.  (but TSimpleServer dies anyway)
             return None
+
+        if self._should_verify:
+            client.peercert = client.getpeercert()
+            try:
+                self._validate_callback(client.peercert, addr[0])
+                client.is_valid = True
+            except Exception:
+                logger.warn('Failed to validate client certificate address',
+                            exc_info=True)
+                client.close()
+                plain_client.close()
+                return None
+
         result = TSocket.TSocket()
         result.setHandle(client)
         return result
diff --git a/lib/py/test/test_sslsocket.py b/lib/py/test/test_sslsocket.py
index fe03961..98d47ae 100644
--- a/lib/py/test/test_sslsocket.py
+++ b/lib/py/test/test_sslsocket.py
@@ -35,8 +35,11 @@
 SERVER_PEM = os.path.join(ROOT_DIR, 'test', 'keys', 'server.pem')
 SERVER_CERT = os.path.join(ROOT_DIR, 'test', 'keys', 'server.crt')
 SERVER_KEY = os.path.join(ROOT_DIR, 'test', 'keys', 'server.key')
-CLIENT_CERT = os.path.join(ROOT_DIR, 'test', 'keys', 'client.crt')
-CLIENT_KEY = os.path.join(ROOT_DIR, 'test', 'keys', 'client.key')
+CLIENT_CERT_NO_IP = os.path.join(ROOT_DIR, 'test', 'keys', 'client.crt')
+CLIENT_KEY_NO_IP = os.path.join(ROOT_DIR, 'test', 'keys', 'client.key')
+CLIENT_CERT = os.path.join(ROOT_DIR, 'test', 'keys', 'client_v3.crt')
+CLIENT_KEY = os.path.join(ROOT_DIR, 'test', 'keys', 'client_v3.key')
+CLIENT_CA = os.path.join(ROOT_DIR, 'test', 'keys', 'CA.pem')
 
 TEST_PORT = 23458
 TEST_ADDR = '/tmp/.thrift.domain.sock.%d' % TEST_PORT
@@ -188,6 +191,24 @@
         server = TSSLServerSocket(
             port=TEST_PORT, cert_reqs=ssl.CERT_REQUIRED, keyfile=SERVER_KEY,
             certfile=SERVER_CERT, ca_certs=CLIENT_CERT)
+        client = TSSLSocket('localhost', TEST_PORT, cert_reqs=ssl.CERT_NONE, certfile=SERVER_CERT, keyfile=SERVER_KEY)
+        self._assert_connection_failure(server, client)
+
+        server = TSSLServerSocket(
+            port=TEST_PORT, cert_reqs=ssl.CERT_REQUIRED, keyfile=SERVER_KEY,
+            certfile=SERVER_CERT, ca_certs=CLIENT_CA)
+        client = TSSLSocket('localhost', TEST_PORT, cert_reqs=ssl.CERT_NONE, certfile=CLIENT_CERT_NO_IP, keyfile=CLIENT_KEY_NO_IP)
+        self._assert_connection_failure(server, client)
+
+        server = TSSLServerSocket(
+            port=TEST_PORT, cert_reqs=ssl.CERT_REQUIRED, keyfile=SERVER_KEY,
+            certfile=SERVER_CERT, ca_certs=CLIENT_CA)
+        client = TSSLSocket('localhost', TEST_PORT, cert_reqs=ssl.CERT_NONE, certfile=CLIENT_CERT, keyfile=CLIENT_KEY)
+        self._assert_connection_success(server, client)
+
+        server = TSSLServerSocket(
+            port=TEST_PORT, cert_reqs=ssl.CERT_OPTIONAL, keyfile=SERVER_KEY,
+            certfile=SERVER_CERT, ca_certs=CLIENT_CA)
         client = TSSLSocket('localhost', TEST_PORT, cert_reqs=ssl.CERT_NONE, certfile=CLIENT_CERT, keyfile=CLIENT_KEY)
         self._assert_connection_success(server, client)
 
@@ -264,14 +285,16 @@
             return
         server_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
         server_context.load_cert_chain(SERVER_CERT, SERVER_KEY)
-        server_context.load_verify_locations(CLIENT_CERT)
+        server_context.load_verify_locations(CLIENT_CA)
+        server_context.verify_mode = ssl.CERT_REQUIRED
+        server = TSSLServerSocket(port=TEST_PORT, ssl_context=server_context)
 
         client_context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
         client_context.load_cert_chain(CLIENT_CERT, CLIENT_KEY)
         client_context.load_verify_locations(SERVER_CERT)
-
-        server = TSSLServerSocket(port=TEST_PORT, ssl_context=server_context)
+        client_context.verify_mode = ssl.CERT_REQUIRED
         client = TSSLSocket('localhost', TEST_PORT, ssl_context=client_context)
+
         self._assert_connection_success(server, client)
 
 if __name__ == '__main__':