THRIFT-3231 CPP: Limit recursion depth to 64
Client: cpp
Patch: Ben Craig <bencraig@apache.org>
diff --git a/compiler/cpp/src/generate/t_cpp_generator.cc b/compiler/cpp/src/generate/t_cpp_generator.cc
index 426434f..aed3935 100644
--- a/compiler/cpp/src/generate/t_cpp_generator.cc
+++ b/compiler/cpp/src/generate/t_cpp_generator.cc
@@ -1367,10 +1367,16 @@
   vector<t_field*>::const_iterator f_iter;
 
   // Declare stack tmp variables
-  out << endl << indent() << "uint32_t xfer = 0;" << endl << indent() << "std::string fname;"
-      << endl << indent() << "::apache::thrift::protocol::TType ftype;" << endl << indent()
-      << "int16_t fid;" << endl << endl << indent() << "xfer += iprot->readStructBegin(fname);"
-      << endl << endl << indent() << "using ::apache::thrift::protocol::TProtocolException;" << endl
+  out << endl
+      << indent() << "apache::thrift::protocol::TRecursionTracker tracker(*iprot);" << endl
+      << indent() << "uint32_t xfer = 0;" << endl
+      << indent() << "std::string fname;" << endl
+      << indent() << "::apache::thrift::protocol::TType ftype;" << endl
+      << indent() << "int16_t fid;" << endl
+      << endl
+      << indent() << "xfer += iprot->readStructBegin(fname);" << endl
+      << endl
+      << indent() << "using ::apache::thrift::protocol::TProtocolException;" << endl
       << endl;
 
   // Required variables aren't in __isset, so we need tmp vars to check them.
@@ -1486,7 +1492,7 @@
 
   out << indent() << "uint32_t xfer = 0;" << endl;
 
-  indent(out) << "oprot->incrementRecursionDepth();" << endl;
+  indent(out) << "apache::thrift::protocol::TRecursionTracker tracker(*oprot);" << endl;
   indent(out) << "xfer += oprot->writeStructBegin(\"" << name << "\");" << endl;
 
   for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) {
@@ -1522,7 +1528,7 @@
   // Write the struct map
   out << indent() << "xfer += oprot->writeFieldStop();" << endl << indent()
       << "xfer += oprot->writeStructEnd();" << endl << indent()
-      << "oprot->decrementRecursionDepth();" << endl << indent() << "return xfer;" << endl;
+      << "return xfer;" << endl;
 
   indent_down();
   indent(out) << "}" << endl << endl;
diff --git a/lib/cpp/CMakeLists.txt b/lib/cpp/CMakeLists.txt
index b97e356..bab2e84 100755
--- a/lib/cpp/CMakeLists.txt
+++ b/lib/cpp/CMakeLists.txt
@@ -39,11 +39,12 @@
    src/thrift/concurrency/TimerManager.cpp
    src/thrift/concurrency/Util.cpp
    src/thrift/processor/PeekProcessor.cpp
+   src/thrift/protocol/TBase64Utils.cpp
    src/thrift/protocol/TDebugProtocol.cpp
    src/thrift/protocol/TDenseProtocol.cpp
    src/thrift/protocol/TJSONProtocol.cpp
-   src/thrift/protocol/TBase64Utils.cpp
    src/thrift/protocol/TMultiplexedProtocol.cpp
+   src/thrift/protocol/TProtocol.cpp
    src/thrift/transport/TTransportException.cpp
    src/thrift/transport/TFDTransport.cpp
    src/thrift/transport/TSimpleFileTransport.cpp
diff --git a/lib/cpp/Makefile.am b/lib/cpp/Makefile.am
index 9156577..0ecbeee 100755
--- a/lib/cpp/Makefile.am
+++ b/lib/cpp/Makefile.am
@@ -75,6 +75,7 @@
                        src/thrift/protocol/TJSONProtocol.cpp \
                        src/thrift/protocol/TBase64Utils.cpp \
                        src/thrift/protocol/TMultiplexedProtocol.cpp \
+                       src/thrift/protocol/TProtocol.cpp \
                        src/thrift/transport/TTransportException.cpp \
                        src/thrift/transport/TFDTransport.cpp \
                        src/thrift/transport/TFileTransport.cpp \
diff --git a/lib/cpp/src/thrift/protocol/TProtocol.cpp b/lib/cpp/src/thrift/protocol/TProtocol.cpp
new file mode 100644
index 0000000..c378aca
--- /dev/null
+++ b/lib/cpp/src/thrift/protocol/TProtocol.cpp
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <thrift/protocol/TProtocol.h>
+
+namespace apache {
+namespace thrift {
+namespace protocol {
+
+TProtocol::~TProtocol() {}
+uint32_t TProtocol::skip_virt(TType type) {
+  return ::apache::thrift::protocol::skip(*this, type);
+}
+
+TProtocolFactory::~TProtocolFactory() {}
+
+}}} // apache::thrift::protocol
diff --git a/lib/cpp/src/thrift/protocol/TProtocol.h b/lib/cpp/src/thrift/protocol/TProtocol.h
index 9eec1ee..1aa2122 100644
--- a/lib/cpp/src/thrift/protocol/TProtocol.h
+++ b/lib/cpp/src/thrift/protocol/TProtocol.h
@@ -33,6 +33,7 @@
 #include <string>
 #include <map>
 #include <vector>
+#include <climits>
 
 // Use this to get around strict aliasing rules.
 // For example, uint64_t i = bitwise_cast<uint64_t>(returns_double());
@@ -199,105 +200,6 @@
   T_ONEWAY     = 4
 };
 
