THRIFT-3616 Improve TSSLSocketTest robustness.
Client: Test (Python)
Patch: John Sirois

Previously a combination of fixed ports, fixed paths and delays was used
in all TSSLSocketTest tests that involved making a client-server
connection; now ephemeral ports, unique tmp files and no delays for
successful connection tests are all implemented. A delay still remains
for the failed connection tests to allow for SSL handshake initiation
but not wait too long.

This closes #850
diff --git a/lib/py/test/test_sslsocket.py b/lib/py/test/test_sslsocket.py
index c19dbcd..c76d6d2 100644
--- a/lib/py/test/test_sslsocket.py
+++ b/lib/py/test/test_sslsocket.py
@@ -20,16 +20,19 @@
 import logging
 import os
 import platform
-import select
 import ssl
 import sys
+import tempfile
 import threading
-import time
 import unittest
 import warnings
+from contextlib import contextmanager
+
+import six
 
 import _import_local_thrift
 from thrift.transport.TSSLSocket import TSSLSocket, TSSLServerSocket
+from thrift.transport.TTransport import TTransportException
 
 SCRIPT_DIR = os.path.realpath(os.path.dirname(__file__))
 ROOT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(SCRIPT_DIR)))
@@ -42,22 +45,50 @@
 CLIENT_KEY = os.path.join(ROOT_DIR, 'test', 'keys', 'client_v3.key')
 CLIENT_CA = os.path.join(ROOT_DIR, 'test', 'keys', 'CA.pem')
 
-TEST_PORT = 23458
-TEST_ADDR = '/tmp/.thrift.domain.sock.%d' % TEST_PORT
-CONNECT_DELAY = 0.5
-CONNECT_TIMEOUT = 20.0
 TEST_CIPHERS = 'DES-CBC3-SHA'
 
 
 class ServerAcceptor(threading.Thread):
     def __init__(self, server):
         super(ServerAcceptor, self).__init__()
+        self.daemon = True
         self._server = server
-        self.client = None
+        self._listening = threading.Event()
+        self._port = None
+        self._port_bound = threading.Event()
+        self._client = None
+        self._client_accepted = threading.Event()
 
     def run(self):
         self._server.listen()
-        self.client = self._server.accept()
+        self._listening.set()
+
+        try:
+            address = self._server.handle.getsockname()
+            if len(address) > 1:
+                # AF_INET addresses are 2-tuples (host, port) and AF_INET6 are
+                # 4-tuples (host, port, ...), but in each case port is in the second slot.
+                self._port = address[1]
+        finally:
+            self._port_bound.set()
+
+        try:
+            self._client = self._server.accept()
+        finally:
+            self._client_accepted.set()
+
+    def await_listening(self):
+        self._listening.wait()
+
+    @property
+    def port(self):
+        self._port_bound.wait()
+        return self._port
+
+    @property
+    def client(self):
+        self._client_accepted.wait()
+        return self._client
 
 
 # Python 2.6 compat
@@ -75,185 +106,182 @@
 
 
 class TSSLSocketTest(unittest.TestCase):
-    def _assert_connection_failure(self, server, client):
+    def _server_socket(self, **kwargs):
+        return TSSLServerSocket(port=0, **kwargs)
+
+    @contextmanager
+    def _connectable_client(self, server, path=None, **client_kwargs):
         acc = ServerAcceptor(server)
         try:
             acc.start()
-            time.sleep(CONNECT_DELAY / 2)
-            client.setTimeout(CONNECT_TIMEOUT / 2)
-            with self._assert_raises(Exception):
-                logging.disable(logging.CRITICAL)
-                client.open()
-                select.select([], [client.handle], [], CONNECT_TIMEOUT / 2)
-            # self.assertIsNone(acc.client)
-            self.assertTrue(acc.client is None)
+            acc.await_listening()
+
+            host, port = ('localhost', acc.port) if path is None else (None, None)
+            client = TSSLSocket(host, port, path, **client_kwargs)
+            yield acc, client
         finally:
