THRIFT-4904: Fix python unit test errors and exception escapes
Due to the way SSL layers on top of sockets, it was possible
to complete a connection and then have the server close it.
This happened when the client was not checking certificates
but the server was. The TSSLSocket unit test was enhanced to
perform a read and a write in addition to connecting, making
the test more complete.
The TSocket read() and write() calls were leaking OSError,
socket.error, and ssl.SSLError exceptions. These cases are now
wrapped in a TTransportException of the appropriate type,
and the original exception is attached as an argument named inner.
diff --git a/lib/py/src/transport/TSSLSocket.py b/lib/py/src/transport/TSSLSocket.py
index 066d8da..5b3ae59 100644
--- a/lib/py/src/transport/TSSLSocket.py
+++ b/lib/py/src/transport/TSSLSocket.py
@@ -291,11 +291,11 @@
plain_sock = socket.socket(family, socktype)
try:
return self._wrap_socket(plain_sock)
- except Exception:
+ except Exception as ex:
plain_sock.close()
msg = 'failed to initialize SSL'
logger.exception(msg)
- raise TTransportException(TTransportException.NOT_OPEN, msg)
+ raise TTransportException(type=TTransportException.NOT_OPEN, message=msg, inner=ex)
def open(self):
super(TSSLSocket, self).open()
@@ -307,7 +307,7 @@
except TTransportException:
raise
except Exception as ex:
- raise TTransportException(TTransportException.UNKNOWN, str(ex))
+ raise TTransportException(message=str(ex), inner=ex)
class TSSLServerSocket(TSocket.TServerSocket, TSSLBase):
diff --git a/lib/py/src/transport/TSocket.py b/lib/py/src/transport/TSocket.py
index c8be25a..df25d42 100644
--- a/lib/py/src/transport/TSocket.py
+++ b/lib/py/src/transport/TSocket.py
@@ -94,13 +94,13 @@
def open(self):
if self.handle:
- raise TTransportException(TTransportException.ALREADY_OPEN)
+ raise TTransportException(type=TTransportException.ALREADY_OPEN, message="already open")
try:
addrs = self._resolveAddr()
- except socket.gaierror:
+ except socket.gaierror as gai:
msg = 'failed to resolve sockaddr for ' + str(self._address)
logger.exception(msg)
- raise TTransportException(TTransportException.NOT_OPEN, msg)
+ raise TTransportException(type=TTransportException.NOT_OPEN, message=msg, inner=gai)
for family, socktype, _, _, sockaddr in addrs:
handle = self._do_open(family, socktype)
@@ -119,7 +119,7 @@
msg = 'Could not connect to any of %s' % list(map(lambda a: a[4],
addrs))
logger.error(msg)
- raise TTransportException(TTransportException.NOT_OPEN, msg)
+ raise TTransportException(type=TTransportException.NOT_OPEN, message=msg)
def read(self, sz):
try:
@@ -134,8 +134,10 @@
self.close()
# Trigger the check to raise the END_OF_FILE exception below.
buff = ''
+ elif e.args[0] == errno.ETIMEDOUT:
+ raise TTransportException(type=TTransportException.TIMED_OUT, message="read timeout", inner=e)
else:
- raise
+ raise TTransportException(message="unexpected exception", inner=e)
if len(buff) == 0:
raise TTransportException(type=TTransportException.END_OF_FILE,
message='TSocket read 0 bytes')
@@ -148,12 +150,15 @@
sent = 0
have = len(buff)
while sent < have:
- plus = self.handle.send(buff)
- if plus == 0:
- raise TTransportException(type=TTransportException.END_OF_FILE,
- message='TSocket sent 0 bytes')
- sent += plus
- buff = buff[plus:]
+ try:
+ plus = self.handle.send(buff)
+ if plus == 0:
+ raise TTransportException(type=TTransportException.END_OF_FILE,
+ message='TSocket sent 0 bytes')
+ sent += plus
+ buff = buff[plus:]
+ except socket.error as e:
+ raise TTransportException(message="unexpected exception", inner=e)
def flush(self):
pass
diff --git a/lib/py/src/transport/TTransport.py b/lib/py/src/transport/TTransport.py
index 8573ba7..9dbe95d 100644
--- a/lib/py/src/transport/TTransport.py
+++ b/lib/py/src/transport/TTransport.py
@@ -34,9 +34,10 @@
SIZE_LIMIT = 6
INVALID_CLIENT_TYPE = 7
- def __init__(self, type=UNKNOWN, message=None):
+ def __init__(self, type=UNKNOWN, message=None, inner=None):
TException.__init__(self, message)
self.type = type
+ self.inner = inner
class TTransportBase(object):
diff --git a/lib/py/test/test_sslsocket.py b/lib/py/test/test_sslsocket.py
index 598c174..f4c87f1 100644
--- a/lib/py/test/test_sslsocket.py
+++ b/lib/py/test/test_sslsocket.py
@@ -75,6 +75,9 @@
try:
self._client = self._server.accept()
+ if self._client:
+ self._client.read(5) # hello
+ self._client.write(b"there")
except Exception:
logging.exception('error on server side (%s):' % self.name)
if not self._expect_failure:
@@ -141,7 +144,8 @@
client.setTimeout(20)
with self._assert_raises(TTransportException):
client.open()
- self.assertTrue(acc.client is None)
+ client.write(b"hello")
+ client.read(5) # b"there"
finally:
logging.disable(logging.NOTSET)
@@ -153,8 +157,10 @@
def _assert_connection_success(self, server, path=None, **client_args):
with self._connectable_client(server, path=path, **client_args) as (acc, client):
- client.open()
try:
+ client.open()
+ client.write(b"hello")
+ self.assertEqual(client.read(5), b"there")
self.assertTrue(acc.client is not None)
finally:
client.close()