THRIFT-3917 Check backports.ssl_match_hostname module version
This closes #1076
diff --git a/lib/py/src/transport/sslcompat.py b/lib/py/src/transport/sslcompat.py
index 19cfaca..7bf5e06 100644
--- a/lib/py/src/transport/sslcompat.py
+++ b/lib/py/src/transport/sslcompat.py
@@ -17,10 +17,13 @@
# under the License.
#
+import logging
import sys
from thrift.transport.TTransport import TTransportException
+logger = logging.getLogger(__name__)
+
def legacy_validate_callback(self, cert, hostname):
"""legacy method to validate the peer's SSL certificate, and to check
@@ -61,20 +64,36 @@
% (hostname, cert))
-try:
-    import ipaddress  # noqa
-    _match_has_ipaddress = True
-except ImportError:
-    _match_has_ipaddress = False
+def _optional_dependencies():
+    """Return ``(ipaddr, match)``: ipaddress-support flag and hostname matcher."""
+    try:
+        import ipaddress  # noqa
+        logger.debug('ipaddress module is available')
+        ipaddr = True
+    except ImportError:
+        logger.warning('ipaddress module is unavailable')
+        ipaddr = False
-try:
-    from backports.ssl_match_hostname import match_hostname
-    _match_hostname = match_hostname
-except ImportError:
     if sys.hexversion < 0x030500F0:
-        _match_has_ipaddress = False
+        try:
+            from backports.ssl_match_hostname import match_hostname, __version__ as ver
+            # Compare (major, minor); a non-numeric part such as '2a3' counts as 0.
+            ver = tuple(int(x) if x.isdigit() else 0 for x in ver.split('.')[:2])
+            logger.debug('backports.ssl_match_hostname module is available')
+            if ver >= (3, 5):
+                return ipaddr, match_hostname
+            logger.warning('backports.ssl_match_hostname module is too old')
+            ipaddr = False
+        except ImportError:
+            logger.warning('backports.ssl_match_hostname is unavailable')
+            ipaddr = False
     try:
         from ssl import match_hostname
-        _match_hostname = match_hostname
+        logger.debug('ssl.match_hostname is available')
+        match = match_hostname
     except ImportError:
-        _match_hostname = legacy_validate_callback
+        logger.warning('using legacy validation callback')
+        match = legacy_validate_callback
+    return ipaddr, match
+
+_match_has_ipaddress, _match_hostname = _optional_dependencies()
diff --git a/lib/py/test/test_sslsocket.py b/lib/py/test/test_sslsocket.py
index 93d34d0..3e4b266 100644
--- a/lib/py/test/test_sslsocket.py
+++ b/lib/py/test/test_sslsocket.py
@@ -30,8 +30,6 @@
from contextlib import contextmanager
import _import_local_thrift # noqa
-from thrift.transport.TSSLSocket import TSSLSocket, TSSLServerSocket
-from thrift.transport.TTransport import TTransportException
SCRIPT_DIR = os.path.realpath(os.path.dirname(__file__))
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(SCRIPT_DIR)))
@@ -133,16 +131,16 @@
def _assert_connection_failure(self, server, path=None, **client_args):
logging.disable(logging.CRITICAL)
- with self._connectable_client(server, True, path=path, **client_args) as (acc, client):
- try:
+ try: # now also covers _connectable_client enter/exit, so logging is re-enabled even if it raises
+ with self._connectable_client(server, True, path=path, **client_args) as (acc, client):
# We need to wait for a connection failure, but not too long. 20ms is a tunable
# compromise between test speed and stability
client.setTimeout(20)
with self._assert_raises(TTransportException):
client.open()
self.assertTrue(acc.client is None)
- finally:
- logging.disable(logging.NOTSET)
+ finally:
+ logging.disable(logging.NOTSET)
def _assert_raises(self, exc):
if sys.hexversion >= 0x020700F0:
@@ -334,6 +332,8 @@
self._assert_connection_success(server, ssl_context=client_context)
if __name__ == '__main__':
- # import logging
- # logging.basicConfig(level=logging.DEBUG)
+ logging.basicConfig(level=logging.WARN)
+ # Imported after basicConfig, presumably so import-time log messages use it. NOTE(review): these names are bound only when this file runs as a script — confirm no runner imports it as a module.
+ from thrift.transport.TSSLSocket import TSSLSocket, TSSLServerSocket
+ from thrift.transport.TTransport import TTransportException
unittest.main()