THRIFT-2103 [python] Support for SSL certificates with Subject Alternative Names
diff --git a/lib/py/CMakeLists.txt b/lib/py/CMakeLists.txt
index 7bb91fe..2ec8b56 100755
--- a/lib/py/CMakeLists.txt
+++ b/lib/py/CMakeLists.txt
@@ -20,6 +20,7 @@
include_directories(${PYTHON_INCLUDE_DIRS})
add_custom_target(python_build ALL
+ COMMAND ${PIP_EXECUTABLE} install -r requirements.txt || ${PIP_EXECUTABLE} install --user -r requirements.txt
COMMAND ${PYTHON_EXECUTABLE} setup.py build
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
COMMENT "Building Python library"
diff --git a/lib/py/Makefile.am b/lib/py/Makefile.am
index 5183b9c..a075eb1 100755
--- a/lib/py/Makefile.am
+++ b/lib/py/Makefile.am
@@ -21,6 +21,7 @@
if WITH_PY3
py3-build:
+ $(PIP3) install -r requirements.txt || $(PIP3) install --user -r requirements.txt
$(PYTHON3) setup.py build
py3-test: py3-build
$(PYTHON3) test/thrift_json.py
@@ -31,6 +32,7 @@
endif
all-local: py3-build
+ $(PIP) install -r requirements.txt || $(PIP) install --user -r requirements.txt
$(PYTHON) setup.py build
# We're ignoring prefix here because site-packages seems to be
@@ -38,6 +40,7 @@
# Old version (can't put inline because it's not portable).
#$(PYTHON) setup.py install --prefix=$(prefix) --root=$(DESTDIR) $(PYTHON_SETUPUTIL_ARGS)
install-exec-hook:
+ $(PIP) install -r requirements.txt
$(PYTHON) setup.py install --root=$(DESTDIR) --prefix=$(PY_PREFIX) $(PYTHON_SETUPUTIL_ARGS)
clean-local:
diff --git a/lib/py/requirements.txt b/lib/py/requirements.txt
new file mode 100644
index 0000000..7cf8b31
--- /dev/null
+++ b/lib/py/requirements.txt
@@ -0,0 +1,3 @@
+six
+backports.ssl_match_hostname
+ipaddress
diff --git a/lib/py/src/transport/TSSLSocket.py b/lib/py/src/transport/TSSLSocket.py
index 3f1a909..dfaa5db 100644
--- a/lib/py/src/transport/TSSLSocket.py
+++ b/lib/py/src/transport/TSSLSocket.py
@@ -23,12 +23,14 @@
import ssl
import sys
import warnings
+from backports.ssl_match_hostname import match_hostname
from thrift.transport import TSocket
from thrift.transport.TTransport import TTransportException
logger = logging.getLogger(__name__)
-warnings.filterwarnings('default', category=DeprecationWarning, module=__name__)
+warnings.filterwarnings(
+ 'default', category=DeprecationWarning, module=__name__)
class TSSLBase(object):
@@ -38,10 +40,13 @@
# ciphers argument is not available for Python < 2.7.0
_has_ciphers = sys.hexversion >= 0x020700F0
- # For pythoon >= 2.7.9, use latest TLS that both client and server supports.
+ # For python >= 2.7.9, use latest TLS that both client and server
+ # support.
# SSL 2.0 and 3.0 are disabled via ssl.OP_NO_SSLv2 and ssl.OP_NO_SSLv3.
- # For pythoon < 2.7.9, use TLS 1.0 since TLSv1_X nare OP_NO_SSLvX are unavailable.
- _default_protocol = ssl.PROTOCOL_SSLv23 if _has_ssl_context else ssl.PROTOCOL_TLSv1
+ # For python < 2.7.9, use TLS 1.0 since neither TLSv1_X nor
+ # OP_NO_SSLvX is available.
+ _default_protocol = ssl.PROTOCOL_SSLv23 if _has_ssl_context else \
+ ssl.PROTOCOL_TLSv1
def _init_context(self, ssl_version):
if self._has_ssl_context:
@@ -54,6 +59,13 @@
self._ssl_version = ssl_version
@property
+ def _should_verify(self):
+ if self._has_ssl_context:
+ return self._context.verify_mode != ssl.CERT_NONE
+ else:
+ return self.cert_reqs != ssl.CERT_NONE
+
+ @property
def ssl_version(self):
if self._has_ssl_context:
return self.ssl_context.protocol
@@ -76,10 +88,14 @@
return
real_pos = pos + 3
warnings.warn(
- '%dth positional argument is deprecated. Use keyward argument insteand.' % real_pos,
+ '%dth positional argument is deprecated. Use keyward argument insteand.'
+ % real_pos,
DeprecationWarning)
+
if key in kwargs:
- raise TypeError('Duplicate argument: %dth argument and %s keyward argument.', (real_pos, key))
+ raise TypeError(
+ 'Duplicate argument: %dth argument and %s keyward argument.'
+ % (real_pos, key))
kwargs[key] = args[pos]
def _unix_socket_arg(self, host, port, args, kwargs):
@@ -91,13 +107,16 @@
def __getattr__(self, key):
if key == 'SSL_VERSION':
- warnings.warn('Use ssl_version attribute instead.', DeprecationWarning)
+ warnings.warn('Use ssl_version attribute instead.',
+ DeprecationWarning)
return self.ssl_version
def __init__(self, server_side, host, ssl_opts):
self._server_side = server_side
if TSSLBase.SSL_VERSION != self._default_protocol:
- warnings.warn('SSL_VERSION is deprecated. Use ssl_version keyward argument instead.', DeprecationWarning)
+ warnings.warn(
+ 'SSL_VERSION is deprecated. Use ssl_version keyward argument instead.',
+ DeprecationWarning)
self._context = ssl_opts.pop('ssl_context', None)
self._server_hostname = None
if not self._server_side:
@@ -105,9 +124,12 @@
if self._context:
self._custom_context = True
if ssl_opts:
- raise ValueError('Incompatible arguments: ssl_context and %s' % ' '.join(ssl_opts.keys()))
+ raise ValueError(
+ 'Incompatible arguments: ssl_context and %s'
+ % ' '.join(ssl_opts.keys()))
if not self._has_ssl_context:
- raise ValueError('ssl_context is not available for this version of Python')
+ raise ValueError(
+ 'ssl_context is not available for this version of Python')
else:
self._custom_context = False
ssl_version = ssl_opts.pop('ssl_version', TSSLBase.SSL_VERSION)
@@ -119,11 +141,13 @@
self.ciphers = ssl_opts.pop('ciphers', None)
if ssl_opts:
- raise ValueError('Unknown keyword arguments: ', ' '.join(ssl_opts.keys()))
+ raise ValueError(
+ 'Unknown keyword arguments: ', ' '.join(ssl_opts.keys()))
- if self.cert_reqs != ssl.CERT_NONE:
+ if self._should_verify:
if not self.ca_certs:
- raise ValueError('ca_certs is needed when cert_reqs is not ssl.CERT_NONE')
+ raise ValueError(
+ 'ca_certs is needed when cert_reqs is not ssl.CERT_NONE')
if not os.access(self.ca_certs, os.R_OK):
raise IOError('Certificate Authority ca_certs file "%s" '
'is not readable, cannot validate SSL '
@@ -146,13 +170,15 @@
if not self._custom_context:
self.ssl_context.verify_mode = self.cert_reqs
if self.certfile:
- self.ssl_context.load_cert_chain(self.certfile, self.keyfile)
+ self.ssl_context.load_cert_chain(self.certfile,
+ self.keyfile)
if self.ciphers:
self.ssl_context.set_ciphers(self.ciphers)
if self.ca_certs:
self.ssl_context.load_verify_locations(self.ca_certs)
- return self.ssl_context.wrap_socket(sock, server_side=self._server_side,
- server_hostname=self._server_hostname)
+ return self.ssl_context.wrap_socket(
+ sock, server_side=self._server_side,
+ server_hostname=self._server_hostname)
else:
ssl_opts = {
'ssl_version': self._ssl_version,
@@ -166,7 +192,8 @@
if self._has_ciphers:
ssl_opts['ciphers'] = self.ciphers
else:
- logger.warning('ciphers is specified but ignored due to old Python version')
+ logger.warning(
+ 'ciphers is specified but ignored due to old Python version')
return ssl.wrap_socket(sock, **ssl_opts)
@@ -179,20 +206,29 @@
"""
# New signature
- # def __init__(self, host='localhost', port=9090, unix_socket=None, **ssl_args):
+ # def __init__(self, host='localhost', port=9090, unix_socket=None,
+ # **ssl_args):
# Deprecated signature
- # def __init__(self, host='localhost', port=9090, validate=True, ca_certs=None, keyfile=None, certfile=None, unix_socket=None, ciphers=None):
+ # def __init__(self, host='localhost', port=9090, validate=True,
+ # ca_certs=None, keyfile=None, certfile=None,
+ # unix_socket=None, ciphers=None):
def __init__(self, host='localhost', port=9090, *args, **kwargs):
"""Positional arguments: ``host``, ``port``, ``unix_socket``
- Keyword arguments: ``keyfile``, ``certfile``, ``cert_reqs``, ``ssl_version``,
- ``ca_certs``, ``ciphers`` (Python 2.7.0 or later),
+ Keyword arguments: ``keyfile``, ``certfile``, ``cert_reqs``,
+ ``ssl_version``, ``ca_certs``,
+ ``ciphers`` (Python 2.7.0 or later),
``server_hostname`` (Python 2.7.9 or later)
Passed to ssl.wrap_socket. See ssl.wrap_socket documentation.
- Alternative keywoard arguments: (Python 2.7.9 or later)
+ Alternative keyword arguments: (Python 2.7.9 or later)
``ssl_context``: ssl.SSLContext to be used for SSLContext.wrap_socket
``server_hostname``: Passed to SSLContext.wrap_socket
+
+ Common keyword argument:
+ ``validate_callback`` (cert, hostname) -> None:
+ Called after the SSL handshake; may raise if hostname does not
+ match the cert. Defaults to ssl_match_hostname's match_hostname.
"""
self.is_valid = False
self.peercert = None
@@ -212,13 +248,15 @@
if validate is not None:
cert_reqs_name = 'CERT_REQUIRED' if validate else 'CERT_NONE'
warnings.warn(
- 'validate is deprecated. Use cert_reqs=ssl.%s instead' % cert_reqs_name,
+ 'validate is deprecated. Use cert_reqs=ssl.%s instead'
+ % cert_reqs_name,
DeprecationWarning)
if 'cert_reqs' in kwargs:
raise TypeError('Cannot specify both validate and cert_reqs')
kwargs['cert_reqs'] = ssl.CERT_REQUIRED if validate else ssl.CERT_NONE
unix_socket = kwargs.pop('unix_socket', None)
+ self._validate_callback = kwargs.pop('validate_callback', match_hostname)
TSSLBase.__init__(self, False, host, kwargs)
TSocket.TSocket.__init__(self, host, port, unix_socket)
@@ -245,7 +283,9 @@
self.handle.connect(ip_port)
except socket.error as e:
if res is not res0[-1]:
- logger.warning('Error while connecting with %s. Trying next one.', ip_port, exc_info=True)
+ logger.warning(
+ 'Error while connecting with %s. Trying next one.',
+ ip_port, exc_info=True)
continue
else:
raise
@@ -255,27 +295,35 @@
message = 'Could not connect to secure socket %s: %s' \
% (self._unix_socket, e)
else:
- message = 'Could not connect to %s:%d: %s' % (self.host, self.port, e)
- logger.error('Error while connecting with %s.', ip_port, exc_info=True)
- raise TTransportException(type=TTransportException.NOT_OPEN,
- message=message)
- if self.validate:
- self._validate_cert()
+ message = 'Could not connect to %s:%d: %s' \
+ % (self.host, self.port, e)
+ logger.error(
+ 'Error while connecting with %s.', ip_port, exc_info=True)
+ raise TTransportException(TTransportException.NOT_OPEN, message)
- def _validate_cert(self):
- """internal method to validate the peer's SSL certificate, and to check the
- commonName of the certificate to ensure it matches the hostname we
+ if self._should_verify:
+ self.peercert = self.handle.getpeercert()
+ try:
+ self._validate_callback(self.peercert, self._server_hostname)
+ self.is_valid = True
+ except TTransportException:
+ raise
+ except Exception as ex:
+ raise TTransportException(TTransportException.UNKNOWN, str(ex))
+
+ @staticmethod
+ def legacy_validate_callback(self, cert, hostname):
+ """legacy method to validate the peer's SSL certificate, and to check
+ the commonName of the certificate to ensure it matches the hostname we
used to make this connection. Does not support subjectAltName records
in certificates.
raises TTransportException if the certificate fails validation.
"""
- cert = self.handle.getpeercert()
- self.peercert = cert
if 'subject' not in cert:
raise TTransportException(
- type=TTransportException.NOT_OPEN,
- message='No SSL certificate found from %s:%s' % (self.host, self.port))
+ TTransportException.NOT_OPEN,
+ 'No SSL certificate found from %s:%s' % (self.host, self.port))
fields = cert['subject']
for field in fields:
# ensure structure we get back is what we expect
@@ -289,19 +337,18 @@
continue
certhost = cert_value
# this check should be performed by some sort of Access Manager
- if certhost == self.host:
+ if certhost == hostname:
# success, cert commonName matches desired hostname
- self.is_valid = True
return
else:
raise TTransportException(
- type=TTransportException.UNKNOWN,
- message='Hostname we connected to "%s" doesn\'t match certificate '
+ TTransportException.UNKNOWN,
+ 'Hostname we connected to "%s" doesn\'t match certificate '
'provided commonName "%s"' % (self.host, certhost))
raise TTransportException(
- type=TTransportException.UNKNOWN,
- message='Could not validate SSL certificate from '
- 'host "%s". Cert=%s' % (self.host, cert))
+ TTransportException.UNKNOWN,
+ 'Could not validate SSL certificate from host "%s". Cert=%s'
+ % (hostname, cert))
class TSSLServerSocket(TSocket.TServerSocket, TSSLBase):
@@ -322,7 +369,7 @@
``ca_certs``, ``ciphers`` (Python 2.7.0 or later)
See ssl.wrap_socket documentation.
- Alternative keywoard arguments: (Python 2.7.9 or later)
+ Alternative keyword arguments: (Python 2.7.9 or later)
``ssl_context``: ssl.SSLContext to be used for SSLContext.wrap_socket
``server_hostname``: Passed to SSLContext.wrap_socket
"""
@@ -346,7 +393,8 @@
TSocket.TServerSocket.__init__(self, host, port, unix_socket)
def setCertfile(self, certfile):
- """Set or change the server certificate file used to wrap new connections.
+ """Set or change the server certificate file used to wrap new
+ connections.
@param certfile: The filename of the server certificate,
i.e. '/etc/certs/server.pem'
diff --git a/lib/py/test/test_sslsocket.py b/lib/py/test/test_sslsocket.py
index fa156a0..fe03961 100644
--- a/lib/py/test/test_sslsocket.py
+++ b/lib/py/test/test_sslsocket.py
@@ -41,7 +41,7 @@
TEST_PORT = 23458
TEST_ADDR = '/tmp/.thrift.domain.sock.%d' % TEST_PORT
CONNECT_DELAY = 0.5
-CONNECT_TIMEOUT = 10.0
+CONNECT_TIMEOUT = 20.0
TEST_CIPHERS = 'DES-CBC3-SHA'
@@ -72,19 +72,21 @@
class TSSLSocketTest(unittest.TestCase):
def _assert_connection_failure(self, server, client):
+ acc = ServerAcceptor(server)
try:
- acc = ServerAcceptor(server)
acc.start()
- time.sleep(CONNECT_DELAY)
- client.setTimeout(CONNECT_TIMEOUT)
+ time.sleep(CONNECT_DELAY / 2)
+ client.setTimeout(CONNECT_TIMEOUT / 2)
with self._assert_raises(Exception):
client.open()
- select.select([], [client.handle], [], CONNECT_TIMEOUT)
+ select.select([], [client.handle], [], CONNECT_TIMEOUT / 2)
# self.assertIsNone(acc.client)
self.assertTrue(acc.client is None)
finally:
- server.close()
client.close()
+ if acc.client:
+ acc.client.close()
+ server.close()
def _assert_raises(self, exc):
if sys.hexversion >= 0x020700F0:
@@ -93,18 +95,20 @@
return AssertRaises(exc)
def _assert_connection_success(self, server, client):
+ acc = ServerAcceptor(server)
try:
- acc = ServerAcceptor(server)
acc.start()
- time.sleep(0.15)
+ time.sleep(CONNECT_DELAY)
client.setTimeout(CONNECT_TIMEOUT)
client.open()
select.select([], [client.handle], [], CONNECT_TIMEOUT)
# self.assertIsNotNone(acc.client)
self.assertTrue(acc.client is not None)
finally:
- server.close()
client.close()
+ if acc.client:
+ acc.client.close()
+ server.close()
# deprecated feature
def test_deprecation(self):