-            logging.disable(logging.NOTSET)
-            client.close()
             if acc.client:
                 acc.client.close()
             server.close()
 
+    def _assert_connection_failure(self, server, path=None, **client_args):
+        with self._connectable_client(server, path=path, **client_args) as (acc, client):
+            try:
+                logging.disable(logging.CRITICAL)
+                # We need to wait for a connection failure, but not too long.  20ms is a tunable
+                # compromise between test speed and stability
+                client.setTimeout(20)
+                with self._assert_raises(TTransportException):
+                    client.open()
+                self.assertTrue(acc.client is None)
+            finally:
+                logging.disable(logging.NOTSET)
+
     def _assert_raises(self, exc):
         if sys.hexversion >= 0x020700F0:
             return self.assertRaises(exc)
         else:
             return AssertRaises(exc)
 
-    def _assert_connection_success(self, server, client):
-        acc = ServerAcceptor(server)
-        try:
-            acc.start()
-            time.sleep(CONNECT_DELAY)
-            client.setTimeout(CONNECT_TIMEOUT)
+    def _assert_connection_success(self, server, path=None, **client_args):
+        with self._connectable_client(server, path=path, **client_args) as (acc, client):
             client.open()
-            select.select([], [client.handle], [], CONNECT_TIMEOUT)
-            # self.assertIsNotNone(acc.client)
-            self.assertTrue(acc.client is not None)
-        finally:
-            client.close()
-            if acc.client:
-                acc.client.close()
-            server.close()
+            try:
+                self.assertTrue(acc.client is not None)
+            finally:
+                client.close()
 
     # deprecated feature
     def test_deprecation(self):
+        if not six.PY3:
+            # The checks below currently only work for python3.
+            # See: https://issues.apache.org/jira/browse/THRIFT-3618
+            print('skiping test_deprecation')
+            return
+
         with warnings.catch_warnings(record=True) as w:
             warnings.filterwarnings('always', category=DeprecationWarning, module='thrift.*SSL.*')
-            TSSLSocket('localhost', TEST_PORT, validate=True, ca_certs=SERVER_CERT)
+            TSSLSocket('localhost', 0, validate=True, ca_certs=SERVER_CERT)
             self.assertEqual(len(w), 1)
 
         with warnings.catch_warnings(record=True) as w:
             warnings.filterwarnings('always', category=DeprecationWarning, module='thrift.*SSL.*')
             # Deprecated signature
             # def __init__(self, host='localhost', port=9090, validate=True, ca_certs=None, keyfile=None, certfile=None, unix_socket=None, ciphers=None):
-            client = TSSLSocket('localhost', TEST_PORT, True, SERVER_CERT, CLIENT_KEY, CLIENT_CERT, None, TEST_CIPHERS)
+            client = TSSLSocket('localhost', 0, True, SERVER_CERT, CLIENT_KEY, CLIENT_CERT, None, TEST_CIPHERS)
             self.assertEqual(len(w), 7)
 
         with warnings.catch_warnings(record=True) as w:
             warnings.filterwarnings('always', category=DeprecationWarning, module='thrift.*SSL.*')
             # Deprecated signature
             # def __init__(self, host=None, port=9090, certfile='cert.pem', unix_socket=None, ciphers=None):
-            server = TSSLServerSocket(None, TEST_PORT, SERVER_PEM, None, TEST_CIPHERS)
+            server = TSSLServerSocket(None, 0, SERVER_PEM, None, TEST_CIPHERS)
             self.assertEqual(len(w), 3)
 
-        self._assert_connection_success(server, client)
-
     # deprecated feature
     def test_set_cert_reqs_by_validate(self):
-        c1 = TSSLSocket('localhost', TEST_PORT, validate=True, ca_certs=SERVER_CERT)
+        c1 = TSSLSocket('localhost', 0, validate=True, ca_certs=SERVER_CERT)
         self.assertEqual(c1.cert_reqs, ssl.CERT_REQUIRED)
 
