THRIFT-3596 Better conformance to PEP8
This closes #832
diff --git a/lib/py/src/transport/THttpClient.py b/lib/py/src/transport/THttpClient.py
index 5abd41c..95f118c 100644
--- a/lib/py/src/transport/THttpClient.py
+++ b/lib/py/src/transport/THttpClient.py
@@ -26,130 +26,130 @@
from six.moves import urllib
from six.moves import http_client
-from .TTransport import *
+from .TTransport import TTransportBase
import six
class THttpClient(TTransportBase):
- """Http implementation of TTransport base."""
+ """Http implementation of TTransport base."""
- def __init__(self, uri_or_host, port=None, path=None):
- """THttpClient supports two different types constructor parameters.
+ def __init__(self, uri_or_host, port=None, path=None):
+ """THttpClient supports two different types constructor parameters.
- THttpClient(host, port, path) - deprecated
- THttpClient(uri)
+ THttpClient(host, port, path) - deprecated
+ THttpClient(uri)
- Only the second supports https.
- """
- if port is not None:
- warnings.warn(
- "Please use the THttpClient('http://host:port/path') syntax",
- DeprecationWarning,
- stacklevel=2)
- self.host = uri_or_host
- self.port = port
- assert path
- self.path = path
- self.scheme = 'http'
- else:
- parsed = urllib.parse.urlparse(uri_or_host)
- self.scheme = parsed.scheme
- assert self.scheme in ('http', 'https')
- if self.scheme == 'http':
- self.port = parsed.port or http_client.HTTP_PORT
- elif self.scheme == 'https':
- self.port = parsed.port or http_client.HTTPS_PORT
- self.host = parsed.hostname
- self.path = parsed.path
- if parsed.query:
- self.path += '?%s' % parsed.query
- self.__wbuf = BytesIO()
- self.__http = None
- self.__http_response = None
- self.__timeout = None
- self.__custom_headers = None
+ Only the second supports https.
+ """
+ if port is not None:
+ warnings.warn(
+ "Please use the THttpClient('http://host:port/path') syntax",
+ DeprecationWarning,
+ stacklevel=2)
+ self.host = uri_or_host
+ self.port = port
+ assert path
+ self.path = path
+ self.scheme = 'http'
+ else:
+ parsed = urllib.parse.urlparse(uri_or_host)
+ self.scheme = parsed.scheme
+ assert self.scheme in ('http', 'https')
+ if self.scheme == 'http':
+ self.port = parsed.port or http_client.HTTP_PORT
+ elif self.scheme == 'https':
+ self.port = parsed.port or http_client.HTTPS_PORT
+ self.host = parsed.hostname
+ self.path = parsed.path
+ if parsed.query:
+ self.path += '?%s' % parsed.query
+ self.__wbuf = BytesIO()
+ self.__http = None
+ self.__http_response = None
+ self.__timeout = None
+ self.__custom_headers = None
- def open(self):
- if self.scheme == 'http':
- self.__http = http_client.HTTPConnection(self.host, self.port)
- else:
- self.__http = http_client.HTTPSConnection(self.host, self.port)
+ def open(self):
+ if self.scheme == 'http':
+ self.__http = http_client.HTTPConnection(self.host, self.port)
+ else:
+ self.__http = http_client.HTTPSConnection(self.host, self.port)
- def close(self):
- self.__http.close()
- self.__http = None
- self.__http_response = None
+ def close(self):
+ self.__http.close()
+ self.__http = None
+ self.__http_response = None
- def isOpen(self):
- return self.__http is not None
+ def isOpen(self):
+ return self.__http is not None
- def setTimeout(self, ms):
- if not hasattr(socket, 'getdefaulttimeout'):
- raise NotImplementedError
+ def setTimeout(self, ms):
+ if not hasattr(socket, 'getdefaulttimeout'):
+ raise NotImplementedError
- if ms is None:
- self.__timeout = None
- else:
- self.__timeout = ms / 1000.0
+ if ms is None:
+ self.__timeout = None
+ else:
+ self.__timeout = ms / 1000.0
- def setCustomHeaders(self, headers):
- self.__custom_headers = headers
+ def setCustomHeaders(self, headers):
+ self.__custom_headers = headers
- def read(self, sz):
- return self.__http_response.read(sz)
+ def read(self, sz):
+ return self.__http_response.read(sz)
- def write(self, buf):
- self.__wbuf.write(buf)
+ def write(self, buf):
+ self.__wbuf.write(buf)
- def __withTimeout(f):
- def _f(*args, **kwargs):
- orig_timeout = socket.getdefaulttimeout()
- socket.setdefaulttimeout(args[0].__timeout)
- try:
- result = f(*args, **kwargs)
- finally:
- socket.setdefaulttimeout(orig_timeout)
- return result
- return _f
+ def __withTimeout(f):
+ def _f(*args, **kwargs):
+ orig_timeout = socket.getdefaulttimeout()
+ socket.setdefaulttimeout(args[0].__timeout)
+ try:
+ result = f(*args, **kwargs)
+ finally:
+ socket.setdefaulttimeout(orig_timeout)
+ return result
+ return _f
- def flush(self):
- if self.isOpen():
- self.close()
- self.open()
+ def flush(self):
+ if self.isOpen():
+ self.close()
+ self.open()
- # Pull data out of buffer
- data = self.__wbuf.getvalue()
- self.__wbuf = BytesIO()
+ # Pull data out of buffer
+ data = self.__wbuf.getvalue()
+ self.__wbuf = BytesIO()
- # HTTP request
- self.__http.putrequest('POST', self.path)
+ # HTTP request
+ self.__http.putrequest('POST', self.path)
- # Write headers
- self.__http.putheader('Content-Type', 'application/x-thrift')
- self.__http.putheader('Content-Length', str(len(data)))
+ # Write headers
+ self.__http.putheader('Content-Type', 'application/x-thrift')
+ self.__http.putheader('Content-Length', str(len(data)))
- if not self.__custom_headers or 'User-Agent' not in self.__custom_headers:
- user_agent = 'Python/THttpClient'
- script = os.path.basename(sys.argv[0])
- if script:
- user_agent = '%s (%s)' % (user_agent, urllib.parse.quote(script))
- self.__http.putheader('User-Agent', user_agent)
+ if not self.__custom_headers or 'User-Agent' not in self.__custom_headers:
+ user_agent = 'Python/THttpClient'
+ script = os.path.basename(sys.argv[0])
+ if script:
+ user_agent = '%s (%s)' % (user_agent, urllib.parse.quote(script))
+ self.__http.putheader('User-Agent', user_agent)
- if self.__custom_headers:
- for key, val in six.iteritems(self.__custom_headers):
- self.__http.putheader(key, val)
+ if self.__custom_headers:
+ for key, val in six.iteritems(self.__custom_headers):
+ self.__http.putheader(key, val)
- self.__http.endheaders()
+ self.__http.endheaders()
- # Write payload
- self.__http.send(data)
+ # Write payload
+ self.__http.send(data)
- # Get reply to flush the request
- self.__http_response = self.__http.getresponse()
- self.code = self.__http_response.status
- self.message = self.__http_response.reason
- self.headers = self.__http_response.msg
+ # Get reply to flush the request
+ self.__http_response = self.__http.getresponse()
+ self.code = self.__http_response.status
+ self.message = self.__http_response.reason
+ self.headers = self.__http_response.msg
- # Decorate if we know how to timeout
- if hasattr(socket, 'getdefaulttimeout'):
- flush = __withTimeout(flush)
+ # Decorate if we know how to timeout
+ if hasattr(socket, 'getdefaulttimeout'):
+ flush = __withTimeout(flush)
diff --git a/lib/py/src/transport/TSSLSocket.py b/lib/py/src/transport/TSSLSocket.py
index 9be0912..3f1a909 100644
--- a/lib/py/src/transport/TSSLSocket.py
+++ b/lib/py/src/transport/TSSLSocket.py
@@ -32,345 +32,345 @@
class TSSLBase(object):
- # SSLContext is not available for Python < 2.7.9
- _has_ssl_context = sys.hexversion >= 0x020709F0
+ # SSLContext is not available for Python < 2.7.9
+ _has_ssl_context = sys.hexversion >= 0x020709F0
- # ciphers argument is not available for Python < 2.7.0
- _has_ciphers = sys.hexversion >= 0x020700F0
+ # ciphers argument is not available for Python < 2.7.0
+ _has_ciphers = sys.hexversion >= 0x020700F0
- # For pythoon >= 2.7.9, use latest TLS that both client and server supports.
- # SSL 2.0 and 3.0 are disabled via ssl.OP_NO_SSLv2 and ssl.OP_NO_SSLv3.
- # For pythoon < 2.7.9, use TLS 1.0 since TLSv1_X nare OP_NO_SSLvX are unavailable.
- _default_protocol = ssl.PROTOCOL_SSLv23 if _has_ssl_context else ssl.PROTOCOL_TLSv1
+ # For Python >= 2.7.9, use the latest TLS version that both client and server support.
+ # SSL 2.0 and 3.0 are disabled via ssl.OP_NO_SSLv2 and ssl.OP_NO_SSLv3.
+ # For Python < 2.7.9, use TLS 1.0 since TLSv1_X and OP_NO_SSLvX are unavailable.
+ _default_protocol = ssl.PROTOCOL_SSLv23 if _has_ssl_context else ssl.PROTOCOL_TLSv1
- def _init_context(self, ssl_version):
- if self._has_ssl_context:
- self._context = ssl.SSLContext(ssl_version)
- if self._context.protocol == ssl.PROTOCOL_SSLv23:
- self._context.options |= ssl.OP_NO_SSLv2
- self._context.options |= ssl.OP_NO_SSLv3
- else:
- self._context = None
- self._ssl_version = ssl_version
+ def _init_context(self, ssl_version):
+ if self._has_ssl_context:
+ self._context = ssl.SSLContext(ssl_version)
+ if self._context.protocol == ssl.PROTOCOL_SSLv23:
+ self._context.options |= ssl.OP_NO_SSLv2
+ self._context.options |= ssl.OP_NO_SSLv3
+ else:
+ self._context = None
+ self._ssl_version = ssl_version
- @property
- def ssl_version(self):
- if self._has_ssl_context:
- return self.ssl_context.protocol
- else:
- return self._ssl_version
+ @property
+ def ssl_version(self):
+ if self._has_ssl_context:
+ return self.ssl_context.protocol
+ else:
+ return self._ssl_version
- @property
- def ssl_context(self):
- return self._context
+ @property
+ def ssl_context(self):
+ return self._context
- SSL_VERSION = _default_protocol
- """
+ SSL_VERSION = _default_protocol
+ """
Default SSL version.
For backword compatibility, it can be modified.
Use __init__ keywoard argument "ssl_version" instead.
"""
- def _deprecated_arg(self, args, kwargs, pos, key):
- if len(args) <= pos:
- return
- real_pos = pos + 3
- warnings.warn(
- '%dth positional argument is deprecated. Use keyward argument insteand.' % real_pos,
- DeprecationWarning)
- if key in kwargs:
- raise TypeError('Duplicate argument: %dth argument and %s keyward argument.', (real_pos, key))
- kwargs[key] = args[pos]
+ def _deprecated_arg(self, args, kwargs, pos, key):
+ if len(args) <= pos:
+ return
+ real_pos = pos + 3
+ warnings.warn(
+ '%dth positional argument is deprecated. Use keyward argument insteand.' % real_pos,
+ DeprecationWarning)
+ if key in kwargs:
+ raise TypeError('Duplicate argument: %dth argument and %s keyward argument.', (real_pos, key))
+ kwargs[key] = args[pos]
- def _unix_socket_arg(self, host, port, args, kwargs):
- key = 'unix_socket'
- if host is None and port is None and len(args) == 1 and key not in kwargs:
- kwargs[key] = args[0]
- return True
- return False
+ def _unix_socket_arg(self, host, port, args, kwargs):
+ key = 'unix_socket'
+ if host is None and port is None and len(args) == 1 and key not in kwargs:
+ kwargs[key] = args[0]
+ return True
+ return False
- def __getattr__(self, key):
- if key == 'SSL_VERSION':
- warnings.warn('Use ssl_version attribute instead.', DeprecationWarning)
- return self.ssl_version
+ def __getattr__(self, key):
+ if key == 'SSL_VERSION':
+ warnings.warn('Use ssl_version attribute instead.', DeprecationWarning)
+ return self.ssl_version
- def __init__(self, server_side, host, ssl_opts):
- self._server_side = server_side
- if TSSLBase.SSL_VERSION != self._default_protocol:
- warnings.warn('SSL_VERSION is deprecated. Use ssl_version keyward argument instead.', DeprecationWarning)
- self._context = ssl_opts.pop('ssl_context', None)
- self._server_hostname = None
- if not self._server_side:
- self._server_hostname = ssl_opts.pop('server_hostname', host)
- if self._context:
- self._custom_context = True
- if ssl_opts:
- raise ValueError('Incompatible arguments: ssl_context and %s' % ' '.join(ssl_opts.keys()))
- if not self._has_ssl_context:
- raise ValueError('ssl_context is not available for this version of Python')
- else:
- self._custom_context = False
- ssl_version = ssl_opts.pop('ssl_version', TSSLBase.SSL_VERSION)
- self._init_context(ssl_version)
- self.cert_reqs = ssl_opts.pop('cert_reqs', ssl.CERT_REQUIRED)
- self.ca_certs = ssl_opts.pop('ca_certs', None)
- self.keyfile = ssl_opts.pop('keyfile', None)
- self.certfile = ssl_opts.pop('certfile', None)
- self.ciphers = ssl_opts.pop('ciphers', None)
-
- if ssl_opts:
- raise ValueError('Unknown keyword arguments: ', ' '.join(ssl_opts.keys()))
-
- if self.cert_reqs != ssl.CERT_NONE:
- if not self.ca_certs:
- raise ValueError('ca_certs is needed when cert_reqs is not ssl.CERT_NONE')
- if not os.access(self.ca_certs, os.R_OK):
- raise IOError('Certificate Authority ca_certs file "%s" '
- 'is not readable, cannot validate SSL '
- 'certificates.' % (self.ca_certs))
-
- @property
- def certfile(self):
- return self._certfile
-
- @certfile.setter
- def certfile(self, certfile):
- if self._server_side and not certfile:
- raise ValueError('certfile is needed for server-side')
- if certfile and not os.access(certfile, os.R_OK):
- raise IOError('No such certfile found: %s' % (certfile))
- self._certfile = certfile
-
- def _wrap_socket(self, sock):
- if self._has_ssl_context:
- if not self._custom_context:
- self.ssl_context.verify_mode = self.cert_reqs
- if self.certfile:
- self.ssl_context.load_cert_chain(self.certfile, self.keyfile)
- if self.ciphers:
- self.ssl_context.set_ciphers(self.ciphers)
- if self.ca_certs:
- self.ssl_context.load_verify_locations(self.ca_certs)
- return self.ssl_context.wrap_socket(sock, server_side=self._server_side,
- server_hostname=self._server_hostname)
- else:
- ssl_opts = {
- 'ssl_version': self._ssl_version,
- 'server_side': self._server_side,
- 'ca_certs': self.ca_certs,
- 'keyfile': self.keyfile,
- 'certfile': self.certfile,
- 'cert_reqs': self.cert_reqs,
- }
- if self.ciphers:
- if self._has_ciphers:
- ssl_opts['ciphers'] = self.ciphers
+ def __init__(self, server_side, host, ssl_opts):
+ self._server_side = server_side
+ if TSSLBase.SSL_VERSION != self._default_protocol:
+ warnings.warn('SSL_VERSION is deprecated. Use ssl_version keyward argument instead.', DeprecationWarning)
+ self._context = ssl_opts.pop('ssl_context', None)
+ self._server_hostname = None
+ if not self._server_side:
+ self._server_hostname = ssl_opts.pop('server_hostname', host)
+ if self._context:
+ self._custom_context = True
+ if ssl_opts:
+ raise ValueError('Incompatible arguments: ssl_context and %s' % ' '.join(ssl_opts.keys()))
+ if not self._has_ssl_context:
+ raise ValueError('ssl_context is not available for this version of Python')
else:
- logger.warning('ciphers is specified but ignored due to old Python version')
- return ssl.wrap_socket(sock, **ssl_opts)
+ self._custom_context = False
+ ssl_version = ssl_opts.pop('ssl_version', TSSLBase.SSL_VERSION)
+ self._init_context(ssl_version)
+ self.cert_reqs = ssl_opts.pop('cert_reqs', ssl.CERT_REQUIRED)
+ self.ca_certs = ssl_opts.pop('ca_certs', None)
+ self.keyfile = ssl_opts.pop('keyfile', None)
+ self.certfile = ssl_opts.pop('certfile', None)
+ self.ciphers = ssl_opts.pop('ciphers', None)
+
+ if ssl_opts:
+ raise ValueError('Unknown keyword arguments: ', ' '.join(ssl_opts.keys()))
+
+ if self.cert_reqs != ssl.CERT_NONE:
+ if not self.ca_certs:
+ raise ValueError('ca_certs is needed when cert_reqs is not ssl.CERT_NONE')
+ if not os.access(self.ca_certs, os.R_OK):
+ raise IOError('Certificate Authority ca_certs file "%s" '
+ 'is not readable, cannot validate SSL '
+ 'certificates.' % (self.ca_certs))
+
+ @property
+ def certfile(self):
+ return self._certfile
+
+ @certfile.setter
+ def certfile(self, certfile):
+ if self._server_side and not certfile:
+ raise ValueError('certfile is needed for server-side')
+ if certfile and not os.access(certfile, os.R_OK):
+ raise IOError('No such certfile found: %s' % (certfile))
+ self._certfile = certfile
+
+ def _wrap_socket(self, sock):
+ if self._has_ssl_context:
+ if not self._custom_context:
+ self.ssl_context.verify_mode = self.cert_reqs
+ if self.certfile:
+ self.ssl_context.load_cert_chain(self.certfile, self.keyfile)
+ if self.ciphers:
+ self.ssl_context.set_ciphers(self.ciphers)
+ if self.ca_certs:
+ self.ssl_context.load_verify_locations(self.ca_certs)
+ return self.ssl_context.wrap_socket(sock, server_side=self._server_side,
+ server_hostname=self._server_hostname)
+ else:
+ ssl_opts = {
+ 'ssl_version': self._ssl_version,
+ 'server_side': self._server_side,
+ 'ca_certs': self.ca_certs,
+ 'keyfile': self.keyfile,
+ 'certfile': self.certfile,
+ 'cert_reqs': self.cert_reqs,
+ }
+ if self.ciphers:
+ if self._has_ciphers:
+ ssl_opts['ciphers'] = self.ciphers
+ else:
+ logger.warning('ciphers is specified but ignored due to old Python version')
+ return ssl.wrap_socket(sock, **ssl_opts)
class TSSLSocket(TSocket.TSocket, TSSLBase):
- """
- SSL implementation of TSocket
-
- This class creates outbound sockets wrapped using the
- python standard ssl module for encrypted connections.
- """
-
- # New signature
- # def __init__(self, host='localhost', port=9090, unix_socket=None, **ssl_args):
- # Deprecated signature
- # def __init__(self, host='localhost', port=9090, validate=True, ca_certs=None, keyfile=None, certfile=None, unix_socket=None, ciphers=None):
- def __init__(self, host='localhost', port=9090, *args, **kwargs):
- """Positional arguments: ``host``, ``port``, ``unix_socket``
-
- Keyword arguments: ``keyfile``, ``certfile``, ``cert_reqs``, ``ssl_version``,
- ``ca_certs``, ``ciphers`` (Python 2.7.0 or later),
- ``server_hostname`` (Python 2.7.9 or later)
- Passed to ssl.wrap_socket. See ssl.wrap_socket documentation.
-
- Alternative keywoard arguments: (Python 2.7.9 or later)
- ``ssl_context``: ssl.SSLContext to be used for SSLContext.wrap_socket
- ``server_hostname``: Passed to SSLContext.wrap_socket
"""
- self.is_valid = False
- self.peercert = None
+ SSL implementation of TSocket
- if args:
- if len(args) > 6:
- raise TypeError('Too many positional argument')
- if not self._unix_socket_arg(host, port, args, kwargs):
- self._deprecated_arg(args, kwargs, 0, 'validate')
- self._deprecated_arg(args, kwargs, 1, 'ca_certs')
- self._deprecated_arg(args, kwargs, 2, 'keyfile')
- self._deprecated_arg(args, kwargs, 3, 'certfile')
- self._deprecated_arg(args, kwargs, 4, 'unix_socket')
- self._deprecated_arg(args, kwargs, 5, 'ciphers')
+ This class creates outbound sockets wrapped using the
+ python standard ssl module for encrypted connections.
+ """
- validate = kwargs.pop('validate', None)
- if validate is not None:
- cert_reqs_name = 'CERT_REQUIRED' if validate else 'CERT_NONE'
- warnings.warn(
- 'validate is deprecated. Use cert_reqs=ssl.%s instead' % cert_reqs_name,
- DeprecationWarning)
- if 'cert_reqs' in kwargs:
- raise TypeError('Cannot specify both validate and cert_reqs')
- kwargs['cert_reqs'] = ssl.CERT_REQUIRED if validate else ssl.CERT_NONE
+ # New signature
+ # def __init__(self, host='localhost', port=9090, unix_socket=None, **ssl_args):
+ # Deprecated signature
+ # def __init__(self, host='localhost', port=9090, validate=True, ca_certs=None, keyfile=None, certfile=None, unix_socket=None, ciphers=None):
+ def __init__(self, host='localhost', port=9090, *args, **kwargs):
+ """Positional arguments: ``host``, ``port``, ``unix_socket``
- unix_socket = kwargs.pop('unix_socket', None)
- TSSLBase.__init__(self, False, host, kwargs)
- TSocket.TSocket.__init__(self, host, port, unix_socket)
+ Keyword arguments: ``keyfile``, ``certfile``, ``cert_reqs``, ``ssl_version``,
+ ``ca_certs``, ``ciphers`` (Python 2.7.0 or later),
+ ``server_hostname`` (Python 2.7.9 or later)
+ Passed to ssl.wrap_socket. See ssl.wrap_socket documentation.
- @property
- def validate(self):
- warnings.warn('Use cert_reqs instead', DeprecationWarning)
- return self.cert_reqs != ssl.CERT_NONE
+ Alternative keyword arguments: (Python 2.7.9 or later)
+ ``ssl_context``: ssl.SSLContext to be used for SSLContext.wrap_socket
+ ``server_hostname``: Passed to SSLContext.wrap_socket
+ """
+ self.is_valid = False
+ self.peercert = None
- @validate.setter
- def validate(self, value):
- warnings.warn('Use cert_reqs instead', DeprecationWarning)
- self.cert_reqs = ssl.CERT_REQUIRED if value else ssl.CERT_NONE
+ if args:
+ if len(args) > 6:
+ raise TypeError('Too many positional argument')
+ if not self._unix_socket_arg(host, port, args, kwargs):
+ self._deprecated_arg(args, kwargs, 0, 'validate')
+ self._deprecated_arg(args, kwargs, 1, 'ca_certs')
+ self._deprecated_arg(args, kwargs, 2, 'keyfile')
+ self._deprecated_arg(args, kwargs, 3, 'certfile')
+ self._deprecated_arg(args, kwargs, 4, 'unix_socket')
+ self._deprecated_arg(args, kwargs, 5, 'ciphers')
- def open(self):
- try:
- res0 = self._resolveAddr()
- for res in res0:
- sock_family, sock_type = res[0:2]
- ip_port = res[4]
- plain_sock = socket.socket(sock_family, sock_type)
- self.handle = self._wrap_socket(plain_sock)
- self.handle.settimeout(self._timeout)
+ validate = kwargs.pop('validate', None)
+ if validate is not None:
+ cert_reqs_name = 'CERT_REQUIRED' if validate else 'CERT_NONE'
+ warnings.warn(
+ 'validate is deprecated. Use cert_reqs=ssl.%s instead' % cert_reqs_name,
+ DeprecationWarning)
+ if 'cert_reqs' in kwargs:
+ raise TypeError('Cannot specify both validate and cert_reqs')
+ kwargs['cert_reqs'] = ssl.CERT_REQUIRED if validate else ssl.CERT_NONE
+
+ unix_socket = kwargs.pop('unix_socket', None)
+ TSSLBase.__init__(self, False, host, kwargs)
+ TSocket.TSocket.__init__(self, host, port, unix_socket)
+
+ @property
+ def validate(self):
+ warnings.warn('Use cert_reqs instead', DeprecationWarning)
+ return self.cert_reqs != ssl.CERT_NONE
+
+ @validate.setter
+ def validate(self, value):
+ warnings.warn('Use cert_reqs instead', DeprecationWarning)
+ self.cert_reqs = ssl.CERT_REQUIRED if value else ssl.CERT_NONE
+
+ def open(self):
try:
- self.handle.connect(ip_port)
+ res0 = self._resolveAddr()
+ for res in res0:
+ sock_family, sock_type = res[0:2]
+ ip_port = res[4]
+ plain_sock = socket.socket(sock_family, sock_type)
+ self.handle = self._wrap_socket(plain_sock)
+ self.handle.settimeout(self._timeout)
+ try:
+ self.handle.connect(ip_port)
+ except socket.error as e:
+ if res is not res0[-1]:
+ logger.warning('Error while connecting with %s. Trying next one.', ip_port, exc_info=True)
+ continue
+ else:
+ raise
+ break
except socket.error as e:
- if res is not res0[-1]:
- logger.warning('Error while connecting with %s. Trying next one.', ip_port, exc_info=True)
- continue
- else:
- raise
- break
- except socket.error as e:
- if self._unix_socket:
- message = 'Could not connect to secure socket %s: %s' \
- % (self._unix_socket, e)
- else:
- message = 'Could not connect to %s:%d: %s' % (self.host, self.port, e)
- logger.error('Error while connecting with %s.', ip_port, exc_info=True)
- raise TTransportException(type=TTransportException.NOT_OPEN,
- message=message)
- if self.validate:
- self._validate_cert()
+ if self._unix_socket:
+ message = 'Could not connect to secure socket %s: %s' \
+ % (self._unix_socket, e)
+ else:
+ message = 'Could not connect to %s:%d: %s' % (self.host, self.port, e)
+ logger.error('Error while connecting with %s.', ip_port, exc_info=True)
+ raise TTransportException(type=TTransportException.NOT_OPEN,
+ message=message)
+ if self.validate:
+ self._validate_cert()
- def _validate_cert(self):
- """internal method to validate the peer's SSL certificate, and to check the
- commonName of the certificate to ensure it matches the hostname we
- used to make this connection. Does not support subjectAltName records
- in certificates.
+ def _validate_cert(self):
+ """internal method to validate the peer's SSL certificate, and to check the
+ commonName of the certificate to ensure it matches the hostname we
+ used to make this connection. Does not support subjectAltName records
+ in certificates.
- raises TTransportException if the certificate fails validation.
- """
- cert = self.handle.getpeercert()
- self.peercert = cert
- if 'subject' not in cert:
- raise TTransportException(
- type=TTransportException.NOT_OPEN,
- message='No SSL certificate found from %s:%s' % (self.host, self.port))
- fields = cert['subject']
- for field in fields:
- # ensure structure we get back is what we expect
- if not isinstance(field, tuple):
- continue
- cert_pair = field[0]
- if len(cert_pair) < 2:
- continue
- cert_key, cert_value = cert_pair[0:2]
- if cert_key != 'commonName':
- continue
- certhost = cert_value
- # this check should be performed by some sort of Access Manager
- if certhost == self.host:
- # success, cert commonName matches desired hostname
- self.is_valid = True
- return
- else:
+ raises TTransportException if the certificate fails validation.
+ """
+ cert = self.handle.getpeercert()
+ self.peercert = cert
+ if 'subject' not in cert:
+ raise TTransportException(
+ type=TTransportException.NOT_OPEN,
+ message='No SSL certificate found from %s:%s' % (self.host, self.port))
+ fields = cert['subject']
+ for field in fields:
+ # ensure structure we get back is what we expect
+ if not isinstance(field, tuple):
+ continue
+ cert_pair = field[0]
+ if len(cert_pair) < 2:
+ continue
+ cert_key, cert_value = cert_pair[0:2]
+ if cert_key != 'commonName':
+ continue
+ certhost = cert_value
+ # this check should be performed by some sort of Access Manager
+ if certhost == self.host:
+ # success, cert commonName matches desired hostname
+ self.is_valid = True
+ return
+ else:
+ raise TTransportException(
+ type=TTransportException.UNKNOWN,
+ message='Hostname we connected to "%s" doesn\'t match certificate '
+ 'provided commonName "%s"' % (self.host, certhost))
raise TTransportException(
- type=TTransportException.UNKNOWN,
- message='Hostname we connected to "%s" doesn\'t match certificate '
- 'provided commonName "%s"' % (self.host, certhost))
- raise TTransportException(
- type=TTransportException.UNKNOWN,
- message='Could not validate SSL certificate from '
- 'host "%s". Cert=%s' % (self.host, cert))
+ type=TTransportException.UNKNOWN,
+ message='Could not validate SSL certificate from '
+ 'host "%s". Cert=%s' % (self.host, cert))
class TSSLServerSocket(TSocket.TServerSocket, TSSLBase):
- """SSL implementation of TServerSocket
+ """SSL implementation of TServerSocket
- This uses the ssl module's wrap_socket() method to provide SSL
- negotiated encryption.
- """
-
- # New signature
- # def __init__(self, host='localhost', port=9090, unix_socket=None, **ssl_args):
- # Deprecated signature
- # def __init__(self, host=None, port=9090, certfile='cert.pem', unix_socket=None, ciphers=None):
- def __init__(self, host=None, port=9090, *args, **kwargs):
- """Positional arguments: ``host``, ``port``, ``unix_socket``
-
- Keyword arguments: ``keyfile``, ``certfile``, ``cert_reqs``, ``ssl_version``,
- ``ca_certs``, ``ciphers`` (Python 2.7.0 or later)
- See ssl.wrap_socket documentation.
-
- Alternative keywoard arguments: (Python 2.7.9 or later)
- ``ssl_context``: ssl.SSLContext to be used for SSLContext.wrap_socket
- ``server_hostname``: Passed to SSLContext.wrap_socket
+ This uses the ssl module's wrap_socket() method to provide SSL
+ negotiated encryption.
"""
- if args:
- if len(args) > 3:
- raise TypeError('Too many positional argument')
- if not self._unix_socket_arg(host, port, args, kwargs):
- self._deprecated_arg(args, kwargs, 0, 'certfile')
- self._deprecated_arg(args, kwargs, 1, 'unix_socket')
- self._deprecated_arg(args, kwargs, 2, 'ciphers')
- if 'ssl_context' not in kwargs:
- # Preserve existing behaviors for default values
- if 'cert_reqs' not in kwargs:
- kwargs['cert_reqs'] = ssl.CERT_NONE
- if'certfile' not in kwargs:
- kwargs['certfile'] = 'cert.pem'
+ # New signature
+ # def __init__(self, host='localhost', port=9090, unix_socket=None, **ssl_args):
+ # Deprecated signature
+ # def __init__(self, host=None, port=9090, certfile='cert.pem', unix_socket=None, ciphers=None):
+ def __init__(self, host=None, port=9090, *args, **kwargs):
+ """Positional arguments: ``host``, ``port``, ``unix_socket``
- unix_socket = kwargs.pop('unix_socket', None)
- TSSLBase.__init__(self, True, None, kwargs)
- TSocket.TServerSocket.__init__(self, host, port, unix_socket)
+ Keyword arguments: ``keyfile``, ``certfile``, ``cert_reqs``, ``ssl_version``,
+ ``ca_certs``, ``ciphers`` (Python 2.7.0 or later)
+ See ssl.wrap_socket documentation.
- def setCertfile(self, certfile):
- """Set or change the server certificate file used to wrap new connections.
+ Alternative keyword arguments: (Python 2.7.9 or later)
+ ``ssl_context``: ssl.SSLContext to be used for SSLContext.wrap_socket
+ ``server_hostname``: Passed to SSLContext.wrap_socket
+ """
+ if args:
+ if len(args) > 3:
+ raise TypeError('Too many positional argument')
+ if not self._unix_socket_arg(host, port, args, kwargs):
+ self._deprecated_arg(args, kwargs, 0, 'certfile')
+ self._deprecated_arg(args, kwargs, 1, 'unix_socket')
+ self._deprecated_arg(args, kwargs, 2, 'ciphers')
- @param certfile: The filename of the server certificate,
- i.e. '/etc/certs/server.pem'
- @type certfile: str
+ if 'ssl_context' not in kwargs:
+ # Preserve existing behaviors for default values
+ if 'cert_reqs' not in kwargs:
+ kwargs['cert_reqs'] = ssl.CERT_NONE
+ if'certfile' not in kwargs:
+ kwargs['certfile'] = 'cert.pem'
- Raises an IOError exception if the certfile is not present or unreadable.
- """
- warnings.warn('Use certfile property instead.', DeprecationWarning)
- self.certfile = certfile
+ unix_socket = kwargs.pop('unix_socket', None)
+ TSSLBase.__init__(self, True, None, kwargs)
+ TSocket.TServerSocket.__init__(self, host, port, unix_socket)
- def accept(self):
- plain_client, addr = self.handle.accept()
- try:
- client = self._wrap_socket(plain_client)
- except ssl.SSLError:
- logger.error('Error while accepting from %s', addr, exc_info=True)
- # failed handshake/ssl wrap, close socket to client
- plain_client.close()
- # raise
- # We can't raise the exception, because it kills most TServer derived
- # serve() methods.
- # Instead, return None, and let the TServer instance deal with it in
- # other exception handling. (but TSimpleServer dies anyway)
- return None
- result = TSocket.TSocket()
- result.setHandle(client)
- return result
+ def setCertfile(self, certfile):
+ """Set or change the server certificate file used to wrap new connections.
+
+ @param certfile: The filename of the server certificate,
+ i.e. '/etc/certs/server.pem'
+ @type certfile: str
+
+ Raises an IOError exception if the certfile is not present or unreadable.
+ """
+ warnings.warn('Use certfile property instead.', DeprecationWarning)
+ self.certfile = certfile
+
+ def accept(self):
+ plain_client, addr = self.handle.accept()
+ try:
+ client = self._wrap_socket(plain_client)
+ except ssl.SSLError:
+ logger.error('Error while accepting from %s', addr, exc_info=True)
+ # failed handshake/ssl wrap, close socket to client
+ plain_client.close()
+ # raise
+ # We can't raise the exception, because it kills most TServer derived
+ # serve() methods.
+ # Instead, return None, and let the TServer instance deal with it in
+ # other exception handling. (but TSimpleServer dies anyway)
+ return None
+ result = TSocket.TSocket()
+ result.setHandle(client)
+ return result
diff --git a/lib/py/src/transport/TSocket.py b/lib/py/src/transport/TSocket.py
index cb204a4..a8ed4b7 100644
--- a/lib/py/src/transport/TSocket.py
+++ b/lib/py/src/transport/TSocket.py
@@ -22,159 +22,159 @@
import socket
import sys
-from .TTransport import *
+from .TTransport import TTransportBase, TTransportException, TServerTransportBase
class TSocketBase(TTransportBase):
- def _resolveAddr(self):
- if self._unix_socket is not None:
- return [(socket.AF_UNIX, socket.SOCK_STREAM, None, None,
- self._unix_socket)]
- else:
- return socket.getaddrinfo(self.host,
- self.port,
- self._socket_family,
- socket.SOCK_STREAM,
- 0,
- socket.AI_PASSIVE | socket.AI_ADDRCONFIG)
+ def _resolveAddr(self):
+ if self._unix_socket is not None:
+ return [(socket.AF_UNIX, socket.SOCK_STREAM, None, None,
+ self._unix_socket)]
+ else:
+ return socket.getaddrinfo(self.host,
+ self.port,
+ self._socket_family,
+ socket.SOCK_STREAM,
+ 0,
+ socket.AI_PASSIVE | socket.AI_ADDRCONFIG)
- def close(self):
- if self.handle:
- self.handle.close()
- self.handle = None
+ def close(self):
+ if self.handle:
+ self.handle.close()
+ self.handle = None
class TSocket(TSocketBase):
- """Socket implementation of TTransport base."""
+ """Socket implementation of TTransport base."""
- def __init__(self, host='localhost', port=9090, unix_socket=None, socket_family=socket.AF_UNSPEC):
- """Initialize a TSocket
+ def __init__(self, host='localhost', port=9090, unix_socket=None, socket_family=socket.AF_UNSPEC):
+ """Initialize a TSocket
- @param host(str) The host to connect to.
- @param port(int) The (TCP) port to connect to.
- @param unix_socket(str) The filename of a unix socket to connect to.
- (host and port will be ignored.)
- @param socket_family(int) The socket family to use with this socket.
- """
- self.host = host
- self.port = port
- self.handle = None
- self._unix_socket = unix_socket
- self._timeout = None
- self._socket_family = socket_family
+ @param host(str) The host to connect to.
+ @param port(int) The (TCP) port to connect to.
+ @param unix_socket(str) The filename of a unix socket to connect to.
+ (host and port will be ignored.)
+ @param socket_family(int) The socket family to use with this socket.
+ """
+ self.host = host
+ self.port = port
+ self.handle = None
+ self._unix_socket = unix_socket
+ self._timeout = None
+ self._socket_family = socket_family
- def setHandle(self, h):
- self.handle = h
+ def setHandle(self, h):
+ self.handle = h
- def isOpen(self):
- return self.handle is not None
+ def isOpen(self):
+ return self.handle is not None
- def setTimeout(self, ms):
- if ms is None:
- self._timeout = None
- else:
- self._timeout = ms / 1000.0
+ def setTimeout(self, ms):
+ if ms is None:
+ self._timeout = None
+ else:
+ self._timeout = ms / 1000.0
- if self.handle is not None:
- self.handle.settimeout(self._timeout)
+ if self.handle is not None:
+ self.handle.settimeout(self._timeout)
- def open(self):
- try:
- res0 = self._resolveAddr()
- for res in res0:
- self.handle = socket.socket(res[0], res[1])
- self.handle.settimeout(self._timeout)
+ def open(self):
try:
- self.handle.connect(res[4])
+ res0 = self._resolveAddr()
+ for res in res0:
+ self.handle = socket.socket(res[0], res[1])
+ self.handle.settimeout(self._timeout)
+ try:
+ self.handle.connect(res[4])
+ except socket.error as e:
+ if res is not res0[-1]:
+ continue
+ else:
+ raise e
+ break
except socket.error as e:
- if res is not res0[-1]:
- continue
- else:
- raise e
- break
- except socket.error as e:
- if self._unix_socket:
- message = 'Could not connect to socket %s' % self._unix_socket
- else:
- message = 'Could not connect to %s:%d' % (self.host, self.port)
- raise TTransportException(type=TTransportException.NOT_OPEN,
- message=message)
+ if self._unix_socket:
+ message = 'Could not connect to socket %s' % self._unix_socket
+ else:
+ message = 'Could not connect to %s:%d' % (self.host, self.port)
+ raise TTransportException(type=TTransportException.NOT_OPEN,
+ message=message)
- def read(self, sz):
- try:
- buff = self.handle.recv(sz)
- except socket.error as e:
- if (e.args[0] == errno.ECONNRESET and
- (sys.platform == 'darwin' or sys.platform.startswith('freebsd'))):
- # freebsd and Mach don't follow POSIX semantic of recv
- # and fail with ECONNRESET if peer performed shutdown.
- # See corresponding comment and code in TSocket::read()
- # in lib/cpp/src/transport/TSocket.cpp.
- self.close()
- # Trigger the check to raise the END_OF_FILE exception below.
- buff = ''
- else:
- raise
- if len(buff) == 0:
- raise TTransportException(type=TTransportException.END_OF_FILE,
- message='TSocket read 0 bytes')
- return buff
+ def read(self, sz):
+ try:
+ buff = self.handle.recv(sz)
+ except socket.error as e:
+ if (e.args[0] == errno.ECONNRESET and
+ (sys.platform == 'darwin' or sys.platform.startswith('freebsd'))):
+ # FreeBSD and Mach don't follow the POSIX semantics of recv
+ # and fail with ECONNRESET if peer performed shutdown.
+ # See corresponding comment and code in TSocket::read()
+ # in lib/cpp/src/transport/TSocket.cpp.
+ self.close()
+ # Trigger the check to raise the END_OF_FILE exception below.
+ buff = ''
+ else:
+ raise
+ if len(buff) == 0:
+ raise TTransportException(type=TTransportException.END_OF_FILE,
+ message='TSocket read 0 bytes')
+ return buff
- def write(self, buff):
- if not self.handle:
- raise TTransportException(type=TTransportException.NOT_OPEN,
- message='Transport not open')
- sent = 0
- have = len(buff)
- while sent < have:
- plus = self.handle.send(buff)
- if plus == 0:
- raise TTransportException(type=TTransportException.END_OF_FILE,
- message='TSocket sent 0 bytes')
- sent += plus
- buff = buff[plus:]
+ def write(self, buff):
+ if not self.handle:
+ raise TTransportException(type=TTransportException.NOT_OPEN,
+ message='Transport not open')
+ sent = 0
+ have = len(buff)
+ while sent < have:
+ plus = self.handle.send(buff)
+ if plus == 0:
+ raise TTransportException(type=TTransportException.END_OF_FILE,
+ message='TSocket sent 0 bytes')
+ sent += plus
+ buff = buff[plus:]
- def flush(self):
- pass
+ def flush(self):
+ pass
class TServerSocket(TSocketBase, TServerTransportBase):
- """Socket implementation of TServerTransport base."""
+ """Socket implementation of TServerTransport base."""
- def __init__(self, host=None, port=9090, unix_socket=None, socket_family=socket.AF_UNSPEC):
- self.host = host
- self.port = port
- self._unix_socket = unix_socket
- self._socket_family = socket_family
- self.handle = None
+ def __init__(self, host=None, port=9090, unix_socket=None, socket_family=socket.AF_UNSPEC):
+ self.host = host
+ self.port = port
+ self._unix_socket = unix_socket
+ self._socket_family = socket_family
+ self.handle = None
- def listen(self):
- res0 = self._resolveAddr()
- socket_family = self._socket_family == socket.AF_UNSPEC and socket.AF_INET6 or self._socket_family
- for res in res0:
- if res[0] is socket_family or res is res0[-1]:
- break
+ def listen(self):
+ res0 = self._resolveAddr()
+ socket_family = self._socket_family == socket.AF_UNSPEC and socket.AF_INET6 or self._socket_family
+ for res in res0:
+ if res[0] is socket_family or res is res0[-1]:
+ break
- # We need remove the old unix socket if the file exists and
- # nobody is listening on it.
- if self._unix_socket:
- tmp = socket.socket(res[0], res[1])
- try:
- tmp.connect(res[4])
- except socket.error as err:
- eno, message = err.args
- if eno == errno.ECONNREFUSED:
- os.unlink(res[4])
+ # We need to remove the old unix socket if the file exists and
+ # nobody is listening on it.
+ if self._unix_socket:
+ tmp = socket.socket(res[0], res[1])
+ try:
+ tmp.connect(res[4])
+ except socket.error as err:
+ eno, message = err.args
+ if eno == errno.ECONNREFUSED:
+ os.unlink(res[4])
- self.handle = socket.socket(res[0], res[1])
- self.handle.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
- if hasattr(self.handle, 'settimeout'):
- self.handle.settimeout(None)
- self.handle.bind(res[4])
- self.handle.listen(128)
+ self.handle = socket.socket(res[0], res[1])
+ self.handle.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+ if hasattr(self.handle, 'settimeout'):
+ self.handle.settimeout(None)
+ self.handle.bind(res[4])
+ self.handle.listen(128)
- def accept(self):
- client, addr = self.handle.accept()
- result = TSocket()
- result.setHandle(client)
- return result
+ def accept(self):
+ client, addr = self.handle.accept()
+ result = TSocket()
+ result.setHandle(client)
+ return result
diff --git a/lib/py/src/transport/TTransport.py b/lib/py/src/transport/TTransport.py
index f99b3b9..6669891 100644
--- a/lib/py/src/transport/TTransport.py
+++ b/lib/py/src/transport/TTransport.py
@@ -23,427 +23,426 @@
class TTransportException(TException):
- """Custom Transport Exception class"""
+ """Custom Transport Exception class"""
- UNKNOWN = 0
- NOT_OPEN = 1
- ALREADY_OPEN = 2
- TIMED_OUT = 3
- END_OF_FILE = 4
- NEGATIVE_SIZE = 5
- SIZE_LIMIT = 6
+ UNKNOWN = 0
+ NOT_OPEN = 1
+ ALREADY_OPEN = 2
+ TIMED_OUT = 3
+ END_OF_FILE = 4
+ NEGATIVE_SIZE = 5
+ SIZE_LIMIT = 6
- def __init__(self, type=UNKNOWN, message=None):
- TException.__init__(self, message)
- self.type = type
+ def __init__(self, type=UNKNOWN, message=None):
+ TException.__init__(self, message)
+ self.type = type
class TTransportBase(object):
- """Base class for Thrift transport layer."""
+ """Base class for Thrift transport layer."""
- def isOpen(self):
- pass
+ def isOpen(self):
+ pass
- def open(self):
- pass
+ def open(self):
+ pass
- def close(self):
- pass
+ def close(self):
+ pass
- def read(self, sz):
- pass
+ def read(self, sz):
+ pass
- def readAll(self, sz):
- buff = b''
- have = 0
- while (have < sz):
- chunk = self.read(sz - have)
- have += len(chunk)
- buff += chunk
+ def readAll(self, sz):
+ buff = b''
+ have = 0
+ while (have < sz):
+ chunk = self.read(sz - have)
+ have += len(chunk)
+ buff += chunk
- if len(chunk) == 0:
- raise EOFError()
+ if len(chunk) == 0:
+ raise EOFError()
- return buff
+ return buff
- def write(self, buf):
- pass
+ def write(self, buf):
+ pass
- def flush(self):
- pass
+ def flush(self):
+ pass
# This class should be thought of as an interface.
class CReadableTransport(object):
- """base class for transports that are readable from C"""
+ """base class for transports that are readable from C"""
- # TODO(dreiss): Think about changing this interface to allow us to use
- # a (Python, not c) StringIO instead, because it allows
- # you to write after reading.
+ # TODO(dreiss): Think about changing this interface to allow us to use
+ # a (Python, not c) StringIO instead, because it allows
+ # you to write after reading.
- # NOTE: This is a classic class, so properties will NOT work
- # correctly for setting.
- @property
- def cstringio_buf(self):
- """A cStringIO buffer that contains the current chunk we are reading."""
- pass
+ # NOTE: This is a classic class, so properties will NOT work
+ # correctly for setting.
+ @property
+ def cstringio_buf(self):
+ """A cStringIO buffer that contains the current chunk we are reading."""
+ pass
- def cstringio_refill(self, partialread, reqlen):
- """Refills cstringio_buf.
+ def cstringio_refill(self, partialread, reqlen):
+ """Refills cstringio_buf.
- Returns the currently used buffer (which can but need not be the same as
- the old cstringio_buf). partialread is what the C code has read from the
- buffer, and should be inserted into the buffer before any more reads. The
- return value must be a new, not borrowed reference. Something along the
- lines of self._buf should be fine.
+ Returns the currently used buffer (which can but need not be the same as
+ the old cstringio_buf). partialread is what the C code has read from the
+ buffer, and should be inserted into the buffer before any more reads. The
+ return value must be a new, not borrowed reference. Something along the
+ lines of self._buf should be fine.
- If reqlen bytes can't be read, throw EOFError.
- """
- pass
+ If reqlen bytes can't be read, throw EOFError.
+ """
+ pass
class TServerTransportBase(object):
- """Base class for Thrift server transports."""
+ """Base class for Thrift server transports."""
- def listen(self):
- pass
+ def listen(self):
+ pass
- def accept(self):
- pass
+ def accept(self):
+ pass
- def close(self):
- pass
+ def close(self):
+ pass
class TTransportFactoryBase(object):
- """Base class for a Transport Factory"""
+ """Base class for a Transport Factory"""
- def getTransport(self, trans):
- return trans
+ def getTransport(self, trans):
+ return trans
class TBufferedTransportFactory(object):
- """Factory transport that builds buffered transports"""
+ """Factory transport that builds buffered transports"""
- def getTransport(self, trans):
- buffered = TBufferedTransport(trans)
- return buffered
+ def getTransport(self, trans):
+ buffered = TBufferedTransport(trans)
+ return buffered
class TBufferedTransport(TTransportBase, CReadableTransport):
- """Class that wraps another transport and buffers its I/O.
+ """Class that wraps another transport and buffers its I/O.
- The implementation uses a (configurable) fixed-size read buffer
- but buffers all writes until a flush is performed.
- """
- DEFAULT_BUFFER = 4096
+ The implementation uses a (configurable) fixed-size read buffer
+ but buffers all writes until a flush is performed.
+ """
+ DEFAULT_BUFFER = 4096
- def __init__(self, trans, rbuf_size=DEFAULT_BUFFER):
- self.__trans = trans
- self.__wbuf = BufferIO()
- # Pass string argument to initialize read buffer as cStringIO.InputType
- self.__rbuf = BufferIO(b'')
- self.__rbuf_size = rbuf_size
+ def __init__(self, trans, rbuf_size=DEFAULT_BUFFER):
+ self.__trans = trans
+ self.__wbuf = BufferIO()
+ # Pass string argument to initialize read buffer as cStringIO.InputType
+ self.__rbuf = BufferIO(b'')
+ self.__rbuf_size = rbuf_size
- def isOpen(self):
- return self.__trans.isOpen()
+ def isOpen(self):
+ return self.__trans.isOpen()
- def open(self):
- return self.__trans.open()
+ def open(self):
+ return self.__trans.open()
- def close(self):
- return self.__trans.close()
+ def close(self):
+ return self.__trans.close()
- def read(self, sz):
- ret = self.__rbuf.read(sz)
- if len(ret) != 0:
- return ret
- self.__rbuf = BufferIO(self.__trans.read(max(sz, self.__rbuf_size)))
- return self.__rbuf.read(sz)
+ def read(self, sz):
+ ret = self.__rbuf.read(sz)
+ if len(ret) != 0:
+ return ret
+ self.__rbuf = BufferIO(self.__trans.read(max(sz, self.__rbuf_size)))
+ return self.__rbuf.read(sz)
- def write(self, buf):
- try:
- self.__wbuf.write(buf)
- except Exception as e:
- # on exception reset wbuf so it doesn't contain a partial function call
- self.__wbuf = BufferIO()
- raise e
- self.__wbuf.getvalue()
+ def write(self, buf):
+ try:
+ self.__wbuf.write(buf)
+ except Exception as e:
+ # on exception reset wbuf so it doesn't contain a partial function call
+ self.__wbuf = BufferIO()
+ raise e
+ self.__wbuf.getvalue()
- def flush(self):
- out = self.__wbuf.getvalue()
- # reset wbuf before write/flush to preserve state on underlying failure
- self.__wbuf = BufferIO()
- self.__trans.write(out)
- self.__trans.flush()
+ def flush(self):
+ out = self.__wbuf.getvalue()
+ # reset wbuf before write/flush to preserve state on underlying failure
+ self.__wbuf = BufferIO()
+ self.__trans.write(out)
+ self.__trans.flush()
- # Implement the CReadableTransport interface.
- @property
- def cstringio_buf(self):
- return self.__rbuf
+ # Implement the CReadableTransport interface.
+ @property
+ def cstringio_buf(self):
+ return self.__rbuf
- def cstringio_refill(self, partialread, reqlen):
- retstring = partialread
- if reqlen < self.__rbuf_size:
- # try to make a read of as much as we can.
- retstring += self.__trans.read(self.__rbuf_size)
+ def cstringio_refill(self, partialread, reqlen):
+ retstring = partialread
+ if reqlen < self.__rbuf_size:
+ # try to make a read of as much as we can.
+ retstring += self.__trans.read(self.__rbuf_size)
- # but make sure we do read reqlen bytes.
- if len(retstring) < reqlen:
- retstring += self.__trans.readAll(reqlen - len(retstring))
+ # but make sure we do read reqlen bytes.
+ if len(retstring) < reqlen:
+ retstring += self.__trans.readAll(reqlen - len(retstring))
- self.__rbuf = BufferIO(retstring)
- return self.__rbuf
+ self.__rbuf = BufferIO(retstring)
+ return self.__rbuf
class TMemoryBuffer(TTransportBase, CReadableTransport):
- """Wraps a cBytesIO object as a TTransport.
+ """Wraps a cBytesIO object as a TTransport.
- NOTE: Unlike the C++ version of this class, you cannot write to it
- then immediately read from it. If you want to read from a
- TMemoryBuffer, you must either pass a string to the constructor.
- TODO(dreiss): Make this work like the C++ version.
- """
+ NOTE: Unlike the C++ version of this class, you cannot write to it
+ then immediately read from it. If you want to read from a
+ TMemoryBuffer, you must pass a string to the constructor.
+ TODO(dreiss): Make this work like the C++ version.
+ """
- def __init__(self, value=None):
- """value -- a value to read from for stringio
+ def __init__(self, value=None):
+ """value -- a value to read from for stringio
- If value is set, this will be a transport for reading,
- otherwise, it is for writing"""
- if value is not None:
- self._buffer = BufferIO(value)
- else:
- self._buffer = BufferIO()
+ If value is set, this will be a transport for reading,
+ otherwise, it is for writing"""
+ if value is not None:
+ self._buffer = BufferIO(value)
+ else:
+ self._buffer = BufferIO()
- def isOpen(self):
- return not self._buffer.closed
+ def isOpen(self):
+ return not self._buffer.closed
- def open(self):
- pass
+ def open(self):
+ pass
- def close(self):
- self._buffer.close()
+ def close(self):
+ self._buffer.close()
- def read(self, sz):
- return self._buffer.read(sz)
+ def read(self, sz):
+ return self._buffer.read(sz)
- def write(self, buf):
- self._buffer.write(buf)
+ def write(self, buf):
+ self._buffer.write(buf)
- def flush(self):
- pass
+ def flush(self):
+ pass
- def getvalue(self):
- return self._buffer.getvalue()
+ def getvalue(self):
+ return self._buffer.getvalue()
- # Implement the CReadableTransport interface.
- @property
- def cstringio_buf(self):
- return self._buffer
+ # Implement the CReadableTransport interface.
+ @property
+ def cstringio_buf(self):
+ return self._buffer
- def cstringio_refill(self, partialread, reqlen):
- # only one shot at reading...
- raise EOFError()
+ def cstringio_refill(self, partialread, reqlen):
+ # only one shot at reading...
+ raise EOFError()
class TFramedTransportFactory(object):
- """Factory transport that builds framed transports"""
+ """Factory transport that builds framed transports"""
- def getTransport(self, trans):
- framed = TFramedTransport(trans)
- return framed
+ def getTransport(self, trans):
+ framed = TFramedTransport(trans)
+ return framed
class TFramedTransport(TTransportBase, CReadableTransport):
- """Class that wraps another transport and frames its I/O when writing."""
+ """Class that wraps another transport and frames its I/O when writing."""
- def __init__(self, trans,):
- self.__trans = trans
- self.__rbuf = BufferIO(b'')
- self.__wbuf = BufferIO()
+ def __init__(self, trans,):
+ self.__trans = trans
+ self.__rbuf = BufferIO(b'')
+ self.__wbuf = BufferIO()
- def isOpen(self):
- return self.__trans.isOpen()
+ def isOpen(self):
+ return self.__trans.isOpen()
- def open(self):
- return self.__trans.open()
+ def open(self):
+ return self.__trans.open()
- def close(self):
- return self.__trans.close()
+ def close(self):
+ return self.__trans.close()
- def read(self, sz):
- ret = self.__rbuf.read(sz)
- if len(ret) != 0:
- return ret
+ def read(self, sz):
+ ret = self.__rbuf.read(sz)
+ if len(ret) != 0:
+ return ret
- self.readFrame()
- return self.__rbuf.read(sz)
+ self.readFrame()
+ return self.__rbuf.read(sz)
- def readFrame(self):
- buff = self.__trans.readAll(4)
- sz, = unpack('!i', buff)
- self.__rbuf = BufferIO(self.__trans.readAll(sz))
+ def readFrame(self):
+ buff = self.__trans.readAll(4)
+ sz, = unpack('!i', buff)
+ self.__rbuf = BufferIO(self.__trans.readAll(sz))
- def write(self, buf):
- self.__wbuf.write(buf)
+ def write(self, buf):
+ self.__wbuf.write(buf)
- def flush(self):
- wout = self.__wbuf.getvalue()
- wsz = len(wout)
- # reset wbuf before write/flush to preserve state on underlying failure
- self.__wbuf = BufferIO()
- # N.B.: Doing this string concatenation is WAY cheaper than making
- # two separate calls to the underlying socket object. Socket writes in
- # Python turn out to be REALLY expensive, but it seems to do a pretty
- # good job of managing string buffer operations without excessive copies
- buf = pack("!i", wsz) + wout
- self.__trans.write(buf)
- self.__trans.flush()
+ def flush(self):
+ wout = self.__wbuf.getvalue()
+ wsz = len(wout)
+ # reset wbuf before write/flush to preserve state on underlying failure
+ self.__wbuf = BufferIO()
+ # N.B.: Doing this string concatenation is WAY cheaper than making
+ # two separate calls to the underlying socket object. Socket writes in
+ # Python turn out to be REALLY expensive, but it seems to do a pretty
+ # good job of managing string buffer operations without excessive copies
+ buf = pack("!i", wsz) + wout
+ self.__trans.write(buf)
+ self.__trans.flush()
- # Implement the CReadableTransport interface.
- @property
- def cstringio_buf(self):
- return self.__rbuf
+ # Implement the CReadableTransport interface.
+ @property
+ def cstringio_buf(self):
+ return self.__rbuf
- def cstringio_refill(self, prefix, reqlen):
- # self.__rbuf will already be empty here because fastbinary doesn't
- # ask for a refill until the previous buffer is empty. Therefore,
- # we can start reading new frames immediately.
- while len(prefix) < reqlen:
- self.readFrame()
- prefix += self.__rbuf.getvalue()
- self.__rbuf = BufferIO(prefix)
- return self.__rbuf
+ def cstringio_refill(self, prefix, reqlen):
+ # self.__rbuf will already be empty here because fastbinary doesn't
+ # ask for a refill until the previous buffer is empty. Therefore,
+ # we can start reading new frames immediately.
+ while len(prefix) < reqlen:
+ self.readFrame()
+ prefix += self.__rbuf.getvalue()
+ self.__rbuf = BufferIO(prefix)
+ return self.__rbuf
class TFileObjectTransport(TTransportBase):
- """Wraps a file-like object to make it work as a Thrift transport."""
+ """Wraps a file-like object to make it work as a Thrift transport."""
- def __init__(self, fileobj):
- self.fileobj = fileobj
+ def __init__(self, fileobj):
+ self.fileobj = fileobj
- def isOpen(self):
- return True
+ def isOpen(self):
+ return True
- def close(self):
- self.fileobj.close()
+ def close(self):
+ self.fileobj.close()
- def read(self, sz):
- return self.fileobj.read(sz)
+ def read(self, sz):
+ return self.fileobj.read(sz)
- def write(self, buf):
- self.fileobj.write(buf)
+ def write(self, buf):
+ self.fileobj.write(buf)
- def flush(self):
- self.fileobj.flush()
+ def flush(self):
+ self.fileobj.flush()
class TSaslClientTransport(TTransportBase, CReadableTransport):
- """
- SASL transport
- """
-
- START = 1
- OK = 2
- BAD = 3
- ERROR = 4
- COMPLETE = 5
-
- def __init__(self, transport, host, service, mechanism='GSSAPI',
- **sasl_kwargs):
"""
- transport: an underlying transport to use, typically just a TSocket
- host: the name of the server, from a SASL perspective
- service: the name of the server's service, from a SASL perspective
- mechanism: the name of the preferred mechanism to use
-
- All other kwargs will be passed to the puresasl.client.SASLClient
- constructor.
+ SASL transport
"""
- from puresasl.client import SASLClient
+ START = 1
+ OK = 2
+ BAD = 3
+ ERROR = 4
+ COMPLETE = 5
- self.transport = transport
- self.sasl = SASLClient(host, service, mechanism, **sasl_kwargs)
+ def __init__(self, transport, host, service, mechanism='GSSAPI',
+ **sasl_kwargs):
+ """
+ transport: an underlying transport to use, typically just a TSocket
+ host: the name of the server, from a SASL perspective
+ service: the name of the server's service, from a SASL perspective
+ mechanism: the name of the preferred mechanism to use
- self.__wbuf = BufferIO()
- self.__rbuf = BufferIO(b'')
+ All other kwargs will be passed to the puresasl.client.SASLClient
+ constructor.
+ """
- def open(self):
- if not self.transport.isOpen():
- self.transport.open()
+ from puresasl.client import SASLClient
- self.send_sasl_msg(self.START, self.sasl.mechanism)
- self.send_sasl_msg(self.OK, self.sasl.process())
+ self.transport = transport
+ self.sasl = SASLClient(host, service, mechanism, **sasl_kwargs)
- while True:
- status, challenge = self.recv_sasl_msg()
- if status == self.OK:
- self.send_sasl_msg(self.OK, self.sasl.process(challenge))
- elif status == self.COMPLETE:
- if not self.sasl.complete:
- raise TTransportException("The server erroneously indicated "
- "that SASL negotiation was complete")
+ self.__wbuf = BufferIO()
+ self.__rbuf = BufferIO(b'')
+
+ def open(self):
+ if not self.transport.isOpen():
+ self.transport.open()
+
+ self.send_sasl_msg(self.START, self.sasl.mechanism)
+ self.send_sasl_msg(self.OK, self.sasl.process())
+
+ while True:
+ status, challenge = self.recv_sasl_msg()
+ if status == self.OK:
+ self.send_sasl_msg(self.OK, self.sasl.process(challenge))
+ elif status == self.COMPLETE:
+ if not self.sasl.complete:
+ raise TTransportException("The server erroneously indicated "
+ "that SASL negotiation was complete")
+ else:
+ break
+ else:
+ raise TTransportException("Bad SASL negotiation status: %d (%s)"
+ % (status, challenge))
+
+ def send_sasl_msg(self, status, body):
+ header = pack(">BI", status, len(body))
+ self.transport.write(header + body)
+ self.transport.flush()
+
+ def recv_sasl_msg(self):
+ header = self.transport.readAll(5)
+ status, length = unpack(">BI", header)
+ if length > 0:
+ payload = self.transport.readAll(length)
else:
- break
- else:
- raise TTransportException("Bad SASL negotiation status: %d (%s)"
- % (status, challenge))
+ payload = ""
+ return status, payload
- def send_sasl_msg(self, status, body):
- header = pack(">BI", status, len(body))
- self.transport.write(header + body)
- self.transport.flush()
+ def write(self, data):
+ self.__wbuf.write(data)
- def recv_sasl_msg(self):
- header = self.transport.readAll(5)
- status, length = unpack(">BI", header)
- if length > 0:
- payload = self.transport.readAll(length)
- else:
- payload = ""
- return status, payload
+ def flush(self):
+ data = self.__wbuf.getvalue()
+ encoded = self.sasl.wrap(data)
+ self.transport.write(''.join((pack("!i", len(encoded)), encoded)))
+ self.transport.flush()
+ self.__wbuf = BufferIO()
- def write(self, data):
- self.__wbuf.write(data)
+ def read(self, sz):
+ ret = self.__rbuf.read(sz)
+ if len(ret) != 0:
+ return ret
- def flush(self):
- data = self.__wbuf.getvalue()
- encoded = self.sasl.wrap(data)
- self.transport.write(''.join((pack("!i", len(encoded)), encoded)))
- self.transport.flush()
- self.__wbuf = BufferIO()
+ self._read_frame()
+ return self.__rbuf.read(sz)
- def read(self, sz):
- ret = self.__rbuf.read(sz)
- if len(ret) != 0:
- return ret
+ def _read_frame(self):
+ header = self.transport.readAll(4)
+ length, = unpack('!i', header)
+ encoded = self.transport.readAll(length)
+ self.__rbuf = BufferIO(self.sasl.unwrap(encoded))
- self._read_frame()
- return self.__rbuf.read(sz)
+ def close(self):
+ self.sasl.dispose()
+ self.transport.close()
- def _read_frame(self):
- header = self.transport.readAll(4)
- length, = unpack('!i', header)
- encoded = self.transport.readAll(length)
- self.__rbuf = BufferIO(self.sasl.unwrap(encoded))
+ # based on TFramedTransport
+ @property
+ def cstringio_buf(self):
+ return self.__rbuf
- def close(self):
- self.sasl.dispose()
- self.transport.close()
-
- # based on TFramedTransport
- @property
- def cstringio_buf(self):
- return self.__rbuf
-
- def cstringio_refill(self, prefix, reqlen):
- # self.__rbuf will already be empty here because fastbinary doesn't
- # ask for a refill until the previous buffer is empty. Therefore,
- # we can start reading new frames immediately.
- while len(prefix) < reqlen:
- self._read_frame()
- prefix += self.__rbuf.getvalue()
- self.__rbuf = BufferIO(prefix)
- return self.__rbuf
-
+ def cstringio_refill(self, prefix, reqlen):
+ # self.__rbuf will already be empty here because fastbinary doesn't
+ # ask for a refill until the previous buffer is empty. Therefore,
+ # we can start reading new frames immediately.
+ while len(prefix) < reqlen:
+ self._read_frame()
+ prefix += self.__rbuf.getvalue()
+ self.__rbuf = BufferIO(prefix)
+ return self.__rbuf
diff --git a/lib/py/src/transport/TTwisted.py b/lib/py/src/transport/TTwisted.py
index 6149a6c..5710b57 100644
--- a/lib/py/src/transport/TTwisted.py
+++ b/lib/py/src/transport/TTwisted.py
@@ -120,7 +120,7 @@
MAX_LENGTH = 2 ** 31 - 1
def __init__(self, client_class, iprot_factory, oprot_factory=None,
- host=None, service=None, mechanism='GSSAPI', **sasl_kwargs):
+ host=None, service=None, mechanism='GSSAPI', **sasl_kwargs):
"""
host: the name of the server, from a SASL perspective
service: the name of the server's service, from a SASL perspective
@@ -236,7 +236,7 @@
d = self.factory.processor.process(iprot, oprot)
d.addCallbacks(self.processOk, self.processError,
- callbackArgs=(tmo,))
+ callbackArgs=(tmo,))
class IThriftServerFactory(Interface):
@@ -288,7 +288,7 @@
def buildProtocol(self, addr):
p = self.protocol(self.client_class, self.iprot_factory,
- self.oprot_factory)
+ self.oprot_factory)
p.factory = self
return p
@@ -298,7 +298,7 @@
allowedMethods = ('POST',)
def __init__(self, processor, inputProtocolFactory,
- outputProtocolFactory=None):
+ outputProtocolFactory=None):
resource.Resource.__init__(self)
self.inputProtocolFactory = inputProtocolFactory
if outputProtocolFactory is None:
diff --git a/lib/py/src/transport/TZlibTransport.py b/lib/py/src/transport/TZlibTransport.py
index 7fe5853..e848579 100644
--- a/lib/py/src/transport/TZlibTransport.py
+++ b/lib/py/src/transport/TZlibTransport.py
@@ -29,220 +29,220 @@
class TZlibTransportFactory(object):
- """Factory transport that builds zlib compressed transports.
+ """Factory transport that builds zlib compressed transports.
- This factory caches the last single client/transport that it was passed
- and returns the same TZlibTransport object that was created.
+ This factory caches the last single client/transport that it was passed
+ and returns the same TZlibTransport object that was created.
- This caching means the TServer class will get the _same_ transport
- object for both input and output transports from this factory.
- (For non-threaded scenarios only, since the cache only holds one object)
+ This caching means the TServer class will get the _same_ transport
+ object for both input and output transports from this factory.
+ (For non-threaded scenarios only, since the cache only holds one object)
- The purpose of this caching is to allocate only one TZlibTransport where
- only one is really needed (since it must have separate read/write buffers),
- and makes the statistics from getCompSavings() and getCompRatio()
- easier to understand.
- """
- # class scoped cache of last transport given and zlibtransport returned
- _last_trans = None
- _last_z = None
-
- def getTransport(self, trans, compresslevel=9):
- """Wrap a transport, trans, with the TZlibTransport
- compressed transport class, returning a new
- transport to the caller.
-
- @param compresslevel: The zlib compression level, ranging
- from 0 (no compression) to 9 (best compression). Defaults to 9.
- @type compresslevel: int
-
- This method returns a TZlibTransport which wraps the
- passed C{trans} TTransport derived instance.
+ The purpose of this caching is to allocate only one TZlibTransport where
+ only one is really needed (since it must have separate read/write buffers),
+ and to make the statistics from getCompSavings() and getCompRatio()
+ easier to understand.
"""
- if trans == self._last_trans:
- return self._last_z
- ztrans = TZlibTransport(trans, compresslevel)
- self._last_trans = trans
- self._last_z = ztrans
- return ztrans
+ # class scoped cache of last transport given and zlibtransport returned
+ _last_trans = None
+ _last_z = None
+
+ def getTransport(self, trans, compresslevel=9):
+ """Wrap a transport, trans, with the TZlibTransport
+ compressed transport class, returning a new
+ transport to the caller.
+
+ @param compresslevel: The zlib compression level, ranging
+ from 0 (no compression) to 9 (best compression). Defaults to 9.
+ @type compresslevel: int
+
+ This method returns a TZlibTransport which wraps the
+ passed C{trans} TTransport derived instance.
+ """
+ if trans == self._last_trans:
+ return self._last_z
+ ztrans = TZlibTransport(trans, compresslevel)
+ self._last_trans = trans
+ self._last_z = ztrans
+ return ztrans
class TZlibTransport(TTransportBase, CReadableTransport):
- """Class that wraps a transport with zlib, compressing writes
- and decompresses reads, using the python standard
- library zlib module.
- """
- # Read buffer size for the python fastbinary C extension,
- # the TBinaryProtocolAccelerated class.
- DEFAULT_BUFFSIZE = 4096
-
- def __init__(self, trans, compresslevel=9):
- """Create a new TZlibTransport, wrapping C{trans}, another
- TTransport derived object.
-
- @param trans: A thrift transport object, i.e. a TSocket() object.
- @type trans: TTransport
- @param compresslevel: The zlib compression level, ranging
- from 0 (no compression) to 9 (best compression). Default is 9.
- @type compresslevel: int
+ """Class that wraps a transport with zlib, compressing writes
+ and decompressing reads, using the Python standard
+ library zlib module.
"""
- self.__trans = trans
- self.compresslevel = compresslevel
- self.__rbuf = BufferIO()
- self.__wbuf = BufferIO()
- self._init_zlib()
- self._init_stats()
+ # Read buffer size for the python fastbinary C extension,
+ # the TBinaryProtocolAccelerated class.
+ DEFAULT_BUFFSIZE = 4096
- def _reinit_buffers(self):
- """Internal method to initialize/reset the internal StringIO objects
- for read and write buffers.
- """
- self.__rbuf = BufferIO()
- self.__wbuf = BufferIO()
+ def __init__(self, trans, compresslevel=9):
+ """Create a new TZlibTransport, wrapping C{trans}, another
+ TTransport derived object.
- def _init_stats(self):
- """Internal method to reset the internal statistics counters
- for compression ratios and bandwidth savings.
- """
- self.bytes_in = 0
- self.bytes_out = 0
- self.bytes_in_comp = 0
- self.bytes_out_comp = 0
+ @param trans: A thrift transport object, e.g. a TSocket() object.
+ @type trans: TTransport
+ @param compresslevel: The zlib compression level, ranging
+ from 0 (no compression) to 9 (best compression). Default is 9.
+ @type compresslevel: int
+ """
+ self.__trans = trans
+ self.compresslevel = compresslevel
+ self.__rbuf = BufferIO()
+ self.__wbuf = BufferIO()
+ self._init_zlib()
+ self._init_stats()
- def _init_zlib(self):
- """Internal method for setting up the zlib compression and
- decompression objects.
- """
- self._zcomp_read = zlib.decompressobj()
- self._zcomp_write = zlib.compressobj(self.compresslevel)
+ def _reinit_buffers(self):
+ """Internal method to initialize/reset the internal BufferIO objects
+ for read and write buffers.
+ """
+ self.__rbuf = BufferIO()
+ self.__wbuf = BufferIO()
- def getCompRatio(self):
- """Get the current measured compression ratios (in,out) from
- this transport.
+ def _init_stats(self):
+ """Internal method to reset the internal statistics counters
+ for compression ratios and bandwidth savings.
+ """
+ self.bytes_in = 0
+ self.bytes_out = 0
+ self.bytes_in_comp = 0
+ self.bytes_out_comp = 0
- Returns a tuple of:
- (inbound_compression_ratio, outbound_compression_ratio)
+ def _init_zlib(self):
+ """Internal method for setting up the zlib compression and
+ decompression objects.
+ """
+ self._zcomp_read = zlib.decompressobj()
+ self._zcomp_write = zlib.compressobj(self.compresslevel)
- The compression ratios are computed as:
- compressed / uncompressed
+ def getCompRatio(self):
+ """Get the current measured compression ratios (in,out) from
+ this transport.
- E.g., data that compresses by 10x will have a ratio of: 0.10
- and data that compresses to half of ts original size will
- have a ratio of 0.5
+ Returns a tuple of:
+ (inbound_compression_ratio, outbound_compression_ratio)
- None is returned if no bytes have yet been processed in
- a particular direction.
- """
- r_percent, w_percent = (None, None)
- if self.bytes_in > 0:
- r_percent = self.bytes_in_comp / self.bytes_in
- if self.bytes_out > 0:
- w_percent = self.bytes_out_comp / self.bytes_out
- return (r_percent, w_percent)
+ The compression ratios are computed as:
+ compressed / uncompressed
- def getCompSavings(self):
- """Get the current count of saved bytes due to data
- compression.
+ E.g., data that compresses by 10x will have a ratio of: 0.10
+ and data that compresses to half of its original size will
+ have a ratio of 0.5
- Returns a tuple of:
- (inbound_saved_bytes, outbound_saved_bytes)
+ None is returned if no bytes have yet been processed in
+ a particular direction.
+ """
+ r_percent, w_percent = (None, None)
+ if self.bytes_in > 0:
+ r_percent = self.bytes_in_comp / self.bytes_in
+ if self.bytes_out > 0:
+ w_percent = self.bytes_out_comp / self.bytes_out
+ return (r_percent, w_percent)
- Note: if compression is actually expanding your
- data (only likely with very tiny thrift objects), then
- the values returned will be negative.
- """
- r_saved = self.bytes_in - self.bytes_in_comp
- w_saved = self.bytes_out - self.bytes_out_comp
- return (r_saved, w_saved)
+ def getCompSavings(self):
+ """Get the current count of saved bytes due to data
+ compression.
- def isOpen(self):
- """Return the underlying transport's open status"""
- return self.__trans.isOpen()
+ Returns a tuple of:
+ (inbound_saved_bytes, outbound_saved_bytes)
- def open(self):
- """Open the underlying transport"""
- self._init_stats()
- return self.__trans.open()
+ Note: if compression is actually expanding your
+ data (only likely with very tiny thrift objects), then
+ the values returned will be negative.
+ """
+ r_saved = self.bytes_in - self.bytes_in_comp
+ w_saved = self.bytes_out - self.bytes_out_comp
+ return (r_saved, w_saved)
- def listen(self):
- """Invoke the underlying transport's listen() method"""
- self.__trans.listen()
+ def isOpen(self):
+ """Return the underlying transport's open status"""
+ return self.__trans.isOpen()
- def accept(self):
- """Accept connections on the underlying transport"""
- return self.__trans.accept()
+ def open(self):
+ """Open the underlying transport"""
+ self._init_stats()
+ return self.__trans.open()
- def close(self):
- """Close the underlying transport,"""
- self._reinit_buffers()
- self._init_zlib()
- return self.__trans.close()
+ def listen(self):
+ """Invoke the underlying transport's listen() method"""
+ self.__trans.listen()
- def read(self, sz):
- """Read up to sz bytes from the decompressed bytes buffer, and
- read from the underlying transport if the decompression
- buffer is empty.
- """
- ret = self.__rbuf.read(sz)
- if len(ret) > 0:
- return ret
- # keep reading from transport until something comes back
- while True:
- if self.readComp(sz):
- break
- ret = self.__rbuf.read(sz)
- return ret
+ def accept(self):
+ """Accept connections on the underlying transport"""
+ return self.__trans.accept()
- def readComp(self, sz):
- """Read compressed data from the underlying transport, then
- decompress it and append it to the internal StringIO read buffer
- """
- zbuf = self.__trans.read(sz)
- zbuf = self._zcomp_read.unconsumed_tail + zbuf
- buf = self._zcomp_read.decompress(zbuf)
- self.bytes_in += len(zbuf)
- self.bytes_in_comp += len(buf)
- old = self.__rbuf.read()
- self.__rbuf = BufferIO(old + buf)
- if len(old) + len(buf) == 0:
- return False
- return True
+ def close(self):
+ """Close the underlying transport."""
+ self._reinit_buffers()
+ self._init_zlib()
+ return self.__trans.close()
- def write(self, buf):
- """Write some bytes, putting them into the internal write
- buffer for eventual compression.
- """
- self.__wbuf.write(buf)
+ def read(self, sz):
+ """Read up to sz bytes from the decompressed bytes buffer, and
+ read from the underlying transport if the decompression
+ buffer is empty.
+ """
+ ret = self.__rbuf.read(sz)
+ if len(ret) > 0:
+ return ret
+ # keep reading from transport until something comes back
+ while True:
+ if self.readComp(sz):
+ break
+ ret = self.__rbuf.read(sz)
+ return ret
- def flush(self):
- """Flush any queued up data in the write buffer and ensure the
- compression buffer is flushed out to the underlying transport
- """
- wout = self.__wbuf.getvalue()
- if len(wout) > 0:
- zbuf = self._zcomp_write.compress(wout)
- self.bytes_out += len(wout)
- self.bytes_out_comp += len(zbuf)
- else:
- zbuf = ''
- ztail = self._zcomp_write.flush(zlib.Z_SYNC_FLUSH)
- self.bytes_out_comp += len(ztail)
- if (len(zbuf) + len(ztail)) > 0:
- self.__wbuf = BufferIO()
- self.__trans.write(zbuf + ztail)
- self.__trans.flush()
+ def readComp(self, sz):
+ """Read compressed data from the underlying transport, then
+ decompress it and append it to the internal BufferIO read buffer
+ """
+ zbuf = self.__trans.read(sz)
+ zbuf = self._zcomp_read.unconsumed_tail + zbuf
+ buf = self._zcomp_read.decompress(zbuf)
+ self.bytes_in += len(zbuf)
+ self.bytes_in_comp += len(buf)
+ old = self.__rbuf.read()
+ self.__rbuf = BufferIO(old + buf)
+ if len(old) + len(buf) == 0:
+ return False
+ return True
- @property
- def cstringio_buf(self):
- """Implement the CReadableTransport interface"""
- return self.__rbuf
+ def write(self, buf):
+ """Write some bytes, putting them into the internal write
+ buffer for eventual compression.
+ """
+ self.__wbuf.write(buf)
- def cstringio_refill(self, partialread, reqlen):
- """Implement the CReadableTransport interface for refill"""
- retstring = partialread
- if reqlen < self.DEFAULT_BUFFSIZE:
- retstring += self.read(self.DEFAULT_BUFFSIZE)
- while len(retstring) < reqlen:
- retstring += self.read(reqlen - len(retstring))
- self.__rbuf = BufferIO(retstring)
- return self.__rbuf
+ def flush(self):
+ """Flush any queued up data in the write buffer and ensure the
+ compression buffer is flushed out to the underlying transport
+ """
+ wout = self.__wbuf.getvalue()
+ if len(wout) > 0:
+ zbuf = self._zcomp_write.compress(wout)
+ self.bytes_out += len(wout)
+ self.bytes_out_comp += len(zbuf)
+ else:
+ zbuf = ''
+ ztail = self._zcomp_write.flush(zlib.Z_SYNC_FLUSH)
+ self.bytes_out_comp += len(ztail)
+ if (len(zbuf) + len(ztail)) > 0:
+ self.__wbuf = BufferIO()
+ self.__trans.write(zbuf + ztail)
+ self.__trans.flush()
+
+ @property
+ def cstringio_buf(self):
+ """Implement the CReadableTransport interface"""
+ return self.__rbuf
+
+ def cstringio_refill(self, partialread, reqlen):
+ """Implement the CReadableTransport interface for refill"""
+ retstring = partialread
+ if reqlen < self.DEFAULT_BUFFSIZE:
+ retstring += self.read(self.DEFAULT_BUFFSIZE)
+ while len(retstring) < reqlen:
+ retstring += self.read(reqlen - len(retstring))
+ self.__rbuf = BufferIO(retstring)
+ return self.__rbuf