THRIFT-3231 CPP: Limit recursion depth to 64
Client: cpp
Patch: Ben Craig <bencraig@apache.org>
diff --git a/compiler/cpp/src/generate/t_cpp_generator.cc b/compiler/cpp/src/generate/t_cpp_generator.cc
index 426434f..aed3935 100644
--- a/compiler/cpp/src/generate/t_cpp_generator.cc
+++ b/compiler/cpp/src/generate/t_cpp_generator.cc
@@ -1367,10 +1367,16 @@
vector<t_field*>::const_iterator f_iter;
// Declare stack tmp variables
- out << endl << indent() << "uint32_t xfer = 0;" << endl << indent() << "std::string fname;"
- << endl << indent() << "::apache::thrift::protocol::TType ftype;" << endl << indent()
- << "int16_t fid;" << endl << endl << indent() << "xfer += iprot->readStructBegin(fname);"
- << endl << endl << indent() << "using ::apache::thrift::protocol::TProtocolException;" << endl
+ out << endl
+ << indent() << "apache::thrift::protocol::TRecursionTracker tracker(*iprot);" << endl
+ << indent() << "uint32_t xfer = 0;" << endl
+ << indent() << "std::string fname;" << endl
+ << indent() << "::apache::thrift::protocol::TType ftype;" << endl
+ << indent() << "int16_t fid;" << endl
+ << endl
+ << indent() << "xfer += iprot->readStructBegin(fname);" << endl
+ << endl
+ << indent() << "using ::apache::thrift::protocol::TProtocolException;" << endl
<< endl;
// Required variables aren't in __isset, so we need tmp vars to check them.
@@ -1486,7 +1492,7 @@
out << indent() << "uint32_t xfer = 0;" << endl;
- indent(out) << "oprot->incrementRecursionDepth();" << endl;
+ indent(out) << "apache::thrift::protocol::TRecursionTracker tracker(*oprot);" << endl;
indent(out) << "xfer += oprot->writeStructBegin(\"" << name << "\");" << endl;
for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) {
@@ -1522,7 +1528,7 @@
// Write the struct map
out << indent() << "xfer += oprot->writeFieldStop();" << endl << indent()
<< "xfer += oprot->writeStructEnd();" << endl << indent()
- << "oprot->decrementRecursionDepth();" << endl << indent() << "return xfer;" << endl;
+ << "return xfer;" << endl;
indent_down();
indent(out) << "}" << endl << endl;
diff --git a/lib/cpp/CMakeLists.txt b/lib/cpp/CMakeLists.txt
index b97e356..bab2e84 100755
--- a/lib/cpp/CMakeLists.txt
+++ b/lib/cpp/CMakeLists.txt
@@ -39,11 +39,12 @@
src/thrift/concurrency/TimerManager.cpp
src/thrift/concurrency/Util.cpp
src/thrift/processor/PeekProcessor.cpp
+ src/thrift/protocol/TBase64Utils.cpp
src/thrift/protocol/TDebugProtocol.cpp
src/thrift/protocol/TDenseProtocol.cpp
src/thrift/protocol/TJSONProtocol.cpp
- src/thrift/protocol/TBase64Utils.cpp
src/thrift/protocol/TMultiplexedProtocol.cpp
+ src/thrift/protocol/TProtocol.cpp
src/thrift/transport/TTransportException.cpp
src/thrift/transport/TFDTransport.cpp
src/thrift/transport/TSimpleFileTransport.cpp
diff --git a/lib/cpp/Makefile.am b/lib/cpp/Makefile.am
index 9156577..0ecbeee 100755
--- a/lib/cpp/Makefile.am
+++ b/lib/cpp/Makefile.am
@@ -75,6 +75,7 @@
src/thrift/protocol/TJSONProtocol.cpp \
src/thrift/protocol/TBase64Utils.cpp \
src/thrift/protocol/TMultiplexedProtocol.cpp \
+ src/thrift/protocol/TProtocol.cpp \
src/thrift/transport/TTransportException.cpp \
src/thrift/transport/TFDTransport.cpp \
src/thrift/transport/TFileTransport.cpp \
diff --git a/lib/cpp/src/thrift/protocol/TProtocol.cpp b/lib/cpp/src/thrift/protocol/TProtocol.cpp
new file mode 100644
index 0000000..c378aca
--- /dev/null
+++ b/lib/cpp/src/thrift/protocol/TProtocol.cpp
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <thrift/protocol/TProtocol.h>
+
+namespace apache {
+namespace thrift {
+namespace protocol {
+
+TProtocol::~TProtocol() {}
+uint32_t TProtocol::skip_virt(TType type) {
+ return ::apache::thrift::protocol::skip(*this, type);
+}
+
+TProtocolFactory::~TProtocolFactory() {}
+
+}}} // apache::thrift::protocol
diff --git a/lib/cpp/src/thrift/protocol/TProtocol.h b/lib/cpp/src/thrift/protocol/TProtocol.h
index 9eec1ee..1aa2122 100644
--- a/lib/cpp/src/thrift/protocol/TProtocol.h
+++ b/lib/cpp/src/thrift/protocol/TProtocol.h
@@ -33,6 +33,7 @@
#include <string>
#include <map>
#include <vector>
+#include <climits>
// Use this to get around strict aliasing rules.
// For example, uint64_t i = bitwise_cast<uint64_t>(returns_double());
@@ -199,105 +200,6 @@
T_ONEWAY = 4
};
-
-/**
- * Helper template for implementing TProtocol::skip().
- *
- * Templatized to avoid having to make virtual function calls.
- */
-template <class Protocol_>
-uint32_t skip(Protocol_& prot, TType type) {
- switch (type) {
- case T_BOOL: {
- bool boolv;
- return prot.readBool(boolv);
- }
- case T_BYTE: {
- int8_t bytev;
- return prot.readByte(bytev);
- }
- case T_I16: {
- int16_t i16;
- return prot.readI16(i16);
- }
- case T_I32: {
- int32_t i32;
- return prot.readI32(i32);
- }
- case T_I64: {
- int64_t i64;
- return prot.readI64(i64);
- }
- case T_DOUBLE: {
- double dub;
- return prot.readDouble(dub);
- }
- case T_STRING: {
- std::string str;
- return prot.readBinary(str);
- }
- case T_STRUCT: {
- uint32_t result = 0;
- std::string name;
- int16_t fid;
- TType ftype;
- result += prot.readStructBegin(name);
- while (true) {
- result += prot.readFieldBegin(name, ftype, fid);
- if (ftype == T_STOP) {
- break;
- }
- result += skip(prot, ftype);
- result += prot.readFieldEnd();
- }
- result += prot.readStructEnd();
- return result;
- }
- case T_MAP: {
- uint32_t result = 0;
- TType keyType;
- TType valType;
- uint32_t i, size;
- result += prot.readMapBegin(keyType, valType, size);
- for (i = 0; i < size; i++) {
- result += skip(prot, keyType);
- result += skip(prot, valType);
- }
- result += prot.readMapEnd();
- return result;
- }
- case T_SET: {
- uint32_t result = 0;
- TType elemType;
- uint32_t i, size;
- result += prot.readSetBegin(elemType, size);
- for (i = 0; i < size; i++) {
- result += skip(prot, elemType);
- }
- result += prot.readSetEnd();
- return result;
- }
- case T_LIST: {
- uint32_t result = 0;
- TType elemType;
- uint32_t i, size;
- result += prot.readListBegin(elemType, size);
- for (i = 0; i < size; i++) {
- result += skip(prot, elemType);
- }
- result += prot.readListEnd();
- return result;
- }
- case T_STOP:
- case T_VOID:
- case T_U64:
- case T_UTF8:
- case T_UTF16:
- break;
- }
- return 0;
-}
-
static const uint32_t DEFAULT_RECURSION_LIMIT = 64;
/**
@@ -316,7 +218,7 @@
*/
class TProtocol {
public:
- virtual ~TProtocol() {}
+ virtual ~TProtocol();
/**
* Writing functions.
@@ -641,7 +543,7 @@
T_VIRTUAL_CALL();
return skip_virt(type);
}
- virtual uint32_t skip_virt(TType type) { return ::apache::thrift::protocol::skip(*this, type); }
+ virtual uint32_t skip_virt(TType type);
inline boost::shared_ptr<TTransport> getTransport() { return ptrans_; }
@@ -657,10 +559,13 @@
}
void decrementRecursionDepth() { --recursion_depth_; }
+ uint32_t getRecursionLimit() const {return recursion_limit_;}
+ void setRecurisionLimit(uint32_t depth) {recursion_limit_ = depth;}
protected:
TProtocol(boost::shared_ptr<TTransport> ptrans)
- : ptrans_(ptrans), recursion_depth_(0), recursion_limit_(DEFAULT_RECURSION_LIMIT) {}
+ : ptrans_(ptrans), recursion_depth_(0), recursion_limit_(DEFAULT_RECURSION_LIMIT)
+ {}
boost::shared_ptr<TTransport> ptrans_;
@@ -677,7 +582,7 @@
public:
TProtocolFactory() {}
- virtual ~TProtocolFactory() {}
+ virtual ~TProtocolFactory();
virtual boost::shared_ptr<TProtocol> getProtocol(boost::shared_ptr<TTransport> trans) = 0;
};
@@ -712,8 +617,116 @@
static uint64_t fromWire64(uint64_t x) {return letohll(x);}
};
+struct TRecursionTracker {
+ TProtocol &prot_;
+ TRecursionTracker(TProtocol &prot) : prot_(prot) {
+ prot_.incrementRecursionDepth();
+ }
+ ~TRecursionTracker() {
+ prot_.decrementRecursionDepth();
+ }
+};
+
+/**
+ * Helper template for implementing TProtocol::skip().
+ *
+ * Templatized to avoid having to make virtual function calls.
+ */
+template <class Protocol_>
+uint32_t skip(Protocol_& prot, TType type) {
+ TRecursionTracker tracker(prot);
+
+ switch (type) {
+ case T_BOOL: {
+ bool boolv;
+ return prot.readBool(boolv);
+ }
+ case T_BYTE: {
+ int8_t bytev;
+ return prot.readByte(bytev);
+ }
+ case T_I16: {
+ int16_t i16;
+ return prot.readI16(i16);
+ }
+ case T_I32: {
+ int32_t i32;
+ return prot.readI32(i32);
+ }
+ case T_I64: {
+ int64_t i64;
+ return prot.readI64(i64);
+ }
+ case T_DOUBLE: {
+ double dub;
+ return prot.readDouble(dub);
+ }
+ case T_STRING: {
+ std::string str;
+ return prot.readBinary(str);
+ }
+ case T_STRUCT: {
+ uint32_t result = 0;
+ std::string name;
+ int16_t fid;
+ TType ftype;
+ result += prot.readStructBegin(name);
+ while (true) {
+ result += prot.readFieldBegin(name, ftype, fid);
+ if (ftype == T_STOP) {
+ break;
+ }
+ result += skip(prot, ftype);
+ result += prot.readFieldEnd();
+ }
+ result += prot.readStructEnd();
+ return result;
+ }
+ case T_MAP: {
+ uint32_t result = 0;
+ TType keyType;
+ TType valType;
+ uint32_t i, size;
+ result += prot.readMapBegin(keyType, valType, size);
+ for (i = 0; i < size; i++) {
+ result += skip(prot, keyType);
+ result += skip(prot, valType);
+ }
+ result += prot.readMapEnd();
+ return result;
+ }
+ case T_SET: {
+ uint32_t result = 0;
+ TType elemType;
+ uint32_t i, size;
+ result += prot.readSetBegin(elemType, size);
+ for (i = 0; i < size; i++) {
+ result += skip(prot, elemType);
+ }
+ result += prot.readSetEnd();
+ return result;
+ }
+ case T_LIST: {
+ uint32_t result = 0;
+ TType elemType;
+ uint32_t i, size;
+ result += prot.readListBegin(elemType, size);
+ for (i = 0; i < size; i++) {
+ result += skip(prot, elemType);
+ }
+ result += prot.readListEnd();
+ return result;
+ }
+ case T_STOP:
+ case T_VOID:
+ case T_U64:
+ case T_UTF8:
+ case T_UTF16:
+ break;
+ }
+ return 0;
}
-}
-} // apache::thrift::protocol
+
+}}} // apache::thrift::protocol
#endif // #define _THRIFT_PROTOCOL_TPROTOCOL_H_ 1