-        c1 = TSSLSocket('localhost', TEST_PORT, validate=False)
+        c1 = TSSLSocket('localhost', 0, validate=False)
         self.assertEqual(c1.cert_reqs, ssl.CERT_NONE)
 
     # deprecated feature
     def test_set_validate_by_cert_reqs(self):
-        c1 = TSSLSocket('localhost', TEST_PORT, cert_reqs=ssl.CERT_NONE)
+        c1 = TSSLSocket('localhost', 0, cert_reqs=ssl.CERT_NONE)
         self.assertFalse(c1.validate)
 
-        c2 = TSSLSocket('localhost', TEST_PORT, cert_reqs=ssl.CERT_REQUIRED, ca_certs=SERVER_CERT)
+        c2 = TSSLSocket('localhost', 0, cert_reqs=ssl.CERT_REQUIRED, ca_certs=SERVER_CERT)
         self.assertTrue(c2.validate)
 
-        c3 = TSSLSocket('localhost', TEST_PORT, cert_reqs=ssl.CERT_OPTIONAL, ca_certs=SERVER_CERT)
+        c3 = TSSLSocket('localhost', 0, cert_reqs=ssl.CERT_OPTIONAL, ca_certs=SERVER_CERT)
         self.assertTrue(c3.validate)
 
     def test_unix_domain_socket(self):
         if platform.system() == 'Windows':
             print('skipping test_unix_domain_socket')
             return
-        server = TSSLServerSocket(unix_socket=TEST_ADDR, keyfile=SERVER_KEY, certfile=SERVER_CERT)
-        client = TSSLSocket(None, None, TEST_ADDR, cert_reqs=ssl.CERT_NONE)
-        self._assert_connection_success(server, client)
+        fd, path = tempfile.mkstemp()
+        os.close(fd)
+        try:
+            server = self._server_socket(unix_socket=path, keyfile=SERVER_KEY, certfile=SERVER_CERT)
+            self._assert_connection_success(server, path=path, cert_reqs=ssl.CERT_NONE)
+        finally:
+            os.unlink(path)
 
     def test_server_cert(self):
-        server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT)
-        client = TSSLSocket('localhost', TEST_PORT, cert_reqs=ssl.CERT_REQUIRED, ca_certs=SERVER_CERT)
-        self._assert_connection_success(server, client)
+        server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT)
+        self._assert_connection_success(server, cert_reqs=ssl.CERT_REQUIRED, ca_certs=SERVER_CERT)
 
-        server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT)
-        # server cert on in ca_certs
-        client = TSSLSocket('localhost', TEST_PORT, cert_reqs=ssl.CERT_REQUIRED, ca_certs=CLIENT_CERT)
-        self._assert_connection_failure(server, client)
+        server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT)
+        # server cert not in ca_certs
+        self._assert_connection_failure(server, cert_reqs=ssl.CERT_REQUIRED, ca_certs=CLIENT_CERT)
 
-        server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT)
-        client = TSSLSocket('localhost', TEST_PORT, cert_reqs=ssl.CERT_NONE)
-        self._assert_connection_success(server, client)
+        server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT)
+        self._assert_connection_success(server, cert_reqs=ssl.CERT_NONE)
 
     def test_set_server_cert(self):
-        server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=CLIENT_CERT)
+        server = self._server_socket(keyfile=SERVER_KEY, certfile=CLIENT_CERT)
         with self._assert_raises(Exception):
             server.certfile = 'foo'
         with self._assert_raises(Exception):
             server.certfile = None
         server.certfile = SERVER_CERT
-        client = TSSLSocket('localhost', TEST_PORT, cert_reqs=ssl.CERT_REQUIRED, ca_certs=SERVER_CERT)
-        self._assert_connection_success(server, client)
+        self._assert_connection_success(server, cert_reqs=ssl.CERT_REQUIRED, ca_certs=SERVER_CERT)
 
     def test_client_cert(self):
