THRIFT-1857 Python 3 Support
Client: Python
Patch: Thomas Bartelmess, Eevee (Alex Munroe), helgridly, Christian Verkerk, Jeroen Vlek, Nobuaki Sukegawa
This closes #213 and closes #680
diff --git a/lib/py/src/protocol/TBase.py b/lib/py/src/protocol/TBase.py
index 6cbd5f3..118a679 100644
--- a/lib/py/src/protocol/TBase.py
+++ b/lib/py/src/protocol/TBase.py
@@ -17,7 +17,6 @@
# under the License.
#
-from thrift.Thrift import *
from thrift.protocol import TBinaryProtocol
from thrift.transport import TTransport
@@ -31,8 +30,7 @@
__slots__ = []
def __repr__(self):
- L = ['%s=%r' % (key, getattr(self, key))
- for key in self.__slots__]
+ L = ['%s=%r' % (key, getattr(self, key)) for key in self.__slots__]
return '%s(%s)' % (self.__class__.__name__, ', '.join(L))
def __eq__(self, other):
@@ -50,9 +48,9 @@
def read(self, iprot):
if (iprot.__class__ == TBinaryProtocol.TBinaryProtocolAccelerated and
- isinstance(iprot.trans, TTransport.CReadableTransport) and
- self.thrift_spec is not None and
- fastbinary is not None):
+ isinstance(iprot.trans, TTransport.CReadableTransport) and
+ self.thrift_spec is not None and
+ fastbinary is not None):
fastbinary.decode_binary(self,
iprot.trans,
(self.__class__, self.thrift_spec))
@@ -61,21 +59,13 @@
def write(self, oprot):
if (oprot.__class__ == TBinaryProtocol.TBinaryProtocolAccelerated and
- self.thrift_spec is not None and
- fastbinary is not None):
+ self.thrift_spec is not None and
+ fastbinary is not None):
oprot.trans.write(
fastbinary.encode_binary(self, (self.__class__, self.thrift_spec)))
return
oprot.writeStruct(self, self.thrift_spec)
-class TExceptionBase(Exception):
- # old style class so python2.4 can raise exceptions derived from this
- # This can't inherit from TBase because of that limitation.
+class TExceptionBase(TBase, Exception):
__slots__ = []
-
- __repr__ = TBase.__repr__.im_func
- __eq__ = TBase.__eq__.im_func
- __ne__ = TBase.__ne__.im_func
- read = TBase.read.im_func
- write = TBase.write.im_func
diff --git a/lib/py/src/protocol/TBinaryProtocol.py b/lib/py/src/protocol/TBinaryProtocol.py
index 6fdd08c..f92f558 100644
--- a/lib/py/src/protocol/TBinaryProtocol.py
+++ b/lib/py/src/protocol/TBinaryProtocol.py
@@ -17,7 +17,7 @@
# under the License.
#
-from TProtocol import *
+from .TProtocol import TType, TProtocolBase, TProtocolException
from struct import pack, unpack
@@ -118,7 +118,7 @@
buff = pack("!d", dub)
self.trans.write(buff)
- def writeString(self, str):
+ def writeBinary(self, str):
self.writeI32(len(str))
self.trans.write(str)
@@ -217,10 +217,10 @@
val, = unpack('!d', buff)
return val
- def readString(self):
+ def readBinary(self):
len = self.readI32()
- str = self.trans.readAll(len)
- return str
+ s = self.trans.readAll(len)
+ return s
class TBinaryProtocolFactory:
diff --git a/lib/py/src/protocol/TCompactProtocol.py b/lib/py/src/protocol/TCompactProtocol.py
index 7054ab0..b8d171e 100644
--- a/lib/py/src/protocol/TCompactProtocol.py
+++ b/lib/py/src/protocol/TCompactProtocol.py
@@ -17,9 +17,11 @@
# under the License.
#
-from TProtocol import *
+from .TProtocol import TType, TProtocolBase, TProtocolException, checkIntegerLimits
from struct import pack, unpack
+from ..compat import binary_to_str, str_to_binary
+
__all__ = ['TCompactProtocol', 'TCompactProtocolFactory']
CLEAR = 0
@@ -62,7 +64,7 @@
else:
out.append((n & 0xff) | 0x80)
n = n >> 7
- trans.write(''.join(map(chr, out)))
+ trans.write(bytearray(out))
def readVarint(trans):
@@ -141,7 +143,7 @@
self.__writeUByte(self.PROTOCOL_ID)
self.__writeUByte(self.VERSION | (type << self.TYPE_SHIFT_AMOUNT))
self.__writeVarint(seqid)
- self.__writeString(name)
+ self.__writeBinary(str_to_binary(name))
self.state = VALUE_WRITE
def writeMessageEnd(self):
@@ -254,10 +256,10 @@
def writeDouble(self, dub):
self.trans.write(pack('<d', dub))
- def __writeString(self, s):
+ def __writeBinary(self, s):
self.__writeSize(len(s))
self.trans.write(s)
- writeString = writer(__writeString)
+ writeBinary = writer(__writeBinary)
def readFieldBegin(self):
assert self.state == FIELD_READ, self.state
@@ -302,7 +304,7 @@
def __readSize(self):
result = self.__readVarint()
if result < 0:
- raise TException("Length < 0")
+ raise TProtocolException(TProtocolException.NEGATIVE_SIZE, "Length < 0")  # type first, then message
return result
def readMessageBegin(self):
@@ -310,15 +312,15 @@
proto_id = self.__readUByte()
if proto_id != self.PROTOCOL_ID:
raise TProtocolException(TProtocolException.BAD_VERSION,
- 'Bad protocol id in the message: %d' % proto_id)
+ 'Bad protocol id in the message: %d' % proto_id)
ver_type = self.__readUByte()
type = (ver_type >> self.TYPE_SHIFT_AMOUNT) & self.TYPE_BITS
version = ver_type & self.VERSION_MASK
if version != self.VERSION:
raise TProtocolException(TProtocolException.BAD_VERSION,
- 'Bad version: %d (expect %d)' % (version, self.VERSION))
+ 'Bad version: %d (expect %d)' % (version, self.VERSION))
seqid = self.__readVarint()
- name = self.__readString()
+ name = binary_to_str(self.__readBinary())
return (name, type, seqid)
def readMessageEnd(self):
@@ -388,10 +390,10 @@
val, = unpack('<d', buff)
return val
- def __readString(self):
+ def __readBinary(self):
len = self.__readSize()
return self.trans.readAll(len)
- readString = reader(__readString)
+ readBinary = reader(__readBinary)
def __getTType(self, byte):
return TTYPES[byte & 0x0f]
diff --git a/lib/py/src/protocol/TJSONProtocol.py b/lib/py/src/protocol/TJSONProtocol.py
index 7807a6c..3ed8bcb 100644
--- a/lib/py/src/protocol/TJSONProtocol.py
+++ b/lib/py/src/protocol/TJSONProtocol.py
@@ -17,11 +17,13 @@
# under the License.
#
-from TProtocol import TType, TProtocolBase, TProtocolException, \
- checkIntegerLimits
+from .TProtocol import TType, TProtocolBase, TProtocolException, checkIntegerLimits
import base64
-import json
import math
+import sys
+
+from ..compat import str_to_binary
+
__all__ = ['TJSONProtocol',
'TJSONProtocolFactory',
@@ -30,20 +32,39 @@
VERSION = 1
-COMMA = ','
-COLON = ':'
-LBRACE = '{'
-RBRACE = '}'
-LBRACKET = '['
-RBRACKET = ']'
-QUOTE = '"'
-BACKSLASH = '\\'
-ZERO = '0'
+COMMA = b','
+COLON = b':'
+LBRACE = b'{'
+RBRACE = b'}'
+LBRACKET = b'['
+RBRACKET = b']'
+QUOTE = b'"'
+BACKSLASH = b'\\'
+ZERO = b'0'
-ESCSEQ = '\\u00'
-ESCAPE_CHAR = '"\\bfnrt/'
-ESCAPE_CHAR_VALS = ['"', '\\', '\b', '\f', '\n', '\r', '\t', '/']
-NUMERIC_CHAR = '+-.0123456789Ee'
+ESCSEQ0 = ord('\\')
+ESCSEQ1 = ord('u')
+ESCAPE_CHAR_VALS = {
+ '"': '\\"',
+ '\\': '\\\\',
+ '\b': '\\b',
+ '\f': '\\f',
+ '\n': '\\n',
+ '\r': '\\r',
+ '\t': '\\t',
+ # '/': '\\/',
+}
+ESCAPE_CHARS = {
+ b'"': '"',
+ b'\\': '\\',
+ b'b': '\b',
+ b'f': '\f',
+ b'n': '\n',
+ b'r': '\r',
+ b't': '\t',
+ b'/': '/',
+}
+NUMERIC_CHAR = b'+-.0123456789Ee'
CTYPES = {TType.BOOL: 'tf',
TType.BYTE: 'i8',
@@ -70,7 +91,7 @@
def doIO(self, function):
pass
-
+
def write(self):
pass
@@ -85,7 +106,7 @@
class JSONListContext(JSONBaseContext):
-
+
def doIO(self, function):
if self.first is True:
self.first = False
@@ -100,7 +121,7 @@
class JSONPairContext(JSONBaseContext):
-
+
def __init__(self, protocol):
super(JSONPairContext, self).__init__(protocol)
self.colon = True
@@ -146,6 +167,7 @@
self.hasData = True
return self.data
+
class TJSONProtocolBase(TProtocolBase):
def __init__(self, trans):
@@ -174,14 +196,22 @@
def writeJSONString(self, string):
self.context.write()
- self.trans.write(json.dumps(string, ensure_ascii=False))
+ json_str = ['"']
+ for s in string:
+ escaped = ESCAPE_CHAR_VALS.get(s, s)
+ json_str.append(escaped)
+ json_str.append('"')
+ self.trans.write(str_to_binary(''.join(json_str)))
def writeJSONNumber(self, number, formatter='{}'):
self.context.write()
- jsNumber = formatter.format(number)
+ jsNumber = str(formatter.format(number)).encode('ascii')
if self.context.escapeNum():
- jsNumber = "%s%s%s" % (QUOTE, jsNumber, QUOTE)
- self.trans.write(jsNumber)
+ self.trans.write(QUOTE)
+ self.trans.write(jsNumber)
+ self.trans.write(QUOTE)
+ else:
+ self.trans.write(jsNumber)
def writeJSONBase64(self, binary):
self.context.write()
@@ -222,18 +252,23 @@
character = self.reader.read()
if character == QUOTE:
break
- if character == ESCSEQ[0]:
+ if ord(character) == ESCSEQ0:
character = self.reader.read()
- if character == ESCSEQ[1]:
- self.readJSONSyntaxChar(ZERO)
- self.readJSONSyntaxChar(ZERO)
- character = json.JSONDecoder().decode('"\u00%s"' % self.trans.read(2))
+ if ord(character) == ESCSEQ1:
+ character = chr(int(self.trans.read(4), 16))  # \uXXXX digits are hexadecimal
else:
- off = ESCAPE_CHAR.find(character)
- if off == -1:
+ if character not in ESCAPE_CHARS:
raise TProtocolException(TProtocolException.INVALID_DATA,
"Expected control char")
- character = ESCAPE_CHAR_VALS[off]
+ character = ESCAPE_CHARS[character]
+ elif chr(ord(character)) in ESCAPE_CHAR_VALS:  # keys are str; on py3 character is bytes
+ raise TProtocolException(TProtocolException.INVALID_DATA,
+ "Unescaped control char")
+ elif sys.version_info[0] > 2:
+ utf8_bytes = bytearray([ord(character)])
+ while ord(self.reader.peek()) >= 0x80:
+ utf8_bytes.append(ord(self.reader.read()))
+ character = utf8_bytes.decode('utf8')
string.append(character)
return ''.join(string)
@@ -251,7 +286,7 @@
if self.isJSONNumeric(character) is False:
break
numeric.append(self.reader.read())
- return ''.join(numeric)
+ return b''.join(numeric).decode('ascii')
def readJSONInteger(self):
self.context.read()
@@ -267,12 +302,12 @@
def readJSONDouble(self):
self.context.read()
if self.reader.peek() == QUOTE:
- string = self.readJSONString(True)
+ string = self.readJSONString(True)
try:
double = float(string)
if (self.context.escapeNum is False and
- not math.isinf(double) and
- not math.isnan(double)):
+ not math.isinf(double) and
+ not math.isnan(double)):
raise TProtocolException(TProtocolException.INVALID_DATA,
"Numeric data unexpectedly quoted")
return double
@@ -430,12 +465,12 @@
def writeMapEnd(self):
self.writeJSONObjectEnd()
self.writeJSONArrayEnd()
-
+
def writeListBegin(self, etype, size):
self.writeJSONArrayStart()
self.writeJSONString(CTYPES[etype])
self.writeJSONNumber(size)
-
+
def writeListEnd(self):
self.writeJSONArrayEnd()
@@ -443,7 +478,7 @@
self.writeJSONArrayStart()
self.writeJSONString(CTYPES[etype])
self.writeJSONNumber(size)
-
+
def writeSetEnd(self):
self.writeJSONArrayEnd()
@@ -472,7 +507,7 @@
def writeString(self, string):
self.writeJSONString(string)
-
+
def writeBinary(self, binary):
self.writeJSONBase64(binary)
@@ -485,49 +520,49 @@
class TSimpleJSONProtocol(TJSONProtocolBase):
"""Simple, readable, write-only JSON protocol.
-
+
Useful for interacting with scripting languages.
"""
def readMessageBegin(self):
raise NotImplementedError()
-
+
def readMessageEnd(self):
raise NotImplementedError()
-
+
def readStructBegin(self):
raise NotImplementedError()
-
+
def readStructEnd(self):
raise NotImplementedError()
-
+
def writeMessageBegin(self, name, request_type, seqid):
self.resetWriteContext()
-
+
def writeMessageEnd(self):
pass
-
+
def writeStructBegin(self, name):
self.writeJSONObjectStart()
-
+
def writeStructEnd(self):
self.writeJSONObjectEnd()
-
+
def writeFieldBegin(self, name, ttype, fid):
self.writeJSONString(name)
-
+
def writeFieldEnd(self):
pass
-
+
def writeMapBegin(self, ktype, vtype, size):
self.writeJSONObjectStart()
-
+
def writeMapEnd(self):
self.writeJSONObjectEnd()
-
+
def _writeCollectionBegin(self, etype, size):
self.writeJSONArrayStart()
-
+
def _writeCollectionEnd(self):
self.writeJSONArrayEnd()
writeListBegin = _writeCollectionBegin
@@ -550,16 +585,16 @@
def writeI64(self, i64):
checkIntegerLimits(i64, 64)
self.writeJSONNumber(i64)
-
+
def writeBool(self, boolean):
self.writeJSONNumber(1 if boolean is True else 0)
def writeDouble(self, dbl):
self.writeJSONNumber(dbl)
-
+
def writeString(self, string):
self.writeJSONString(string)
-
+
def writeBinary(self, binary):
self.writeJSONBase64(binary)
diff --git a/lib/py/src/protocol/TProtocol.py b/lib/py/src/protocol/TProtocol.py
index 311a635..22339c0 100644
--- a/lib/py/src/protocol/TProtocol.py
+++ b/lib/py/src/protocol/TProtocol.py
@@ -17,7 +17,10 @@
# under the License.
#
-from thrift.Thrift import *
+from thrift.Thrift import TException, TType
+import six
+
+from ..compat import binary_to_str, str_to_binary
class TProtocolException(TException):
@@ -100,6 +103,9 @@
pass
def writeString(self, str_val):
+ self.writeBinary(str_to_binary(str_val))
+
+ def writeBinary(self, str_val):
pass
def readMessageBegin(self):
@@ -157,6 +163,9 @@
pass
def readString(self):
+ return binary_to_str(self.readBinary())
+
+ def readBinary(self):
pass
def skip(self, ttype):
@@ -187,18 +196,18 @@
self.readStructEnd()
elif ttype == TType.MAP:
(ktype, vtype, size) = self.readMapBegin()
- for i in xrange(size):
+ for i in range(size):
self.skip(ktype)
self.skip(vtype)
self.readMapEnd()
elif ttype == TType.SET:
(etype, size) = self.readSetBegin()
- for i in xrange(size):
+ for i in range(size):
self.skip(etype)
self.readSetEnd()
elif ttype == TType.LIST:
(etype, size) = self.readListBegin()
- for i in xrange(size):
+ for i in range(size):
self.skip(etype)
self.readListEnd()
@@ -246,13 +255,13 @@
(list_type, list_len) = self.readListBegin()
if tspec is None:
# list values are simple types
- for idx in xrange(list_len):
+ for idx in range(list_len):
results.append(reader())
else:
# this is like an inlined readFieldByTType
container_reader = self._TTYPE_HANDLERS[list_type][0]
val_reader = getattr(self, container_reader)
- for idx in xrange(list_len):
+ for idx in range(list_len):
val = val_reader(tspec)
results.append(val)
self.readListEnd()
@@ -266,12 +275,12 @@
(set_type, set_len) = self.readSetBegin()
if tspec is None:
# set members are simple types
- for idx in xrange(set_len):
+ for idx in range(set_len):
results.add(reader())
else:
container_reader = self._TTYPE_HANDLERS[set_type][0]
val_reader = getattr(self, container_reader)
- for idx in xrange(set_len):
+ for idx in range(set_len):
results.add(val_reader(tspec))
self.readSetEnd()
return results
@@ -292,7 +301,7 @@
key_reader = getattr(self, self._TTYPE_HANDLERS[key_ttype][0])
val_reader = getattr(self, self._TTYPE_HANDLERS[val_ttype][0])
# list values are simple types
- for idx in xrange(map_len):
+ for idx in range(map_len):
if key_spec is None:
k_val = key_reader()
else:
@@ -363,7 +372,7 @@
k_writer = getattr(self, ktype_name)
v_writer = getattr(self, vtype_name)
self.writeMapBegin(k_type, v_type, len(val))
- for m_key, m_val in val.iteritems():
+ for m_key, m_val in six.iteritems(val):
if not k_is_container:
k_writer(m_key)
else:
@@ -402,6 +411,7 @@
else:
writer(val)
+
def checkIntegerLimits(i, bits):
if bits == 8 and (i < -128 or i > 127):
raise TProtocolException(TProtocolException.INVALID_DATA,
@@ -416,6 +426,7 @@
raise TProtocolException(TProtocolException.INVALID_DATA,
"i64 requires -9223372036854775808 <= number <= 9223372036854775807")
+
class TProtocolFactory:
def getProtocol(self, trans):
pass