THRIFT-3532 Add configurable string and container read size limit to Python protocols
This closes #787
diff --git a/lib/py/src/protocol/TBase.py b/lib/py/src/protocol/TBase.py
index 4f71e11..d106f4e 100644
--- a/lib/py/src/protocol/TBase.py
+++ b/lib/py/src/protocol/TBase.py
@@ -53,7 +53,9 @@
fastbinary is not None):
fastbinary.decode_binary(self,
iprot.trans,
- (self.__class__, self.thrift_spec))
+ (self.__class__, self.thrift_spec),
+ iprot.string_length_limit,
+ iprot.container_length_limit)
return
iprot.readStruct(self, self.thrift_spec)
@@ -90,5 +92,7 @@
self = cls()
return fastbinary.decode_binary(None,
iprot.trans,
- (self.__class__, self.thrift_spec))
+ (self.__class__, self.thrift_spec),
+ iprot.string_length_limit,
+ iprot.container_length_limit)
return iprot.readStruct(cls, cls.thrift_spec, True)
diff --git a/lib/py/src/protocol/TBinaryProtocol.py b/lib/py/src/protocol/TBinaryProtocol.py
index 43cb5a4..db4ea31 100644
--- a/lib/py/src/protocol/TBinaryProtocol.py
+++ b/lib/py/src/protocol/TBinaryProtocol.py
@@ -36,10 +36,18 @@
TYPE_MASK = 0x000000ff
- def __init__(self, trans, strictRead=False, strictWrite=True):
+ def __init__(self, trans, strictRead=False, strictWrite=True, **kwargs):
TProtocolBase.__init__(self, trans)
self.strictRead = strictRead
self.strictWrite = strictWrite
+ self.string_length_limit = kwargs.get('string_length_limit', None)
+ self.container_length_limit = kwargs.get('container_length_limit', None)
+
+ def _check_string_length(self, length):
+ self._check_length(self.string_length_limit, length)
+
+ def _check_container_length(self, length):
+ self._check_length(self.container_length_limit, length)
def writeMessageBegin(self, name, type, seqid):
if self.strictWrite:
@@ -165,6 +173,7 @@
ktype = self.readByte()
vtype = self.readByte()
size = self.readI32()
+ self._check_container_length(size)
return (ktype, vtype, size)
def readMapEnd(self):
@@ -173,6 +182,7 @@
def readListBegin(self):
etype = self.readByte()
size = self.readI32()
+ self._check_container_length(size)
return (etype, size)
def readListEnd(self):
@@ -181,6 +191,7 @@
def readSetBegin(self):
etype = self.readByte()
size = self.readI32()
+ self._check_container_length(size)
return (etype, size)
def readSetEnd(self):
@@ -218,18 +229,23 @@
return val
def readBinary(self):
- len = self.readI32()
- s = self.trans.readAll(len)
+ size = self.readI32()
+ self._check_string_length(size)
+ s = self.trans.readAll(size)
return s
class TBinaryProtocolFactory(object):
- def __init__(self, strictRead=False, strictWrite=True):
+ def __init__(self, strictRead=False, strictWrite=True, **kwargs):
self.strictRead = strictRead
self.strictWrite = strictWrite
+ self.string_length_limit = kwargs.get('string_length_limit', None)
+ self.container_length_limit = kwargs.get('container_length_limit', None)
def getProtocol(self, trans):
- prot = TBinaryProtocol(trans, self.strictRead, self.strictWrite)
+ prot = TBinaryProtocol(trans, self.strictRead, self.strictWrite,
+ string_length_limit=self.string_length_limit,
+ container_length_limit=self.container_length_limit)
return prot
@@ -256,5 +272,14 @@
class TBinaryProtocolAcceleratedFactory(object):
+ def __init__(self,
+ string_length_limit=None,
+ container_length_limit=None):
+ self.string_length_limit = string_length_limit
+ self.container_length_limit = container_length_limit
+
def getProtocol(self, trans):
- return TBinaryProtocolAccelerated(trans)
+ return TBinaryProtocolAccelerated(
+ trans,
+ string_length_limit=self.string_length_limit,
+ container_length_limit=self.container_length_limit)
diff --git a/lib/py/src/protocol/TCompactProtocol.py b/lib/py/src/protocol/TCompactProtocol.py
index 6023066..3d9c0e6 100644
--- a/lib/py/src/protocol/TCompactProtocol.py
+++ b/lib/py/src/protocol/TCompactProtocol.py
@@ -126,7 +126,9 @@
TYPE_BITS = 0x07
TYPE_SHIFT_AMOUNT = 5
- def __init__(self, trans):
+ def __init__(self, trans,
+ string_length_limit=None,
+ container_length_limit=None):
TProtocolBase.__init__(self, trans)
self.state = CLEAR
self.__last_fid = 0
@@ -134,6 +136,14 @@
self.__bool_value = None
self.__structs = []
self.__containers = []
+ self.string_length_limit = string_length_limit
+ self.container_length_limit = container_length_limit
+
+ def _check_string_length(self, length):
+ self._check_length(self.string_length_limit, length)
+
+ def _check_container_length(self, length):
+ self._check_length(self.container_length_limit, length)
def __writeVarint(self, n):
writeVarint(self.trans, n)
@@ -344,6 +354,7 @@
type = self.__getTType(size_type)
if size == 15:
size = self.__readSize()
+ self._check_container_length(size)
self.__containers.append(self.state)
self.state = CONTAINER_READ
return type, size
@@ -353,6 +364,7 @@
def readMapBegin(self):
assert self.state in (VALUE_READ, CONTAINER_READ), self.state
size = self.__readSize()
+ self._check_container_length(size)
types = 0
if size > 0:
types = self.__readUByte()
@@ -391,8 +403,9 @@
return val
def __readBinary(self):
- len = self.__readSize()
- return self.trans.readAll(len)
+ size = self.__readSize()
+ self._check_string_length(size)
+ return self.trans.readAll(size)
readBinary = reader(__readBinary)
def __getTType(self, byte):
@@ -400,8 +413,13 @@
class TCompactProtocolFactory(object):
- def __init__(self):
- pass
+ def __init__(self,
+ string_length_limit=None,
+ container_length_limit=None):
+ self.string_length_limit = string_length_limit
+ self.container_length_limit = container_length_limit
def getProtocol(self, trans):
- return TCompactProtocol(trans)
+ return TCompactProtocol(trans,  # limits passed by keyword, like the binary factories
+ string_length_limit=self.string_length_limit,
+ container_length_limit=self.container_length_limit)
diff --git a/lib/py/src/protocol/TJSONProtocol.py b/lib/py/src/protocol/TJSONProtocol.py
index 3612e91..f9e65fb 100644
--- a/lib/py/src/protocol/TJSONProtocol.py
+++ b/lib/py/src/protocol/TJSONProtocol.py
@@ -175,6 +175,15 @@
self.resetWriteContext()
self.resetReadContext()
+ # We don't have length limit implementation for JSON protocols
+ @property
+ def string_length_limit(self):
+ return None  # no string length limit is implemented for JSON protocols
+
+ @property
+ def container_length_limit(self):
+ return None  # no container length limit is implemented for JSON protocols
+
def resetWriteContext(self):
self.context = JSONBaseContext(self)
self.contextStack = [self.context]
@@ -560,10 +569,17 @@
class TJSONProtocolFactory(object):
-
def getProtocol(self, trans):
return TJSONProtocol(trans)
+ @property
+ def string_length_limit(self):
+ return None  # the factory always reports "no limit"
+
+ @property
+ def container_length_limit(self):
+ return None  # the factory always reports "no limit"
+
class TSimpleJSONProtocol(TJSONProtocolBase):
"""Simple, readable, write-only JSON protocol.
diff --git a/lib/py/src/protocol/TProtocol.py b/lib/py/src/protocol/TProtocol.py
index 450e0fa..d9aa2e8 100644
--- a/lib/py/src/protocol/TProtocol.py
+++ b/lib/py/src/protocol/TProtocol.py
@@ -18,6 +18,7 @@
#
from thrift.Thrift import TException, TType, TFrozenDict
+from thrift.transport.TTransport import TTransportException
from ..compat import binary_to_str, str_to_binary
import six
@@ -48,6 +49,15 @@
def __init__(self, trans):
self.trans = trans
+ @staticmethod
+ def _check_length(limit, length):
+ if length < 0:
+ raise TTransportException(TTransportException.NEGATIVE_SIZE,
+ 'Negative length: %d' % length)
+ if limit is not None and length > limit:
+ raise TTransportException(TTransportException.SIZE_LIMIT,
+ 'Length exceeded max allowed: %d' % limit)
+
def writeMessageBegin(self, name, ttype, seqid):
pass
diff --git a/lib/py/src/protocol/fastbinary.c b/lib/py/src/protocol/fastbinary.c
index 337201b..da57c85 100644
--- a/lib/py/src/protocol/fastbinary.c
+++ b/lib/py/src/protocol/fastbinary.c
@@ -189,22 +189,19 @@
return false;
}
if (!CHECK_RANGE(len, 0, INT32_MAX)) {
- PyErr_SetString(PyExc_OverflowError, "string size out of range");
+ PyErr_SetString(PyExc_OverflowError, "size out of range: exceeded INT32_MAX");
return false;
}
return true;
}
-#define MAX_LIST_SIZE (10000)
-
static inline bool
-check_list_length(Py_ssize_t len) {
- // error from getting the int
- if (INT_CONV_ERROR_OCCURRED(len)) {
+check_length_limit(Py_ssize_t len, long limit) {
+ if (!check_ssize_t_32(len)) {
return false;
}
- if (!CHECK_RANGE(len, 0, MAX_LIST_SIZE)) {
- PyErr_SetString(PyExc_OverflowError, "list size out of the sanity limit (10000 items max)");
+ if (len > limit) {
+ PyErr_Format(PyExc_OverflowError, "size exceeded specified limit: %ld", limit); /* limit is a long: %ld, not %d */
return false;
}
return true;
@@ -891,10 +888,10 @@
/* --- HELPER FUNCTION FOR DECODE_VAL --- */
static PyObject*
-decode_val(DecodeBuffer* input, TType type, PyObject* typeargs);
+decode_val(DecodeBuffer* input, TType type, PyObject* typeargs, long string_limit, long container_limit);
static PyObject*
-decode_struct(DecodeBuffer* input, PyObject* output, PyObject* klass, PyObject* spec_seq) {
+decode_struct(DecodeBuffer* input, PyObject* output, PyObject* klass, PyObject* spec_seq, long string_limit, long container_limit) {
int spec_seq_len = PyTuple_Size(spec_seq);
bool immutable = output == Py_None;
PyObject* kwargs = NULL;
@@ -954,7 +951,7 @@
}
}
- fieldval = decode_val(input, parsedspec.type, parsedspec.typeargs);
+ fieldval = decode_val(input, parsedspec.type, parsedspec.typeargs, string_limit, container_limit);
if (fieldval == NULL) {
goto error;
}
@@ -991,7 +988,7 @@
// Returns a new reference.
static PyObject*
-decode_val(DecodeBuffer* input, TType type, PyObject* typeargs) {
+decode_val(DecodeBuffer* input, TType type, PyObject* typeargs, long string_limit, long container_limit) {
switch (type) {
case T_BOOL: {
@@ -1059,6 +1056,9 @@
+ if (!check_length_limit(len, string_limit)) { /* validate the declared length before reading the payload */
+ return NULL;
+ }
if (!readBytes(input, &buf, len)) {
return NULL;
}
if (is_utf8(typeargs))
return PyUnicode_DecodeUTF8(buf, len, 0);
@@ -1083,7 +1083,7 @@
}
len = readI32(input);
- if (!check_list_length(len)) {
+ if (!check_length_limit(len, container_limit)) {
return NULL;
}
@@ -1094,7 +1094,7 @@
}
for (i = 0; i < len; i++) {
- PyObject* item = decode_val(input, parsedargs.element_type, parsedargs.typeargs);
+ PyObject* item = decode_val(input, parsedargs.element_type, parsedargs.typeargs, string_limit, container_limit);
if (!item) {
Py_DECREF(ret);
return NULL;
@@ -1135,8 +1135,8 @@
}
len = readI32(input);
- if (!check_ssize_t_32(len)) {
- return false;
+ if (!check_length_limit(len, container_limit)) {
+ return NULL;
}
ret = PyDict_New();
@@ -1147,11 +1147,11 @@
for (i = 0; i < len; i++) {
PyObject* k = NULL;
PyObject* v = NULL;
- k = decode_val(input, parsedargs.ktag, parsedargs.ktypeargs);
+ k = decode_val(input, parsedargs.ktag, parsedargs.ktypeargs, string_limit, container_limit);
if (k == NULL) {
goto loop_error;
}
- v = decode_val(input, parsedargs.vtag, parsedargs.vtypeargs);
+ v = decode_val(input, parsedargs.vtag, parsedargs.vtypeargs, string_limit, container_limit);
if (v == NULL) {
goto loop_error;
}
@@ -1199,7 +1199,7 @@
return NULL;
}
- return decode_struct(input, Py_None, parsedargs.klass, parsedargs.spec);
+ return decode_struct(input, Py_None, parsedargs.klass, parsedargs.spec, string_limit, container_limit);
}
case T_STOP:
@@ -1213,6 +1213,15 @@
}
}
+static long as_long_or(PyObject* value, long default_value) {
+ long v = PyInt_AsLong(value);
+ if (INT_CONV_ERROR_OCCURRED(v)) {
+ PyErr_Clear();
+ return default_value;
+ }
+ return v;
+}
+
/* --- TOP-LEVEL WRAPPER FOR INPUT -- */
@@ -1222,12 +1231,18 @@
PyObject* transport = NULL;
PyObject* typeargs = NULL;
StructTypeArgs parsedargs;
+ PyObject* string_limit_obj = NULL;
+ PyObject* container_limit_obj = NULL;
+ long string_limit = 0;
+ long container_limit = 0;
DecodeBuffer input = {0, 0};
PyObject* ret = NULL;
- if (!PyArg_ParseTuple(args, "OOO", &output_obj, &transport, &typeargs)) {
+ if (!PyArg_ParseTuple(args, "OOOOO", &output_obj, &transport, &typeargs, &string_limit_obj, &container_limit_obj)) {
return NULL;
}
+ string_limit = as_long_or(string_limit_obj, INT32_MAX);
+ container_limit = as_long_or(container_limit_obj, INT32_MAX);
if (!parse_struct_args(&parsedargs, typeargs)) {
return NULL;
@@ -1237,7 +1252,7 @@
return NULL;
}
- ret = decode_struct(&input, output_obj, parsedargs.klass, parsedargs.spec);
+ ret = decode_struct(&input, output_obj, parsedargs.klass, parsedargs.spec, string_limit, container_limit);
free_decodebuf(&input);
return ret;
}
diff --git a/lib/py/src/transport/TTransport.py b/lib/py/src/transport/TTransport.py
index 8e2da8d..f99b3b9 100644
--- a/lib/py/src/transport/TTransport.py
+++ b/lib/py/src/transport/TTransport.py
@@ -30,6 +30,8 @@
ALREADY_OPEN = 2
TIMED_OUT = 3
END_OF_FILE = 4
+ NEGATIVE_SIZE = 5
+ SIZE_LIMIT = 6
def __init__(self, type=UNKNOWN, message=None):
TException.__init__(self, message)