-        server = TSSLServerSocket(
-            port=TEST_PORT, cert_reqs=ssl.CERT_REQUIRED, keyfile=SERVER_KEY,
+        server = self._server_socket(
+            cert_reqs=ssl.CERT_REQUIRED, keyfile=SERVER_KEY,
             certfile=SERVER_CERT, ca_certs=CLIENT_CERT)
-        client = TSSLSocket('localhost', TEST_PORT, cert_reqs=ssl.CERT_NONE, certfile=SERVER_CERT, keyfile=SERVER_KEY)
-        self._assert_connection_failure(server, client)
+        self._assert_connection_failure(server, cert_reqs=ssl.CERT_NONE, certfile=SERVER_CERT, keyfile=SERVER_KEY)
 
-        server = TSSLServerSocket(
-            port=TEST_PORT, cert_reqs=ssl.CERT_REQUIRED, keyfile=SERVER_KEY,
+        server = self._server_socket(
+            cert_reqs=ssl.CERT_REQUIRED, keyfile=SERVER_KEY,
             certfile=SERVER_CERT, ca_certs=CLIENT_CA)
-        client = TSSLSocket('localhost', TEST_PORT, cert_reqs=ssl.CERT_NONE, certfile=CLIENT_CERT_NO_IP, keyfile=CLIENT_KEY_NO_IP)
-        self._assert_connection_failure(server, client)
+        self._assert_connection_failure(server, cert_reqs=ssl.CERT_NONE, certfile=CLIENT_CERT_NO_IP, keyfile=CLIENT_KEY_NO_IP)
 
-        server = TSSLServerSocket(
-            port=TEST_PORT, cert_reqs=ssl.CERT_REQUIRED, keyfile=SERVER_KEY,
+        server = self._server_socket(
+            cert_reqs=ssl.CERT_REQUIRED, keyfile=SERVER_KEY,
             certfile=SERVER_CERT, ca_certs=CLIENT_CA)
-        client = TSSLSocket('localhost', TEST_PORT, cert_reqs=ssl.CERT_NONE, certfile=CLIENT_CERT, keyfile=CLIENT_KEY)
-        self._assert_connection_success(server, client)
+        self._assert_connection_success(server, cert_reqs=ssl.CERT_NONE, certfile=CLIENT_CERT, keyfile=CLIENT_KEY)
 
-        server = TSSLServerSocket(
-            port=TEST_PORT, cert_reqs=ssl.CERT_OPTIONAL, keyfile=SERVER_KEY,
+        server = self._server_socket(
+            cert_reqs=ssl.CERT_OPTIONAL, keyfile=SERVER_KEY,
             certfile=SERVER_CERT, ca_certs=CLIENT_CA)
-        client = TSSLSocket('localhost', TEST_PORT, cert_reqs=ssl.CERT_NONE, certfile=CLIENT_CERT, keyfile=CLIENT_KEY)
-        self._assert_connection_success(server, client)
+        self._assert_connection_success(server, cert_reqs=ssl.CERT_NONE, certfile=CLIENT_CERT, keyfile=CLIENT_KEY)
 
     def test_ciphers(self):
-        server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT, ciphers=TEST_CIPHERS)
-        client = TSSLSocket('localhost', TEST_PORT, ca_certs=SERVER_CERT, ciphers=TEST_CIPHERS)
-        self._assert_connection_success(server, client)
+        server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT, ciphers=TEST_CIPHERS)
+        self._assert_connection_success(server, ca_certs=SERVER_CERT, ciphers=TEST_CIPHERS)
 
         if not TSSLSocket._has_ciphers:
             # unittest.skip is not available for Python 2.6
             print('skipping test_ciphers')
             return
-        server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT)
-        client = TSSLSocket('localhost', TEST_PORT, ca_certs=SERVER_CERT, ciphers='NULL')
-        self._assert_connection_failure(server, client)
+        server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT)
+        self._assert_connection_failure(server, ca_certs=SERVER_CERT, ciphers='NULL')
 
