THRIFT-3596 Better conformance to PEP8
This closes #832
diff --git a/lib/py/src/TMultiplexedProcessor.py b/lib/py/src/TMultiplexedProcessor.py
index a8d5565..581214b 100644
--- a/lib/py/src/TMultiplexedProcessor.py
+++ b/lib/py/src/TMultiplexedProcessor.py
@@ -20,39 +20,36 @@
from thrift.Thrift import TProcessor, TMessageType, TException
from thrift.protocol import TProtocolDecorator, TMultiplexedProtocol
+
class TMultiplexedProcessor(TProcessor):
- def __init__(self):
- self.services = {}
+ def __init__(self):
+ self.services = {}
- def registerProcessor(self, serviceName, processor):
- self.services[serviceName] = processor
+ def registerProcessor(self, serviceName, processor):
+ self.services[serviceName] = processor
- def process(self, iprot, oprot):
- (name, type, seqid) = iprot.readMessageBegin();
- if type != TMessageType.CALL & type != TMessageType.ONEWAY:
- raise TException("TMultiplex protocol only supports CALL & ONEWAY")
+ def process(self, iprot, oprot):
+ (name, type, seqid) = iprot.readMessageBegin()
+ if type not in (TMessageType.CALL, TMessageType.ONEWAY):
+ raise TException("TMultiplex protocol only supports CALL & ONEWAY")
- index = name.find(TMultiplexedProtocol.SEPARATOR)
- if index < 0:
- raise TException("Service name not found in message name: " + name + ". Did you forget to use TMultiplexProtocol in your client?")
+ index = name.find(TMultiplexedProtocol.SEPARATOR)
+ if index < 0:
+ raise TException("Service name not found in message name: " + name + ". Did you forget to use TMultiplexProtocol in your client?")
- serviceName = name[0:index]
- call = name[index+len(TMultiplexedProtocol.SEPARATOR):]
- if not serviceName in self.services:
- raise TException("Service name not found: " + serviceName + ". Did you forget to call registerProcessor()?")
+ serviceName = name[0:index]
+ call = name[index + len(TMultiplexedProtocol.SEPARATOR):]
+ if serviceName not in self.services:
+ raise TException("Service name not found: " + serviceName + ". Did you forget to call registerProcessor()?")
- standardMessage = (
- call,
- type,
- seqid
- )
- return self.services[serviceName].process(StoredMessageProtocol(iprot, standardMessage), oprot)
+ standardMessage = (call, type, seqid)
+ return self.services[serviceName].process(StoredMessageProtocol(iprot, standardMessage), oprot)
class StoredMessageProtocol(TProtocolDecorator.TProtocolDecorator):
- def __init__(self, protocol, messageBegin):
- TProtocolDecorator.TProtocolDecorator.__init__(self, protocol)
- self.messageBegin = messageBegin
+ def __init__(self, protocol, messageBegin):
+ TProtocolDecorator.TProtocolDecorator.__init__(self, protocol)
+ self.messageBegin = messageBegin
- def readMessageBegin(self):
- return self.messageBegin
+ def readMessageBegin(self):
+ return self.messageBegin
diff --git a/lib/py/src/TSCons.py b/lib/py/src/TSCons.py
index ed2601a..bc67d70 100644
--- a/lib/py/src/TSCons.py
+++ b/lib/py/src/TSCons.py
@@ -20,18 +20,17 @@
from os import path
from SCons.Builder import Builder
from six.moves import map
-from six.moves import zip
def scons_env(env, add=''):
- opath = path.dirname(path.abspath('$TARGET'))
- lstr = 'thrift --gen cpp -o ' + opath + ' ' + add + ' $SOURCE'
- cppbuild = Builder(action=lstr)
- env.Append(BUILDERS={'ThriftCpp': cppbuild})
+ opath = path.dirname(path.abspath('$TARGET'))
+ lstr = 'thrift --gen cpp -o ' + opath + ' ' + add + ' $SOURCE'
+ cppbuild = Builder(action=lstr)
+ env.Append(BUILDERS={'ThriftCpp': cppbuild})
def gen_cpp(env, dir, file):
- scons_env(env)
- suffixes = ['_types.h', '_types.cpp']
- targets = map(lambda s: 'gen-cpp/' + file + s, suffixes)
- return env.ThriftCpp(targets, dir + file + '.thrift')
+ scons_env(env)
+ suffixes = ['_types.h', '_types.cpp']
+ targets = map(lambda s: 'gen-cpp/' + file + s, suffixes)
+ return env.ThriftCpp(targets, dir + file + '.thrift')
diff --git a/lib/py/src/TTornado.py b/lib/py/src/TTornado.py
index e3b4df7..e01a49f 100644
--- a/lib/py/src/TTornado.py
+++ b/lib/py/src/TTornado.py
@@ -18,10 +18,9 @@
#
from __future__ import absolute_import
+import logging
import socket
import struct
-import logging
-logger = logging.getLogger(__name__)
from .transport.TTransport import TTransportException, TTransportBase, TMemoryBuffer
@@ -32,6 +31,8 @@
__all__ = ['TTornadoServer', 'TTornadoStreamTransport']
+logger = logging.getLogger(__name__)
+
class _Lock(object):
def __init__(self):
diff --git a/lib/py/src/Thrift.py b/lib/py/src/Thrift.py
index 11ee796..c4dabdc 100644
--- a/lib/py/src/Thrift.py
+++ b/lib/py/src/Thrift.py
@@ -21,170 +21,172 @@
class TType(object):
- STOP = 0
- VOID = 1
- BOOL = 2
- BYTE = 3
- I08 = 3
- DOUBLE = 4
- I16 = 6
- I32 = 8
- I64 = 10
- STRING = 11
- UTF7 = 11
- STRUCT = 12
- MAP = 13
- SET = 14
- LIST = 15
- UTF8 = 16
- UTF16 = 17
+ STOP = 0
+ VOID = 1
+ BOOL = 2
+ BYTE = 3
+ I08 = 3
+ DOUBLE = 4
+ I16 = 6
+ I32 = 8
+ I64 = 10
+ STRING = 11
+ UTF7 = 11
+ STRUCT = 12
+ MAP = 13
+ SET = 14
+ LIST = 15
+ UTF8 = 16
+ UTF16 = 17
- _VALUES_TO_NAMES = ('STOP',
- 'VOID',
- 'BOOL',
- 'BYTE',
- 'DOUBLE',
- None,
- 'I16',
- None,
- 'I32',
- None,
- 'I64',
- 'STRING',
- 'STRUCT',
- 'MAP',
- 'SET',
- 'LIST',
- 'UTF8',
- 'UTF16')
+ _VALUES_TO_NAMES = (
+ 'STOP',
+ 'VOID',
+ 'BOOL',
+ 'BYTE',
+ 'DOUBLE',
+ None,
+ 'I16',
+ None,
+ 'I32',
+ None,
+ 'I64',
+ 'STRING',
+ 'STRUCT',
+ 'MAP',
+ 'SET',
+ 'LIST',
+ 'UTF8',
+ 'UTF16',
+ )
class TMessageType(object):
- CALL = 1
- REPLY = 2
- EXCEPTION = 3
- ONEWAY = 4
+ CALL = 1
+ REPLY = 2
+ EXCEPTION = 3
+ ONEWAY = 4
class TProcessor(object):
- """Base class for procsessor, which works on two streams."""
+ """Base class for processor, which works on two streams."""
- def process(iprot, oprot):
- pass
+ def process(self, iprot, oprot):
+ pass
class TException(Exception):
- """Base class for all thrift exceptions."""
+ """Base class for all thrift exceptions."""
- # BaseException.message is deprecated in Python v[2.6,3.0)
- if (2, 6, 0) <= sys.version_info < (3, 0):
- def _get_message(self):
- return self._message
+ # BaseException.message is deprecated in Python v[2.6,3.0)
+ if (2, 6, 0) <= sys.version_info < (3, 0):
+ def _get_message(self):
+ return self._message
- def _set_message(self, message):
- self._message = message
- message = property(_get_message, _set_message)
+ def _set_message(self, message):
+ self._message = message
+ message = property(_get_message, _set_message)
- def __init__(self, message=None):
- Exception.__init__(self, message)
- self.message = message
+ def __init__(self, message=None):
+ Exception.__init__(self, message)
+ self.message = message
class TApplicationException(TException):
- """Application level thrift exceptions."""
+ """Application level thrift exceptions."""
- UNKNOWN = 0
- UNKNOWN_METHOD = 1
- INVALID_MESSAGE_TYPE = 2
- WRONG_METHOD_NAME = 3
- BAD_SEQUENCE_ID = 4
- MISSING_RESULT = 5
- INTERNAL_ERROR = 6
- PROTOCOL_ERROR = 7
- INVALID_TRANSFORM = 8
- INVALID_PROTOCOL = 9
- UNSUPPORTED_CLIENT_TYPE = 10
+ UNKNOWN = 0
+ UNKNOWN_METHOD = 1
+ INVALID_MESSAGE_TYPE = 2
+ WRONG_METHOD_NAME = 3
+ BAD_SEQUENCE_ID = 4
+ MISSING_RESULT = 5
+ INTERNAL_ERROR = 6
+ PROTOCOL_ERROR = 7
+ INVALID_TRANSFORM = 8
+ INVALID_PROTOCOL = 9
+ UNSUPPORTED_CLIENT_TYPE = 10
- def __init__(self, type=UNKNOWN, message=None):
- TException.__init__(self, message)
- self.type = type
+ def __init__(self, type=UNKNOWN, message=None):
+ TException.__init__(self, message)
+ self.type = type
- def __str__(self):
- if self.message:
- return self.message
- elif self.type == self.UNKNOWN_METHOD:
- return 'Unknown method'
- elif self.type == self.INVALID_MESSAGE_TYPE:
- return 'Invalid message type'
- elif self.type == self.WRONG_METHOD_NAME:
- return 'Wrong method name'
- elif self.type == self.BAD_SEQUENCE_ID:
- return 'Bad sequence ID'
- elif self.type == self.MISSING_RESULT:
- return 'Missing result'
- elif self.type == self.INTERNAL_ERROR:
- return 'Internal error'
- elif self.type == self.PROTOCOL_ERROR:
- return 'Protocol error'
- elif self.type == self.INVALID_TRANSFORM:
- return 'Invalid transform'
- elif self.type == self.INVALID_PROTOCOL:
- return 'Invalid protocol'
- elif self.type == self.UNSUPPORTED_CLIENT_TYPE:
- return 'Unsupported client type'
- else:
- return 'Default (unknown) TApplicationException'
-
- def read(self, iprot):
- iprot.readStructBegin()
- while True:
- (fname, ftype, fid) = iprot.readFieldBegin()
- if ftype == TType.STOP:
- break
- if fid == 1:
- if ftype == TType.STRING:
- self.message = iprot.readString()
+ def __str__(self):
+ if self.message:
+ return self.message
+ elif self.type == self.UNKNOWN_METHOD:
+ return 'Unknown method'
+ elif self.type == self.INVALID_MESSAGE_TYPE:
+ return 'Invalid message type'
+ elif self.type == self.WRONG_METHOD_NAME:
+ return 'Wrong method name'
+ elif self.type == self.BAD_SEQUENCE_ID:
+ return 'Bad sequence ID'
+ elif self.type == self.MISSING_RESULT:
+ return 'Missing result'
+ elif self.type == self.INTERNAL_ERROR:
+ return 'Internal error'
+ elif self.type == self.PROTOCOL_ERROR:
+ return 'Protocol error'
+ elif self.type == self.INVALID_TRANSFORM:
+ return 'Invalid transform'
+ elif self.type == self.INVALID_PROTOCOL:
+ return 'Invalid protocol'
+ elif self.type == self.UNSUPPORTED_CLIENT_TYPE:
+ return 'Unsupported client type'
else:
- iprot.skip(ftype)
- elif fid == 2:
- if ftype == TType.I32:
- self.type = iprot.readI32()
- else:
- iprot.skip(ftype)
- else:
- iprot.skip(ftype)
- iprot.readFieldEnd()
- iprot.readStructEnd()
+ return 'Default (unknown) TApplicationException'
- def write(self, oprot):
- oprot.writeStructBegin('TApplicationException')
- if self.message is not None:
- oprot.writeFieldBegin('message', TType.STRING, 1)
- oprot.writeString(self.message)
- oprot.writeFieldEnd()
- if self.type is not None:
- oprot.writeFieldBegin('type', TType.I32, 2)
- oprot.writeI32(self.type)
- oprot.writeFieldEnd()
- oprot.writeFieldStop()
- oprot.writeStructEnd()
+ def read(self, iprot):
+ iprot.readStructBegin()
+ while True:
+ (fname, ftype, fid) = iprot.readFieldBegin()
+ if ftype == TType.STOP:
+ break
+ if fid == 1:
+ if ftype == TType.STRING:
+ self.message = iprot.readString()
+ else:
+ iprot.skip(ftype)
+ elif fid == 2:
+ if ftype == TType.I32:
+ self.type = iprot.readI32()
+ else:
+ iprot.skip(ftype)
+ else:
+ iprot.skip(ftype)
+ iprot.readFieldEnd()
+ iprot.readStructEnd()
+
+ def write(self, oprot):
+ oprot.writeStructBegin('TApplicationException')
+ if self.message is not None:
+ oprot.writeFieldBegin('message', TType.STRING, 1)
+ oprot.writeString(self.message)
+ oprot.writeFieldEnd()
+ if self.type is not None:
+ oprot.writeFieldBegin('type', TType.I32, 2)
+ oprot.writeI32(self.type)
+ oprot.writeFieldEnd()
+ oprot.writeFieldStop()
+ oprot.writeStructEnd()
class TFrozenDict(dict):
- """A dictionary that is "frozen" like a frozenset"""
+ """A dictionary that is "frozen" like a frozenset"""
- def __init__(self, *args, **kwargs):
- super(TFrozenDict, self).__init__(*args, **kwargs)
- # Sort the items so they will be in a consistent order.
- # XOR in the hash of the class so we don't collide with
- # the hash of a list of tuples.
- self.__hashval = hash(TFrozenDict) ^ hash(tuple(sorted(self.items())))
+ def __init__(self, *args, **kwargs):
+ super(TFrozenDict, self).__init__(*args, **kwargs)
+ # Sort the items so they will be in a consistent order.
+ # XOR in the hash of the class so we don't collide with
+ # the hash of a list of tuples.
+ self.__hashval = hash(TFrozenDict) ^ hash(tuple(sorted(self.items())))
- def __setitem__(self, *args):
- raise TypeError("Can't modify frozen TFreezableDict")
+ def __setitem__(self, *args):
+ raise TypeError("Can't modify frozen TFreezableDict")
- def __delitem__(self, *args):
- raise TypeError("Can't modify frozen TFreezableDict")
+ def __delitem__(self, *args):
+ raise TypeError("Can't modify frozen TFreezableDict")
- def __hash__(self):
- return self.__hashval
+ def __hash__(self):
+ return self.__hashval
diff --git a/lib/py/src/compat.py b/lib/py/src/compat.py
index 06f672a..42403ea 100644
--- a/lib/py/src/compat.py
+++ b/lib/py/src/compat.py
@@ -1,27 +1,46 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
import sys
if sys.version_info[0] == 2:
- from cStringIO import StringIO as BufferIO
+ from cStringIO import StringIO as BufferIO
- def binary_to_str(bin_val):
- return bin_val
+ def binary_to_str(bin_val):
+ return bin_val
- def str_to_binary(str_val):
- return str_val
+ def str_to_binary(str_val):
+ return str_val
else:
- from io import BytesIO as BufferIO
+ from io import BytesIO as BufferIO
- def binary_to_str(bin_val):
- try:
- return bin_val.decode('utf8')
- except:
- return bin_val
+ def binary_to_str(bin_val):
+ try:
+ return bin_val.decode('utf8')
+ except Exception:
+ return bin_val
- def str_to_binary(str_val):
- try:
- return bytes(str_val, 'utf8')
- except:
- return str_val
+ def str_to_binary(str_val):
+ try:
+ return bytes(str_val, 'utf8')
+ except Exception:
+ return str_val
diff --git a/lib/py/src/protocol/TBase.py b/lib/py/src/protocol/TBase.py
index d106f4e..87caf0d 100644
--- a/lib/py/src/protocol/TBase.py
+++ b/lib/py/src/protocol/TBase.py
@@ -21,78 +21,79 @@
from thrift.transport import TTransport
try:
- from thrift.protocol import fastbinary
+ from thrift.protocol import fastbinary
except:
- fastbinary = None
+ fastbinary = None
class TBase(object):
- __slots__ = ()
+ __slots__ = ()
- def __repr__(self):
- L = ['%s=%r' % (key, getattr(self, key)) for key in self.__slots__]
- return '%s(%s)' % (self.__class__.__name__, ', '.join(L))
+ def __repr__(self):
+ L = ['%s=%r' % (key, getattr(self, key)) for key in self.__slots__]
+ return '%s(%s)' % (self.__class__.__name__, ', '.join(L))
- def __eq__(self, other):
- if not isinstance(other, self.__class__):
- return False
- for attr in self.__slots__:
- my_val = getattr(self, attr)
- other_val = getattr(other, attr)
- if my_val != other_val:
- return False
- return True
+ def __eq__(self, other):
+ if not isinstance(other, self.__class__):
+ return False
+ for attr in self.__slots__:
+ my_val = getattr(self, attr)
+ other_val = getattr(other, attr)
+ if my_val != other_val:
+ return False
+ return True
- def __ne__(self, other):
- return not (self == other)
+ def __ne__(self, other):
+ return not (self == other)
- def read(self, iprot):
- if (iprot.__class__ == TBinaryProtocol.TBinaryProtocolAccelerated and
- isinstance(iprot.trans, TTransport.CReadableTransport) and
- self.thrift_spec is not None and
- fastbinary is not None):
- fastbinary.decode_binary(self,
- iprot.trans,
- (self.__class__, self.thrift_spec),
- iprot.string_length_limit,
- iprot.container_length_limit)
- return
- iprot.readStruct(self, self.thrift_spec)
+ def read(self, iprot):
+ if (iprot.__class__ == TBinaryProtocol.TBinaryProtocolAccelerated and
+ isinstance(iprot.trans, TTransport.CReadableTransport) and
+ self.thrift_spec is not None and
+ fastbinary is not None):
+ fastbinary.decode_binary(self,
+ iprot.trans,
+ (self.__class__, self.thrift_spec),
+ iprot.string_length_limit,
+ iprot.container_length_limit)
+ return
+ iprot.readStruct(self, self.thrift_spec)
- def write(self, oprot):
- if (oprot.__class__ == TBinaryProtocol.TBinaryProtocolAccelerated and
- self.thrift_spec is not None and
- fastbinary is not None):
- oprot.trans.write(
- fastbinary.encode_binary(self, (self.__class__, self.thrift_spec)))
- return
- oprot.writeStruct(self, self.thrift_spec)
+ def write(self, oprot):
+ if (oprot.__class__ == TBinaryProtocol.TBinaryProtocolAccelerated and
+ self.thrift_spec is not None and
+ fastbinary is not None):
+ oprot.trans.write(
+ fastbinary.encode_binary(
+ self, (self.__class__, self.thrift_spec)))
+ return
+ oprot.writeStruct(self, self.thrift_spec)
class TExceptionBase(TBase, Exception):
- pass
+ pass
class TFrozenBase(TBase):
- def __setitem__(self, *args):
- raise TypeError("Can't modify frozen struct")
+ def __setitem__(self, *args):
+ raise TypeError("Can't modify frozen struct")
- def __delitem__(self, *args):
- raise TypeError("Can't modify frozen struct")
+ def __delitem__(self, *args):
+ raise TypeError("Can't modify frozen struct")
- def __hash__(self, *args):
- return hash(self.__class__) ^ hash(self.__slots__)
+ def __hash__(self, *args):
+ return hash(self.__class__) ^ hash(self.__slots__)
- @classmethod
- def read(cls, iprot):
- if (iprot.__class__ == TBinaryProtocol.TBinaryProtocolAccelerated and
- isinstance(iprot.trans, TTransport.CReadableTransport) and
- cls.thrift_spec is not None and
- fastbinary is not None):
- self = cls()
- return fastbinary.decode_binary(None,
- iprot.trans,
- (self.__class__, self.thrift_spec),
- iprot.string_length_limit,
- iprot.container_length_limit)
- return iprot.readStruct(cls, cls.thrift_spec, True)
+ @classmethod
+ def read(cls, iprot):
+ if (iprot.__class__ == TBinaryProtocol.TBinaryProtocolAccelerated and
+ isinstance(iprot.trans, TTransport.CReadableTransport) and
+ cls.thrift_spec is not None and
+ fastbinary is not None):
+ self = cls()
+ return fastbinary.decode_binary(None,
+ iprot.trans,
+ (self.__class__, self.thrift_spec),
+ iprot.string_length_limit,
+ iprot.container_length_limit)
+ return iprot.readStruct(cls, cls.thrift_spec, True)
diff --git a/lib/py/src/protocol/TBinaryProtocol.py b/lib/py/src/protocol/TBinaryProtocol.py
index db4ea31..7fce12f 100644
--- a/lib/py/src/protocol/TBinaryProtocol.py
+++ b/lib/py/src/protocol/TBinaryProtocol.py
@@ -22,264 +22,264 @@
class TBinaryProtocol(TProtocolBase):
- """Binary implementation of the Thrift protocol driver."""
+ """Binary implementation of the Thrift protocol driver."""
- # NastyHaxx. Python 2.4+ on 32-bit machines forces hex constants to be
- # positive, converting this into a long. If we hardcode the int value
- # instead it'll stay in 32 bit-land.
+ # NastyHaxx. Python 2.4+ on 32-bit machines forces hex constants to be
+ # positive, converting this into a long. If we hardcode the int value
+ # instead it'll stay in 32 bit-land.
- # VERSION_MASK = 0xffff0000
- VERSION_MASK = -65536
+ # VERSION_MASK = 0xffff0000
+ VERSION_MASK = -65536
- # VERSION_1 = 0x80010000
- VERSION_1 = -2147418112
+ # VERSION_1 = 0x80010000
+ VERSION_1 = -2147418112
- TYPE_MASK = 0x000000ff
+ TYPE_MASK = 0x000000ff
- def __init__(self, trans, strictRead=False, strictWrite=True, **kwargs):
- TProtocolBase.__init__(self, trans)
- self.strictRead = strictRead
- self.strictWrite = strictWrite
- self.string_length_limit = kwargs.get('string_length_limit', None)
- self.container_length_limit = kwargs.get('container_length_limit', None)
+ def __init__(self, trans, strictRead=False, strictWrite=True, **kwargs):
+ TProtocolBase.__init__(self, trans)
+ self.strictRead = strictRead
+ self.strictWrite = strictWrite
+ self.string_length_limit = kwargs.get('string_length_limit', None)
+ self.container_length_limit = kwargs.get('container_length_limit', None)
- def _check_string_length(self, length):
- self._check_length(self.string_length_limit, length)
+ def _check_string_length(self, length):
+ self._check_length(self.string_length_limit, length)
- def _check_container_length(self, length):
- self._check_length(self.container_length_limit, length)
+ def _check_container_length(self, length):
+ self._check_length(self.container_length_limit, length)
- def writeMessageBegin(self, name, type, seqid):
- if self.strictWrite:
- self.writeI32(TBinaryProtocol.VERSION_1 | type)
- self.writeString(name)
- self.writeI32(seqid)
- else:
- self.writeString(name)
- self.writeByte(type)
- self.writeI32(seqid)
+ def writeMessageBegin(self, name, type, seqid):
+ if self.strictWrite:
+ self.writeI32(TBinaryProtocol.VERSION_1 | type)
+ self.writeString(name)
+ self.writeI32(seqid)
+ else:
+ self.writeString(name)
+ self.writeByte(type)
+ self.writeI32(seqid)
- def writeMessageEnd(self):
- pass
+ def writeMessageEnd(self):
+ pass
- def writeStructBegin(self, name):
- pass
+ def writeStructBegin(self, name):
+ pass
- def writeStructEnd(self):
- pass
+ def writeStructEnd(self):
+ pass
- def writeFieldBegin(self, name, type, id):
- self.writeByte(type)
- self.writeI16(id)
+ def writeFieldBegin(self, name, type, id):
+ self.writeByte(type)
+ self.writeI16(id)
- def writeFieldEnd(self):
- pass
+ def writeFieldEnd(self):
+ pass
- def writeFieldStop(self):
- self.writeByte(TType.STOP)
+ def writeFieldStop(self):
+ self.writeByte(TType.STOP)
- def writeMapBegin(self, ktype, vtype, size):
- self.writeByte(ktype)
- self.writeByte(vtype)
- self.writeI32(size)
+ def writeMapBegin(self, ktype, vtype, size):
+ self.writeByte(ktype)
+ self.writeByte(vtype)
+ self.writeI32(size)
- def writeMapEnd(self):
- pass
+ def writeMapEnd(self):
+ pass
- def writeListBegin(self, etype, size):
- self.writeByte(etype)
- self.writeI32(size)
+ def writeListBegin(self, etype, size):
+ self.writeByte(etype)
+ self.writeI32(size)
- def writeListEnd(self):
- pass
+ def writeListEnd(self):
+ pass
- def writeSetBegin(self, etype, size):
- self.writeByte(etype)
- self.writeI32(size)
+ def writeSetBegin(self, etype, size):
+ self.writeByte(etype)
+ self.writeI32(size)
- def writeSetEnd(self):
- pass
+ def writeSetEnd(self):
+ pass
- def writeBool(self, bool):
- if bool:
- self.writeByte(1)
- else:
- self.writeByte(0)
+ def writeBool(self, bool):
+ if bool:
+ self.writeByte(1)
+ else:
+ self.writeByte(0)
- def writeByte(self, byte):
- buff = pack("!b", byte)
- self.trans.write(buff)
+ def writeByte(self, byte):
+ buff = pack("!b", byte)
+ self.trans.write(buff)
- def writeI16(self, i16):
- buff = pack("!h", i16)
- self.trans.write(buff)
+ def writeI16(self, i16):
+ buff = pack("!h", i16)
+ self.trans.write(buff)
- def writeI32(self, i32):
- buff = pack("!i", i32)
- self.trans.write(buff)
+ def writeI32(self, i32):
+ buff = pack("!i", i32)
+ self.trans.write(buff)
- def writeI64(self, i64):
- buff = pack("!q", i64)
- self.trans.write(buff)
+ def writeI64(self, i64):
+ buff = pack("!q", i64)
+ self.trans.write(buff)
- def writeDouble(self, dub):
- buff = pack("!d", dub)
- self.trans.write(buff)
+ def writeDouble(self, dub):
+ buff = pack("!d", dub)
+ self.trans.write(buff)
- def writeBinary(self, str):
- self.writeI32(len(str))
- self.trans.write(str)
+ def writeBinary(self, str):
+ self.writeI32(len(str))
+ self.trans.write(str)
- def readMessageBegin(self):
- sz = self.readI32()
- if sz < 0:
- version = sz & TBinaryProtocol.VERSION_MASK
- if version != TBinaryProtocol.VERSION_1:
- raise TProtocolException(
- type=TProtocolException.BAD_VERSION,
- message='Bad version in readMessageBegin: %d' % (sz))
- type = sz & TBinaryProtocol.TYPE_MASK
- name = self.readString()
- seqid = self.readI32()
- else:
- if self.strictRead:
- raise TProtocolException(type=TProtocolException.BAD_VERSION,
- message='No protocol version header')
- name = self.trans.readAll(sz)
- type = self.readByte()
- seqid = self.readI32()
- return (name, type, seqid)
+ def readMessageBegin(self):
+ sz = self.readI32()
+ if sz < 0:
+ version = sz & TBinaryProtocol.VERSION_MASK
+ if version != TBinaryProtocol.VERSION_1:
+ raise TProtocolException(
+ type=TProtocolException.BAD_VERSION,
+ message='Bad version in readMessageBegin: %d' % (sz))
+ type = sz & TBinaryProtocol.TYPE_MASK
+ name = self.readString()
+ seqid = self.readI32()
+ else:
+ if self.strictRead:
+ raise TProtocolException(type=TProtocolException.BAD_VERSION,
+ message='No protocol version header')
+ name = self.trans.readAll(sz)
+ type = self.readByte()
+ seqid = self.readI32()
+ return (name, type, seqid)
- def readMessageEnd(self):
- pass
+ def readMessageEnd(self):
+ pass
- def readStructBegin(self):
- pass
+ def readStructBegin(self):
+ pass
- def readStructEnd(self):
- pass
+ def readStructEnd(self):
+ pass
- def readFieldBegin(self):
- type = self.readByte()
- if type == TType.STOP:
- return (None, type, 0)
- id = self.readI16()
- return (None, type, id)
+ def readFieldBegin(self):
+ type = self.readByte()
+ if type == TType.STOP:
+ return (None, type, 0)
+ id = self.readI16()
+ return (None, type, id)
- def readFieldEnd(self):
- pass
+ def readFieldEnd(self):
+ pass
- def readMapBegin(self):
- ktype = self.readByte()
- vtype = self.readByte()
- size = self.readI32()
- self._check_container_length(size)
- return (ktype, vtype, size)
+ def readMapBegin(self):
+ ktype = self.readByte()
+ vtype = self.readByte()
+ size = self.readI32()
+ self._check_container_length(size)
+ return (ktype, vtype, size)
- def readMapEnd(self):
- pass
+ def readMapEnd(self):
+ pass
- def readListBegin(self):
- etype = self.readByte()
- size = self.readI32()
- self._check_container_length(size)
- return (etype, size)
+ def readListBegin(self):
+ etype = self.readByte()
+ size = self.readI32()
+ self._check_container_length(size)
+ return (etype, size)
- def readListEnd(self):
- pass
+ def readListEnd(self):
+ pass
- def readSetBegin(self):
- etype = self.readByte()
- size = self.readI32()
- self._check_container_length(size)
- return (etype, size)
+ def readSetBegin(self):
+ etype = self.readByte()
+ size = self.readI32()
+ self._check_container_length(size)
+ return (etype, size)
- def readSetEnd(self):
- pass
+ def readSetEnd(self):
+ pass
- def readBool(self):
- byte = self.readByte()
- if byte == 0:
- return False
- return True
+ def readBool(self):
+ byte = self.readByte()
+ if byte == 0:
+ return False
+ return True
- def readByte(self):
- buff = self.trans.readAll(1)
- val, = unpack('!b', buff)
- return val
+ def readByte(self):
+ buff = self.trans.readAll(1)
+ val, = unpack('!b', buff)
+ return val
- def readI16(self):
- buff = self.trans.readAll(2)
- val, = unpack('!h', buff)
- return val
+ def readI16(self):
+ buff = self.trans.readAll(2)
+ val, = unpack('!h', buff)
+ return val
- def readI32(self):
- buff = self.trans.readAll(4)
- val, = unpack('!i', buff)
- return val
+ def readI32(self):
+ buff = self.trans.readAll(4)
+ val, = unpack('!i', buff)
+ return val
- def readI64(self):
- buff = self.trans.readAll(8)
- val, = unpack('!q', buff)
- return val
+ def readI64(self):
+ buff = self.trans.readAll(8)
+ val, = unpack('!q', buff)
+ return val
- def readDouble(self):
- buff = self.trans.readAll(8)
- val, = unpack('!d', buff)
- return val
+ def readDouble(self):
+ buff = self.trans.readAll(8)
+ val, = unpack('!d', buff)
+ return val
- def readBinary(self):
- size = self.readI32()
- self._check_string_length(size)
- s = self.trans.readAll(size)
- return s
+ def readBinary(self):
+ size = self.readI32()
+ self._check_string_length(size)
+ s = self.trans.readAll(size)
+ return s
class TBinaryProtocolFactory(object):
- def __init__(self, strictRead=False, strictWrite=True, **kwargs):
- self.strictRead = strictRead
- self.strictWrite = strictWrite
- self.string_length_limit = kwargs.get('string_length_limit', None)
- self.container_length_limit = kwargs.get('container_length_limit', None)
+ def __init__(self, strictRead=False, strictWrite=True, **kwargs):
+ self.strictRead = strictRead
+ self.strictWrite = strictWrite
+ self.string_length_limit = kwargs.get('string_length_limit', None)
+ self.container_length_limit = kwargs.get('container_length_limit', None)
- def getProtocol(self, trans):
- prot = TBinaryProtocol(trans, self.strictRead, self.strictWrite,
- string_length_limit=self.string_length_limit,
- container_length_limit=self.container_length_limit)
- return prot
+ def getProtocol(self, trans):
+ prot = TBinaryProtocol(trans, self.strictRead, self.strictWrite,
+ string_length_limit=self.string_length_limit,
+ container_length_limit=self.container_length_limit)
+ return prot
class TBinaryProtocolAccelerated(TBinaryProtocol):
- """C-Accelerated version of TBinaryProtocol.
+ """C-Accelerated version of TBinaryProtocol.
- This class does not override any of TBinaryProtocol's methods,
- but the generated code recognizes it directly and will call into
- our C module to do the encoding, bypassing this object entirely.
- We inherit from TBinaryProtocol so that the normal TBinaryProtocol
- encoding can happen if the fastbinary module doesn't work for some
- reason. (TODO(dreiss): Make this happen sanely in more cases.)
+ This class does not override any of TBinaryProtocol's methods,
+ but the generated code recognizes it directly and will call into
+ our C module to do the encoding, bypassing this object entirely.
+ We inherit from TBinaryProtocol so that the normal TBinaryProtocol
+ encoding can happen if the fastbinary module doesn't work for some
+ reason. (TODO(dreiss): Make this happen sanely in more cases.)
- In order to take advantage of the C module, just use
- TBinaryProtocolAccelerated instead of TBinaryProtocol.
+ In order to take advantage of the C module, just use
+ TBinaryProtocolAccelerated instead of TBinaryProtocol.
- NOTE: This code was contributed by an external developer.
- The internal Thrift team has reviewed and tested it,
- but we cannot guarantee that it is production-ready.
- Please feel free to report bugs and/or success stories
- to the public mailing list.
- """
- pass
+ NOTE: This code was contributed by an external developer.
+ The internal Thrift team has reviewed and tested it,
+ but we cannot guarantee that it is production-ready.
+ Please feel free to report bugs and/or success stories
+ to the public mailing list.
+ """
+ pass
class TBinaryProtocolAcceleratedFactory(object):
- def __init__(self,
- string_length_limit=None,
- container_length_limit=None):
- self.string_length_limit = string_length_limit
- self.container_length_limit = container_length_limit
+ def __init__(self,
+ string_length_limit=None,
+ container_length_limit=None):
+ self.string_length_limit = string_length_limit
+ self.container_length_limit = container_length_limit
- def getProtocol(self, trans):
- return TBinaryProtocolAccelerated(
- trans,
- string_length_limit=self.string_length_limit,
- container_length_limit=self.container_length_limit)
+ def getProtocol(self, trans):
+ return TBinaryProtocolAccelerated(
+ trans,
+ string_length_limit=self.string_length_limit,
+ container_length_limit=self.container_length_limit)
diff --git a/lib/py/src/protocol/TCompactProtocol.py b/lib/py/src/protocol/TCompactProtocol.py
index 3d9c0e6..8d3db1a 100644
--- a/lib/py/src/protocol/TCompactProtocol.py
+++ b/lib/py/src/protocol/TCompactProtocol.py
@@ -36,390 +36,391 @@
def make_helper(v_from, container):
- def helper(func):
- def nested(self, *args, **kwargs):
- assert self.state in (v_from, container), (self.state, v_from, container)
- return func(self, *args, **kwargs)
- return nested
- return helper
+ def helper(func):
+ def nested(self, *args, **kwargs):
+ assert self.state in (v_from, container), (self.state, v_from, container)
+ return func(self, *args, **kwargs)
+ return nested
+ return helper
writer = make_helper(VALUE_WRITE, CONTAINER_WRITE)
reader = make_helper(VALUE_READ, CONTAINER_READ)
def makeZigZag(n, bits):
- checkIntegerLimits(n, bits)
- return (n << 1) ^ (n >> (bits - 1))
+ checkIntegerLimits(n, bits)
+ return (n << 1) ^ (n >> (bits - 1))
def fromZigZag(n):
- return (n >> 1) ^ -(n & 1)
+ return (n >> 1) ^ -(n & 1)
def writeVarint(trans, n):
- out = bytearray()
- while True:
- if n & ~0x7f == 0:
- out.append(n)
- break
- else:
- out.append((n & 0xff) | 0x80)
- n = n >> 7
- trans.write(bytes(out))
+ out = bytearray()
+ while True:
+ if n & ~0x7f == 0:
+ out.append(n)
+ break
+ else:
+ out.append((n & 0xff) | 0x80)
+ n = n >> 7
+ trans.write(bytes(out))
def readVarint(trans):
- result = 0
- shift = 0
- while True:
- x = trans.readAll(1)
- byte = ord(x)
- result |= (byte & 0x7f) << shift
- if byte >> 7 == 0:
- return result
- shift += 7
+ result = 0
+ shift = 0
+ while True:
+ x = trans.readAll(1)
+ byte = ord(x)
+ result |= (byte & 0x7f) << shift
+ if byte >> 7 == 0:
+ return result
+ shift += 7
class CompactType(object):
- STOP = 0x00
- TRUE = 0x01
- FALSE = 0x02
- BYTE = 0x03
- I16 = 0x04
- I32 = 0x05
- I64 = 0x06
- DOUBLE = 0x07
- BINARY = 0x08
- LIST = 0x09
- SET = 0x0A
- MAP = 0x0B
- STRUCT = 0x0C
+ STOP = 0x00
+ TRUE = 0x01
+ FALSE = 0x02
+ BYTE = 0x03
+ I16 = 0x04
+ I32 = 0x05
+ I64 = 0x06
+ DOUBLE = 0x07
+ BINARY = 0x08
+ LIST = 0x09
+ SET = 0x0A
+ MAP = 0x0B
+ STRUCT = 0x0C
-CTYPES = {TType.STOP: CompactType.STOP,
- TType.BOOL: CompactType.TRUE, # used for collection
- TType.BYTE: CompactType.BYTE,
- TType.I16: CompactType.I16,
- TType.I32: CompactType.I32,
- TType.I64: CompactType.I64,
- TType.DOUBLE: CompactType.DOUBLE,
- TType.STRING: CompactType.BINARY,
- TType.STRUCT: CompactType.STRUCT,
- TType.LIST: CompactType.LIST,
- TType.SET: CompactType.SET,
- TType.MAP: CompactType.MAP
- }
+CTYPES = {
+ TType.STOP: CompactType.STOP,
+ TType.BOOL: CompactType.TRUE, # used for collection
+ TType.BYTE: CompactType.BYTE,
+ TType.I16: CompactType.I16,
+ TType.I32: CompactType.I32,
+ TType.I64: CompactType.I64,
+ TType.DOUBLE: CompactType.DOUBLE,
+ TType.STRING: CompactType.BINARY,
+ TType.STRUCT: CompactType.STRUCT,
+ TType.LIST: CompactType.LIST,
+ TType.SET: CompactType.SET,
+ TType.MAP: CompactType.MAP,
+}
TTYPES = {}
for k, v in CTYPES.items():
- TTYPES[v] = k
+ TTYPES[v] = k
TTYPES[CompactType.FALSE] = TType.BOOL
del k
del v
class TCompactProtocol(TProtocolBase):
- """Compact implementation of the Thrift protocol driver."""
+ """Compact implementation of the Thrift protocol driver."""
- PROTOCOL_ID = 0x82
- VERSION = 1
- VERSION_MASK = 0x1f
- TYPE_MASK = 0xe0
- TYPE_BITS = 0x07
- TYPE_SHIFT_AMOUNT = 5
+ PROTOCOL_ID = 0x82
+ VERSION = 1
+ VERSION_MASK = 0x1f
+ TYPE_MASK = 0xe0
+ TYPE_BITS = 0x07
+ TYPE_SHIFT_AMOUNT = 5
- def __init__(self, trans,
- string_length_limit=None,
- container_length_limit=None):
- TProtocolBase.__init__(self, trans)
- self.state = CLEAR
- self.__last_fid = 0
- self.__bool_fid = None
- self.__bool_value = None
- self.__structs = []
- self.__containers = []
- self.string_length_limit = string_length_limit
- self.container_length_limit = container_length_limit
+ def __init__(self, trans,
+ string_length_limit=None,
+ container_length_limit=None):
+ TProtocolBase.__init__(self, trans)
+ self.state = CLEAR
+ self.__last_fid = 0
+ self.__bool_fid = None
+ self.__bool_value = None
+ self.__structs = []
+ self.__containers = []
+ self.string_length_limit = string_length_limit
+ self.container_length_limit = container_length_limit
- def _check_string_length(self, length):
- self._check_length(self.string_length_limit, length)
+ def _check_string_length(self, length):
+ self._check_length(self.string_length_limit, length)
- def _check_container_length(self, length):
- self._check_length(self.container_length_limit, length)
+ def _check_container_length(self, length):
+ self._check_length(self.container_length_limit, length)
- def __writeVarint(self, n):
- writeVarint(self.trans, n)
+ def __writeVarint(self, n):
+ writeVarint(self.trans, n)
- def writeMessageBegin(self, name, type, seqid):
- assert self.state == CLEAR
- self.__writeUByte(self.PROTOCOL_ID)
- self.__writeUByte(self.VERSION | (type << self.TYPE_SHIFT_AMOUNT))
- self.__writeVarint(seqid)
- self.__writeBinary(str_to_binary(name))
- self.state = VALUE_WRITE
+ def writeMessageBegin(self, name, type, seqid):
+ assert self.state == CLEAR
+ self.__writeUByte(self.PROTOCOL_ID)
+ self.__writeUByte(self.VERSION | (type << self.TYPE_SHIFT_AMOUNT))
+ self.__writeVarint(seqid)
+ self.__writeBinary(str_to_binary(name))
+ self.state = VALUE_WRITE
- def writeMessageEnd(self):
- assert self.state == VALUE_WRITE
- self.state = CLEAR
+ def writeMessageEnd(self):
+ assert self.state == VALUE_WRITE
+ self.state = CLEAR
- def writeStructBegin(self, name):
- assert self.state in (CLEAR, CONTAINER_WRITE, VALUE_WRITE), self.state
- self.__structs.append((self.state, self.__last_fid))
- self.state = FIELD_WRITE
- self.__last_fid = 0
+ def writeStructBegin(self, name):
+ assert self.state in (CLEAR, CONTAINER_WRITE, VALUE_WRITE), self.state
+ self.__structs.append((self.state, self.__last_fid))
+ self.state = FIELD_WRITE
+ self.__last_fid = 0
- def writeStructEnd(self):
- assert self.state == FIELD_WRITE
- self.state, self.__last_fid = self.__structs.pop()
+ def writeStructEnd(self):
+ assert self.state == FIELD_WRITE
+ self.state, self.__last_fid = self.__structs.pop()
- def writeFieldStop(self):
- self.__writeByte(0)
+ def writeFieldStop(self):
+ self.__writeByte(0)
- def __writeFieldHeader(self, type, fid):
- delta = fid - self.__last_fid
- if 0 < delta <= 15:
- self.__writeUByte(delta << 4 | type)
- else:
- self.__writeByte(type)
- self.__writeI16(fid)
- self.__last_fid = fid
+ def __writeFieldHeader(self, type, fid):
+ delta = fid - self.__last_fid
+ if 0 < delta <= 15:
+ self.__writeUByte(delta << 4 | type)
+ else:
+ self.__writeByte(type)
+ self.__writeI16(fid)
+ self.__last_fid = fid
- def writeFieldBegin(self, name, type, fid):
- assert self.state == FIELD_WRITE, self.state
- if type == TType.BOOL:
- self.state = BOOL_WRITE
- self.__bool_fid = fid
- else:
- self.state = VALUE_WRITE
- self.__writeFieldHeader(CTYPES[type], fid)
+ def writeFieldBegin(self, name, type, fid):
+ assert self.state == FIELD_WRITE, self.state
+ if type == TType.BOOL:
+ self.state = BOOL_WRITE
+ self.__bool_fid = fid
+ else:
+ self.state = VALUE_WRITE
+ self.__writeFieldHeader(CTYPES[type], fid)
- def writeFieldEnd(self):
- assert self.state in (VALUE_WRITE, BOOL_WRITE), self.state
- self.state = FIELD_WRITE
+ def writeFieldEnd(self):
+ assert self.state in (VALUE_WRITE, BOOL_WRITE), self.state
+ self.state = FIELD_WRITE
- def __writeUByte(self, byte):
- self.trans.write(pack('!B', byte))
+ def __writeUByte(self, byte):
+ self.trans.write(pack('!B', byte))
- def __writeByte(self, byte):
- self.trans.write(pack('!b', byte))
+ def __writeByte(self, byte):
+ self.trans.write(pack('!b', byte))
- def __writeI16(self, i16):
- self.__writeVarint(makeZigZag(i16, 16))
+ def __writeI16(self, i16):
+ self.__writeVarint(makeZigZag(i16, 16))
- def __writeSize(self, i32):
- self.__writeVarint(i32)
+ def __writeSize(self, i32):
+ self.__writeVarint(i32)
- def writeCollectionBegin(self, etype, size):
- assert self.state in (VALUE_WRITE, CONTAINER_WRITE), self.state
- if size <= 14:
- self.__writeUByte(size << 4 | CTYPES[etype])
- else:
- self.__writeUByte(0xf0 | CTYPES[etype])
- self.__writeSize(size)
- self.__containers.append(self.state)
- self.state = CONTAINER_WRITE
- writeSetBegin = writeCollectionBegin
- writeListBegin = writeCollectionBegin
+ def writeCollectionBegin(self, etype, size):
+ assert self.state in (VALUE_WRITE, CONTAINER_WRITE), self.state
+ if size <= 14:
+ self.__writeUByte(size << 4 | CTYPES[etype])
+ else:
+ self.__writeUByte(0xf0 | CTYPES[etype])
+ self.__writeSize(size)
+ self.__containers.append(self.state)
+ self.state = CONTAINER_WRITE
+ writeSetBegin = writeCollectionBegin
+ writeListBegin = writeCollectionBegin
- def writeMapBegin(self, ktype, vtype, size):
- assert self.state in (VALUE_WRITE, CONTAINER_WRITE), self.state
- if size == 0:
- self.__writeByte(0)
- else:
- self.__writeSize(size)
- self.__writeUByte(CTYPES[ktype] << 4 | CTYPES[vtype])
- self.__containers.append(self.state)
- self.state = CONTAINER_WRITE
+ def writeMapBegin(self, ktype, vtype, size):
+ assert self.state in (VALUE_WRITE, CONTAINER_WRITE), self.state
+ if size == 0:
+ self.__writeByte(0)
+ else:
+ self.__writeSize(size)
+ self.__writeUByte(CTYPES[ktype] << 4 | CTYPES[vtype])
+ self.__containers.append(self.state)
+ self.state = CONTAINER_WRITE
- def writeCollectionEnd(self):
- assert self.state == CONTAINER_WRITE, self.state
- self.state = self.__containers.pop()
- writeMapEnd = writeCollectionEnd
- writeSetEnd = writeCollectionEnd
- writeListEnd = writeCollectionEnd
+ def writeCollectionEnd(self):
+ assert self.state == CONTAINER_WRITE, self.state
+ self.state = self.__containers.pop()
+ writeMapEnd = writeCollectionEnd
+ writeSetEnd = writeCollectionEnd
+ writeListEnd = writeCollectionEnd
- def writeBool(self, bool):
- if self.state == BOOL_WRITE:
- if bool:
- ctype = CompactType.TRUE
- else:
- ctype = CompactType.FALSE
- self.__writeFieldHeader(ctype, self.__bool_fid)
- elif self.state == CONTAINER_WRITE:
- if bool:
- self.__writeByte(CompactType.TRUE)
- else:
- self.__writeByte(CompactType.FALSE)
- else:
- raise AssertionError("Invalid state in compact protocol")
+ def writeBool(self, bool):
+ if self.state == BOOL_WRITE:
+ if bool:
+ ctype = CompactType.TRUE
+ else:
+ ctype = CompactType.FALSE
+ self.__writeFieldHeader(ctype, self.__bool_fid)
+ elif self.state == CONTAINER_WRITE:
+ if bool:
+ self.__writeByte(CompactType.TRUE)
+ else:
+ self.__writeByte(CompactType.FALSE)
+ else:
+ raise AssertionError("Invalid state in compact protocol")
- writeByte = writer(__writeByte)
- writeI16 = writer(__writeI16)
+ writeByte = writer(__writeByte)
+ writeI16 = writer(__writeI16)
- @writer
- def writeI32(self, i32):
- self.__writeVarint(makeZigZag(i32, 32))
+ @writer
+ def writeI32(self, i32):
+ self.__writeVarint(makeZigZag(i32, 32))
- @writer
- def writeI64(self, i64):
- self.__writeVarint(makeZigZag(i64, 64))
+ @writer
+ def writeI64(self, i64):
+ self.__writeVarint(makeZigZag(i64, 64))
- @writer
- def writeDouble(self, dub):
- self.trans.write(pack('<d', dub))
+ @writer
+ def writeDouble(self, dub):
+ self.trans.write(pack('<d', dub))
- def __writeBinary(self, s):
- self.__writeSize(len(s))
- self.trans.write(s)
- writeBinary = writer(__writeBinary)
+ def __writeBinary(self, s):
+ self.__writeSize(len(s))
+ self.trans.write(s)
+ writeBinary = writer(__writeBinary)
- def readFieldBegin(self):
- assert self.state == FIELD_READ, self.state
- type = self.__readUByte()
- if type & 0x0f == TType.STOP:
- return (None, 0, 0)
- delta = type >> 4
- if delta == 0:
- fid = self.__readI16()
- else:
- fid = self.__last_fid + delta
- self.__last_fid = fid
- type = type & 0x0f
- if type == CompactType.TRUE:
- self.state = BOOL_READ
- self.__bool_value = True
- elif type == CompactType.FALSE:
- self.state = BOOL_READ
- self.__bool_value = False
- else:
- self.state = VALUE_READ
- return (None, self.__getTType(type), fid)
+ def readFieldBegin(self):
+ assert self.state == FIELD_READ, self.state
+ type = self.__readUByte()
+ if type & 0x0f == TType.STOP:
+ return (None, 0, 0)
+ delta = type >> 4
+ if delta == 0:
+ fid = self.__readI16()
+ else:
+ fid = self.__last_fid + delta
+ self.__last_fid = fid
+ type = type & 0x0f
+ if type == CompactType.TRUE:
+ self.state = BOOL_READ
+ self.__bool_value = True
+ elif type == CompactType.FALSE:
+ self.state = BOOL_READ
+ self.__bool_value = False
+ else:
+ self.state = VALUE_READ
+ return (None, self.__getTType(type), fid)
- def readFieldEnd(self):
- assert self.state in (VALUE_READ, BOOL_READ), self.state
- self.state = FIELD_READ
+ def readFieldEnd(self):
+ assert self.state in (VALUE_READ, BOOL_READ), self.state
+ self.state = FIELD_READ
- def __readUByte(self):
- result, = unpack('!B', self.trans.readAll(1))
- return result
+ def __readUByte(self):
+ result, = unpack('!B', self.trans.readAll(1))
+ return result
- def __readByte(self):
- result, = unpack('!b', self.trans.readAll(1))
- return result
+ def __readByte(self):
+ result, = unpack('!b', self.trans.readAll(1))
+ return result
- def __readVarint(self):
- return readVarint(self.trans)
+ def __readVarint(self):
+ return readVarint(self.trans)
- def __readZigZag(self):
- return fromZigZag(self.__readVarint())
+ def __readZigZag(self):
+ return fromZigZag(self.__readVarint())
- def __readSize(self):
- result = self.__readVarint()
- if result < 0:
- raise TProtocolException("Length < 0")
- return result
+ def __readSize(self):
+ result = self.__readVarint()
+ if result < 0:
+ raise TProtocolException("Length < 0")
+ return result
- def readMessageBegin(self):
- assert self.state == CLEAR
- proto_id = self.__readUByte()
- if proto_id != self.PROTOCOL_ID:
- raise TProtocolException(TProtocolException.BAD_VERSION,
- 'Bad protocol id in the message: %d' % proto_id)
- ver_type = self.__readUByte()
- type = (ver_type >> self.TYPE_SHIFT_AMOUNT) & self.TYPE_BITS
- version = ver_type & self.VERSION_MASK
- if version != self.VERSION:
- raise TProtocolException(TProtocolException.BAD_VERSION,
- 'Bad version: %d (expect %d)' % (version, self.VERSION))
- seqid = self.__readVarint()
- name = binary_to_str(self.__readBinary())
- return (name, type, seqid)
+ def readMessageBegin(self):
+ assert self.state == CLEAR
+ proto_id = self.__readUByte()
+ if proto_id != self.PROTOCOL_ID:
+ raise TProtocolException(TProtocolException.BAD_VERSION,
+ 'Bad protocol id in the message: %d' % proto_id)
+ ver_type = self.__readUByte()
+ type = (ver_type >> self.TYPE_SHIFT_AMOUNT) & self.TYPE_BITS
+ version = ver_type & self.VERSION_MASK
+ if version != self.VERSION:
+ raise TProtocolException(TProtocolException.BAD_VERSION,
+ 'Bad version: %d (expect %d)' % (version, self.VERSION))
+ seqid = self.__readVarint()
+ name = binary_to_str(self.__readBinary())
+ return (name, type, seqid)
- def readMessageEnd(self):
- assert self.state == CLEAR
- assert len(self.__structs) == 0
+ def readMessageEnd(self):
+ assert self.state == CLEAR
+ assert len(self.__structs) == 0
- def readStructBegin(self):
- assert self.state in (CLEAR, CONTAINER_READ, VALUE_READ), self.state
- self.__structs.append((self.state, self.__last_fid))
- self.state = FIELD_READ
- self.__last_fid = 0
+ def readStructBegin(self):
+ assert self.state in (CLEAR, CONTAINER_READ, VALUE_READ), self.state
+ self.__structs.append((self.state, self.__last_fid))
+ self.state = FIELD_READ
+ self.__last_fid = 0
- def readStructEnd(self):
- assert self.state == FIELD_READ
- self.state, self.__last_fid = self.__structs.pop()
+ def readStructEnd(self):
+ assert self.state == FIELD_READ
+ self.state, self.__last_fid = self.__structs.pop()
- def readCollectionBegin(self):
- assert self.state in (VALUE_READ, CONTAINER_READ), self.state
- size_type = self.__readUByte()
- size = size_type >> 4
- type = self.__getTType(size_type)
- if size == 15:
- size = self.__readSize()
- self._check_container_length(size)
- self.__containers.append(self.state)
- self.state = CONTAINER_READ
- return type, size
- readSetBegin = readCollectionBegin
- readListBegin = readCollectionBegin
+ def readCollectionBegin(self):
+ assert self.state in (VALUE_READ, CONTAINER_READ), self.state
+ size_type = self.__readUByte()
+ size = size_type >> 4
+ type = self.__getTType(size_type)
+ if size == 15:
+ size = self.__readSize()
+ self._check_container_length(size)
+ self.__containers.append(self.state)
+ self.state = CONTAINER_READ
+ return type, size
+ readSetBegin = readCollectionBegin
+ readListBegin = readCollectionBegin
- def readMapBegin(self):
- assert self.state in (VALUE_READ, CONTAINER_READ), self.state
- size = self.__readSize()
- self._check_container_length(size)
- types = 0
- if size > 0:
- types = self.__readUByte()
- vtype = self.__getTType(types)
- ktype = self.__getTType(types >> 4)
- self.__containers.append(self.state)
- self.state = CONTAINER_READ
- return (ktype, vtype, size)
+ def readMapBegin(self):
+ assert self.state in (VALUE_READ, CONTAINER_READ), self.state
+ size = self.__readSize()
+ self._check_container_length(size)
+ types = 0
+ if size > 0:
+ types = self.__readUByte()
+ vtype = self.__getTType(types)
+ ktype = self.__getTType(types >> 4)
+ self.__containers.append(self.state)
+ self.state = CONTAINER_READ
+ return (ktype, vtype, size)
- def readCollectionEnd(self):
- assert self.state == CONTAINER_READ, self.state
- self.state = self.__containers.pop()
- readSetEnd = readCollectionEnd
- readListEnd = readCollectionEnd
- readMapEnd = readCollectionEnd
+ def readCollectionEnd(self):
+ assert self.state == CONTAINER_READ, self.state
+ self.state = self.__containers.pop()
+ readSetEnd = readCollectionEnd
+ readListEnd = readCollectionEnd
+ readMapEnd = readCollectionEnd
- def readBool(self):
- if self.state == BOOL_READ:
- return self.__bool_value == CompactType.TRUE
- elif self.state == CONTAINER_READ:
- return self.__readByte() == CompactType.TRUE
- else:
- raise AssertionError("Invalid state in compact protocol: %d" %
- self.state)
+ def readBool(self):
+ if self.state == BOOL_READ:
+ return self.__bool_value == CompactType.TRUE
+ elif self.state == CONTAINER_READ:
+ return self.__readByte() == CompactType.TRUE
+ else:
+ raise AssertionError("Invalid state in compact protocol: %d" %
+ self.state)
- readByte = reader(__readByte)
- __readI16 = __readZigZag
- readI16 = reader(__readZigZag)
- readI32 = reader(__readZigZag)
- readI64 = reader(__readZigZag)
+ readByte = reader(__readByte)
+ __readI16 = __readZigZag
+ readI16 = reader(__readZigZag)
+ readI32 = reader(__readZigZag)
+ readI64 = reader(__readZigZag)
- @reader
- def readDouble(self):
- buff = self.trans.readAll(8)
- val, = unpack('<d', buff)
- return val
+ @reader
+ def readDouble(self):
+ buff = self.trans.readAll(8)
+ val, = unpack('<d', buff)
+ return val
- def __readBinary(self):
- size = self.__readSize()
- self._check_string_length(size)
- return self.trans.readAll(size)
- readBinary = reader(__readBinary)
+ def __readBinary(self):
+ size = self.__readSize()
+ self._check_string_length(size)
+ return self.trans.readAll(size)
+ readBinary = reader(__readBinary)
- def __getTType(self, byte):
- return TTYPES[byte & 0x0f]
+ def __getTType(self, byte):
+ return TTYPES[byte & 0x0f]
class TCompactProtocolFactory(object):
- def __init__(self,
- string_length_limit=None,
- container_length_limit=None):
- self.string_length_limit = string_length_limit
- self.container_length_limit = container_length_limit
+ def __init__(self,
+ string_length_limit=None,
+ container_length_limit=None):
+ self.string_length_limit = string_length_limit
+ self.container_length_limit = container_length_limit
- def getProtocol(self, trans):
- return TCompactProtocol(trans,
- self.string_length_limit,
- self.container_length_limit)
+ def getProtocol(self, trans):
+ return TCompactProtocol(trans,
+ self.string_length_limit,
+ self.container_length_limit)
diff --git a/lib/py/src/protocol/TJSONProtocol.py b/lib/py/src/protocol/TJSONProtocol.py
index f9e65fb..db2099a 100644
--- a/lib/py/src/protocol/TJSONProtocol.py
+++ b/lib/py/src/protocol/TJSONProtocol.py
@@ -17,7 +17,8 @@
# under the License.
#
-from .TProtocol import TType, TProtocolBase, TProtocolException, checkIntegerLimits
+from .TProtocol import (TType, TProtocolBase, TProtocolException,
+ checkIntegerLimits)
import base64
import math
import sys
@@ -45,14 +46,14 @@
ESCSEQ0 = ord('\\')
ESCSEQ1 = ord('u')
ESCAPE_CHAR_VALS = {
- '"': '\\"',
- '\\': '\\\\',
- '\b': '\\b',
- '\f': '\\f',
- '\n': '\\n',
- '\r': '\\r',
- '\t': '\\t',
- # '/': '\\/',
+ '"': '\\"',
+ '\\': '\\\\',
+ '\b': '\\b',
+ '\f': '\\f',
+ '\n': '\\n',
+ '\r': '\\r',
+ '\t': '\\t',
+ # '/': '\\/',
}
ESCAPE_CHARS = {
b'"': '"',
@@ -66,519 +67,527 @@
}
NUMERIC_CHAR = b'+-.0123456789Ee'
-CTYPES = {TType.BOOL: 'tf',
- TType.BYTE: 'i8',
- TType.I16: 'i16',
- TType.I32: 'i32',
- TType.I64: 'i64',
- TType.DOUBLE: 'dbl',
- TType.STRING: 'str',
- TType.STRUCT: 'rec',
- TType.LIST: 'lst',
- TType.SET: 'set',
- TType.MAP: 'map'}
+CTYPES = {
+ TType.BOOL: 'tf',
+ TType.BYTE: 'i8',
+ TType.I16: 'i16',
+ TType.I32: 'i32',
+ TType.I64: 'i64',
+ TType.DOUBLE: 'dbl',
+ TType.STRING: 'str',
+ TType.STRUCT: 'rec',
+ TType.LIST: 'lst',
+ TType.SET: 'set',
+ TType.MAP: 'map',
+}
JTYPES = {}
for key in CTYPES.keys():
- JTYPES[CTYPES[key]] = key
+ JTYPES[CTYPES[key]] = key
class JSONBaseContext(object):
- def __init__(self, protocol):
- self.protocol = protocol
- self.first = True
+ def __init__(self, protocol):
+ self.protocol = protocol
+ self.first = True
- def doIO(self, function):
- pass
+ def doIO(self, function):
+ pass
- def write(self):
- pass
+ def write(self):
+ pass
- def read(self):
- pass
+ def read(self):
+ pass
- def escapeNum(self):
- return False
+ def escapeNum(self):
+ return False
- def __str__(self):
- return self.__class__.__name__
+ def __str__(self):
+ return self.__class__.__name__
class JSONListContext(JSONBaseContext):
- def doIO(self, function):
- if self.first is True:
- self.first = False
- else:
- function(COMMA)
+ def doIO(self, function):
+ if self.first is True:
+ self.first = False
+ else:
+ function(COMMA)
- def write(self):
- self.doIO(self.protocol.trans.write)
+ def write(self):
+ self.doIO(self.protocol.trans.write)
- def read(self):
- self.doIO(self.protocol.readJSONSyntaxChar)
+ def read(self):
+ self.doIO(self.protocol.readJSONSyntaxChar)
class JSONPairContext(JSONBaseContext):
- def __init__(self, protocol):
- super(JSONPairContext, self).__init__(protocol)
- self.colon = True
+ def __init__(self, protocol):
+ super(JSONPairContext, self).__init__(protocol)
+ self.colon = True
- def doIO(self, function):
- if self.first:
- self.first = False
- self.colon = True
- else:
- function(COLON if self.colon else COMMA)
- self.colon = not self.colon
+ def doIO(self, function):
+ if self.first:
+ self.first = False
+ self.colon = True
+ else:
+ function(COLON if self.colon else COMMA)
+ self.colon = not self.colon
- def write(self):
- self.doIO(self.protocol.trans.write)
+ def write(self):
+ self.doIO(self.protocol.trans.write)
- def read(self):
- self.doIO(self.protocol.readJSONSyntaxChar)
+ def read(self):
+ self.doIO(self.protocol.readJSONSyntaxChar)
- def escapeNum(self):
- return self.colon
+ def escapeNum(self):
+ return self.colon
- def __str__(self):
- return '%s, colon=%s' % (self.__class__.__name__, self.colon)
+ def __str__(self):
+ return '%s, colon=%s' % (self.__class__.__name__, self.colon)
class LookaheadReader():
- hasData = False
- data = ''
+ hasData = False
+ data = ''
- def __init__(self, protocol):
- self.protocol = protocol
+ def __init__(self, protocol):
+ self.protocol = protocol
- def read(self):
- if self.hasData is True:
- self.hasData = False
- else:
- self.data = self.protocol.trans.read(1)
- return self.data
+ def read(self):
+ if self.hasData is True:
+ self.hasData = False
+ else:
+ self.data = self.protocol.trans.read(1)
+ return self.data
- def peek(self):
- if self.hasData is False:
- self.data = self.protocol.trans.read(1)
- self.hasData = True
- return self.data
+ def peek(self):
+ if self.hasData is False:
+ self.data = self.protocol.trans.read(1)
+ self.hasData = True
+ return self.data
class TJSONProtocolBase(TProtocolBase):
- def __init__(self, trans):
- TProtocolBase.__init__(self, trans)
- self.resetWriteContext()
- self.resetReadContext()
+ def __init__(self, trans):
+ TProtocolBase.__init__(self, trans)
+ self.resetWriteContext()
+ self.resetReadContext()
- # We don't have length limit implementation for JSON protocols
- @property
- def string_length_limit(senf):
- return None
+ # We don't have length limit implementation for JSON protocols
+ @property
+ def string_length_limit(self):
+ return None
- @property
- def container_length_limit(senf):
- return None
+ @property
+ def container_length_limit(self):
+ return None
- def resetWriteContext(self):
- self.context = JSONBaseContext(self)
- self.contextStack = [self.context]
+ def resetWriteContext(self):
+ self.context = JSONBaseContext(self)
+ self.contextStack = [self.context]
- def resetReadContext(self):
- self.resetWriteContext()
- self.reader = LookaheadReader(self)
+ def resetReadContext(self):
+ self.resetWriteContext()
+ self.reader = LookaheadReader(self)
- def pushContext(self, ctx):
- self.contextStack.append(ctx)
- self.context = ctx
+ def pushContext(self, ctx):
+ self.contextStack.append(ctx)
+ self.context = ctx
- def popContext(self):
- self.contextStack.pop()
- if self.contextStack:
- self.context = self.contextStack[-1]
- else:
- self.context = JSONBaseContext(self)
-
- def writeJSONString(self, string):
- self.context.write()
- json_str = ['"']
- for s in string:
- escaped = ESCAPE_CHAR_VALS.get(s, s)
- json_str.append(escaped)
- json_str.append('"')
- self.trans.write(str_to_binary(''.join(json_str)))
-
- def writeJSONNumber(self, number, formatter='{0}'):
- self.context.write()
- jsNumber = str(formatter.format(number)).encode('ascii')
- if self.context.escapeNum():
- self.trans.write(QUOTE)
- self.trans.write(jsNumber)
- self.trans.write(QUOTE)
- else:
- self.trans.write(jsNumber)
-
- def writeJSONBase64(self, binary):
- self.context.write()
- self.trans.write(QUOTE)
- self.trans.write(base64.b64encode(binary))
- self.trans.write(QUOTE)
-
- def writeJSONObjectStart(self):
- self.context.write()
- self.trans.write(LBRACE)
- self.pushContext(JSONPairContext(self))
-
- def writeJSONObjectEnd(self):
- self.popContext()
- self.trans.write(RBRACE)
-
- def writeJSONArrayStart(self):
- self.context.write()
- self.trans.write(LBRACKET)
- self.pushContext(JSONListContext(self))
-
- def writeJSONArrayEnd(self):
- self.popContext()
- self.trans.write(RBRACKET)
-
- def readJSONSyntaxChar(self, character):
- current = self.reader.read()
- if character != current:
- raise TProtocolException(TProtocolException.INVALID_DATA,
- "Unexpected character: %s" % current)
-
- def _isHighSurrogate(self, codeunit):
- return codeunit >= 0xd800 and codeunit <= 0xdbff
-
- def _isLowSurrogate(self, codeunit):
- return codeunit >= 0xdc00 and codeunit <= 0xdfff
-
- def _toChar(self, high, low=None):
- if not low:
- if sys.version_info[0] == 2:
- return ("\\u%04x" % high).decode('unicode-escape').encode('utf-8')
- else:
- return chr(high)
- else:
- codepoint = (1 << 16) + ((high & 0x3ff) << 10)
- codepoint += low & 0x3ff
- if sys.version_info[0] == 2:
- s = "\\U%08x" % codepoint
- return s.decode('unicode-escape').encode('utf-8')
- else:
- return chr(codepoint)
-
- def readJSONString(self, skipContext):
- highSurrogate = None
- string = []
- if skipContext is False:
- self.context.read()
- self.readJSONSyntaxChar(QUOTE)
- while True:
- character = self.reader.read()
- if character == QUOTE:
- break
- if ord(character) == ESCSEQ0:
- character = self.reader.read()
- if ord(character) == ESCSEQ1:
- character = self.trans.read(4).decode('ascii')
- codeunit = int(character, 16)
- if self._isHighSurrogate(codeunit):
- if highSurrogate:
- raise TProtocolException(TProtocolException.INVALID_DATA,
- "Expected low surrogate char")
- highSurrogate = codeunit
- continue
- elif self._isLowSurrogate(codeunit):
- if not highSurrogate:
- raise TProtocolException(TProtocolException.INVALID_DATA,
- "Expected high surrogate char")
- character = self._toChar(highSurrogate, codeunit)
- highSurrogate = None
- else:
- character = self._toChar(codeunit)
+ def popContext(self):
+ self.contextStack.pop()
+ if self.contextStack:
+ self.context = self.contextStack[-1]
else:
- if character not in ESCAPE_CHARS:
+ self.context = JSONBaseContext(self)
+
+ def writeJSONString(self, string):
+ self.context.write()
+ json_str = ['"']
+ for s in string:
+ escaped = ESCAPE_CHAR_VALS.get(s, s)
+ json_str.append(escaped)
+ json_str.append('"')
+ self.trans.write(str_to_binary(''.join(json_str)))
+
+ def writeJSONNumber(self, number, formatter='{0}'):
+ self.context.write()
+ jsNumber = str(formatter.format(number)).encode('ascii')
+ if self.context.escapeNum():
+ self.trans.write(QUOTE)
+ self.trans.write(jsNumber)
+ self.trans.write(QUOTE)
+ else:
+ self.trans.write(jsNumber)
+
+ def writeJSONBase64(self, binary):
+ self.context.write()
+ self.trans.write(QUOTE)
+ self.trans.write(base64.b64encode(binary))
+ self.trans.write(QUOTE)
+
+ def writeJSONObjectStart(self):
+ self.context.write()
+ self.trans.write(LBRACE)
+ self.pushContext(JSONPairContext(self))
+
+ def writeJSONObjectEnd(self):
+ self.popContext()
+ self.trans.write(RBRACE)
+
+ def writeJSONArrayStart(self):
+ self.context.write()
+ self.trans.write(LBRACKET)
+ self.pushContext(JSONListContext(self))
+
+ def writeJSONArrayEnd(self):
+ self.popContext()
+ self.trans.write(RBRACKET)
+
+ def readJSONSyntaxChar(self, character):
+ current = self.reader.read()
+ if character != current:
raise TProtocolException(TProtocolException.INVALID_DATA,
- "Expected control char")
- character = ESCAPE_CHARS[character]
- elif character in ESCAPE_CHAR_VALS:
- raise TProtocolException(TProtocolException.INVALID_DATA,
- "Unescaped control char")
- elif sys.version_info[0] > 2:
- utf8_bytes = bytearray([ord(character)])
- while ord(self.reader.peek()) >= 0x80:
- utf8_bytes.append(ord(self.reader.read()))
- character = utf8_bytes.decode('utf8')
- string.append(character)
+ "Unexpected character: %s" % current)
- if highSurrogate:
- raise TProtocolException(TProtocolException.INVALID_DATA,
- "Expected low surrogate char")
- return ''.join(string)
+ def _isHighSurrogate(self, codeunit):
+ return codeunit >= 0xd800 and codeunit <= 0xdbff
- def isJSONNumeric(self, character):
- return (True if NUMERIC_CHAR.find(character) != - 1 else False)
+ def _isLowSurrogate(self, codeunit):
+ return codeunit >= 0xdc00 and codeunit <= 0xdfff
- def readJSONQuotes(self):
- if (self.context.escapeNum()):
- self.readJSONSyntaxChar(QUOTE)
+ def _toChar(self, high, low=None):
+ if not low:
+ if sys.version_info[0] == 2:
+ return ("\\u%04x" % high).decode('unicode-escape') \
+ .encode('utf-8')
+ else:
+ return chr(high)
+ else:
+ codepoint = (1 << 16) + ((high & 0x3ff) << 10)
+ codepoint += low & 0x3ff
+ if sys.version_info[0] == 2:
+ s = "\\U%08x" % codepoint
+ return s.decode('unicode-escape').encode('utf-8')
+ else:
+ return chr(codepoint)
- def readJSONNumericChars(self):
- numeric = []
- while True:
- character = self.reader.peek()
- if self.isJSONNumeric(character) is False:
- break
- numeric.append(self.reader.read())
- return b''.join(numeric).decode('ascii')
-
- def readJSONInteger(self):
- self.context.read()
- self.readJSONQuotes()
- numeric = self.readJSONNumericChars()
- self.readJSONQuotes()
- try:
- return int(numeric)
- except ValueError:
- raise TProtocolException(TProtocolException.INVALID_DATA,
- "Bad data encounted in numeric data")
-
- def readJSONDouble(self):
- self.context.read()
- if self.reader.peek() == QUOTE:
- string = self.readJSONString(True)
- try:
- double = float(string)
- if (self.context.escapeNum is False and
- not math.isinf(double) and
- not math.isnan(double)):
- raise TProtocolException(TProtocolException.INVALID_DATA,
- "Numeric data unexpectedly quoted")
- return double
- except ValueError:
- raise TProtocolException(TProtocolException.INVALID_DATA,
- "Bad data encounted in numeric data")
- else:
- if self.context.escapeNum() is True:
+ def readJSONString(self, skipContext):
+ highSurrogate = None
+ string = []
+ if skipContext is False:
+ self.context.read()
self.readJSONSyntaxChar(QUOTE)
- try:
- return float(self.readJSONNumericChars())
- except ValueError:
- raise TProtocolException(TProtocolException.INVALID_DATA,
- "Bad data encounted in numeric data")
+ while True:
+ character = self.reader.read()
+ if character == QUOTE:
+ break
+ if ord(character) == ESCSEQ0:
+ character = self.reader.read()
+ if ord(character) == ESCSEQ1:
+ character = self.trans.read(4).decode('ascii')
+ codeunit = int(character, 16)
+ if self._isHighSurrogate(codeunit):
+ if highSurrogate:
+ raise TProtocolException(
+ TProtocolException.INVALID_DATA,
+ "Expected low surrogate char")
+ highSurrogate = codeunit
+ continue
+ elif self._isLowSurrogate(codeunit):
+ if not highSurrogate:
+ raise TProtocolException(
+ TProtocolException.INVALID_DATA,
+ "Expected high surrogate char")
+ character = self._toChar(highSurrogate, codeunit)
+ highSurrogate = None
+ else:
+ character = self._toChar(codeunit)
+ else:
+ if character not in ESCAPE_CHARS:
+ raise TProtocolException(
+ TProtocolException.INVALID_DATA,
+ "Expected control char")
+ character = ESCAPE_CHARS[character]
+ elif character in ESCAPE_CHAR_VALS:
+ raise TProtocolException(TProtocolException.INVALID_DATA,
+ "Unescaped control char")
+ elif sys.version_info[0] > 2:
+ utf8_bytes = bytearray([ord(character)])
+ while ord(self.reader.peek()) >= 0x80:
+ utf8_bytes.append(ord(self.reader.read()))
+ character = utf8_bytes.decode('utf8')
+ string.append(character)
- def readJSONBase64(self):
- string = self.readJSONString(False)
- size = len(string)
- m = size % 4
- # Force padding since b64encode method does not allow it
- if m != 0:
- for i in range(4 - m):
- string += '='
- return base64.b64decode(string)
+ if highSurrogate:
+ raise TProtocolException(TProtocolException.INVALID_DATA,
+ "Expected low surrogate char")
+ return ''.join(string)
- def readJSONObjectStart(self):
- self.context.read()
- self.readJSONSyntaxChar(LBRACE)
- self.pushContext(JSONPairContext(self))
+ def isJSONNumeric(self, character):
+ return bool(character) and NUMERIC_CHAR.find(character) != -1
- def readJSONObjectEnd(self):
- self.readJSONSyntaxChar(RBRACE)
- self.popContext()
+ def readJSONQuotes(self):
+ if (self.context.escapeNum()):
+ self.readJSONSyntaxChar(QUOTE)
- def readJSONArrayStart(self):
- self.context.read()
- self.readJSONSyntaxChar(LBRACKET)
- self.pushContext(JSONListContext(self))
+ def readJSONNumericChars(self):
+ numeric = []
+ while True:
+ character = self.reader.peek()
+ if self.isJSONNumeric(character) is False:
+ break
+ numeric.append(self.reader.read())
+ return b''.join(numeric).decode('ascii')
- def readJSONArrayEnd(self):
- self.readJSONSyntaxChar(RBRACKET)
- self.popContext()
+ def readJSONInteger(self):
+ self.context.read()
+ self.readJSONQuotes()
+ numeric = self.readJSONNumericChars()
+ self.readJSONQuotes()
+ try:
+ return int(numeric)
+ except ValueError:
+ raise TProtocolException(TProtocolException.INVALID_DATA,
+ "Bad data encounted in numeric data")
+
+ def readJSONDouble(self):
+ self.context.read()
+ if self.reader.peek() == QUOTE:
+ string = self.readJSONString(True)
+ try:
+ double = float(string)
+ if (self.context.escapeNum() is False and
+ not math.isinf(double) and
+ not math.isnan(double)):
+ raise TProtocolException(
+ TProtocolException.INVALID_DATA,
+ "Numeric data unexpectedly quoted")
+ return double
+ except ValueError:
+ raise TProtocolException(TProtocolException.INVALID_DATA,
+ "Bad data encounted in numeric data")
+ else:
+ if self.context.escapeNum() is True:
+ self.readJSONSyntaxChar(QUOTE)
+ try:
+ return float(self.readJSONNumericChars())
+ except ValueError:
+ raise TProtocolException(TProtocolException.INVALID_DATA,
+ "Bad data encounted in numeric data")
+
+ def readJSONBase64(self):
+ string = self.readJSONString(False)
+ size = len(string)
+ m = size % 4
+ # Force padding since b64encode method does not allow it
+ if m != 0:
+ for i in range(4 - m):
+ string += '='
+ return base64.b64decode(string)
+
+ def readJSONObjectStart(self):
+ self.context.read()
+ self.readJSONSyntaxChar(LBRACE)
+ self.pushContext(JSONPairContext(self))
+
+ def readJSONObjectEnd(self):
+ self.readJSONSyntaxChar(RBRACE)
+ self.popContext()
+
+ def readJSONArrayStart(self):
+ self.context.read()
+ self.readJSONSyntaxChar(LBRACKET)
+ self.pushContext(JSONListContext(self))
+
+ def readJSONArrayEnd(self):
+ self.readJSONSyntaxChar(RBRACKET)
+ self.popContext()
class TJSONProtocol(TJSONProtocolBase):
- def readMessageBegin(self):
- self.resetReadContext()
- self.readJSONArrayStart()
- if self.readJSONInteger() != VERSION:
- raise TProtocolException(TProtocolException.BAD_VERSION,
- "Message contained bad version.")
- name = self.readJSONString(False)
- typen = self.readJSONInteger()
- seqid = self.readJSONInteger()
- return (name, typen, seqid)
+ def readMessageBegin(self):
+ self.resetReadContext()
+ self.readJSONArrayStart()
+ if self.readJSONInteger() != VERSION:
+ raise TProtocolException(TProtocolException.BAD_VERSION,
+ "Message contained bad version.")
+ name = self.readJSONString(False)
+ typen = self.readJSONInteger()
+ seqid = self.readJSONInteger()
+ return (name, typen, seqid)
- def readMessageEnd(self):
- self.readJSONArrayEnd()
+ def readMessageEnd(self):
+ self.readJSONArrayEnd()
- def readStructBegin(self):
- self.readJSONObjectStart()
+ def readStructBegin(self):
+ self.readJSONObjectStart()
- def readStructEnd(self):
- self.readJSONObjectEnd()
+ def readStructEnd(self):
+ self.readJSONObjectEnd()
- def readFieldBegin(self):
- character = self.reader.peek()
- ttype = 0
- id = 0
- if character == RBRACE:
- ttype = TType.STOP
- else:
- id = self.readJSONInteger()
- self.readJSONObjectStart()
- ttype = JTYPES[self.readJSONString(False)]
- return (None, ttype, id)
+ def readFieldBegin(self):
+ character = self.reader.peek()
+ ttype = 0
+ id = 0
+ if character == RBRACE:
+ ttype = TType.STOP
+ else:
+ id = self.readJSONInteger()
+ self.readJSONObjectStart()
+ ttype = JTYPES[self.readJSONString(False)]
+ return (None, ttype, id)
- def readFieldEnd(self):
- self.readJSONObjectEnd()
+ def readFieldEnd(self):
+ self.readJSONObjectEnd()
- def readMapBegin(self):
- self.readJSONArrayStart()
- keyType = JTYPES[self.readJSONString(False)]
- valueType = JTYPES[self.readJSONString(False)]
- size = self.readJSONInteger()
- self.readJSONObjectStart()
- return (keyType, valueType, size)
+ def readMapBegin(self):
+ self.readJSONArrayStart()
+ keyType = JTYPES[self.readJSONString(False)]
+ valueType = JTYPES[self.readJSONString(False)]
+ size = self.readJSONInteger()
+ self.readJSONObjectStart()
+ return (keyType, valueType, size)
- def readMapEnd(self):
- self.readJSONObjectEnd()
- self.readJSONArrayEnd()
+ def readMapEnd(self):
+ self.readJSONObjectEnd()
+ self.readJSONArrayEnd()
- def readCollectionBegin(self):
- self.readJSONArrayStart()
- elemType = JTYPES[self.readJSONString(False)]
- size = self.readJSONInteger()
- return (elemType, size)
- readListBegin = readCollectionBegin
- readSetBegin = readCollectionBegin
+ def readCollectionBegin(self):
+ self.readJSONArrayStart()
+ elemType = JTYPES[self.readJSONString(False)]
+ size = self.readJSONInteger()
+ return (elemType, size)
+ readListBegin = readCollectionBegin
+ readSetBegin = readCollectionBegin
- def readCollectionEnd(self):
- self.readJSONArrayEnd()
- readSetEnd = readCollectionEnd
- readListEnd = readCollectionEnd
+ def readCollectionEnd(self):
+ self.readJSONArrayEnd()
+ readSetEnd = readCollectionEnd
+ readListEnd = readCollectionEnd
- def readBool(self):
- return (False if self.readJSONInteger() == 0 else True)
+ def readBool(self):
+ return (False if self.readJSONInteger() == 0 else True)
- def readNumber(self):
- return self.readJSONInteger()
- readByte = readNumber
- readI16 = readNumber
- readI32 = readNumber
- readI64 = readNumber
+ def readNumber(self):
+ return self.readJSONInteger()
+ readByte = readNumber
+ readI16 = readNumber
+ readI32 = readNumber
+ readI64 = readNumber
- def readDouble(self):
- return self.readJSONDouble()
+ def readDouble(self):
+ return self.readJSONDouble()
- def readString(self):
- return self.readJSONString(False)
+ def readString(self):
+ return self.readJSONString(False)
- def readBinary(self):
- return self.readJSONBase64()
+ def readBinary(self):
+ return self.readJSONBase64()
- def writeMessageBegin(self, name, request_type, seqid):
- self.resetWriteContext()
- self.writeJSONArrayStart()
- self.writeJSONNumber(VERSION)
- self.writeJSONString(name)
- self.writeJSONNumber(request_type)
- self.writeJSONNumber(seqid)
+ def writeMessageBegin(self, name, request_type, seqid):
+ self.resetWriteContext()
+ self.writeJSONArrayStart()
+ self.writeJSONNumber(VERSION)
+ self.writeJSONString(name)
+ self.writeJSONNumber(request_type)
+ self.writeJSONNumber(seqid)
- def writeMessageEnd(self):
- self.writeJSONArrayEnd()
+ def writeMessageEnd(self):
+ self.writeJSONArrayEnd()
- def writeStructBegin(self, name):
- self.writeJSONObjectStart()
+ def writeStructBegin(self, name):
+ self.writeJSONObjectStart()
- def writeStructEnd(self):
- self.writeJSONObjectEnd()
+ def writeStructEnd(self):
+ self.writeJSONObjectEnd()
- def writeFieldBegin(self, name, ttype, id):
- self.writeJSONNumber(id)
- self.writeJSONObjectStart()
- self.writeJSONString(CTYPES[ttype])
+ def writeFieldBegin(self, name, ttype, id):
+ self.writeJSONNumber(id)
+ self.writeJSONObjectStart()
+ self.writeJSONString(CTYPES[ttype])
- def writeFieldEnd(self):
- self.writeJSONObjectEnd()
+ def writeFieldEnd(self):
+ self.writeJSONObjectEnd()
- def writeFieldStop(self):
- pass
+ def writeFieldStop(self):
+ pass
- def writeMapBegin(self, ktype, vtype, size):
- self.writeJSONArrayStart()
- self.writeJSONString(CTYPES[ktype])
- self.writeJSONString(CTYPES[vtype])
- self.writeJSONNumber(size)
- self.writeJSONObjectStart()
+ def writeMapBegin(self, ktype, vtype, size):
+ self.writeJSONArrayStart()
+ self.writeJSONString(CTYPES[ktype])
+ self.writeJSONString(CTYPES[vtype])
+ self.writeJSONNumber(size)
+ self.writeJSONObjectStart()
- def writeMapEnd(self):
- self.writeJSONObjectEnd()
- self.writeJSONArrayEnd()
+ def writeMapEnd(self):
+ self.writeJSONObjectEnd()
+ self.writeJSONArrayEnd()
- def writeListBegin(self, etype, size):
- self.writeJSONArrayStart()
- self.writeJSONString(CTYPES[etype])
- self.writeJSONNumber(size)
+ def writeListBegin(self, etype, size):
+ self.writeJSONArrayStart()
+ self.writeJSONString(CTYPES[etype])
+ self.writeJSONNumber(size)
- def writeListEnd(self):
- self.writeJSONArrayEnd()
+ def writeListEnd(self):
+ self.writeJSONArrayEnd()
- def writeSetBegin(self, etype, size):
- self.writeJSONArrayStart()
- self.writeJSONString(CTYPES[etype])
- self.writeJSONNumber(size)
+ def writeSetBegin(self, etype, size):
+ self.writeJSONArrayStart()
+ self.writeJSONString(CTYPES[etype])
+ self.writeJSONNumber(size)
- def writeSetEnd(self):
- self.writeJSONArrayEnd()
+ def writeSetEnd(self):
+ self.writeJSONArrayEnd()
- def writeBool(self, boolean):
- self.writeJSONNumber(1 if boolean is True else 0)
+ def writeBool(self, boolean):
+ self.writeJSONNumber(1 if boolean else 0)
- def writeByte(self, byte):
- checkIntegerLimits(byte, 8)
- self.writeJSONNumber(byte)
+ def writeByte(self, byte):
+ checkIntegerLimits(byte, 8)
+ self.writeJSONNumber(byte)
- def writeI16(self, i16):
- checkIntegerLimits(i16, 16)
- self.writeJSONNumber(i16)
+ def writeI16(self, i16):
+ checkIntegerLimits(i16, 16)
+ self.writeJSONNumber(i16)
- def writeI32(self, i32):
- checkIntegerLimits(i32, 32)
- self.writeJSONNumber(i32)
+ def writeI32(self, i32):
+ checkIntegerLimits(i32, 32)
+ self.writeJSONNumber(i32)
- def writeI64(self, i64):
- checkIntegerLimits(i64, 64)
- self.writeJSONNumber(i64)
+ def writeI64(self, i64):
+ checkIntegerLimits(i64, 64)
+ self.writeJSONNumber(i64)
- def writeDouble(self, dbl):
- # 17 significant digits should be just enough for any double precision value.
- self.writeJSONNumber(dbl, '{0:.17g}')
+ def writeDouble(self, dbl):
+ # 17 significant digits should be just enough for any double precision
+ # value.
+ self.writeJSONNumber(dbl, '{0:.17g}')
- def writeString(self, string):
- self.writeJSONString(string)
+ def writeString(self, string):
+ self.writeJSONString(string)
- def writeBinary(self, binary):
- self.writeJSONBase64(binary)
+ def writeBinary(self, binary):
+ self.writeJSONBase64(binary)
class TJSONProtocolFactory(object):
- def getProtocol(self, trans):
- return TJSONProtocol(trans)
+ def getProtocol(self, trans):
+ return TJSONProtocol(trans)
- @property
- def string_length_limit(senf):
- return None
+ @property
+ def string_length_limit(self):
+ return None
- @property
- def container_length_limit(senf):
- return None
+ @property
+ def container_length_limit(self):
+ return None
class TSimpleJSONProtocol(TJSONProtocolBase):
diff --git a/lib/py/src/protocol/TMultiplexedProtocol.py b/lib/py/src/protocol/TMultiplexedProtocol.py
index d25f367..309f896 100644
--- a/lib/py/src/protocol/TMultiplexedProtocol.py
+++ b/lib/py/src/protocol/TMultiplexedProtocol.py
@@ -22,18 +22,19 @@
SEPARATOR = ":"
-class TMultiplexedProtocol(TProtocolDecorator.TProtocolDecorator):
- def __init__(self, protocol, serviceName):
- TProtocolDecorator.TProtocolDecorator.__init__(self, protocol)
- self.serviceName = serviceName
- def writeMessageBegin(self, name, type, seqid):
- if (type == TMessageType.CALL or
- type == TMessageType.ONEWAY):
- self.protocol.writeMessageBegin(
- self.serviceName + SEPARATOR + name,
- type,
- seqid
- )
- else:
- self.protocol.writeMessageBegin(name, type, seqid)
+class TMultiplexedProtocol(TProtocolDecorator.TProtocolDecorator):
+ def __init__(self, protocol, serviceName):
+ TProtocolDecorator.TProtocolDecorator.__init__(self, protocol)
+ self.serviceName = serviceName
+
+ def writeMessageBegin(self, name, type, seqid):
+ if (type == TMessageType.CALL or
+ type == TMessageType.ONEWAY):
+ self.protocol.writeMessageBegin(
+ self.serviceName + SEPARATOR + name,
+ type,
+ seqid
+ )
+ else:
+ self.protocol.writeMessageBegin(name, type, seqid)
diff --git a/lib/py/src/protocol/TProtocol.py b/lib/py/src/protocol/TProtocol.py
index d9aa2e8..ed6938b 100644
--- a/lib/py/src/protocol/TProtocol.py
+++ b/lib/py/src/protocol/TProtocol.py
@@ -28,373 +28,373 @@
class TProtocolException(TException):
- """Custom Protocol Exception class"""
+ """Custom Protocol Exception class"""
- UNKNOWN = 0
- INVALID_DATA = 1
- NEGATIVE_SIZE = 2
- SIZE_LIMIT = 3
- BAD_VERSION = 4
- NOT_IMPLEMENTED = 5
- DEPTH_LIMIT = 6
+ UNKNOWN = 0
+ INVALID_DATA = 1
+ NEGATIVE_SIZE = 2
+ SIZE_LIMIT = 3
+ BAD_VERSION = 4
+ NOT_IMPLEMENTED = 5
+ DEPTH_LIMIT = 6
- def __init__(self, type=UNKNOWN, message=None):
- TException.__init__(self, message)
- self.type = type
+ def __init__(self, type=UNKNOWN, message=None):
+ TException.__init__(self, message)
+ self.type = type
class TProtocolBase(object):
- """Base class for Thrift protocol driver."""
+ """Base class for Thrift protocol driver."""
- def __init__(self, trans):
- self.trans = trans
+ def __init__(self, trans):
+ self.trans = trans
- @staticmethod
- def _check_length(limit, length):
- if length < 0:
- raise TTransportException(TTransportException.NEGATIVE_SIZE,
- 'Negative length: %d' % length)
- if limit is not None and length > limit:
- raise TTransportException(TTransportException.SIZE_LIMIT,
- 'Length exceeded max allowed: %d' % limit)
+ @staticmethod
+ def _check_length(limit, length):
+ if length < 0:
+ raise TTransportException(TTransportException.NEGATIVE_SIZE,
+ 'Negative length: %d' % length)
+ if limit is not None and length > limit:
+ raise TTransportException(TTransportException.SIZE_LIMIT,
+ 'Length exceeded max allowed: %d' % limit)
- def writeMessageBegin(self, name, ttype, seqid):
- pass
+ def writeMessageBegin(self, name, ttype, seqid):
+ pass
- def writeMessageEnd(self):
- pass
+ def writeMessageEnd(self):
+ pass
- def writeStructBegin(self, name):
- pass
+ def writeStructBegin(self, name):
+ pass
- def writeStructEnd(self):
- pass
+ def writeStructEnd(self):
+ pass
- def writeFieldBegin(self, name, ttype, fid):
- pass
+ def writeFieldBegin(self, name, ttype, fid):
+ pass
- def writeFieldEnd(self):
- pass
+ def writeFieldEnd(self):
+ pass
- def writeFieldStop(self):
- pass
+ def writeFieldStop(self):
+ pass
- def writeMapBegin(self, ktype, vtype, size):
- pass
+ def writeMapBegin(self, ktype, vtype, size):
+ pass
- def writeMapEnd(self):
- pass
+ def writeMapEnd(self):
+ pass
- def writeListBegin(self, etype, size):
- pass
+ def writeListBegin(self, etype, size):
+ pass
- def writeListEnd(self):
- pass
+ def writeListEnd(self):
+ pass
- def writeSetBegin(self, etype, size):
- pass
+ def writeSetBegin(self, etype, size):
+ pass
- def writeSetEnd(self):
- pass
+ def writeSetEnd(self):
+ pass
- def writeBool(self, bool_val):
- pass
+ def writeBool(self, bool_val):
+ pass
- def writeByte(self, byte):
- pass
+ def writeByte(self, byte):
+ pass
- def writeI16(self, i16):
- pass
+ def writeI16(self, i16):
+ pass
- def writeI32(self, i32):
- pass
+ def writeI32(self, i32):
+ pass
- def writeI64(self, i64):
- pass
+ def writeI64(self, i64):
+ pass
- def writeDouble(self, dub):
- pass
+ def writeDouble(self, dub):
+ pass
- def writeString(self, str_val):
- self.writeBinary(str_to_binary(str_val))
+ def writeString(self, str_val):
+ self.writeBinary(str_to_binary(str_val))
- def writeBinary(self, str_val):
- pass
+ def writeBinary(self, str_val):
+ pass
- def writeUtf8(self, str_val):
- self.writeString(str_val.encode('utf8'))
+ def writeUtf8(self, str_val):
+ self.writeString(str_val.encode('utf8'))
- def readMessageBegin(self):
- pass
+ def readMessageBegin(self):
+ pass
- def readMessageEnd(self):
- pass
+ def readMessageEnd(self):
+ pass
- def readStructBegin(self):
- pass
+ def readStructBegin(self):
+ pass
- def readStructEnd(self):
- pass
+ def readStructEnd(self):
+ pass
- def readFieldBegin(self):
- pass
+ def readFieldBegin(self):
+ pass
- def readFieldEnd(self):
- pass
+ def readFieldEnd(self):
+ pass
- def readMapBegin(self):
- pass
+ def readMapBegin(self):
+ pass
- def readMapEnd(self):
- pass
+ def readMapEnd(self):
+ pass
- def readListBegin(self):
- pass
+ def readListBegin(self):
+ pass
- def readListEnd(self):
- pass
+ def readListEnd(self):
+ pass
- def readSetBegin(self):
- pass
+ def readSetBegin(self):
+ pass
- def readSetEnd(self):
- pass
+ def readSetEnd(self):
+ pass
- def readBool(self):
- pass
+ def readBool(self):
+ pass
- def readByte(self):
- pass
+ def readByte(self):
+ pass
- def readI16(self):
- pass
+ def readI16(self):
+ pass
- def readI32(self):
- pass
+ def readI32(self):
+ pass
- def readI64(self):
- pass
+ def readI64(self):
+ pass
- def readDouble(self):
- pass
+ def readDouble(self):
+ pass
- def readString(self):
- return binary_to_str(self.readBinary())
+ def readString(self):
+ return binary_to_str(self.readBinary())
- def readBinary(self):
- pass
+ def readBinary(self):
+ pass
- def readUtf8(self):
- return self.readString().decode('utf8')
+ def readUtf8(self):
+ return self.readString().decode('utf8')
- def skip(self, ttype):
- if ttype == TType.STOP:
- return
- elif ttype == TType.BOOL:
- self.readBool()
- elif ttype == TType.BYTE:
- self.readByte()
- elif ttype == TType.I16:
- self.readI16()
- elif ttype == TType.I32:
- self.readI32()
- elif ttype == TType.I64:
- self.readI64()
- elif ttype == TType.DOUBLE:
- self.readDouble()
- elif ttype == TType.STRING:
- self.readString()
- elif ttype == TType.STRUCT:
- name = self.readStructBegin()
- while True:
- (name, ttype, id) = self.readFieldBegin()
+ def skip(self, ttype):
if ttype == TType.STOP:
- break
- self.skip(ttype)
- self.readFieldEnd()
- self.readStructEnd()
- elif ttype == TType.MAP:
- (ktype, vtype, size) = self.readMapBegin()
- for i in range(size):
- self.skip(ktype)
- self.skip(vtype)
- self.readMapEnd()
- elif ttype == TType.SET:
- (etype, size) = self.readSetBegin()
- for i in range(size):
- self.skip(etype)
- self.readSetEnd()
- elif ttype == TType.LIST:
- (etype, size) = self.readListBegin()
- for i in range(size):
- self.skip(etype)
- self.readListEnd()
+ return
+ elif ttype == TType.BOOL:
+ self.readBool()
+ elif ttype == TType.BYTE:
+ self.readByte()
+ elif ttype == TType.I16:
+ self.readI16()
+ elif ttype == TType.I32:
+ self.readI32()
+ elif ttype == TType.I64:
+ self.readI64()
+ elif ttype == TType.DOUBLE:
+ self.readDouble()
+ elif ttype == TType.STRING:
+ self.readString()
+ elif ttype == TType.STRUCT:
+ name = self.readStructBegin()
+ while True:
+ (name, ttype, id) = self.readFieldBegin()
+ if ttype == TType.STOP:
+ break
+ self.skip(ttype)
+ self.readFieldEnd()
+ self.readStructEnd()
+ elif ttype == TType.MAP:
+ (ktype, vtype, size) = self.readMapBegin()
+ for i in range(size):
+ self.skip(ktype)
+ self.skip(vtype)
+ self.readMapEnd()
+ elif ttype == TType.SET:
+ (etype, size) = self.readSetBegin()
+ for i in range(size):
+ self.skip(etype)
+ self.readSetEnd()
+ elif ttype == TType.LIST:
+ (etype, size) = self.readListBegin()
+ for i in range(size):
+ self.skip(etype)
+ self.readListEnd()
- # tuple of: ( 'reader method' name, is_container bool, 'writer_method' name )
- _TTYPE_HANDLERS = (
- (None, None, False), # 0 TType.STOP
- (None, None, False), # 1 TType.VOID # TODO: handle void?
- ('readBool', 'writeBool', False), # 2 TType.BOOL
- ('readByte', 'writeByte', False), # 3 TType.BYTE and I08
- ('readDouble', 'writeDouble', False), # 4 TType.DOUBLE
- (None, None, False), # 5 undefined
- ('readI16', 'writeI16', False), # 6 TType.I16
- (None, None, False), # 7 undefined
- ('readI32', 'writeI32', False), # 8 TType.I32
- (None, None, False), # 9 undefined
- ('readI64', 'writeI64', False), # 10 TType.I64
- ('readString', 'writeString', False), # 11 TType.STRING and UTF7
- ('readContainerStruct', 'writeContainerStruct', True), # 12 *.STRUCT
- ('readContainerMap', 'writeContainerMap', True), # 13 TType.MAP
- ('readContainerSet', 'writeContainerSet', True), # 14 TType.SET
- ('readContainerList', 'writeContainerList', True), # 15 TType.LIST
- (None, None, False), # 16 TType.UTF8 # TODO: handle utf8 types?
- (None, None, False) # 17 TType.UTF16 # TODO: handle utf16 types?
- )
+ # tuple of: ( 'reader method' name, 'writer_method' name, is_container bool )
+ _TTYPE_HANDLERS = (
+ (None, None, False), # 0 TType.STOP
+ (None, None, False), # 1 TType.VOID # TODO: handle void?
+ ('readBool', 'writeBool', False), # 2 TType.BOOL
+ ('readByte', 'writeByte', False), # 3 TType.BYTE and I08
+ ('readDouble', 'writeDouble', False), # 4 TType.DOUBLE
+ (None, None, False), # 5 undefined
+ ('readI16', 'writeI16', False), # 6 TType.I16
+ (None, None, False), # 7 undefined
+ ('readI32', 'writeI32', False), # 8 TType.I32
+ (None, None, False), # 9 undefined
+ ('readI64', 'writeI64', False), # 10 TType.I64
+ ('readString', 'writeString', False), # 11 TType.STRING and UTF7
+ ('readContainerStruct', 'writeContainerStruct', True), # 12 *.STRUCT
+ ('readContainerMap', 'writeContainerMap', True), # 13 TType.MAP
+ ('readContainerSet', 'writeContainerSet', True), # 14 TType.SET
+ ('readContainerList', 'writeContainerList', True), # 15 TType.LIST
+ (None, None, False), # 16 TType.UTF8 # TODO: handle utf8 types?
+ (None, None, False) # 17 TType.UTF16 # TODO: handle utf16 types?
+ )
- def _ttype_handlers(self, ttype, spec):
- if spec == 'BINARY':
- if ttype != TType.STRING:
- raise TProtocolException(type=TProtocolException.INVALID_DATA,
- message='Invalid binary field type %d' % ttype)
- return ('readBinary', 'writeBinary', False)
- if sys.version_info[0] == 2 and spec == 'UTF8':
- if ttype != TType.STRING:
- raise TProtocolException(type=TProtocolException.INVALID_DATA,
- message='Invalid string field type %d' % ttype)
- return ('readUtf8', 'writeUtf8', False)
- return self._TTYPE_HANDLERS[ttype] if ttype < len(self._TTYPE_HANDLERS) else (None, None, False)
+ def _ttype_handlers(self, ttype, spec):
+ if spec == 'BINARY':
+ if ttype != TType.STRING:
+ raise TProtocolException(type=TProtocolException.INVALID_DATA,
+ message='Invalid binary field type %d' % ttype)
+ return ('readBinary', 'writeBinary', False)
+ if sys.version_info[0] == 2 and spec == 'UTF8':
+ if ttype != TType.STRING:
+ raise TProtocolException(type=TProtocolException.INVALID_DATA,
+ message='Invalid string field type %d' % ttype)
+ return ('readUtf8', 'writeUtf8', False)
+ return self._TTYPE_HANDLERS[ttype] if ttype < len(self._TTYPE_HANDLERS) else (None, None, False)
- def _read_by_ttype(self, ttype, spec, espec):
- reader_name, _, is_container = self._ttype_handlers(ttype, spec)
- if reader_name is None:
- raise TProtocolException(type=TProtocolException.INVALID_DATA,
- message='Invalid type %d' % (ttype))
- reader_func = getattr(self, reader_name)
- read = (lambda: reader_func(espec)) if is_container else reader_func
- while True:
- yield read()
+ def _read_by_ttype(self, ttype, spec, espec):
+ reader_name, _, is_container = self._ttype_handlers(ttype, spec)
+ if reader_name is None:
+ raise TProtocolException(type=TProtocolException.INVALID_DATA,
+ message='Invalid type %d' % (ttype))
+ reader_func = getattr(self, reader_name)
+ read = (lambda: reader_func(espec)) if is_container else reader_func
+ while True:
+ yield read()
- def readFieldByTType(self, ttype, spec):
- return self._read_by_ttype(ttype, spec, spec).next()
+ def readFieldByTType(self, ttype, spec):
+ return next(self._read_by_ttype(ttype, spec, spec))
- def readContainerList(self, spec):
- ttype, tspec, is_immutable = spec
- (list_type, list_len) = self.readListBegin()
- # TODO: compare types we just decoded with thrift_spec
- elems = islice(self._read_by_ttype(ttype, spec, tspec), list_len)
- results = (tuple if is_immutable else list)(elems)
- self.readListEnd()
- return results
+ def readContainerList(self, spec):
+ ttype, tspec, is_immutable = spec
+ (list_type, list_len) = self.readListBegin()
+ # TODO: compare types we just decoded with thrift_spec
+ elems = islice(self._read_by_ttype(ttype, spec, tspec), list_len)
+ results = (tuple if is_immutable else list)(elems)
+ self.readListEnd()
+ return results
- def readContainerSet(self, spec):
- ttype, tspec, is_immutable = spec
- (set_type, set_len) = self.readSetBegin()
- # TODO: compare types we just decoded with thrift_spec
- elems = islice(self._read_by_ttype(ttype, spec, tspec), set_len)
- results = (frozenset if is_immutable else set)(elems)
- self.readSetEnd()
- return results
+ def readContainerSet(self, spec):
+ ttype, tspec, is_immutable = spec
+ (set_type, set_len) = self.readSetBegin()
+ # TODO: compare types we just decoded with thrift_spec
+ elems = islice(self._read_by_ttype(ttype, spec, tspec), set_len)
+ results = (frozenset if is_immutable else set)(elems)
+ self.readSetEnd()
+ return results
- def readContainerStruct(self, spec):
- (obj_class, obj_spec) = spec
- obj = obj_class()
- obj.read(self)
- return obj
+ def readContainerStruct(self, spec):
+ (obj_class, obj_spec) = spec
+ obj = obj_class()
+ obj.read(self)
+ return obj
- def readContainerMap(self, spec):
- ktype, kspec, vtype, vspec, is_immutable = spec
- (map_ktype, map_vtype, map_len) = self.readMapBegin()
- # TODO: compare types we just decoded with thrift_spec and
- # abort/skip if types disagree
- keys = self._read_by_ttype(ktype, spec, kspec)
- vals = self._read_by_ttype(vtype, spec, vspec)
- keyvals = islice(zip(keys, vals), map_len)
- results = (TFrozenDict if is_immutable else dict)(keyvals)
- self.readMapEnd()
- return results
+ def readContainerMap(self, spec):
+ ktype, kspec, vtype, vspec, is_immutable = spec
+ (map_ktype, map_vtype, map_len) = self.readMapBegin()
+ # TODO: compare types we just decoded with thrift_spec and
+ # abort/skip if types disagree
+ keys = self._read_by_ttype(ktype, spec, kspec)
+ vals = self._read_by_ttype(vtype, spec, vspec)
+ keyvals = islice(zip(keys, vals), map_len)
+ results = (TFrozenDict if is_immutable else dict)(keyvals)
+ self.readMapEnd()
+ return results
- def readStruct(self, obj, thrift_spec, is_immutable=False):
- if is_immutable:
- fields = {}
- self.readStructBegin()
- while True:
- (fname, ftype, fid) = self.readFieldBegin()
- if ftype == TType.STOP:
- break
- try:
- field = thrift_spec[fid]
- except IndexError:
- self.skip(ftype)
- else:
- if field is not None and ftype == field[1]:
- fname = field[2]
- fspec = field[3]
- val = self.readFieldByTType(ftype, fspec)
- if is_immutable:
- fields[fname] = val
- else:
- setattr(obj, fname, val)
- else:
- self.skip(ftype)
- self.readFieldEnd()
- self.readStructEnd()
- if is_immutable:
- return obj(**fields)
+ def readStruct(self, obj, thrift_spec, is_immutable=False):
+ if is_immutable:
+ fields = {}
+ self.readStructBegin()
+ while True:
+ (fname, ftype, fid) = self.readFieldBegin()
+ if ftype == TType.STOP:
+ break
+ try:
+ field = thrift_spec[fid]
+ except IndexError:
+ self.skip(ftype)
+ else:
+ if field is not None and ftype == field[1]:
+ fname = field[2]
+ fspec = field[3]
+ val = self.readFieldByTType(ftype, fspec)
+ if is_immutable:
+ fields[fname] = val
+ else:
+ setattr(obj, fname, val)
+ else:
+ self.skip(ftype)
+ self.readFieldEnd()
+ self.readStructEnd()
+ if is_immutable:
+ return obj(**fields)
- def writeContainerStruct(self, val, spec):
- val.write(self)
+ def writeContainerStruct(self, val, spec):
+ val.write(self)
- def writeContainerList(self, val, spec):
- ttype, tspec, _ = spec
- self.writeListBegin(ttype, len(val))
- for _ in self._write_by_ttype(ttype, val, spec, tspec):
- pass
- self.writeListEnd()
+ def writeContainerList(self, val, spec):
+ ttype, tspec, _ = spec
+ self.writeListBegin(ttype, len(val))
+ for _ in self._write_by_ttype(ttype, val, spec, tspec):
+ pass
+ self.writeListEnd()
- def writeContainerSet(self, val, spec):
- ttype, tspec, _ = spec
- self.writeSetBegin(ttype, len(val))
- for _ in self._write_by_ttype(ttype, val, spec, tspec):
- pass
- self.writeSetEnd()
+ def writeContainerSet(self, val, spec):
+ ttype, tspec, _ = spec
+ self.writeSetBegin(ttype, len(val))
+ for _ in self._write_by_ttype(ttype, val, spec, tspec):
+ pass
+ self.writeSetEnd()
- def writeContainerMap(self, val, spec):
- ktype, kspec, vtype, vspec, _ = spec
- self.writeMapBegin(ktype, vtype, len(val))
- for _ in zip(self._write_by_ttype(ktype, six.iterkeys(val), spec, kspec),
- self._write_by_ttype(vtype, six.itervalues(val), spec, vspec)):
- pass
- self.writeMapEnd()
+ def writeContainerMap(self, val, spec):
+ ktype, kspec, vtype, vspec, _ = spec
+ self.writeMapBegin(ktype, vtype, len(val))
+ for _ in zip(self._write_by_ttype(ktype, six.iterkeys(val), spec, kspec),
+ self._write_by_ttype(vtype, six.itervalues(val), spec, vspec)):
+ pass
+ self.writeMapEnd()
- def writeStruct(self, obj, thrift_spec):
- self.writeStructBegin(obj.__class__.__name__)
- for field in thrift_spec:
- if field is None:
- continue
- fname = field[2]
- val = getattr(obj, fname)
- if val is None:
- # skip writing out unset fields
- continue
- fid = field[0]
- ftype = field[1]
- fspec = field[3]
- self.writeFieldBegin(fname, ftype, fid)
- self.writeFieldByTType(ftype, val, fspec)
- self.writeFieldEnd()
- self.writeFieldStop()
- self.writeStructEnd()
+ def writeStruct(self, obj, thrift_spec):
+ self.writeStructBegin(obj.__class__.__name__)
+ for field in thrift_spec:
+ if field is None:
+ continue
+ fname = field[2]
+ val = getattr(obj, fname)
+ if val is None:
+ # skip writing out unset fields
+ continue
+ fid = field[0]
+ ftype = field[1]
+ fspec = field[3]
+ self.writeFieldBegin(fname, ftype, fid)
+ self.writeFieldByTType(ftype, val, fspec)
+ self.writeFieldEnd()
+ self.writeFieldStop()
+ self.writeStructEnd()
- def _write_by_ttype(self, ttype, vals, spec, espec):
- _, writer_name, is_container = self._ttype_handlers(ttype, spec)
- writer_func = getattr(self, writer_name)
- write = (lambda v: writer_func(v, espec)) if is_container else writer_func
- for v in vals:
- yield write(v)
+ def _write_by_ttype(self, ttype, vals, spec, espec):
+ _, writer_name, is_container = self._ttype_handlers(ttype, spec)
+ writer_func = getattr(self, writer_name)
+ write = (lambda v: writer_func(v, espec)) if is_container else writer_func
+ for v in vals:
+ yield write(v)
- def writeFieldByTType(self, ttype, val, spec):
- self._write_by_ttype(ttype, [val], spec, spec).next()
+ def writeFieldByTType(self, ttype, val, spec):
+ next(self._write_by_ttype(ttype, [val], spec, spec))
def checkIntegerLimits(i, bits):
@@ -408,10 +408,10 @@
raise TProtocolException(TProtocolException.INVALID_DATA,
"i32 requires -2147483648 <= number <= 2147483647")
elif bits == 64 and (i < -9223372036854775808 or i > 9223372036854775807):
- raise TProtocolException(TProtocolException.INVALID_DATA,
- "i64 requires -9223372036854775808 <= number <= 9223372036854775807")
+ raise TProtocolException(TProtocolException.INVALID_DATA,
+ "i64 requires -9223372036854775808 <= number <= 9223372036854775807")
class TProtocolFactory(object):
- def getProtocol(self, trans):
- pass
+ def getProtocol(self, trans):
+ pass
diff --git a/lib/py/src/protocol/TProtocolDecorator.py b/lib/py/src/protocol/TProtocolDecorator.py
index bf50bfa..8b270a4 100644
--- a/lib/py/src/protocol/TProtocolDecorator.py
+++ b/lib/py/src/protocol/TProtocolDecorator.py
@@ -17,26 +17,34 @@
# under the License.
#
+import types
+
from thrift.protocol.TProtocol import TProtocolBase
-from types import *
+
class TProtocolDecorator():
- def __init__(self, protocol):
- TProtocolBase(protocol)
- self.protocol = protocol
+ def __init__(self, protocol):
+ TProtocolBase(protocol)
+ self.protocol = protocol
- def __getattr__(self, name):
- if hasattr(self.protocol, name):
- member = getattr(self.protocol, name)
- if type(member) in [MethodType, FunctionType, LambdaType, BuiltinFunctionType, BuiltinMethodType]:
- return lambda *args, **kwargs: self._wrap(member, args, kwargs)
- else:
- return member
- raise AttributeError(name)
+ def __getattr__(self, name):
+ if hasattr(self.protocol, name):
+ member = getattr(self.protocol, name)
+ if type(member) in [
+ types.MethodType,
+ types.FunctionType,
+ types.LambdaType,
+ types.BuiltinFunctionType,
+ types.BuiltinMethodType,
+ ]:
+ return lambda *args, **kwargs: self._wrap(member, args, kwargs)
+ else:
+ return member
+ raise AttributeError(name)
- def _wrap(self, func, args, kwargs):
- if type(func) == MethodType:
- result = func(*args, **kwargs)
- else:
- result = func(self.protocol, *args, **kwargs)
- return result
+ def _wrap(self, func, args, kwargs):
+ if isinstance(func, types.MethodType):
+ result = func(*args, **kwargs)
+ else:
+ result = func(self.protocol, *args, **kwargs)
+ return result
diff --git a/lib/py/src/protocol/__init__.py b/lib/py/src/protocol/__init__.py
index 7eefb45..7148f66 100644
--- a/lib/py/src/protocol/__init__.py
+++ b/lib/py/src/protocol/__init__.py
@@ -17,4 +17,5 @@
# under the License.
#
-__all__ = ['fastbinary', 'TBase', 'TBinaryProtocol', 'TCompactProtocol', 'TJSONProtocol', 'TProtocol']
+__all__ = ['fastbinary', 'TBase', 'TBinaryProtocol', 'TCompactProtocol',
+ 'TJSONProtocol', 'TProtocol']
diff --git a/lib/py/src/server/THttpServer.py b/lib/py/src/server/THttpServer.py
index bf3b0e3..1b501a7 100644
--- a/lib/py/src/server/THttpServer.py
+++ b/lib/py/src/server/THttpServer.py
@@ -24,64 +24,64 @@
class ResponseException(Exception):
- """Allows handlers to override the HTTP response
+ """Allows handlers to override the HTTP response
- Normally, THttpServer always sends a 200 response. If a handler wants
- to override this behavior (e.g., to simulate a misconfigured or
- overloaded web server during testing), it can raise a ResponseException.
- The function passed to the constructor will be called with the
- RequestHandler as its only argument.
- """
- def __init__(self, handler):
- self.handler = handler
+ Normally, THttpServer always sends a 200 response. If a handler wants
+ to override this behavior (e.g., to simulate a misconfigured or
+ overloaded web server during testing), it can raise a ResponseException.
+ The function passed to the constructor will be called with the
+ RequestHandler as its only argument.
+ """
+ def __init__(self, handler):
+ self.handler = handler
class THttpServer(TServer.TServer):
- """A simple HTTP-based Thrift server
+ """A simple HTTP-based Thrift server
- This class is not very performant, but it is useful (for example) for
- acting as a mock version of an Apache-based PHP Thrift endpoint.
- """
- def __init__(self,
- processor,
- server_address,
- inputProtocolFactory,
- outputProtocolFactory=None,
- server_class=BaseHTTPServer.HTTPServer):
- """Set up protocol factories and HTTP server.
-
- See BaseHTTPServer for server_address.
- See TServer for protocol factories.
+ This class is not very performant, but it is useful (for example) for
+ acting as a mock version of an Apache-based PHP Thrift endpoint.
"""
- if outputProtocolFactory is None:
- outputProtocolFactory = inputProtocolFactory
+ def __init__(self,
+ processor,
+ server_address,
+ inputProtocolFactory,
+ outputProtocolFactory=None,
+ server_class=BaseHTTPServer.HTTPServer):
+ """Set up protocol factories and HTTP server.
- TServer.TServer.__init__(self, processor, None, None, None,
- inputProtocolFactory, outputProtocolFactory)
+ See BaseHTTPServer for server_address.
+ See TServer for protocol factories.
+ """
+ if outputProtocolFactory is None:
+ outputProtocolFactory = inputProtocolFactory
- thttpserver = self
+ TServer.TServer.__init__(self, processor, None, None, None,
+ inputProtocolFactory, outputProtocolFactory)
- class RequestHander(BaseHTTPServer.BaseHTTPRequestHandler):
- def do_POST(self):
- # Don't care about the request path.
- itrans = TTransport.TFileObjectTransport(self.rfile)
- otrans = TTransport.TFileObjectTransport(self.wfile)
- itrans = TTransport.TBufferedTransport(
- itrans, int(self.headers['Content-Length']))
- otrans = TTransport.TMemoryBuffer()
- iprot = thttpserver.inputProtocolFactory.getProtocol(itrans)
- oprot = thttpserver.outputProtocolFactory.getProtocol(otrans)
- try:
- thttpserver.processor.process(iprot, oprot)
- except ResponseException as exn:
- exn.handler(self)
- else:
- self.send_response(200)
- self.send_header("content-type", "application/x-thrift")
- self.end_headers()
- self.wfile.write(otrans.getvalue())
+ thttpserver = self
- self.httpd = server_class(server_address, RequestHander)
+ class RequestHander(BaseHTTPServer.BaseHTTPRequestHandler):
+ def do_POST(self):
+ # Don't care about the request path.
+ itrans = TTransport.TFileObjectTransport(self.rfile)
+ otrans = TTransport.TFileObjectTransport(self.wfile)
+ itrans = TTransport.TBufferedTransport(
+ itrans, int(self.headers['Content-Length']))
+ otrans = TTransport.TMemoryBuffer()
+ iprot = thttpserver.inputProtocolFactory.getProtocol(itrans)
+ oprot = thttpserver.outputProtocolFactory.getProtocol(otrans)
+ try:
+ thttpserver.processor.process(iprot, oprot)
+ except ResponseException as exn:
+ exn.handler(self)
+ else:
+ self.send_response(200)
+ self.send_header("content-type", "application/x-thrift")
+ self.end_headers()
+ self.wfile.write(otrans.getvalue())
- def serve(self):
- self.httpd.serve_forever()
+ self.httpd = server_class(server_address, RequestHander)
+
+ def serve(self):
+ self.httpd.serve_forever()
diff --git a/lib/py/src/server/TNonblockingServer.py b/lib/py/src/server/TNonblockingServer.py
index a930a80..87031c1 100644
--- a/lib/py/src/server/TNonblockingServer.py
+++ b/lib/py/src/server/TNonblockingServer.py
@@ -24,13 +24,12 @@
The thread poool should be sized for concurrent tasks, not
maximum connections
"""
-import threading
-import socket
-import select
-import struct
import logging
-logger = logging.getLogger(__name__)
+import select
+import socket
+import struct
+import threading
from six.moves import queue
@@ -39,6 +38,8 @@
__all__ = ['TNonblockingServer']
+logger = logging.getLogger(__name__)
+
class Worker(threading.Thread):
"""Worker is a small helper to process incoming connection."""
@@ -127,7 +128,7 @@
self.len, = struct.unpack('!i', self.message)
if self.len < 0:
logger.error("negative frame size, it seems client "
- "doesn't use FramedTransport")
+ "doesn't use FramedTransport")
self.close()
elif self.len == 0:
logger.error("empty frame, it's really strange")
@@ -149,7 +150,7 @@
read = self.socket.recv(self.len - len(self.message))
if len(read) == 0:
logger.error("can't read frame from socket (get %d of "
- "%d bytes)" % (len(self.message), self.len))
+ "%d bytes)" % (len(self.message), self.len))
self.close()
return
self.message += read
diff --git a/lib/py/src/server/TProcessPoolServer.py b/lib/py/src/server/TProcessPoolServer.py
index b2c2308..fe6dc81 100644
--- a/lib/py/src/server/TProcessPoolServer.py
+++ b/lib/py/src/server/TProcessPoolServer.py
@@ -19,13 +19,14 @@
import logging
-logger = logging.getLogger(__name__)
-from multiprocessing import Process, Value, Condition, reduction
+from multiprocessing import Process, Value, Condition
from .TServer import TServer
from thrift.transport.TTransport import TTransportException
+logger = logging.getLogger(__name__)
+
class TProcessPoolServer(TServer):
"""Server with a fixed size pool of worker subprocesses to service requests
@@ -59,7 +60,7 @@
try:
client = self.serverTransport.accept()
if not client:
- continue
+ continue
self.serveClient(client)
except (KeyboardInterrupt, SystemExit):
return 0
@@ -76,7 +77,7 @@
try:
while True:
self.processor.process(iprot, oprot)
- except TTransportException as tx:
+ except TTransportException:
pass
except Exception as x:
logger.exception(x)
diff --git a/lib/py/src/server/TServer.py b/lib/py/src/server/TServer.py
index 30f063b..d5d9c98 100644
--- a/lib/py/src/server/TServer.py
+++ b/lib/py/src/server/TServer.py
@@ -18,262 +18,259 @@
#
from six.moves import queue
-import os
-import sys
-import threading
-import traceback
-
import logging
-logger = logging.getLogger(__name__)
+import os
+import threading
-from thrift.Thrift import TProcessor
from thrift.protocol import TBinaryProtocol
from thrift.transport import TTransport
+logger = logging.getLogger(__name__)
+
class TServer(object):
- """Base interface for a server, which must have a serve() method.
+ """Base interface for a server, which must have a serve() method.
- Three constructors for all servers:
- 1) (processor, serverTransport)
- 2) (processor, serverTransport, transportFactory, protocolFactory)
- 3) (processor, serverTransport,
- inputTransportFactory, outputTransportFactory,
- inputProtocolFactory, outputProtocolFactory)
- """
- def __init__(self, *args):
- if (len(args) == 2):
- self.__initArgs__(args[0], args[1],
- TTransport.TTransportFactoryBase(),
- TTransport.TTransportFactoryBase(),
- TBinaryProtocol.TBinaryProtocolFactory(),
- TBinaryProtocol.TBinaryProtocolFactory())
- elif (len(args) == 4):
- self.__initArgs__(args[0], args[1], args[2], args[2], args[3], args[3])
- elif (len(args) == 6):
- self.__initArgs__(args[0], args[1], args[2], args[3], args[4], args[5])
+ Three constructors for all servers:
+ 1) (processor, serverTransport)
+ 2) (processor, serverTransport, transportFactory, protocolFactory)
+ 3) (processor, serverTransport,
+ inputTransportFactory, outputTransportFactory,
+ inputProtocolFactory, outputProtocolFactory)
+ """
+ def __init__(self, *args):
+ if (len(args) == 2):
+ self.__initArgs__(args[0], args[1],
+ TTransport.TTransportFactoryBase(),
+ TTransport.TTransportFactoryBase(),
+ TBinaryProtocol.TBinaryProtocolFactory(),
+ TBinaryProtocol.TBinaryProtocolFactory())
+ elif (len(args) == 4):
+ self.__initArgs__(args[0], args[1], args[2], args[2], args[3], args[3])
+ elif (len(args) == 6):
+ self.__initArgs__(args[0], args[1], args[2], args[3], args[4], args[5])
- def __initArgs__(self, processor, serverTransport,
- inputTransportFactory, outputTransportFactory,
- inputProtocolFactory, outputProtocolFactory):
- self.processor = processor
- self.serverTransport = serverTransport
- self.inputTransportFactory = inputTransportFactory
- self.outputTransportFactory = outputTransportFactory
- self.inputProtocolFactory = inputProtocolFactory
- self.outputProtocolFactory = outputProtocolFactory
+ def __initArgs__(self, processor, serverTransport,
+ inputTransportFactory, outputTransportFactory,
+ inputProtocolFactory, outputProtocolFactory):
+ self.processor = processor
+ self.serverTransport = serverTransport
+ self.inputTransportFactory = inputTransportFactory
+ self.outputTransportFactory = outputTransportFactory
+ self.inputProtocolFactory = inputProtocolFactory
+ self.outputProtocolFactory = outputProtocolFactory
- def serve(self):
- pass
+ def serve(self):
+ pass
class TSimpleServer(TServer):
- """Simple single-threaded server that just pumps around one transport."""
+ """Simple single-threaded server that just pumps around one transport."""
- def __init__(self, *args):
- TServer.__init__(self, *args)
+ def __init__(self, *args):
+ TServer.__init__(self, *args)
- def serve(self):
- self.serverTransport.listen()
- while True:
- client = self.serverTransport.accept()
- if not client:
- continue
- itrans = self.inputTransportFactory.getTransport(client)
- otrans = self.outputTransportFactory.getTransport(client)
- iprot = self.inputProtocolFactory.getProtocol(itrans)
- oprot = self.outputProtocolFactory.getProtocol(otrans)
- try:
+ def serve(self):
+ self.serverTransport.listen()
while True:
- self.processor.process(iprot, oprot)
- except TTransport.TTransportException as tx:
- pass
- except Exception as x:
- logger.exception(x)
+ client = self.serverTransport.accept()
+ if not client:
+ continue
+ itrans = self.inputTransportFactory.getTransport(client)
+ otrans = self.outputTransportFactory.getTransport(client)
+ iprot = self.inputProtocolFactory.getProtocol(itrans)
+ oprot = self.outputProtocolFactory.getProtocol(otrans)
+ try:
+ while True:
+ self.processor.process(iprot, oprot)
+ except TTransport.TTransportException:
+ pass
+ except Exception as x:
+ logger.exception(x)
- itrans.close()
- otrans.close()
+ itrans.close()
+ otrans.close()
class TThreadedServer(TServer):
- """Threaded server that spawns a new thread per each connection."""
+ """Threaded server that spawns a new thread per each connection."""
- def __init__(self, *args, **kwargs):
- TServer.__init__(self, *args)
- self.daemon = kwargs.get("daemon", False)
+ def __init__(self, *args, **kwargs):
+ TServer.__init__(self, *args)
+ self.daemon = kwargs.get("daemon", False)
- def serve(self):
- self.serverTransport.listen()
- while True:
- try:
- client = self.serverTransport.accept()
- if not client:
- continue
- t = threading.Thread(target=self.handle, args=(client,))
- t.setDaemon(self.daemon)
- t.start()
- except KeyboardInterrupt:
- raise
- except Exception as x:
- logger.exception(x)
+ def serve(self):
+ self.serverTransport.listen()
+ while True:
+ try:
+ client = self.serverTransport.accept()
+ if not client:
+ continue
+ t = threading.Thread(target=self.handle, args=(client,))
+ t.setDaemon(self.daemon)
+ t.start()
+ except KeyboardInterrupt:
+ raise
+ except Exception as x:
+ logger.exception(x)
- def handle(self, client):
- itrans = self.inputTransportFactory.getTransport(client)
- otrans = self.outputTransportFactory.getTransport(client)
- iprot = self.inputProtocolFactory.getProtocol(itrans)
- oprot = self.outputProtocolFactory.getProtocol(otrans)
- try:
- while True:
- self.processor.process(iprot, oprot)
- except TTransport.TTransportException as tx:
- pass
- except Exception as x:
- logger.exception(x)
+ def handle(self, client):
+ itrans = self.inputTransportFactory.getTransport(client)
+ otrans = self.outputTransportFactory.getTransport(client)
+ iprot = self.inputProtocolFactory.getProtocol(itrans)
+ oprot = self.outputProtocolFactory.getProtocol(otrans)
+ try:
+ while True:
+ self.processor.process(iprot, oprot)
+ except TTransport.TTransportException:
+ pass
+ except Exception as x:
+ logger.exception(x)
- itrans.close()
- otrans.close()
+ itrans.close()
+ otrans.close()
class TThreadPoolServer(TServer):
- """Server with a fixed size pool of threads which service requests."""
+ """Server with a fixed size pool of threads which service requests."""
- def __init__(self, *args, **kwargs):
- TServer.__init__(self, *args)
- self.clients = queue.Queue()
- self.threads = 10
- self.daemon = kwargs.get("daemon", False)
+ def __init__(self, *args, **kwargs):
+ TServer.__init__(self, *args)
+ self.clients = queue.Queue()
+ self.threads = 10
+ self.daemon = kwargs.get("daemon", False)
- def setNumThreads(self, num):
- """Set the number of worker threads that should be created"""
- self.threads = num
+ def setNumThreads(self, num):
+ """Set the number of worker threads that should be created"""
+ self.threads = num
- def serveThread(self):
- """Loop around getting clients from the shared queue and process them."""
- while True:
- try:
- client = self.clients.get()
- self.serveClient(client)
- except Exception as x:
- logger.exception(x)
+ def serveThread(self):
+ """Loop around getting clients from the shared queue and process them."""
+ while True:
+ try:
+ client = self.clients.get()
+ self.serveClient(client)
+ except Exception as x:
+ logger.exception(x)
- def serveClient(self, client):
- """Process input/output from a client for as long as possible"""
- itrans = self.inputTransportFactory.getTransport(client)
- otrans = self.outputTransportFactory.getTransport(client)
- iprot = self.inputProtocolFactory.getProtocol(itrans)
- oprot = self.outputProtocolFactory.getProtocol(otrans)
- try:
- while True:
- self.processor.process(iprot, oprot)
- except TTransport.TTransportException as tx:
- pass
- except Exception as x:
- logger.exception(x)
+ def serveClient(self, client):
+ """Process input/output from a client for as long as possible"""
+ itrans = self.inputTransportFactory.getTransport(client)
+ otrans = self.outputTransportFactory.getTransport(client)
+ iprot = self.inputProtocolFactory.getProtocol(itrans)
+ oprot = self.outputProtocolFactory.getProtocol(otrans)
+ try:
+ while True:
+ self.processor.process(iprot, oprot)
+ except TTransport.TTransportException:
+ pass
+ except Exception as x:
+ logger.exception(x)
- itrans.close()
- otrans.close()
+ itrans.close()
+ otrans.close()
- def serve(self):
- """Start a fixed number of worker threads and put client into a queue"""
- for i in range(self.threads):
- try:
- t = threading.Thread(target=self.serveThread)
- t.setDaemon(self.daemon)
- t.start()
- except Exception as x:
- logger.exception(x)
+ def serve(self):
+ """Start a fixed number of worker threads and put client into a queue"""
+ for i in range(self.threads):
+ try:
+ t = threading.Thread(target=self.serveThread)
+ t.setDaemon(self.daemon)
+ t.start()
+ except Exception as x:
+ logger.exception(x)
- # Pump the socket for clients
- self.serverTransport.listen()
- while True:
- try:
- client = self.serverTransport.accept()
- if not client:
- continue
- self.clients.put(client)
- except Exception as x:
- logger.exception(x)
+ # Pump the socket for clients
+ self.serverTransport.listen()
+ while True:
+ try:
+ client = self.serverTransport.accept()
+ if not client:
+ continue
+ self.clients.put(client)
+ except Exception as x:
+ logger.exception(x)
class TForkingServer(TServer):
- """A Thrift server that forks a new process for each request
+ """A Thrift server that forks a new process for each request
- This is more scalable than the threaded server as it does not cause
- GIL contention.
+ This is more scalable than the threaded server as it does not cause
+ GIL contention.
- Note that this has different semantics from the threading server.
- Specifically, updates to shared variables will no longer be shared.
- It will also not work on windows.
+ Note that this has different semantics from the threading server.
+ Specifically, updates to shared variables will no longer be shared.
+ It will also not work on windows.
- This code is heavily inspired by SocketServer.ForkingMixIn in the
- Python stdlib.
- """
- def __init__(self, *args):
- TServer.__init__(self, *args)
- self.children = []
+ This code is heavily inspired by SocketServer.ForkingMixIn in the
+ Python stdlib.
+ """
+ def __init__(self, *args):
+ TServer.__init__(self, *args)
+ self.children = []
- def serve(self):
- def try_close(file):
- try:
- file.close()
- except IOError as e:
- logger.warning(e, exc_info=True)
-
- self.serverTransport.listen()
- while True:
- client = self.serverTransport.accept()
- if not client:
- continue
- try:
- pid = os.fork()
-
- if pid: # parent
- # add before collect, otherwise you race w/ waitpid
- self.children.append(pid)
- self.collect_children()
-
- # Parent must close socket or the connection may not get
- # closed promptly
- itrans = self.inputTransportFactory.getTransport(client)
- otrans = self.outputTransportFactory.getTransport(client)
- try_close(itrans)
- try_close(otrans)
- else:
- itrans = self.inputTransportFactory.getTransport(client)
- otrans = self.outputTransportFactory.getTransport(client)
-
- iprot = self.inputProtocolFactory.getProtocol(itrans)
- oprot = self.outputProtocolFactory.getProtocol(otrans)
-
- ecode = 0
- try:
+ def serve(self):
+ def try_close(file):
try:
- while True:
- self.processor.process(iprot, oprot)
+ file.close()
+ except IOError as e:
+ logger.warning(e, exc_info=True)
+
+ self.serverTransport.listen()
+ while True:
+ client = self.serverTransport.accept()
+ if not client:
+ continue
+ try:
+ pid = os.fork()
+
+ if pid: # parent
+ # add before collect, otherwise you race w/ waitpid
+ self.children.append(pid)
+ self.collect_children()
+
+ # Parent must close socket or the connection may not get
+ # closed promptly
+ itrans = self.inputTransportFactory.getTransport(client)
+ otrans = self.outputTransportFactory.getTransport(client)
+ try_close(itrans)
+ try_close(otrans)
+ else:
+ itrans = self.inputTransportFactory.getTransport(client)
+ otrans = self.outputTransportFactory.getTransport(client)
+
+ iprot = self.inputProtocolFactory.getProtocol(itrans)
+ oprot = self.outputProtocolFactory.getProtocol(otrans)
+
+ ecode = 0
+ try:
+ try:
+ while True:
+ self.processor.process(iprot, oprot)
+ except TTransport.TTransportException:
+ pass
+ except Exception as e:
+ logger.exception(e)
+ ecode = 1
+ finally:
+ try_close(itrans)
+ try_close(otrans)
+
+ os._exit(ecode)
+
except TTransport.TTransportException:
- pass
- except Exception as e:
- logger.exception(e)
- ecode = 1
- finally:
- try_close(itrans)
- try_close(otrans)
+ pass
+ except Exception as x:
+ logger.exception(x)
- os._exit(ecode)
+ def collect_children(self):
+ while self.children:
+ try:
+ pid, status = os.waitpid(0, os.WNOHANG)
+ except os.error:
+ pid = None
- except TTransport.TTransportException:
- pass
- except Exception as x:
- logger.exception(x)
-
- def collect_children(self):
- while self.children:
- try:
- pid, status = os.waitpid(0, os.WNOHANG)
- except os.error:
- pid = None
-
- if pid:
- self.children.remove(pid)
- else:
- break
+ if pid:
+ self.children.remove(pid)
+ else:
+ break
diff --git a/lib/py/src/transport/THttpClient.py b/lib/py/src/transport/THttpClient.py
index 5abd41c..95f118c 100644
--- a/lib/py/src/transport/THttpClient.py
+++ b/lib/py/src/transport/THttpClient.py
@@ -26,130 +26,130 @@
from six.moves import urllib
from six.moves import http_client
-from .TTransport import *
+from .TTransport import TTransportBase
import six
class THttpClient(TTransportBase):
- """Http implementation of TTransport base."""
+ """Http implementation of TTransport base."""
- def __init__(self, uri_or_host, port=None, path=None):
- """THttpClient supports two different types constructor parameters.
+ def __init__(self, uri_or_host, port=None, path=None):
+ """THttpClient supports two different types constructor parameters.
- THttpClient(host, port, path) - deprecated
- THttpClient(uri)
+ THttpClient(host, port, path) - deprecated
+ THttpClient(uri)
- Only the second supports https.
- """
- if port is not None:
- warnings.warn(
- "Please use the THttpClient('http://host:port/path') syntax",
- DeprecationWarning,
- stacklevel=2)
- self.host = uri_or_host
- self.port = port
- assert path
- self.path = path
- self.scheme = 'http'
- else:
- parsed = urllib.parse.urlparse(uri_or_host)
- self.scheme = parsed.scheme
- assert self.scheme in ('http', 'https')
- if self.scheme == 'http':
- self.port = parsed.port or http_client.HTTP_PORT
- elif self.scheme == 'https':
- self.port = parsed.port or http_client.HTTPS_PORT
- self.host = parsed.hostname
- self.path = parsed.path
- if parsed.query:
- self.path += '?%s' % parsed.query
- self.__wbuf = BytesIO()
- self.__http = None
- self.__http_response = None
- self.__timeout = None
- self.__custom_headers = None
+ Only the second supports https.
+ """
+ if port is not None:
+ warnings.warn(
+ "Please use the THttpClient('http://host:port/path') syntax",
+ DeprecationWarning,
+ stacklevel=2)
+ self.host = uri_or_host
+ self.port = port
+ assert path
+ self.path = path
+ self.scheme = 'http'
+ else:
+ parsed = urllib.parse.urlparse(uri_or_host)
+ self.scheme = parsed.scheme
+ assert self.scheme in ('http', 'https')
+ if self.scheme == 'http':
+ self.port = parsed.port or http_client.HTTP_PORT
+ elif self.scheme == 'https':
+ self.port = parsed.port or http_client.HTTPS_PORT
+ self.host = parsed.hostname
+ self.path = parsed.path
+ if parsed.query:
+ self.path += '?%s' % parsed.query
+ self.__wbuf = BytesIO()
+ self.__http = None
+ self.__http_response = None
+ self.__timeout = None
+ self.__custom_headers = None
- def open(self):
- if self.scheme == 'http':
- self.__http = http_client.HTTPConnection(self.host, self.port)
- else:
- self.__http = http_client.HTTPSConnection(self.host, self.port)
+ def open(self):
+ if self.scheme == 'http':
+ self.__http = http_client.HTTPConnection(self.host, self.port)
+ else:
+ self.__http = http_client.HTTPSConnection(self.host, self.port)
- def close(self):
- self.__http.close()
- self.__http = None
- self.__http_response = None
+ def close(self):
+ self.__http.close()
+ self.__http = None
+ self.__http_response = None
- def isOpen(self):
- return self.__http is not None
+ def isOpen(self):
+ return self.__http is not None
- def setTimeout(self, ms):
- if not hasattr(socket, 'getdefaulttimeout'):
- raise NotImplementedError
+ def setTimeout(self, ms):
+ if not hasattr(socket, 'getdefaulttimeout'):
+ raise NotImplementedError
- if ms is None:
- self.__timeout = None
- else:
- self.__timeout = ms / 1000.0
+ if ms is None:
+ self.__timeout = None
+ else:
+ self.__timeout = ms / 1000.0
- def setCustomHeaders(self, headers):
- self.__custom_headers = headers
+ def setCustomHeaders(self, headers):
+ self.__custom_headers = headers
- def read(self, sz):
- return self.__http_response.read(sz)
+ def read(self, sz):
+ return self.__http_response.read(sz)
- def write(self, buf):
- self.__wbuf.write(buf)
+ def write(self, buf):
+ self.__wbuf.write(buf)
- def __withTimeout(f):
- def _f(*args, **kwargs):
- orig_timeout = socket.getdefaulttimeout()
- socket.setdefaulttimeout(args[0].__timeout)
- try:
- result = f(*args, **kwargs)
- finally:
- socket.setdefaulttimeout(orig_timeout)
- return result
- return _f
+ def __withTimeout(f):
+ def _f(*args, **kwargs):
+ orig_timeout = socket.getdefaulttimeout()
+ socket.setdefaulttimeout(args[0].__timeout)
+ try:
+ result = f(*args, **kwargs)
+ finally:
+ socket.setdefaulttimeout(orig_timeout)
+ return result
+ return _f
- def flush(self):
- if self.isOpen():
- self.close()
- self.open()
+ def flush(self):
+ if self.isOpen():
+ self.close()
+ self.open()
- # Pull data out of buffer
- data = self.__wbuf.getvalue()
- self.__wbuf = BytesIO()
+ # Pull data out of buffer
+ data = self.__wbuf.getvalue()
+ self.__wbuf = BytesIO()
- # HTTP request
- self.__http.putrequest('POST', self.path)
+ # HTTP request
+ self.__http.putrequest('POST', self.path)
- # Write headers
- self.__http.putheader('Content-Type', 'application/x-thrift')
- self.__http.putheader('Content-Length', str(len(data)))
+ # Write headers
+ self.__http.putheader('Content-Type', 'application/x-thrift')
+ self.__http.putheader('Content-Length', str(len(data)))
- if not self.__custom_headers or 'User-Agent' not in self.__custom_headers:
- user_agent = 'Python/THttpClient'
- script = os.path.basename(sys.argv[0])
- if script:
- user_agent = '%s (%s)' % (user_agent, urllib.parse.quote(script))
- self.__http.putheader('User-Agent', user_agent)
+ if not self.__custom_headers or 'User-Agent' not in self.__custom_headers:
+ user_agent = 'Python/THttpClient'
+ script = os.path.basename(sys.argv[0])
+ if script:
+ user_agent = '%s (%s)' % (user_agent, urllib.parse.quote(script))
+ self.__http.putheader('User-Agent', user_agent)
- if self.__custom_headers:
- for key, val in six.iteritems(self.__custom_headers):
- self.__http.putheader(key, val)
+ if self.__custom_headers:
+ for key, val in six.iteritems(self.__custom_headers):
+ self.__http.putheader(key, val)
- self.__http.endheaders()
+ self.__http.endheaders()
- # Write payload
- self.__http.send(data)
+ # Write payload
+ self.__http.send(data)
- # Get reply to flush the request
- self.__http_response = self.__http.getresponse()
- self.code = self.__http_response.status
- self.message = self.__http_response.reason
- self.headers = self.__http_response.msg
+ # Get reply to flush the request
+ self.__http_response = self.__http.getresponse()
+ self.code = self.__http_response.status
+ self.message = self.__http_response.reason
+ self.headers = self.__http_response.msg
- # Decorate if we know how to timeout
- if hasattr(socket, 'getdefaulttimeout'):
- flush = __withTimeout(flush)
+ # Decorate if we know how to timeout
+ if hasattr(socket, 'getdefaulttimeout'):
+ flush = __withTimeout(flush)
diff --git a/lib/py/src/transport/TSSLSocket.py b/lib/py/src/transport/TSSLSocket.py
index 9be0912..3f1a909 100644
--- a/lib/py/src/transport/TSSLSocket.py
+++ b/lib/py/src/transport/TSSLSocket.py
@@ -32,345 +32,345 @@
class TSSLBase(object):
- # SSLContext is not available for Python < 2.7.9
- _has_ssl_context = sys.hexversion >= 0x020709F0
+ # SSLContext is not available for Python < 2.7.9
+ _has_ssl_context = sys.hexversion >= 0x020709F0
- # ciphers argument is not available for Python < 2.7.0
- _has_ciphers = sys.hexversion >= 0x020700F0
+ # ciphers argument is not available for Python < 2.7.0
+ _has_ciphers = sys.hexversion >= 0x020700F0
- # For pythoon >= 2.7.9, use latest TLS that both client and server supports.
- # SSL 2.0 and 3.0 are disabled via ssl.OP_NO_SSLv2 and ssl.OP_NO_SSLv3.
- # For pythoon < 2.7.9, use TLS 1.0 since TLSv1_X nare OP_NO_SSLvX are unavailable.
- _default_protocol = ssl.PROTOCOL_SSLv23 if _has_ssl_context else ssl.PROTOCOL_TLSv1
+ # For pythoon >= 2.7.9, use latest TLS that both client and server supports.
+ # SSL 2.0 and 3.0 are disabled via ssl.OP_NO_SSLv2 and ssl.OP_NO_SSLv3.
+ # For pythoon < 2.7.9, use TLS 1.0 since TLSv1_X nare OP_NO_SSLvX are unavailable.
+ _default_protocol = ssl.PROTOCOL_SSLv23 if _has_ssl_context else ssl.PROTOCOL_TLSv1
- def _init_context(self, ssl_version):
- if self._has_ssl_context:
- self._context = ssl.SSLContext(ssl_version)
- if self._context.protocol == ssl.PROTOCOL_SSLv23:
- self._context.options |= ssl.OP_NO_SSLv2
- self._context.options |= ssl.OP_NO_SSLv3
- else:
- self._context = None
- self._ssl_version = ssl_version
+ def _init_context(self, ssl_version):
+ if self._has_ssl_context:
+ self._context = ssl.SSLContext(ssl_version)
+ if self._context.protocol == ssl.PROTOCOL_SSLv23:
+ self._context.options |= ssl.OP_NO_SSLv2
+ self._context.options |= ssl.OP_NO_SSLv3
+ else:
+ self._context = None
+ self._ssl_version = ssl_version
- @property
- def ssl_version(self):
- if self._has_ssl_context:
- return self.ssl_context.protocol
- else:
- return self._ssl_version
+ @property
+ def ssl_version(self):
+ if self._has_ssl_context:
+ return self.ssl_context.protocol
+ else:
+ return self._ssl_version
- @property
- def ssl_context(self):
- return self._context
+ @property
+ def ssl_context(self):
+ return self._context
- SSL_VERSION = _default_protocol
- """
+ SSL_VERSION = _default_protocol
+ """
Default SSL version.
For backword compatibility, it can be modified.
Use __init__ keywoard argument "ssl_version" instead.
"""
- def _deprecated_arg(self, args, kwargs, pos, key):
- if len(args) <= pos:
- return
- real_pos = pos + 3
- warnings.warn(
- '%dth positional argument is deprecated. Use keyward argument insteand.' % real_pos,
- DeprecationWarning)
- if key in kwargs:
- raise TypeError('Duplicate argument: %dth argument and %s keyward argument.', (real_pos, key))
- kwargs[key] = args[pos]
+ def _deprecated_arg(self, args, kwargs, pos, key):
+ if len(args) <= pos:
+ return
+ real_pos = pos + 3
+ warnings.warn(
+ '%dth positional argument is deprecated. Use keyward argument insteand.' % real_pos,
+ DeprecationWarning)
+        if key in kwargs:
+            raise TypeError('Duplicate argument: %dth argument and %s keyward argument.' % (real_pos, key))
+ kwargs[key] = args[pos]
- def _unix_socket_arg(self, host, port, args, kwargs):
- key = 'unix_socket'
- if host is None and port is None and len(args) == 1 and key not in kwargs:
- kwargs[key] = args[0]
- return True
- return False
+ def _unix_socket_arg(self, host, port, args, kwargs):
+ key = 'unix_socket'
+ if host is None and port is None and len(args) == 1 and key not in kwargs:
+ kwargs[key] = args[0]
+ return True
+ return False
- def __getattr__(self, key):
- if key == 'SSL_VERSION':
- warnings.warn('Use ssl_version attribute instead.', DeprecationWarning)
- return self.ssl_version
+ def __getattr__(self, key):
+ if key == 'SSL_VERSION':
+ warnings.warn('Use ssl_version attribute instead.', DeprecationWarning)
+ return self.ssl_version
- def __init__(self, server_side, host, ssl_opts):
- self._server_side = server_side
- if TSSLBase.SSL_VERSION != self._default_protocol:
- warnings.warn('SSL_VERSION is deprecated. Use ssl_version keyward argument instead.', DeprecationWarning)
- self._context = ssl_opts.pop('ssl_context', None)
- self._server_hostname = None
- if not self._server_side:
- self._server_hostname = ssl_opts.pop('server_hostname', host)
- if self._context:
- self._custom_context = True
- if ssl_opts:
- raise ValueError('Incompatible arguments: ssl_context and %s' % ' '.join(ssl_opts.keys()))
- if not self._has_ssl_context:
- raise ValueError('ssl_context is not available for this version of Python')
- else:
- self._custom_context = False
- ssl_version = ssl_opts.pop('ssl_version', TSSLBase.SSL_VERSION)
- self._init_context(ssl_version)
- self.cert_reqs = ssl_opts.pop('cert_reqs', ssl.CERT_REQUIRED)
- self.ca_certs = ssl_opts.pop('ca_certs', None)
- self.keyfile = ssl_opts.pop('keyfile', None)
- self.certfile = ssl_opts.pop('certfile', None)
- self.ciphers = ssl_opts.pop('ciphers', None)
-
- if ssl_opts:
- raise ValueError('Unknown keyword arguments: ', ' '.join(ssl_opts.keys()))
-
- if self.cert_reqs != ssl.CERT_NONE:
- if not self.ca_certs:
- raise ValueError('ca_certs is needed when cert_reqs is not ssl.CERT_NONE')
- if not os.access(self.ca_certs, os.R_OK):
- raise IOError('Certificate Authority ca_certs file "%s" '
- 'is not readable, cannot validate SSL '
- 'certificates.' % (self.ca_certs))
-
- @property
- def certfile(self):
- return self._certfile
-
- @certfile.setter
- def certfile(self, certfile):
- if self._server_side and not certfile:
- raise ValueError('certfile is needed for server-side')
- if certfile and not os.access(certfile, os.R_OK):
- raise IOError('No such certfile found: %s' % (certfile))
- self._certfile = certfile
-
- def _wrap_socket(self, sock):
- if self._has_ssl_context:
- if not self._custom_context:
- self.ssl_context.verify_mode = self.cert_reqs
- if self.certfile:
- self.ssl_context.load_cert_chain(self.certfile, self.keyfile)
- if self.ciphers:
- self.ssl_context.set_ciphers(self.ciphers)
- if self.ca_certs:
- self.ssl_context.load_verify_locations(self.ca_certs)
- return self.ssl_context.wrap_socket(sock, server_side=self._server_side,
- server_hostname=self._server_hostname)
- else:
- ssl_opts = {
- 'ssl_version': self._ssl_version,
- 'server_side': self._server_side,
- 'ca_certs': self.ca_certs,
- 'keyfile': self.keyfile,
- 'certfile': self.certfile,
- 'cert_reqs': self.cert_reqs,
- }
- if self.ciphers:
- if self._has_ciphers:
- ssl_opts['ciphers'] = self.ciphers
+ def __init__(self, server_side, host, ssl_opts):
+ self._server_side = server_side
+ if TSSLBase.SSL_VERSION != self._default_protocol:
+ warnings.warn('SSL_VERSION is deprecated. Use ssl_version keyward argument instead.', DeprecationWarning)
+ self._context = ssl_opts.pop('ssl_context', None)
+ self._server_hostname = None
+ if not self._server_side:
+ self._server_hostname = ssl_opts.pop('server_hostname', host)
+ if self._context:
+ self._custom_context = True
+ if ssl_opts:
+ raise ValueError('Incompatible arguments: ssl_context and %s' % ' '.join(ssl_opts.keys()))
+ if not self._has_ssl_context:
+ raise ValueError('ssl_context is not available for this version of Python')
else:
- logger.warning('ciphers is specified but ignored due to old Python version')
- return ssl.wrap_socket(sock, **ssl_opts)
+ self._custom_context = False
+ ssl_version = ssl_opts.pop('ssl_version', TSSLBase.SSL_VERSION)
+ self._init_context(ssl_version)
+ self.cert_reqs = ssl_opts.pop('cert_reqs', ssl.CERT_REQUIRED)
+ self.ca_certs = ssl_opts.pop('ca_certs', None)
+ self.keyfile = ssl_opts.pop('keyfile', None)
+ self.certfile = ssl_opts.pop('certfile', None)
+ self.ciphers = ssl_opts.pop('ciphers', None)
+
+        if ssl_opts:
+            raise ValueError('Unknown keyword arguments: ' + ' '.join(ssl_opts.keys()))
+
+ if self.cert_reqs != ssl.CERT_NONE:
+ if not self.ca_certs:
+ raise ValueError('ca_certs is needed when cert_reqs is not ssl.CERT_NONE')
+ if not os.access(self.ca_certs, os.R_OK):
+ raise IOError('Certificate Authority ca_certs file "%s" '
+ 'is not readable, cannot validate SSL '
+ 'certificates.' % (self.ca_certs))
+
+ @property
+ def certfile(self):
+ return self._certfile
+
+ @certfile.setter
+ def certfile(self, certfile):
+ if self._server_side and not certfile:
+ raise ValueError('certfile is needed for server-side')
+ if certfile and not os.access(certfile, os.R_OK):
+ raise IOError('No such certfile found: %s' % (certfile))
+ self._certfile = certfile
+
+ def _wrap_socket(self, sock):
+ if self._has_ssl_context:
+ if not self._custom_context:
+ self.ssl_context.verify_mode = self.cert_reqs
+ if self.certfile:
+ self.ssl_context.load_cert_chain(self.certfile, self.keyfile)
+ if self.ciphers:
+ self.ssl_context.set_ciphers(self.ciphers)
+ if self.ca_certs:
+ self.ssl_context.load_verify_locations(self.ca_certs)
+ return self.ssl_context.wrap_socket(sock, server_side=self._server_side,
+ server_hostname=self._server_hostname)
+ else:
+ ssl_opts = {
+ 'ssl_version': self._ssl_version,
+ 'server_side': self._server_side,
+ 'ca_certs': self.ca_certs,
+ 'keyfile': self.keyfile,
+ 'certfile': self.certfile,
+ 'cert_reqs': self.cert_reqs,
+ }
+ if self.ciphers:
+ if self._has_ciphers:
+ ssl_opts['ciphers'] = self.ciphers
+ else:
+ logger.warning('ciphers is specified but ignored due to old Python version')
+ return ssl.wrap_socket(sock, **ssl_opts)
class TSSLSocket(TSocket.TSocket, TSSLBase):
- """
- SSL implementation of TSocket
-
- This class creates outbound sockets wrapped using the
- python standard ssl module for encrypted connections.
- """
-
- # New signature
- # def __init__(self, host='localhost', port=9090, unix_socket=None, **ssl_args):
- # Deprecated signature
- # def __init__(self, host='localhost', port=9090, validate=True, ca_certs=None, keyfile=None, certfile=None, unix_socket=None, ciphers=None):
- def __init__(self, host='localhost', port=9090, *args, **kwargs):
- """Positional arguments: ``host``, ``port``, ``unix_socket``
-
- Keyword arguments: ``keyfile``, ``certfile``, ``cert_reqs``, ``ssl_version``,
- ``ca_certs``, ``ciphers`` (Python 2.7.0 or later),
- ``server_hostname`` (Python 2.7.9 or later)
- Passed to ssl.wrap_socket. See ssl.wrap_socket documentation.
-
- Alternative keywoard arguments: (Python 2.7.9 or later)
- ``ssl_context``: ssl.SSLContext to be used for SSLContext.wrap_socket
- ``server_hostname``: Passed to SSLContext.wrap_socket
"""
- self.is_valid = False
- self.peercert = None
+ SSL implementation of TSocket
- if args:
- if len(args) > 6:
- raise TypeError('Too many positional argument')
- if not self._unix_socket_arg(host, port, args, kwargs):
- self._deprecated_arg(args, kwargs, 0, 'validate')
- self._deprecated_arg(args, kwargs, 1, 'ca_certs')
- self._deprecated_arg(args, kwargs, 2, 'keyfile')
- self._deprecated_arg(args, kwargs, 3, 'certfile')
- self._deprecated_arg(args, kwargs, 4, 'unix_socket')
- self._deprecated_arg(args, kwargs, 5, 'ciphers')
+ This class creates outbound sockets wrapped using the
+ python standard ssl module for encrypted connections.
+ """
- validate = kwargs.pop('validate', None)
- if validate is not None:
- cert_reqs_name = 'CERT_REQUIRED' if validate else 'CERT_NONE'
- warnings.warn(
- 'validate is deprecated. Use cert_reqs=ssl.%s instead' % cert_reqs_name,
- DeprecationWarning)
- if 'cert_reqs' in kwargs:
- raise TypeError('Cannot specify both validate and cert_reqs')
- kwargs['cert_reqs'] = ssl.CERT_REQUIRED if validate else ssl.CERT_NONE
+ # New signature
+ # def __init__(self, host='localhost', port=9090, unix_socket=None, **ssl_args):
+ # Deprecated signature
+ # def __init__(self, host='localhost', port=9090, validate=True, ca_certs=None, keyfile=None, certfile=None, unix_socket=None, ciphers=None):
+ def __init__(self, host='localhost', port=9090, *args, **kwargs):
+ """Positional arguments: ``host``, ``port``, ``unix_socket``
- unix_socket = kwargs.pop('unix_socket', None)
- TSSLBase.__init__(self, False, host, kwargs)
- TSocket.TSocket.__init__(self, host, port, unix_socket)
+ Keyword arguments: ``keyfile``, ``certfile``, ``cert_reqs``, ``ssl_version``,
+ ``ca_certs``, ``ciphers`` (Python 2.7.0 or later),
+ ``server_hostname`` (Python 2.7.9 or later)
+ Passed to ssl.wrap_socket. See ssl.wrap_socket documentation.
- @property
- def validate(self):
- warnings.warn('Use cert_reqs instead', DeprecationWarning)
- return self.cert_reqs != ssl.CERT_NONE
+ Alternative keywoard arguments: (Python 2.7.9 or later)
+ ``ssl_context``: ssl.SSLContext to be used for SSLContext.wrap_socket
+ ``server_hostname``: Passed to SSLContext.wrap_socket
+ """
+ self.is_valid = False
+ self.peercert = None
- @validate.setter
- def validate(self, value):
- warnings.warn('Use cert_reqs instead', DeprecationWarning)
- self.cert_reqs = ssl.CERT_REQUIRED if value else ssl.CERT_NONE
+ if args:
+ if len(args) > 6:
+ raise TypeError('Too many positional argument')
+ if not self._unix_socket_arg(host, port, args, kwargs):
+ self._deprecated_arg(args, kwargs, 0, 'validate')
+ self._deprecated_arg(args, kwargs, 1, 'ca_certs')
+ self._deprecated_arg(args, kwargs, 2, 'keyfile')
+ self._deprecated_arg(args, kwargs, 3, 'certfile')
+ self._deprecated_arg(args, kwargs, 4, 'unix_socket')
+ self._deprecated_arg(args, kwargs, 5, 'ciphers')
- def open(self):
- try:
- res0 = self._resolveAddr()
- for res in res0:
- sock_family, sock_type = res[0:2]
- ip_port = res[4]
- plain_sock = socket.socket(sock_family, sock_type)
- self.handle = self._wrap_socket(plain_sock)
- self.handle.settimeout(self._timeout)
+ validate = kwargs.pop('validate', None)
+ if validate is not None:
+ cert_reqs_name = 'CERT_REQUIRED' if validate else 'CERT_NONE'
+ warnings.warn(
+ 'validate is deprecated. Use cert_reqs=ssl.%s instead' % cert_reqs_name,
+ DeprecationWarning)
+ if 'cert_reqs' in kwargs:
+ raise TypeError('Cannot specify both validate and cert_reqs')
+ kwargs['cert_reqs'] = ssl.CERT_REQUIRED if validate else ssl.CERT_NONE
+
+ unix_socket = kwargs.pop('unix_socket', None)
+ TSSLBase.__init__(self, False, host, kwargs)
+ TSocket.TSocket.__init__(self, host, port, unix_socket)
+
+ @property
+ def validate(self):
+ warnings.warn('Use cert_reqs instead', DeprecationWarning)
+ return self.cert_reqs != ssl.CERT_NONE
+
+ @validate.setter
+ def validate(self, value):
+ warnings.warn('Use cert_reqs instead', DeprecationWarning)
+ self.cert_reqs = ssl.CERT_REQUIRED if value else ssl.CERT_NONE
+
+ def open(self):
try:
- self.handle.connect(ip_port)
+ res0 = self._resolveAddr()
+ for res in res0:
+ sock_family, sock_type = res[0:2]
+ ip_port = res[4]
+ plain_sock = socket.socket(sock_family, sock_type)
+ self.handle = self._wrap_socket(plain_sock)
+ self.handle.settimeout(self._timeout)
+ try:
+ self.handle.connect(ip_port)
+ except socket.error as e:
+ if res is not res0[-1]:
+ logger.warning('Error while connecting with %s. Trying next one.', ip_port, exc_info=True)
+ continue
+ else:
+ raise
+ break
except socket.error as e:
- if res is not res0[-1]:
- logger.warning('Error while connecting with %s. Trying next one.', ip_port, exc_info=True)
- continue
- else:
- raise
- break
- except socket.error as e:
- if self._unix_socket:
- message = 'Could not connect to secure socket %s: %s' \
- % (self._unix_socket, e)
- else:
- message = 'Could not connect to %s:%d: %s' % (self.host, self.port, e)
- logger.error('Error while connecting with %s.', ip_port, exc_info=True)
- raise TTransportException(type=TTransportException.NOT_OPEN,
- message=message)
- if self.validate:
- self._validate_cert()
+ if self._unix_socket:
+ message = 'Could not connect to secure socket %s: %s' \
+ % (self._unix_socket, e)
+ else:
+ message = 'Could not connect to %s:%d: %s' % (self.host, self.port, e)
+ logger.error('Error while connecting with %s.', ip_port, exc_info=True)
+ raise TTransportException(type=TTransportException.NOT_OPEN,
+ message=message)
+ if self.validate:
+ self._validate_cert()
- def _validate_cert(self):
- """internal method to validate the peer's SSL certificate, and to check the
- commonName of the certificate to ensure it matches the hostname we
- used to make this connection. Does not support subjectAltName records
- in certificates.
+ def _validate_cert(self):
+ """internal method to validate the peer's SSL certificate, and to check the
+ commonName of the certificate to ensure it matches the hostname we
+ used to make this connection. Does not support subjectAltName records
+ in certificates.
- raises TTransportException if the certificate fails validation.
- """
- cert = self.handle.getpeercert()
- self.peercert = cert
- if 'subject' not in cert:
- raise TTransportException(
- type=TTransportException.NOT_OPEN,
- message='No SSL certificate found from %s:%s' % (self.host, self.port))
- fields = cert['subject']
- for field in fields:
- # ensure structure we get back is what we expect
- if not isinstance(field, tuple):
- continue
- cert_pair = field[0]
- if len(cert_pair) < 2:
- continue
- cert_key, cert_value = cert_pair[0:2]
- if cert_key != 'commonName':
- continue
- certhost = cert_value
- # this check should be performed by some sort of Access Manager
- if certhost == self.host:
- # success, cert commonName matches desired hostname
- self.is_valid = True
- return
- else:
+ raises TTransportException if the certificate fails validation.
+ """
+ cert = self.handle.getpeercert()
+ self.peercert = cert
+ if 'subject' not in cert:
+ raise TTransportException(
+ type=TTransportException.NOT_OPEN,
+ message='No SSL certificate found from %s:%s' % (self.host, self.port))
+ fields = cert['subject']
+ for field in fields:
+ # ensure structure we get back is what we expect
+ if not isinstance(field, tuple):
+ continue
+ cert_pair = field[0]
+ if len(cert_pair) < 2:
+ continue
+ cert_key, cert_value = cert_pair[0:2]
+ if cert_key != 'commonName':
+ continue
+ certhost = cert_value
+ # this check should be performed by some sort of Access Manager
+ if certhost == self.host:
+ # success, cert commonName matches desired hostname
+ self.is_valid = True
+ return
+ else:
+ raise TTransportException(
+ type=TTransportException.UNKNOWN,
+ message='Hostname we connected to "%s" doesn\'t match certificate '
+ 'provided commonName "%s"' % (self.host, certhost))
raise TTransportException(
- type=TTransportException.UNKNOWN,
- message='Hostname we connected to "%s" doesn\'t match certificate '
- 'provided commonName "%s"' % (self.host, certhost))
- raise TTransportException(
- type=TTransportException.UNKNOWN,
- message='Could not validate SSL certificate from '
- 'host "%s". Cert=%s' % (self.host, cert))
+ type=TTransportException.UNKNOWN,
+ message='Could not validate SSL certificate from '
+ 'host "%s". Cert=%s' % (self.host, cert))
class TSSLServerSocket(TSocket.TServerSocket, TSSLBase):
- """SSL implementation of TServerSocket
+ """SSL implementation of TServerSocket
- This uses the ssl module's wrap_socket() method to provide SSL
- negotiated encryption.
- """
-
- # New signature
- # def __init__(self, host='localhost', port=9090, unix_socket=None, **ssl_args):
- # Deprecated signature
- # def __init__(self, host=None, port=9090, certfile='cert.pem', unix_socket=None, ciphers=None):
- def __init__(self, host=None, port=9090, *args, **kwargs):
- """Positional arguments: ``host``, ``port``, ``unix_socket``
-
- Keyword arguments: ``keyfile``, ``certfile``, ``cert_reqs``, ``ssl_version``,
- ``ca_certs``, ``ciphers`` (Python 2.7.0 or later)
- See ssl.wrap_socket documentation.
-
- Alternative keywoard arguments: (Python 2.7.9 or later)
- ``ssl_context``: ssl.SSLContext to be used for SSLContext.wrap_socket
- ``server_hostname``: Passed to SSLContext.wrap_socket
+ This uses the ssl module's wrap_socket() method to provide SSL
+ negotiated encryption.
"""
- if args:
- if len(args) > 3:
- raise TypeError('Too many positional argument')
- if not self._unix_socket_arg(host, port, args, kwargs):
- self._deprecated_arg(args, kwargs, 0, 'certfile')
- self._deprecated_arg(args, kwargs, 1, 'unix_socket')
- self._deprecated_arg(args, kwargs, 2, 'ciphers')
- if 'ssl_context' not in kwargs:
- # Preserve existing behaviors for default values
- if 'cert_reqs' not in kwargs:
- kwargs['cert_reqs'] = ssl.CERT_NONE
- if'certfile' not in kwargs:
- kwargs['certfile'] = 'cert.pem'
+ # New signature
+ # def __init__(self, host='localhost', port=9090, unix_socket=None, **ssl_args):
+ # Deprecated signature
+ # def __init__(self, host=None, port=9090, certfile='cert.pem', unix_socket=None, ciphers=None):
+ def __init__(self, host=None, port=9090, *args, **kwargs):
+ """Positional arguments: ``host``, ``port``, ``unix_socket``
- unix_socket = kwargs.pop('unix_socket', None)
- TSSLBase.__init__(self, True, None, kwargs)
- TSocket.TServerSocket.__init__(self, host, port, unix_socket)
+ Keyword arguments: ``keyfile``, ``certfile``, ``cert_reqs``, ``ssl_version``,
+ ``ca_certs``, ``ciphers`` (Python 2.7.0 or later)
+ See ssl.wrap_socket documentation.
- def setCertfile(self, certfile):
- """Set or change the server certificate file used to wrap new connections.
+ Alternative keywoard arguments: (Python 2.7.9 or later)
+ ``ssl_context``: ssl.SSLContext to be used for SSLContext.wrap_socket
+ ``server_hostname``: Passed to SSLContext.wrap_socket
+ """
+ if args:
+ if len(args) > 3:
+ raise TypeError('Too many positional argument')
+ if not self._unix_socket_arg(host, port, args, kwargs):
+ self._deprecated_arg(args, kwargs, 0, 'certfile')
+ self._deprecated_arg(args, kwargs, 1, 'unix_socket')
+ self._deprecated_arg(args, kwargs, 2, 'ciphers')
- @param certfile: The filename of the server certificate,
- i.e. '/etc/certs/server.pem'
- @type certfile: str
+ if 'ssl_context' not in kwargs:
+ # Preserve existing behaviors for default values
+ if 'cert_reqs' not in kwargs:
+ kwargs['cert_reqs'] = ssl.CERT_NONE
+ if'certfile' not in kwargs:
+ kwargs['certfile'] = 'cert.pem'
- Raises an IOError exception if the certfile is not present or unreadable.
- """
- warnings.warn('Use certfile property instead.', DeprecationWarning)
- self.certfile = certfile
+ unix_socket = kwargs.pop('unix_socket', None)
+ TSSLBase.__init__(self, True, None, kwargs)
+ TSocket.TServerSocket.__init__(self, host, port, unix_socket)
- def accept(self):
- plain_client, addr = self.handle.accept()
- try:
- client = self._wrap_socket(plain_client)
- except ssl.SSLError:
- logger.error('Error while accepting from %s', addr, exc_info=True)
- # failed handshake/ssl wrap, close socket to client
- plain_client.close()
- # raise
- # We can't raise the exception, because it kills most TServer derived
- # serve() methods.
- # Instead, return None, and let the TServer instance deal with it in
- # other exception handling. (but TSimpleServer dies anyway)
- return None
- result = TSocket.TSocket()
- result.setHandle(client)
- return result
+ def setCertfile(self, certfile):
+ """Set or change the server certificate file used to wrap new connections.
+
+ @param certfile: The filename of the server certificate,
+ i.e. '/etc/certs/server.pem'
+ @type certfile: str
+
+ Raises an IOError exception if the certfile is not present or unreadable.
+ """
+ warnings.warn('Use certfile property instead.', DeprecationWarning)
+ self.certfile = certfile
+
+ def accept(self):
+ plain_client, addr = self.handle.accept()
+ try:
+ client = self._wrap_socket(plain_client)
+ except ssl.SSLError:
+ logger.error('Error while accepting from %s', addr, exc_info=True)
+ # failed handshake/ssl wrap, close socket to client
+ plain_client.close()
+ # raise
+ # We can't raise the exception, because it kills most TServer derived
+ # serve() methods.
+ # Instead, return None, and let the TServer instance deal with it in
+ # other exception handling. (but TSimpleServer dies anyway)
+ return None
+ result = TSocket.TSocket()
+ result.setHandle(client)
+ return result
diff --git a/lib/py/src/transport/TSocket.py b/lib/py/src/transport/TSocket.py
index cb204a4..a8ed4b7 100644
--- a/lib/py/src/transport/TSocket.py
+++ b/lib/py/src/transport/TSocket.py
@@ -22,159 +22,159 @@
import socket
import sys
-from .TTransport import *
+from .TTransport import TTransportBase, TTransportException, TServerTransportBase
class TSocketBase(TTransportBase):
- def _resolveAddr(self):
- if self._unix_socket is not None:
- return [(socket.AF_UNIX, socket.SOCK_STREAM, None, None,
- self._unix_socket)]
- else:
- return socket.getaddrinfo(self.host,
- self.port,
- self._socket_family,
- socket.SOCK_STREAM,
- 0,
- socket.AI_PASSIVE | socket.AI_ADDRCONFIG)
+ def _resolveAddr(self):
+ if self._unix_socket is not None:
+ return [(socket.AF_UNIX, socket.SOCK_STREAM, None, None,
+ self._unix_socket)]
+ else:
+ return socket.getaddrinfo(self.host,
+ self.port,
+ self._socket_family,
+ socket.SOCK_STREAM,
+ 0,
+ socket.AI_PASSIVE | socket.AI_ADDRCONFIG)
- def close(self):
- if self.handle:
- self.handle.close()
- self.handle = None
+ def close(self):
+ if self.handle:
+ self.handle.close()
+ self.handle = None
class TSocket(TSocketBase):
- """Socket implementation of TTransport base."""
+ """Socket implementation of TTransport base."""
- def __init__(self, host='localhost', port=9090, unix_socket=None, socket_family=socket.AF_UNSPEC):
- """Initialize a TSocket
+ def __init__(self, host='localhost', port=9090, unix_socket=None, socket_family=socket.AF_UNSPEC):
+ """Initialize a TSocket
- @param host(str) The host to connect to.
- @param port(int) The (TCP) port to connect to.
- @param unix_socket(str) The filename of a unix socket to connect to.
- (host and port will be ignored.)
- @param socket_family(int) The socket family to use with this socket.
- """
- self.host = host
- self.port = port
- self.handle = None
- self._unix_socket = unix_socket
- self._timeout = None
- self._socket_family = socket_family
+ @param host(str) The host to connect to.
+ @param port(int) The (TCP) port to connect to.
+ @param unix_socket(str) The filename of a unix socket to connect to.
+ (host and port will be ignored.)
+ @param socket_family(int) The socket family to use with this socket.
+ """
+ self.host = host
+ self.port = port
+ self.handle = None
+ self._unix_socket = unix_socket
+ self._timeout = None
+ self._socket_family = socket_family
- def setHandle(self, h):
- self.handle = h
+ def setHandle(self, h):
+ self.handle = h
- def isOpen(self):
- return self.handle is not None
+ def isOpen(self):
+ return self.handle is not None
- def setTimeout(self, ms):
- if ms is None:
- self._timeout = None
- else:
- self._timeout = ms / 1000.0
+ def setTimeout(self, ms):
+ if ms is None:
+ self._timeout = None
+ else:
+ self._timeout = ms / 1000.0
- if self.handle is not None:
- self.handle.settimeout(self._timeout)
+ if self.handle is not None:
+ self.handle.settimeout(self._timeout)
- def open(self):
- try:
- res0 = self._resolveAddr()
- for res in res0:
- self.handle = socket.socket(res[0], res[1])
- self.handle.settimeout(self._timeout)
+ def open(self):
try:
- self.handle.connect(res[4])
+ res0 = self._resolveAddr()
+ for res in res0:
+ self.handle = socket.socket(res[0], res[1])
+ self.handle.settimeout(self._timeout)
+ try:
+ self.handle.connect(res[4])
+ except socket.error as e:
+ if res is not res0[-1]:
+ continue
+ else:
+ raise e
+ break
except socket.error as e:
- if res is not res0[-1]:
- continue
- else:
- raise e
- break
- except socket.error as e:
- if self._unix_socket:
- message = 'Could not connect to socket %s' % self._unix_socket
- else:
- message = 'Could not connect to %s:%d' % (self.host, self.port)
- raise TTransportException(type=TTransportException.NOT_OPEN,
- message=message)
+ if self._unix_socket:
+ message = 'Could not connect to socket %s' % self._unix_socket
+ else:
+ message = 'Could not connect to %s:%d' % (self.host, self.port)
+ raise TTransportException(type=TTransportException.NOT_OPEN,
+ message=message)
- def read(self, sz):
- try:
- buff = self.handle.recv(sz)
- except socket.error as e:
- if (e.args[0] == errno.ECONNRESET and
- (sys.platform == 'darwin' or sys.platform.startswith('freebsd'))):
- # freebsd and Mach don't follow POSIX semantic of recv
- # and fail with ECONNRESET if peer performed shutdown.
- # See corresponding comment and code in TSocket::read()
- # in lib/cpp/src/transport/TSocket.cpp.
- self.close()
- # Trigger the check to raise the END_OF_FILE exception below.
- buff = ''
- else:
- raise
- if len(buff) == 0:
- raise TTransportException(type=TTransportException.END_OF_FILE,
- message='TSocket read 0 bytes')
- return buff
+ def read(self, sz):
+ try:
+ buff = self.handle.recv(sz)
+ except socket.error as e:
+ if (e.args[0] == errno.ECONNRESET and
+ (sys.platform == 'darwin' or sys.platform.startswith('freebsd'))):
+ # freebsd and Mach don't follow POSIX semantic of recv
+ # and fail with ECONNRESET if peer performed shutdown.
+ # See corresponding comment and code in TSocket::read()
+ # in lib/cpp/src/transport/TSocket.cpp.
+ self.close()
+ # Trigger the check to raise the END_OF_FILE exception below.
+ buff = ''
+ else:
+ raise
+ if len(buff) == 0:
+ raise TTransportException(type=TTransportException.END_OF_FILE,
+ message='TSocket read 0 bytes')
+ return buff
- def write(self, buff):
- if not self.handle:
- raise TTransportException(type=TTransportException.NOT_OPEN,
- message='Transport not open')
- sent = 0
- have = len(buff)
- while sent < have:
- plus = self.handle.send(buff)
- if plus == 0:
- raise TTransportException(type=TTransportException.END_OF_FILE,
- message='TSocket sent 0 bytes')
- sent += plus
- buff = buff[plus:]
+ def write(self, buff):
+ if not self.handle:
+ raise TTransportException(type=TTransportException.NOT_OPEN,
+ message='Transport not open')
+ sent = 0
+ have = len(buff)
+ while sent < have:
+ plus = self.handle.send(buff)
+ if plus == 0:
+ raise TTransportException(type=TTransportException.END_OF_FILE,
+ message='TSocket sent 0 bytes')
+ sent += plus
+ buff = buff[plus:]
- def flush(self):
- pass
+ def flush(self):
+ pass
class TServerSocket(TSocketBase, TServerTransportBase):
- """Socket implementation of TServerTransport base."""
+ """Socket implementation of TServerTransport base."""
- def __init__(self, host=None, port=9090, unix_socket=None, socket_family=socket.AF_UNSPEC):
- self.host = host
- self.port = port
- self._unix_socket = unix_socket
- self._socket_family = socket_family
- self.handle = None
+ def __init__(self, host=None, port=9090, unix_socket=None, socket_family=socket.AF_UNSPEC):
+ self.host = host
+ self.port = port
+ self._unix_socket = unix_socket
+ self._socket_family = socket_family
+ self.handle = None
- def listen(self):
- res0 = self._resolveAddr()
- socket_family = self._socket_family == socket.AF_UNSPEC and socket.AF_INET6 or self._socket_family
- for res in res0:
- if res[0] is socket_family or res is res0[-1]:
- break
+ def listen(self):
+ res0 = self._resolveAddr()
+ socket_family = self._socket_family == socket.AF_UNSPEC and socket.AF_INET6 or self._socket_family
+ for res in res0:
+ if res[0] is socket_family or res is res0[-1]:
+ break
- # We need remove the old unix socket if the file exists and
- # nobody is listening on it.
- if self._unix_socket:
- tmp = socket.socket(res[0], res[1])
- try:
- tmp.connect(res[4])
- except socket.error as err:
- eno, message = err.args
- if eno == errno.ECONNREFUSED:
- os.unlink(res[4])
+ # We need remove the old unix socket if the file exists and
+ # nobody is listening on it.
+ if self._unix_socket:
+ tmp = socket.socket(res[0], res[1])
+ try:
+ tmp.connect(res[4])
+ except socket.error as err:
+ eno, message = err.args
+ if eno == errno.ECONNREFUSED:
+ os.unlink(res[4])
- self.handle = socket.socket(res[0], res[1])
- self.handle.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
- if hasattr(self.handle, 'settimeout'):
- self.handle.settimeout(None)
- self.handle.bind(res[4])
- self.handle.listen(128)
+ self.handle = socket.socket(res[0], res[1])
+ self.handle.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+ if hasattr(self.handle, 'settimeout'):
+ self.handle.settimeout(None)
+ self.handle.bind(res[4])
+ self.handle.listen(128)
- def accept(self):
- client, addr = self.handle.accept()
- result = TSocket()
- result.setHandle(client)
- return result
+ def accept(self):
+ client, addr = self.handle.accept()
+ result = TSocket()
+ result.setHandle(client)
+ return result
diff --git a/lib/py/src/transport/TTransport.py b/lib/py/src/transport/TTransport.py
index f99b3b9..6669891 100644
--- a/lib/py/src/transport/TTransport.py
+++ b/lib/py/src/transport/TTransport.py
@@ -23,427 +23,426 @@
class TTransportException(TException):
- """Custom Transport Exception class"""
+ """Custom Transport Exception class"""
- UNKNOWN = 0
- NOT_OPEN = 1
- ALREADY_OPEN = 2
- TIMED_OUT = 3
- END_OF_FILE = 4
- NEGATIVE_SIZE = 5
- SIZE_LIMIT = 6
+ UNKNOWN = 0
+ NOT_OPEN = 1
+ ALREADY_OPEN = 2
+ TIMED_OUT = 3
+ END_OF_FILE = 4
+ NEGATIVE_SIZE = 5
+ SIZE_LIMIT = 6
- def __init__(self, type=UNKNOWN, message=None):
- TException.__init__(self, message)
- self.type = type
+ def __init__(self, type=UNKNOWN, message=None):
+ TException.__init__(self, message)
+ self.type = type
class TTransportBase(object):
- """Base class for Thrift transport layer."""
+ """Base class for Thrift transport layer."""
- def isOpen(self):
- pass
+ def isOpen(self):
+ pass
- def open(self):
- pass
+ def open(self):
+ pass
- def close(self):
- pass
+ def close(self):
+ pass
- def read(self, sz):
- pass
+ def read(self, sz):
+ pass
- def readAll(self, sz):
- buff = b''
- have = 0
- while (have < sz):
- chunk = self.read(sz - have)
- have += len(chunk)
- buff += chunk
+ def readAll(self, sz):
+ buff = b''
+ have = 0
+ while (have < sz):
+ chunk = self.read(sz - have)
+ have += len(chunk)
+ buff += chunk
- if len(chunk) == 0:
- raise EOFError()
+ if len(chunk) == 0:
+ raise EOFError()
- return buff
+ return buff
- def write(self, buf):
- pass
+ def write(self, buf):
+ pass
- def flush(self):
- pass
+ def flush(self):
+ pass
# This class should be thought of as an interface.
class CReadableTransport(object):
- """base class for transports that are readable from C"""
+ """base class for transports that are readable from C"""
- # TODO(dreiss): Think about changing this interface to allow us to use
- # a (Python, not c) StringIO instead, because it allows
- # you to write after reading.
+ # TODO(dreiss): Think about changing this interface to allow us to use
+ # a (Python, not c) StringIO instead, because it allows
+ # you to write after reading.
- # NOTE: This is a classic class, so properties will NOT work
- # correctly for setting.
- @property
- def cstringio_buf(self):
- """A cStringIO buffer that contains the current chunk we are reading."""
- pass
+ # NOTE: This is a classic class, so properties will NOT work
+ # correctly for setting.
+ @property
+ def cstringio_buf(self):
+ """A cStringIO buffer that contains the current chunk we are reading."""
+ pass
- def cstringio_refill(self, partialread, reqlen):
- """Refills cstringio_buf.
+ def cstringio_refill(self, partialread, reqlen):
+ """Refills cstringio_buf.
- Returns the currently used buffer (which can but need not be the same as
- the old cstringio_buf). partialread is what the C code has read from the
- buffer, and should be inserted into the buffer before any more reads. The
- return value must be a new, not borrowed reference. Something along the
- lines of self._buf should be fine.
+ Returns the currently used buffer (which can but need not be the same as
+ the old cstringio_buf). partialread is what the C code has read from the
+ buffer, and should be inserted into the buffer before any more reads. The
+ return value must be a new, not borrowed reference. Something along the
+ lines of self._buf should be fine.
- If reqlen bytes can't be read, throw EOFError.
- """
- pass
+ If reqlen bytes can't be read, throw EOFError.
+ """
+ pass
class TServerTransportBase(object):
- """Base class for Thrift server transports."""
+ """Base class for Thrift server transports."""
- def listen(self):
- pass
+ def listen(self):
+ pass
- def accept(self):
- pass
+ def accept(self):
+ pass
- def close(self):
- pass
+ def close(self):
+ pass
class TTransportFactoryBase(object):
- """Base class for a Transport Factory"""
+ """Base class for a Transport Factory"""
- def getTransport(self, trans):
- return trans
+ def getTransport(self, trans):
+ return trans
class TBufferedTransportFactory(object):
- """Factory transport that builds buffered transports"""
+ """Factory transport that builds buffered transports"""
- def getTransport(self, trans):
- buffered = TBufferedTransport(trans)
- return buffered
+ def getTransport(self, trans):
+ buffered = TBufferedTransport(trans)
+ return buffered
class TBufferedTransport(TTransportBase, CReadableTransport):
- """Class that wraps another transport and buffers its I/O.
+ """Class that wraps another transport and buffers its I/O.
- The implementation uses a (configurable) fixed-size read buffer
- but buffers all writes until a flush is performed.
- """
- DEFAULT_BUFFER = 4096
+ The implementation uses a (configurable) fixed-size read buffer
+ but buffers all writes until a flush is performed.
+ """
+ DEFAULT_BUFFER = 4096
- def __init__(self, trans, rbuf_size=DEFAULT_BUFFER):
- self.__trans = trans
- self.__wbuf = BufferIO()
- # Pass string argument to initialize read buffer as cStringIO.InputType
- self.__rbuf = BufferIO(b'')
- self.__rbuf_size = rbuf_size
+ def __init__(self, trans, rbuf_size=DEFAULT_BUFFER):
+ self.__trans = trans
+ self.__wbuf = BufferIO()
+ # Pass string argument to initialize read buffer as cStringIO.InputType
+ self.__rbuf = BufferIO(b'')
+ self.__rbuf_size = rbuf_size
- def isOpen(self):
- return self.__trans.isOpen()
+ def isOpen(self):
+ return self.__trans.isOpen()
- def open(self):
- return self.__trans.open()
+ def open(self):
+ return self.__trans.open()
- def close(self):
- return self.__trans.close()
+ def close(self):
+ return self.__trans.close()
- def read(self, sz):
- ret = self.__rbuf.read(sz)
- if len(ret) != 0:
- return ret
- self.__rbuf = BufferIO(self.__trans.read(max(sz, self.__rbuf_size)))
- return self.__rbuf.read(sz)
+ def read(self, sz):
+ ret = self.__rbuf.read(sz)
+ if len(ret) != 0:
+ return ret
+ self.__rbuf = BufferIO(self.__trans.read(max(sz, self.__rbuf_size)))
+ return self.__rbuf.read(sz)
- def write(self, buf):
- try:
- self.__wbuf.write(buf)
- except Exception as e:
- # on exception reset wbuf so it doesn't contain a partial function call
- self.__wbuf = BufferIO()
- raise e
- self.__wbuf.getvalue()
+ def write(self, buf):
+ try:
+ self.__wbuf.write(buf)
+ except Exception as e:
+ # on exception reset wbuf so it doesn't contain a partial function call
+ self.__wbuf = BufferIO()
+ raise e
+ self.__wbuf.getvalue()
- def flush(self):
- out = self.__wbuf.getvalue()
- # reset wbuf before write/flush to preserve state on underlying failure
- self.__wbuf = BufferIO()
- self.__trans.write(out)
- self.__trans.flush()
+ def flush(self):
+ out = self.__wbuf.getvalue()
+ # reset wbuf before write/flush to preserve state on underlying failure
+ self.__wbuf = BufferIO()
+ self.__trans.write(out)
+ self.__trans.flush()
- # Implement the CReadableTransport interface.
- @property
- def cstringio_buf(self):
- return self.__rbuf
+ # Implement the CReadableTransport interface.
+ @property
+ def cstringio_buf(self):
+ return self.__rbuf
- def cstringio_refill(self, partialread, reqlen):
- retstring = partialread
- if reqlen < self.__rbuf_size:
- # try to make a read of as much as we can.
- retstring += self.__trans.read(self.__rbuf_size)
+ def cstringio_refill(self, partialread, reqlen):
+ retstring = partialread
+ if reqlen < self.__rbuf_size:
+ # try to make a read of as much as we can.
+ retstring += self.__trans.read(self.__rbuf_size)
- # but make sure we do read reqlen bytes.
- if len(retstring) < reqlen:
- retstring += self.__trans.readAll(reqlen - len(retstring))
+ # but make sure we do read reqlen bytes.
+ if len(retstring) < reqlen:
+ retstring += self.__trans.readAll(reqlen - len(retstring))
- self.__rbuf = BufferIO(retstring)
- return self.__rbuf
+ self.__rbuf = BufferIO(retstring)
+ return self.__rbuf
class TMemoryBuffer(TTransportBase, CReadableTransport):
- """Wraps a cBytesIO object as a TTransport.
+ """Wraps an in-memory BufferIO object as a TTransport.
- NOTE: Unlike the C++ version of this class, you cannot write to it
- then immediately read from it. If you want to read from a
- TMemoryBuffer, you must either pass a string to the constructor.
- TODO(dreiss): Make this work like the C++ version.
- """
+ NOTE: Unlike the C++ version of this class, you cannot write to it
+ then immediately read from it. If you want to read from a
+ TMemoryBuffer, you must pass the data to the constructor.
+ TODO(dreiss): Make this work like the C++ version.
+ """
- def __init__(self, value=None):
- """value -- a value to read from for stringio
+ def __init__(self, value=None):
+ """value -- initial data to read from (handed to BufferIO)
- If value is set, this will be a transport for reading,
- otherwise, it is for writing"""
- if value is not None:
- self._buffer = BufferIO(value)
- else:
- self._buffer = BufferIO()
+ If value is set, this will be a transport for reading,
+ otherwise, it is for writing"""
+ if value is not None:
+ self._buffer = BufferIO(value)
+ else:
+ self._buffer = BufferIO()
- def isOpen(self):
- return not self._buffer.closed
+ def isOpen(self):
+ return not self._buffer.closed
- def open(self):
- pass
+ def open(self):
+ pass
- def close(self):
- self._buffer.close()
+ def close(self):
+ self._buffer.close()
- def read(self, sz):
- return self._buffer.read(sz)
+ def read(self, sz):
+ return self._buffer.read(sz)
- def write(self, buf):
- self._buffer.write(buf)
+ def write(self, buf):
+ self._buffer.write(buf)
- def flush(self):
- pass
+ def flush(self):
+ pass
- def getvalue(self):
- return self._buffer.getvalue()
+ def getvalue(self):
+ # whole buffer contents, independent of the current read position
+ return self._buffer.getvalue()
- # Implement the CReadableTransport interface.
- @property
- def cstringio_buf(self):
- return self._buffer
+ # Implement the CReadableTransport interface.
+ @property
+ def cstringio_buf(self):
+ return self._buffer
- def cstringio_refill(self, partialread, reqlen):
- # only one shot at reading...
- raise EOFError()
+ def cstringio_refill(self, partialread, reqlen):
+ # only one shot at reading: there is no underlying transport to
+ # refill from, so running out of buffered data means EOF.
+ raise EOFError()
class TFramedTransportFactory(object):
- """Factory transport that builds framed transports"""
+ """Transport factory that wraps each given transport in a new TFramedTransport."""
- def getTransport(self, trans):
- framed = TFramedTransport(trans)
- return framed
+ def getTransport(self, trans):
+ """Return a new TFramedTransport wrapping trans."""
+ framed = TFramedTransport(trans)
+ return framed
class TFramedTransport(TTransportBase, CReadableTransport):
- """Class that wraps another transport and frames its I/O when writing."""
+ """Class that wraps another transport and frames its I/O.

+ Each flush() sends one length-prefixed frame, and reads consume
+ the incoming data frame by frame.
+ """
- def __init__(self, trans,):
- self.__trans = trans
- self.__rbuf = BufferIO(b'')
- self.__wbuf = BufferIO()
+ def __init__(self, trans,):
+ self.__trans = trans
+ self.__rbuf = BufferIO(b'')
+ self.__wbuf = BufferIO()
- def isOpen(self):
- return self.__trans.isOpen()
+ def isOpen(self):
+ return self.__trans.isOpen()
- def open(self):
- return self.__trans.open()
+ def open(self):
+ return self.__trans.open()
- def close(self):
- return self.__trans.close()
+ def close(self):
+ return self.__trans.close()
- def read(self, sz):
- ret = self.__rbuf.read(sz)
- if len(ret) != 0:
- return ret
+ def read(self, sz):
+ """Return up to sz bytes; a new frame is read only once the current one is used up."""
+ ret = self.__rbuf.read(sz)
+ if len(ret) != 0:
+ return ret
- self.readFrame()
- return self.__rbuf.read(sz)
+ self.readFrame()
+ return self.__rbuf.read(sz)
- def readFrame(self):
- buff = self.__trans.readAll(4)
- sz, = unpack('!i', buff)
- self.__rbuf = BufferIO(self.__trans.readAll(sz))
+ def readFrame(self):
+ """Replace the read buffer with the payload of the next frame.
+
+ A frame is a 4-byte signed big-endian length ('!i') followed by
+ that many payload bytes.
+ """
+ buff = self.__trans.readAll(4)
+ sz, = unpack('!i', buff)
+ self.__rbuf = BufferIO(self.__trans.readAll(sz))
- def write(self, buf):
- self.__wbuf.write(buf)
+ def write(self, buf):
+ self.__wbuf.write(buf)
- def flush(self):
- wout = self.__wbuf.getvalue()
- wsz = len(wout)
- # reset wbuf before write/flush to preserve state on underlying failure
- self.__wbuf = BufferIO()
- # N.B.: Doing this string concatenation is WAY cheaper than making
- # two separate calls to the underlying socket object. Socket writes in
- # Python turn out to be REALLY expensive, but it seems to do a pretty
- # good job of managing string buffer operations without excessive copies
- buf = pack("!i", wsz) + wout
- self.__trans.write(buf)
- self.__trans.flush()
+ def flush(self):
+ """Send all buffered writes as a single length-prefixed frame."""
+ wout = self.__wbuf.getvalue()
+ wsz = len(wout)
+ # reset wbuf before write/flush to preserve state on underlying failure
+ self.__wbuf = BufferIO()
+ # N.B.: Doing this string concatenation is WAY cheaper than making
+ # two separate calls to the underlying socket object. Socket writes in
+ # Python turn out to be REALLY expensive, but it seems to do a pretty
+ # good job of managing string buffer operations without excessive copies
+ buf = pack("!i", wsz) + wout
+ self.__trans.write(buf)
+ self.__trans.flush()
- # Implement the CReadableTransport interface.
- @property
- def cstringio_buf(self):
- return self.__rbuf
+ # Implement the CReadableTransport interface.
+ @property
+ def cstringio_buf(self):
+ return self.__rbuf
- def cstringio_refill(self, prefix, reqlen):
- # self.__rbuf will already be empty here because fastbinary doesn't
- # ask for a refill until the previous buffer is empty. Therefore,
- # we can start reading new frames immediately.
- while len(prefix) < reqlen:
- self.readFrame()
- prefix += self.__rbuf.getvalue()
- self.__rbuf = BufferIO(prefix)
- return self.__rbuf
+ def cstringio_refill(self, prefix, reqlen):
+ # self.__rbuf will already be empty here because fastbinary doesn't
+ # ask for a refill until the previous buffer is empty. Therefore,
+ # we can start reading new frames immediately.
+ while len(prefix) < reqlen:
+ self.readFrame()
+ prefix += self.__rbuf.getvalue()
+ self.__rbuf = BufferIO(prefix)
+ return self.__rbuf
class TFileObjectTransport(TTransportBase):
- """Wraps a file-like object to make it work as a Thrift transport."""
+ """Wraps a file-like object to make it work as a Thrift transport.

+ read/write/flush/close delegate directly to the wrapped object.
+ """
- def __init__(self, fileobj):
- self.fileobj = fileobj
+ def __init__(self, fileobj):
+ self.fileobj = fileobj
- def isOpen(self):
- return True
+ def isOpen(self):
+ # Always reports open; open() is not overridden here, so the file
+ # object is presumably expected to be open already.
+ return True
- def close(self):
- self.fileobj.close()
+ def close(self):
+ self.fileobj.close()
- def read(self, sz):
- return self.fileobj.read(sz)
+ def read(self, sz):
+ return self.fileobj.read(sz)
- def write(self, buf):
- self.fileobj.write(buf)
+ def write(self, buf):
+ self.fileobj.write(buf)
- def flush(self):
- self.fileobj.flush()
+ def flush(self):
+ self.fileobj.flush()
class TSaslClientTransport(TTransportBase, CReadableTransport):
- """
- SASL transport
- """
-
- START = 1
- OK = 2
- BAD = 3
- ERROR = 4
- COMPLETE = 5
-
- def __init__(self, transport, host, service, mechanism='GSSAPI',
- **sasl_kwargs):
"""
- transport: an underlying transport to use, typically just a TSocket
- host: the name of the server, from a SASL perspective
- service: the name of the server's service, from a SASL perspective
- mechanism: the name of the preferred mechanism to use
-
- All other kwargs will be passed to the puresasl.client.SASLClient
- constructor.
+ SASL transport
"""
- from puresasl.client import SASLClient
+ START = 1
+ OK = 2
+ BAD = 3
+ ERROR = 4
+ COMPLETE = 5
- self.transport = transport
- self.sasl = SASLClient(host, service, mechanism, **sasl_kwargs)
+ def __init__(self, transport, host, service, mechanism='GSSAPI',
+ **sasl_kwargs):
+ """
+ transport: an underlying transport to use, typically just a TSocket
+ host: the name of the server, from a SASL perspective
+ service: the name of the server's service, from a SASL perspective
+ mechanism: the name of the preferred mechanism to use
- self.__wbuf = BufferIO()
- self.__rbuf = BufferIO(b'')
+ All other kwargs will be passed to the puresasl.client.SASLClient
+ constructor.
+ """
- def open(self):
- if not self.transport.isOpen():
- self.transport.open()
+ from puresasl.client import SASLClient
- self.send_sasl_msg(self.START, self.sasl.mechanism)
- self.send_sasl_msg(self.OK, self.sasl.process())
+ self.transport = transport
+ self.sasl = SASLClient(host, service, mechanism, **sasl_kwargs)
- while True:
- status, challenge = self.recv_sasl_msg()
- if status == self.OK:
- self.send_sasl_msg(self.OK, self.sasl.process(challenge))
- elif status == self.COMPLETE:
- if not self.sasl.complete:
- raise TTransportException("The server erroneously indicated "
- "that SASL negotiation was complete")
+ self.__wbuf = BufferIO()
+ self.__rbuf = BufferIO(b'')
+
+ def open(self):
+ """Open the wrapped transport if needed, then run the SASL handshake."""
+ if not self.transport.isOpen():
+ self.transport.open()
+
+ # send_sasl_msg appends the body to a packed (bytes) header, so the
+ # mechanism name has to be sent as bytes rather than text.
+ mechanism = self.sasl.mechanism
+ if not isinstance(mechanism, bytes):
+ mechanism = mechanism.encode('ascii')
+ self.send_sasl_msg(self.START, mechanism)
+ self.send_sasl_msg(self.OK, self.sasl.process())
+
+ while True:
+ status, challenge = self.recv_sasl_msg()
+ if status == self.OK:
+ self.send_sasl_msg(self.OK, self.sasl.process(challenge))
+ elif status == self.COMPLETE:
+ if not self.sasl.complete:
+ raise TTransportException("The server erroneously indicated "
+ "that SASL negotiation was complete")
+ else:
+ break
+ else:
+ raise TTransportException("Bad SASL negotiation status: %d (%s)"
+ % (status, challenge))
+
+ def send_sasl_msg(self, status, body):
+ """Send one negotiation message: 1-byte status, 4-byte length, body."""
+ header = pack(">BI", status, len(body))
+ self.transport.write(header + body)
+ self.transport.flush()
+
+ def recv_sasl_msg(self):
+ """Read one negotiation message and return (status, payload)."""
+ header = self.transport.readAll(5)
+ status, length = unpack(">BI", header)
+ if length > 0:
+ payload = self.transport.readAll(length)
else:
- break
- else:
- raise TTransportException("Bad SASL negotiation status: %d (%s)"
- % (status, challenge))
+ # empty payload as bytes, like a non-empty one read from the transport
+ payload = b""
+ return status, payload
- def send_sasl_msg(self, status, body):
- header = pack(">BI", status, len(body))
- self.transport.write(header + body)
- self.transport.flush()
+ def write(self, data):
+ self.__wbuf.write(data)
- def recv_sasl_msg(self):
- header = self.transport.readAll(5)
- status, length = unpack(">BI", header)
- if length > 0:
- payload = self.transport.readAll(length)
- else:
- payload = ""
- return status, payload
+ def flush(self):
+ """Wrap the buffered data with the SASL layer and send it as one frame."""
+ data = self.__wbuf.getvalue()
+ encoded = self.sasl.wrap(data)
+ # pack() returns bytes; ''.join() over bytes items fails on Python 3
+ self.transport.write(pack("!i", len(encoded)) + encoded)
+ self.transport.flush()
+ self.__wbuf = BufferIO()
- def write(self, data):
- self.__wbuf.write(data)
+ def read(self, sz):
+ ret = self.__rbuf.read(sz)
+ if len(ret) != 0:
+ return ret
- def flush(self):
- data = self.__wbuf.getvalue()
- encoded = self.sasl.wrap(data)
- self.transport.write(''.join((pack("!i", len(encoded)), encoded)))
- self.transport.flush()
- self.__wbuf = BufferIO()
+ self._read_frame()
+ return self.__rbuf.read(sz)
- def read(self, sz):
- ret = self.__rbuf.read(sz)
- if len(ret) != 0:
- return ret
+ def _read_frame(self):
+ """Read one 4-byte-length-prefixed frame and unwrap it into the read buffer."""
+ header = self.transport.readAll(4)
+ length, = unpack('!i', header)
+ encoded = self.transport.readAll(length)
+ self.__rbuf = BufferIO(self.sasl.unwrap(encoded))
- self._read_frame()
- return self.__rbuf.read(sz)
+ def close(self):
+ self.sasl.dispose()
+ self.transport.close()
- def _read_frame(self):
- header = self.transport.readAll(4)
- length, = unpack('!i', header)
- encoded = self.transport.readAll(length)
- self.__rbuf = BufferIO(self.sasl.unwrap(encoded))
+ # based on TFramedTransport
+ @property
+ def cstringio_buf(self):
+ return self.__rbuf
- def close(self):
- self.sasl.dispose()
- self.transport.close()
-
- # based on TFramedTransport
- @property
- def cstringio_buf(self):
- return self.__rbuf
-
- def cstringio_refill(self, prefix, reqlen):
- # self.__rbuf will already be empty here because fastbinary doesn't
- # ask for a refill until the previous buffer is empty. Therefore,
- # we can start reading new frames immediately.
- while len(prefix) < reqlen:
- self._read_frame()
- prefix += self.__rbuf.getvalue()
- self.__rbuf = BufferIO(prefix)
- return self.__rbuf
-
+ def cstringio_refill(self, prefix, reqlen):
+ # self.__rbuf will already be empty here because fastbinary doesn't
+ # ask for a refill until the previous buffer is empty. Therefore,
+ # we can start reading new frames immediately.
+ while len(prefix) < reqlen:
+ self._read_frame()
+ prefix += self.__rbuf.getvalue()
+ self.__rbuf = BufferIO(prefix)
+ return self.__rbuf
diff --git a/lib/py/src/transport/TTwisted.py b/lib/py/src/transport/TTwisted.py
index 6149a6c..5710b57 100644
--- a/lib/py/src/transport/TTwisted.py
+++ b/lib/py/src/transport/TTwisted.py
@@ -120,7 +120,7 @@
MAX_LENGTH = 2 ** 31 - 1
def __init__(self, client_class, iprot_factory, oprot_factory=None,
- host=None, service=None, mechanism='GSSAPI', **sasl_kwargs):
+ host=None, service=None, mechanism='GSSAPI', **sasl_kwargs):
"""
host: the name of the server, from a SASL perspective
service: the name of the server's service, from a SASL perspective
@@ -236,7 +236,7 @@
d = self.factory.processor.process(iprot, oprot)
d.addCallbacks(self.processOk, self.processError,
- callbackArgs=(tmo,))
+ callbackArgs=(tmo,))
class IThriftServerFactory(Interface):
@@ -288,7 +288,7 @@
def buildProtocol(self, addr):
p = self.protocol(self.client_class, self.iprot_factory,
- self.oprot_factory)
+ self.oprot_factory)
p.factory = self
return p
@@ -298,7 +298,7 @@
allowedMethods = ('POST',)
def __init__(self, processor, inputProtocolFactory,
- outputProtocolFactory=None):
+ outputProtocolFactory=None):
resource.Resource.__init__(self)
self.inputProtocolFactory = inputProtocolFactory
if outputProtocolFactory is None:
diff --git a/lib/py/src/transport/TZlibTransport.py b/lib/py/src/transport/TZlibTransport.py
index 7fe5853..e848579 100644
--- a/lib/py/src/transport/TZlibTransport.py
+++ b/lib/py/src/transport/TZlibTransport.py
@@ -29,220 +29,220 @@
class TZlibTransportFactory(object):
- """Factory transport that builds zlib compressed transports.
+ """Factory transport that builds zlib compressed transports.
- This factory caches the last single client/transport that it was passed
- and returns the same TZlibTransport object that was created.
+ This factory caches the last single client/transport that it was passed
+ and returns the same TZlibTransport object that was created.
- This caching means the TServer class will get the _same_ transport
- object for both input and output transports from this factory.
- (For non-threaded scenarios only, since the cache only holds one object)
+ This caching means the TServer class will get the _same_ transport
+ object for both input and output transports from this factory.
+ (For non-threaded scenarios only, since the cache only holds one object)
- The purpose of this caching is to allocate only one TZlibTransport where
- only one is really needed (since it must have separate read/write buffers),
- and makes the statistics from getCompSavings() and getCompRatio()
- easier to understand.
- """
- # class scoped cache of last transport given and zlibtransport returned
- _last_trans = None
- _last_z = None
-
- def getTransport(self, trans, compresslevel=9):
- """Wrap a transport, trans, with the TZlibTransport
- compressed transport class, returning a new
- transport to the caller.
-
- @param compresslevel: The zlib compression level, ranging
- from 0 (no compression) to 9 (best compression). Defaults to 9.
- @type compresslevel: int
-
- This method returns a TZlibTransport which wraps the
- passed C{trans} TTransport derived instance.
+ The purpose of this caching is to allocate only one TZlibTransport where
+ only one is really needed (since it must have separate read/write buffers),
+ and makes the statistics from getCompSavings() and getCompRatio()
+ easier to understand.
"""
- if trans == self._last_trans:
- return self._last_z
- ztrans = TZlibTransport(trans, compresslevel)
- self._last_trans = trans
- self._last_z = ztrans
- return ztrans
+ # Cache of the last transport given and the TZlibTransport returned.
+ # The class attributes are only initial values: getTransport assigns
+ # through self, so after the first call the cache is per factory instance.
+ _last_trans = None
+ _last_z = None
+
+ def getTransport(self, trans, compresslevel=9):
+ """Wrap a transport, trans, with the TZlibTransport
+ compressed transport class, returning it to the caller.
+
+ If trans is the transport passed on the previous call and
+ compresslevel is unchanged, the cached TZlibTransport is returned
+ instead of a new one.
+
+ @param compresslevel: The zlib compression level, ranging
+ from 0 (no compression) to 9 (best compression). Defaults to 9.
+ @type compresslevel: int
+
+ This method returns a TZlibTransport which wraps the
+ passed C{trans} TTransport derived instance.
+ """
+ cached = self._last_z
+ # Reuse the cached wrapper only if the compression level matches too;
+ # otherwise a request for a different level would get the old one.
+ if (trans == self._last_trans and cached is not None and
+ cached.compresslevel == compresslevel):
+ return cached
+ ztrans = TZlibTransport(trans, compresslevel)
+ self._last_trans = trans
+ self._last_z = ztrans
+ return ztrans
class TZlibTransport(TTransportBase, CReadableTransport):
- """Class that wraps a transport with zlib, compressing writes
- and decompresses reads, using the python standard
- library zlib module.
- """
- # Read buffer size for the python fastbinary C extension,
- # the TBinaryProtocolAccelerated class.
- DEFAULT_BUFFSIZE = 4096
-
- def __init__(self, trans, compresslevel=9):
- """Create a new TZlibTransport, wrapping C{trans}, another
- TTransport derived object.
-
- @param trans: A thrift transport object, i.e. a TSocket() object.
- @type trans: TTransport
- @param compresslevel: The zlib compression level, ranging
- from 0 (no compression) to 9 (best compression). Default is 9.
- @type compresslevel: int
+ """Class that wraps a transport with zlib, compressing writes
+ and decompresses reads, using the python standard
+ library zlib module.
"""
- self.__trans = trans
- self.compresslevel = compresslevel
- self.__rbuf = BufferIO()
- self.__wbuf = BufferIO()
- self._init_zlib()
- self._init_stats()
+ # Read buffer size for the python fastbinary C extension,
+ # the TBinaryProtocolAccelerated class.
+ DEFAULT_BUFFSIZE = 4096
- def _reinit_buffers(self):
- """Internal method to initialize/reset the internal StringIO objects
- for read and write buffers.
- """
- self.__rbuf = BufferIO()
- self.__wbuf = BufferIO()
+ def __init__(self, trans, compresslevel=9):
+ """Create a new TZlibTransport, wrapping C{trans}, another
+ TTransport derived object.
- def _init_stats(self):
- """Internal method to reset the internal statistics counters
- for compression ratios and bandwidth savings.
- """
- self.bytes_in = 0
- self.bytes_out = 0
- self.bytes_in_comp = 0
- self.bytes_out_comp = 0
+ @param trans: A thrift transport object, e.g. a TSocket() object.
+ @type trans: TTransport
+ @param compresslevel: The zlib compression level, ranging
+ from 0 (no compression) to 9 (best compression). Default is 9.
+ @type compresslevel: int
+ """
+ self.__trans = trans
+ self.compresslevel = compresslevel
+ self.__rbuf = BufferIO()
+ self.__wbuf = BufferIO()
+ self._init_zlib()
+ self._init_stats()
- def _init_zlib(self):
- """Internal method for setting up the zlib compression and
- decompression objects.
- """
- self._zcomp_read = zlib.decompressobj()
- self._zcomp_write = zlib.compressobj(self.compresslevel)
+ def _reinit_buffers(self):
+ """Internal method to initialize/reset the internal BufferIO objects
+ for read and write buffers.
+ """
+ self.__rbuf = BufferIO()
+ self.__wbuf = BufferIO()
- def getCompRatio(self):
- """Get the current measured compression ratios (in,out) from
- this transport.
+ def _init_stats(self):
+ """Internal method to reset the internal statistics counters
+ for compression ratios and bandwidth savings.
+ """
+ # *_comp counters hold compressed (on-the-wire) sizes, the others
+ # uncompressed sizes; readComp() and flush() keep them consistent.
+ self.bytes_in = 0
+ self.bytes_out = 0
+ self.bytes_in_comp = 0
+ self.bytes_out_comp = 0
- Returns a tuple of:
- (inbound_compression_ratio, outbound_compression_ratio)
+ def _init_zlib(self):
+ """Internal method for setting up the zlib compression and
+ decompression objects.
+ """
+ self._zcomp_read = zlib.decompressobj()
+ self._zcomp_write = zlib.compressobj(self.compresslevel)
- The compression ratios are computed as:
- compressed / uncompressed
+ def getCompRatio(self):
+ """Get the current measured compression ratios (in,out) from
+ this transport.
- E.g., data that compresses by 10x will have a ratio of: 0.10
- and data that compresses to half of ts original size will
- have a ratio of 0.5
+ Returns a tuple of:
+ (inbound_compression_ratio, outbound_compression_ratio)
- None is returned if no bytes have yet been processed in
- a particular direction.
- """
- r_percent, w_percent = (None, None)
- if self.bytes_in > 0:
- r_percent = self.bytes_in_comp / self.bytes_in
- if self.bytes_out > 0:
- w_percent = self.bytes_out_comp / self.bytes_out
- return (r_percent, w_percent)
+ The compression ratios are computed as:
+ compressed / uncompressed
- def getCompSavings(self):
- """Get the current count of saved bytes due to data
- compression.
+ E.g., data that compresses by 10x will have a ratio of: 0.10
+ and data that compresses to half of its original size will
+ have a ratio of 0.5
- Returns a tuple of:
- (inbound_saved_bytes, outbound_saved_bytes)
+ None is returned if no bytes have yet been processed in
+ a particular direction.
+ """
+ r_percent, w_percent = (None, None)
+ if self.bytes_in > 0:
+ r_percent = self.bytes_in_comp / self.bytes_in
+ if self.bytes_out > 0:
+ w_percent = self.bytes_out_comp / self.bytes_out
+ return (r_percent, w_percent)
- Note: if compression is actually expanding your
- data (only likely with very tiny thrift objects), then
- the values returned will be negative.
- """
- r_saved = self.bytes_in - self.bytes_in_comp
- w_saved = self.bytes_out - self.bytes_out_comp
- return (r_saved, w_saved)
+ def getCompSavings(self):
+ """Get the current count of saved bytes due to data
+ compression.
- def isOpen(self):
- """Return the underlying transport's open status"""
- return self.__trans.isOpen()
+ Returns a tuple of:
+ (inbound_saved_bytes, outbound_saved_bytes)
- def open(self):
- """Open the underlying transport"""
- self._init_stats()
- return self.__trans.open()
+ Note: if compression is actually expanding your
+ data (only likely with very tiny thrift objects), then
+ the values returned will be negative.
+ """
+ r_saved = self.bytes_in - self.bytes_in_comp
+ w_saved = self.bytes_out - self.bytes_out_comp
+ return (r_saved, w_saved)
- def listen(self):
- """Invoke the underlying transport's listen() method"""
- self.__trans.listen()
+ def isOpen(self):
+ """Return the underlying transport's open status"""
+ return self.__trans.isOpen()
- def accept(self):
- """Accept connections on the underlying transport"""
- return self.__trans.accept()
+ def open(self):
+ """Open the underlying transport"""
+ self._init_stats()
+ return self.__trans.open()
- def close(self):
- """Close the underlying transport,"""
- self._reinit_buffers()
- self._init_zlib()
- return self.__trans.close()
+ def listen(self):
+ """Invoke the underlying transport's listen() method"""
+ self.__trans.listen()
- def read(self, sz):
- """Read up to sz bytes from the decompressed bytes buffer, and
- read from the underlying transport if the decompression
- buffer is empty.
- """
- ret = self.__rbuf.read(sz)
- if len(ret) > 0:
- return ret
- # keep reading from transport until something comes back
- while True:
- if self.readComp(sz):
- break
- ret = self.__rbuf.read(sz)
- return ret
+ def accept(self):
+ """Accept connections on the underlying transport"""
+ return self.__trans.accept()
- def readComp(self, sz):
- """Read compressed data from the underlying transport, then
- decompress it and append it to the internal StringIO read buffer
- """
- zbuf = self.__trans.read(sz)
- zbuf = self._zcomp_read.unconsumed_tail + zbuf
- buf = self._zcomp_read.decompress(zbuf)
- self.bytes_in += len(zbuf)
- self.bytes_in_comp += len(buf)
- old = self.__rbuf.read()
- self.__rbuf = BufferIO(old + buf)
- if len(old) + len(buf) == 0:
- return False
- return True
+ def close(self):
+ """Close the underlying transport."""
+ self._reinit_buffers()
+ self._init_zlib()
+ return self.__trans.close()
- def write(self, buf):
- """Write some bytes, putting them into the internal write
- buffer for eventual compression.
- """
- self.__wbuf.write(buf)
+ def read(self, sz):
+ """Read up to sz bytes from the decompressed bytes buffer, and
+ read from the underlying transport if the decompression
+ buffer is empty.
+ """
+ ret = self.__rbuf.read(sz)
+ if len(ret) > 0:
+ return ret
+ # keep reading from transport until something comes back
+ # NOTE(review): if the underlying transport returns b'' at EOF instead
+ # of raising, readComp() keeps returning False and this never exits.
+ while True:
+ if self.readComp(sz):
+ break
+ ret = self.__rbuf.read(sz)
+ return ret
- def flush(self):
- """Flush any queued up data in the write buffer and ensure the
- compression buffer is flushed out to the underlying transport
- """
- wout = self.__wbuf.getvalue()
- if len(wout) > 0:
- zbuf = self._zcomp_write.compress(wout)
- self.bytes_out += len(wout)
- self.bytes_out_comp += len(zbuf)
- else:
- zbuf = ''
- ztail = self._zcomp_write.flush(zlib.Z_SYNC_FLUSH)
- self.bytes_out_comp += len(ztail)
- if (len(zbuf) + len(ztail)) > 0:
- self.__wbuf = BufferIO()
- self.__trans.write(zbuf + ztail)
- self.__trans.flush()
+ def readComp(self, sz):
+ """Read compressed data from the underlying transport, then
+ decompress it and append it to the internal BufferIO read buffer
+ """
+ zbuf = self.__trans.read(sz)
+ zbuf = self._zcomp_read.unconsumed_tail + zbuf
+ buf = self._zcomp_read.decompress(zbuf)
+ # Same convention as flush(): bytes_in counts uncompressed data,
+ # bytes_in_comp the compressed bytes taken from the wire.
+ self.bytes_in += len(buf)
+ self.bytes_in_comp += len(zbuf)
+ old = self.__rbuf.read()
+ self.__rbuf = BufferIO(old + buf)
+ if len(old) + len(buf) == 0:
+ return False
+ return True
- @property
- def cstringio_buf(self):
- """Implement the CReadableTransport interface"""
- return self.__rbuf
+ def write(self, buf):
+ """Write some bytes, putting them into the internal write
+ buffer for eventual compression.
+ """
+ self.__wbuf.write(buf)
- def cstringio_refill(self, partialread, reqlen):
- """Implement the CReadableTransport interface for refill"""
- retstring = partialread
- if reqlen < self.DEFAULT_BUFFSIZE:
- retstring += self.read(self.DEFAULT_BUFFSIZE)
- while len(retstring) < reqlen:
- retstring += self.read(reqlen - len(retstring))
- self.__rbuf = BufferIO(retstring)
- return self.__rbuf
+ def flush(self):
+ """Flush any queued up data in the write buffer and ensure the
+ compression buffer is flushed out to the underlying transport
+ """
+ wout = self.__wbuf.getvalue()
+ if len(wout) > 0:
+ zbuf = self._zcomp_write.compress(wout)
+ self.bytes_out += len(wout)
+ self.bytes_out_comp += len(zbuf)
+ else:
+ # must be bytes: it is concatenated with the bytes sync-flush tail
+ zbuf = b''
+ ztail = self._zcomp_write.flush(zlib.Z_SYNC_FLUSH)
+ self.bytes_out_comp += len(ztail)
+ if (len(zbuf) + len(ztail)) > 0:
+ self.__wbuf = BufferIO()
+ self.__trans.write(zbuf + ztail)
+ self.__trans.flush()
+
+ @property
+ def cstringio_buf(self):
+ """Implement the CReadableTransport interface"""
+ return self.__rbuf
+
+ def cstringio_refill(self, partialread, reqlen):
+ """Implement the CReadableTransport interface for refill"""
+ retstring = partialread
+ if reqlen < self.DEFAULT_BUFFSIZE:
+ retstring += self.read(self.DEFAULT_BUFFSIZE)
+ while len(retstring) < reqlen:
+ retstring += self.read(reqlen - len(retstring))
+ self.__rbuf = BufferIO(retstring)
+ return self.__rbuf