-
-/**
- * Helper template for implementing TProtocol::skip().
- *
- * Templatized to avoid having to make virtual function calls.
- */
-template <class Protocol_>
-uint32_t skip(Protocol_& prot, TType type) {
-  switch (type) {
-  case T_BOOL: {
-    bool boolv;
-    return prot.readBool(boolv);
-  }
-  case T_BYTE: {
-    int8_t bytev;
-    return prot.readByte(bytev);
-  }
-  case T_I16: {
-    int16_t i16;
-    return prot.readI16(i16);
-  }
-  case T_I32: {
-    int32_t i32;
-    return prot.readI32(i32);
-  }
-  case T_I64: {
-    int64_t i64;
-    return prot.readI64(i64);
-  }
-  case T_DOUBLE: {
-    double dub;
-    return prot.readDouble(dub);
-  }
-  case T_STRING: {
-    std::string str;
-    return prot.readBinary(str);
-  }
-  case T_STRUCT: {
-    uint32_t result = 0;
-    std::string name;
-    int16_t fid;
-    TType ftype;
-    result += prot.readStructBegin(name);
-    while (true) {
-      result += prot.readFieldBegin(name, ftype, fid);
-      if (ftype == T_STOP) {
-        break;
-      }
-      result += skip(prot, ftype);
-      result += prot.readFieldEnd();
-    }
-    result += prot.readStructEnd();
-    return result;
-  }
-  case T_MAP: {
-    uint32_t result = 0;
-    TType keyType;
-    TType valType;
-    uint32_t i, size;
-    result += prot.readMapBegin(keyType, valType, size);
-    for (i = 0; i < size; i++) {
-      result += skip(prot, keyType);
-      result += skip(prot, valType);
-    }
-    result += prot.readMapEnd();
-    return result;
-  }
-  case T_SET: {
-    uint32_t result = 0;
-    TType elemType;
-    uint32_t i, size;
-    result += prot.readSetBegin(elemType, size);
-    for (i = 0; i < size; i++) {
-      result += skip(prot, elemType);
-    }
-    result += prot.readSetEnd();
-    return result;
-  }
-  case T_LIST: {
-    uint32_t result = 0;
-    TType elemType;
-    uint32_t i, size;
-    result += prot.readListBegin(elemType, size);
-    for (i = 0; i < size; i++) {
-      result += skip(prot, elemType);
-    }
-    result += prot.readListEnd();
-    return result;
-  }
-  case T_STOP:
-  case T_VOID:
-  case T_U64:
-  case T_UTF8:
-  case T_UTF16:
-    break;
-  }
-  return 0;
-}
-
 static const uint32_t DEFAULT_RECURSION_LIMIT = 64;
 
 /**
@@ -316,7 +218,7 @@
  */
 class TProtocol {
 public:
-  virtual ~TProtocol() {}
+  virtual ~TProtocol();
 
   /**
    * Writing functions.
@@ -641,7 +543,7 @@
     T_VIRTUAL_CALL();
     return skip_virt(type);
   }
-  virtual uint32_t skip_virt(TType type) { return ::apache::thrift::protocol::skip(*this, type); }
+  virtual uint32_t skip_virt(TType type);
 
   inline boost::shared_ptr<TTransport> getTransport() { return ptrans_; }
 
@@ -657,10 +559,13 @@
   }
 
   void decrementRecursionDepth() { --recursion_depth_; }
+  uint32_t getRecursionLimit() const {return recursion_limit_;}
+  void setRecurisionLimit(uint32_t depth) {recursion_limit_ = depth;}
 
 protected:
   TProtocol(boost::shared_ptr<TTransport> ptrans)
-    : ptrans_(ptrans), recursion_depth_(0), recursion_limit_(DEFAULT_RECURSION_LIMIT) {}
+    : ptrans_(ptrans), recursion_depth_(0), recursion_limit_(DEFAULT_RECURSION_LIMIT)
+  {}
 
   boost::shared_ptr<TTransport> ptrans_;
 
@@ -677,7 +582,7 @@
 public:
   TProtocolFactory() {}
 
-  virtual ~TProtocolFactory() {}
+  virtual ~TProtocolFactory();
 
   virtual boost::shared_ptr<TProtocol> getProtocol(boost::shared_ptr<TTransport> trans) = 0;
 };
@@ -712,8 +617,116 @@
   static uint64_t fromWire64(uint64_t x) {return letohll(x);}
 };
 
+struct TRecursionTracker {
+  TProtocol &prot_;
+  TRecursionTracker(TProtocol &prot) : prot_(prot) {
+    prot_.incrementRecursionDepth();
+  }
+  ~TRecursionTracker() {
+    prot_.decrementRecursionDepth();
+  }
+};
+
+/**
+ * Helper template for implementing TProtocol::skip().
+ *
+ * Templatized to avoid having to make virtual function calls.
+ */
+template <class Protocol_>
+uint32_t skip(Protocol_& prot, TType type) {
+  TRecursionTracker tracker(prot);
+
+  switch (type) {
+  case T_BOOL: {
+    bool boolv;
+    return prot.readBool(boolv);
+  }
+  case T_BYTE: {
+    int8_t bytev;
+    return prot.readByte(bytev);
+  }
+  case T_I16: {
+    int16_t i16;
+    return prot.readI16(i16);
+  }
+  case T_I32: {
+    int32_t i32;
+    return prot.readI32(i32);
+  }
+  case T_I64: {
+    int64_t i64;
+    return prot.readI64(i64);
+  }
+  case T_DOUBLE: {
+    double dub;
+    return prot.readDouble(dub);
+  }
+  case T_STRING: {
+    std::string str;
+    return prot.readBinary(str);
+  }
+  case T_STRUCT: {
+    uint32_t result = 0;
+    std::string name;
+    int16_t fid;
+    TType ftype;
+    result += prot.readStructBegin(name);
+    while (true) {
+      result += prot.readFieldBegin(name, ftype, fid);
+      if (ftype == T_STOP) {
+        break;
+      }
+      result += skip(prot, ftype);
+      result += prot.readFieldEnd();
+    }
+    result += prot.readStructEnd();
+    return result;
+  }
+  case T_MAP: {
+    uint32_t result = 0;
+    TType keyType;
+    TType valType;
+    uint32_t i, size;
+    result += prot.readMapBegin(keyType, valType, size);
+    for (i = 0; i < size; i++) {
+      result += skip(prot, keyType);
+      result += skip(prot, valType);
+    }
+    result += prot.readMapEnd();
+    return result;
+  }
+  case T_SET: {
+    uint32_t result = 0;
+    TType elemType;
+    uint32_t i, size;
+    result += prot.readSetBegin(elemType, size);
+    for (i = 0; i < size; i++) {
+      result += skip(prot, elemType);
+    }
+    result += prot.readSetEnd();
+    return result;
+  }
+  case T_LIST: {
+    uint32_t result = 0;
+    TType elemType;
+    uint32_t i, size;
+    result += prot.readListBegin(elemType, size);
+    for (i = 0; i < size; i++) {
+      result += skip(prot, elemType);
+    }
+    result += prot.readListEnd();
+    return result;
+  }
+  case T_STOP:
+  case T_VOID:
+  case T_U64:
+  case T_UTF8:
+  case T_UTF16:
+    break;
+  }
+  return 0;
 }
-}
-} // apache::thrift::protocol
+
+}}} // apache::thrift::protocol
 
 #endif // #define _THRIFT_PROTOCOL_TPROTOCOL_H_ 1