-        server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT, ciphers=TEST_CIPHERS)
-        client = TSSLSocket('localhost', TEST_PORT, ca_certs=SERVER_CERT, ciphers='NULL')
-        self._assert_connection_failure(server, client)
+        server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT, ciphers=TEST_CIPHERS)
+        self._assert_connection_failure(server, ca_certs=SERVER_CERT, ciphers='NULL')
 
     def test_ssl2_and_ssl3_disabled(self):
         if not hasattr(ssl, 'PROTOCOL_SSLv3'):
             print('PROTOCOL_SSLv3 is not available')
         else:
-            server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT)
-            client = TSSLSocket('localhost', TEST_PORT, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_SSLv3)
-            self._assert_connection_failure(server, client)
+            server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT)
+            self._assert_connection_failure(server, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_SSLv3)
 
-            server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_SSLv3)
-            client = TSSLSocket('localhost', TEST_PORT, ca_certs=SERVER_CERT)
-            self._assert_connection_failure(server, client)
+            server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_SSLv3)
+            self._assert_connection_failure(server, ca_certs=SERVER_CERT)
 
         if not hasattr(ssl, 'PROTOCOL_SSLv2'):
             print('PROTOCOL_SSLv2 is not available')
         else:
-            server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT)
-            client = TSSLSocket('localhost', TEST_PORT, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_SSLv2)
-            self._assert_connection_failure(server, client)
+            server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT)
+            self._assert_connection_failure(server, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_SSLv2)
 
-            server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_SSLv2)
-            client = TSSLSocket('localhost', TEST_PORT, ca_certs=SERVER_CERT)
-            self._assert_connection_failure(server, client)
+            server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_SSLv2)
+            self._assert_connection_failure(server, ca_certs=SERVER_CERT)
 
     def test_newer_tls(self):
         if not TSSLSocket._has_ssl_context:
@@ -263,23 +291,20 @@
         if not hasattr(ssl, 'PROTOCOL_TLSv1_2'):
             print('PROTOCOL_TLSv1_2 is not available')
         else:
-            server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_2)
-            client = TSSLSocket('localhost', TEST_PORT, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_2)
-            self._assert_connection_success(server, client)
+            server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_2)
+            self._assert_connection_success(server, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_2)
 
         if not hasattr(ssl, 'PROTOCOL_TLSv1_1'):
             print('PROTOCOL_TLSv1_1 is not available')
         else:
-            server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_1)
-            client = TSSLSocket('localhost', TEST_PORT, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_1)
-            self._assert_connection_success(server, client)
+            server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_1)
+            self._assert_connection_success(server, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_1)
 
         if not hasattr(ssl, 'PROTOCOL_TLSv1_1') or not hasattr(ssl, 'PROTOCOL_TLSv1_2'):
             print('PROTOCOL_TLSv1_1 and/or PROTOCOL_TLSv1_2 is not available')
         else:
-            server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_2)
-            client = TSSLSocket('localhost', TEST_PORT, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_1)
-            self._assert_connection_failure(server, client)
+            server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_2)
+            self._assert_connection_failure(server, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_1)
 
     def test_ssl_context(self):
         if not TSSLSocket._has_ssl_context:
@@ -290,15 +315,14 @@
         server_context.load_cert_chain(SERVER_CERT, SERVER_KEY)
         server_context.load_verify_locations(CLIENT_CA)
         server_context.verify_mode = ssl.CERT_REQUIRED
-        server = TSSLServerSocket(port=TEST_PORT, ssl_context=server_context)
+        server = self._server_socket(ssl_context=server_context)
 
         client_context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
         client_context.load_cert_chain(CLIENT_CERT, CLIENT_KEY)
         client_context.load_verify_locations(SERVER_CERT)
         client_context.verify_mode = ssl.CERT_REQUIRED
-        client = TSSLSocket('localhost', TEST_PORT, ssl_context=client_context)
 
-        self._assert_connection_success(server, client)
+        self._assert_connection_success(server, ssl_context=client_context)
 
 if __name__ == '__main__':
     # import logging