THRIFT-3239 Limit recursion depth
Client: Haxe
Patch: Jens Geyer

This closes #547
diff --git a/compiler/cpp/src/generate/t_haxe_generator.cc b/compiler/cpp/src/generate/t_haxe_generator.cc
index a0e2f28..dfa36c5 100644
--- a/compiler/cpp/src/generate/t_haxe_generator.cc
+++ b/compiler/cpp/src/generate/t_haxe_generator.cc
@@ -818,6 +818,10 @@
   const vector<t_field*>& fields = tstruct->get_members();
   vector<t_field*>::const_iterator f_iter;
 
+  indent(out) << "iprot.IncrementRecursionDepth();" << endl;
+  indent(out) << "try" << endl;
+  scope_up(out);
+
   // Declare stack tmp variables and read struct header
   out << indent() << "var field : TField;" << endl << indent() << "iprot.readStructBegin();"
       << endl;
@@ -869,6 +873,14 @@
 
   out << indent() << "iprot.readStructEnd();" << endl << endl;
 
+  indent(out) << "iprot.DecrementRecursionDepth();" << endl;
+  scope_down(out);
+  indent(out) << "catch(e:Dynamic)" << endl;
+  scope_up(out);
+  indent(out) << "iprot.DecrementRecursionDepth();" << endl;
+  indent(out) << "throw e;" << endl;
+  scope_down(out);
+
   // check for required fields of primitive type
   // (which can be checked here but not in the general validate method)
   out << endl << indent() << "// check for required fields of primitive type, which can't be "
@@ -952,7 +964,10 @@
   vector<t_field*>::const_iterator f_iter;
 
   // performs various checks (e.g. check that all required fields are set)
-  indent(out) << "validate();" << endl << endl;
+  indent(out) << "validate();" << endl;
+  indent(out) << "oprot.IncrementRecursionDepth();" << endl;
+  indent(out) << "try" << endl;
+  scope_up(out);
 
   indent(out) << "oprot.writeStructBegin(STRUCT_DESC);" << endl;
 
@@ -977,10 +992,18 @@
       indent(out) << "}" << endl;
     }
   }
-  // Write the struct map
-  out << indent() << "oprot.writeFieldStop();" << endl << indent() << "oprot.writeStructEnd();"
-      << endl;
-
+  
+  indent(out) << "oprot.writeFieldStop();" << endl;
+  indent(out) << "oprot.writeStructEnd();" << endl;
+  
+  indent(out) << "oprot.DecrementRecursionDepth();" << endl;
+  scope_down(out);
+  indent(out) << "catch(e:Dynamic)" << endl;
+  scope_up(out);
+  indent(out) << "oprot.DecrementRecursionDepth();" << endl;
+  indent(out) << "throw e;" << endl;
+  scope_down(out);
+  
   indent_down();
   out << indent() << "}" << endl << endl;
 }
@@ -1001,6 +1024,10 @@
   const vector<t_field*>& fields = tstruct->get_sorted_members();
   vector<t_field*>::const_iterator f_iter;
 
+  indent(out) << "oprot.IncrementRecursionDepth();" << endl;
+  indent(out) << "try" << endl;
+  scope_up(out);
+  
   indent(out) << "oprot.writeStructBegin(STRUCT_DESC);" << endl;
 
   bool first = true;
@@ -1028,10 +1055,19 @@
     indent_down();
     indent(out) << "}";
   }
-  // Write the struct map
-  out << endl << indent() << "oprot.writeFieldStop();" << endl << indent()
-      << "oprot.writeStructEnd();" << endl;
-
+  
+  indent(out) << endl;
+  indent(out) << "oprot.writeFieldStop();" << endl;
+  indent(out) << "oprot.writeStructEnd();" << endl;
+  
+  indent(out) << "oprot.DecrementRecursionDepth();" << endl;
+  scope_down(out);
+  indent(out) << "catch(e:Dynamic)" << endl;
+  scope_up(out);
+  indent(out) << "oprot.DecrementRecursionDepth();" << endl;
+  indent(out) << "throw e;" << endl;
+  scope_down(out);
+  
   indent_down();
   out << indent() << "}" << endl << endl;
 }
diff --git a/lib/haxe/src/org/apache/thrift/protocol/TBinaryProtocol.hx b/lib/haxe/src/org/apache/thrift/protocol/TBinaryProtocol.hx
index 377e7ef..7ef291c 100644
--- a/lib/haxe/src/org/apache/thrift/protocol/TBinaryProtocol.hx
+++ b/lib/haxe/src/org/apache/thrift/protocol/TBinaryProtocol.hx
@@ -31,7 +31,7 @@
 /**
 * Binary protocol implementation for thrift.
 */
-class TBinaryProtocol implements TProtocol {
+class TBinaryProtocol extends TRecursionTracker implements TProtocol {
 
     private static var ANONYMOUS_STRUCT:TStruct = new TStruct();
 
diff --git a/lib/haxe/src/org/apache/thrift/protocol/TCompactProtocol.hx b/lib/haxe/src/org/apache/thrift/protocol/TCompactProtocol.hx
index c4d0ced..03b13e2 100644
--- a/lib/haxe/src/org/apache/thrift/protocol/TCompactProtocol.hx
+++ b/lib/haxe/src/org/apache/thrift/protocol/TCompactProtocol.hx
@@ -37,7 +37,7 @@
 /**
 * Compact protocol implementation for thrift.
 */
-class TCompactProtocol implements TProtocol {
+class TCompactProtocol extends TRecursionTracker implements TProtocol {
 
     private static var ANONYMOUS_STRUCT : TStruct = new TStruct("");
     private static var TSTOP : TField = new TField("", TType.STOP, 0);
diff --git a/lib/haxe/src/org/apache/thrift/protocol/TJSONProtocol.hx b/lib/haxe/src/org/apache/thrift/protocol/TJSONProtocol.hx
index aeed8f4..e20ff33 100644
--- a/lib/haxe/src/org/apache/thrift/protocol/TJSONProtocol.hx
+++ b/lib/haxe/src/org/apache/thrift/protocol/TJSONProtocol.hx
@@ -45,7 +45,7 @@
 *
 *  Adapted from the Java version.
 */
-class TJSONProtocol implements TProtocol {
+class TJSONProtocol extends TRecursionTracker implements TProtocol {
 
     public var trans(default,null) : TTransport;
 
diff --git a/lib/haxe/src/org/apache/thrift/protocol/TProtocol.hx b/lib/haxe/src/org/apache/thrift/protocol/TProtocol.hx
index 0998e92..22e88e4 100644
--- a/lib/haxe/src/org/apache/thrift/protocol/TProtocol.hx
+++ b/lib/haxe/src/org/apache/thrift/protocol/TProtocol.hx
@@ -79,4 +79,7 @@
     function readString() : String;
     function readBinary() : Bytes;
 
+	// recursion tracking
+	function IncrementRecursionDepth() : Void;
+	function DecrementRecursionDepth() : Void;
 }
diff --git a/lib/haxe/src/org/apache/thrift/protocol/TProtocolUtil.hx b/lib/haxe/src/org/apache/thrift/protocol/TProtocolUtil.hx
index 794e397..71ed4ba 100644
--- a/lib/haxe/src/org/apache/thrift/protocol/TProtocolUtil.hx
+++ b/lib/haxe/src/org/apache/thrift/protocol/TProtocolUtil.hx
@@ -29,107 +29,82 @@
 class TProtocolUtil {
 
     /**
-     * The maximum recursive depth the skip() function will traverse before
-     * throwing a TException.
-     */
-    private static var maxSkipDepth : Int = Limits.I32_MAX;
-
-    /**
-     * Specifies the maximum recursive depth that the skip function will
-     * traverse before throwing a TException.  This is a global setting, so
-     * any call to skip in this JVM will enforce this value.
-     *
-     * @param depth  the maximum recursive depth.  A value of 2 would allow
-     *    the skip function to skip a structure or collection with basic children,
-     *    but it would not permit skipping a struct that had a field containing
-     *    a child struct.  A value of 1 would only allow skipping of simple
-     *    types and empty structs/collections.
-     */
-    public function setMaxSkipDepth(depth : Int) : Void {
-      maxSkipDepth = depth;
-    }
-
-    /**
      * Skips over the next data element from the provided input TProtocol object.
      *
      * @param prot  the protocol object to read from
      * @param type  the next value will be intepreted as this TType value.
      */
     public static function skip(prot:TProtocol, type : Int) : Void {
-      skipMaxDepth(prot, type, maxSkipDepth);
-    }
+		prot.IncrementRecursionDepth();
+		try
+		{
+			switch (type) {
+				case TType.BOOL: 
+					prot.readBool();
 
-     /**
-     * Skips over the next data element from the provided input TProtocol object.
-     *
-     * @param prot  the protocol object to read from
-     * @param type  the next value will be intepreted as this TType value.
-     * @param maxDepth  this function will only skip complex objects to this
-     *   recursive depth, to prevent Java stack overflow.
-     */
-    public static function skipMaxDepth(prot:TProtocol, type : Int, maxDepth : Int) : Void {
-      if (maxDepth <= 0) {
-        throw new TException("Maximum skip depth exceeded");
-      }
-      switch (type) {
-        case TType.BOOL: {
-          prot.readBool();
-        }
-        case TType.BYTE: {
-          prot.readByte();
-        }
-        case TType.I16: {
-          prot.readI16();
-        }
-        case TType.I32: {
-          prot.readI32();
-        }
-        case TType.I64: {
-          prot.readI64();
-        }
-        case TType.DOUBLE: {
-          prot.readDouble();
-        }
-        case TType.STRING: {
-          prot.readBinary();
-        }
-        case TType.STRUCT: {
-          prot.readStructBegin();
-          while (true) {
-            var field:TField = prot.readFieldBegin();
-            if (field.type == TType.STOP) {
-              break;
-            }
-            skipMaxDepth(prot, field.type, maxDepth - 1);
-            prot.readFieldEnd();
-          }
-          prot.readStructEnd();
-        }
-        case TType.MAP: {
-          var map:TMap = prot.readMapBegin();
-          for (i in 0 ... map.size) {
-            skipMaxDepth(prot, map.keyType, maxDepth - 1);
-            skipMaxDepth(prot, map.valueType, maxDepth - 1);
-          }
-          prot.readMapEnd();
-        }
-        case TType.SET: {
-          var set:TSet = prot.readSetBegin();
-          for (j in 0 ... set.size) {
-            skipMaxDepth(prot, set.elemType, maxDepth - 1);
-          }
-          prot.readSetEnd();
-        }
-        case TType.LIST: {
-          var list:TList = prot.readListBegin();
-          for (k in 0 ... list.size) {
-            skipMaxDepth(prot, list.elemType, maxDepth - 1);
-          }
-          prot.readListEnd();
-        }
-        default:
-          trace("Unknown field type ",type," in skipMaxDepth()");
-      }
+				case TType.BYTE: 
+					prot.readByte();
+
+				case TType.I16: 
+					prot.readI16();
+
+				case TType.I32: 
+					prot.readI32();
+
+				case TType.I64: 
+					prot.readI64();
+
+				case TType.DOUBLE: 
+					prot.readDouble();
+
+				case TType.STRING: 
+					prot.readBinary();
+
+				case TType.STRUCT: 
+					prot.readStructBegin();
+					while (true) {
+						var field:TField = prot.readFieldBegin();
+						if (field.type == TType.STOP) {
+						  break;
+						}
+						skip(prot, field.type);
+						prot.readFieldEnd();
+					}
+					prot.readStructEnd();
+
+				case TType.MAP: 
+					var map:TMap = prot.readMapBegin();
+					for (i in 0 ... map.size) {
+						skip(prot, map.keyType);
+						skip(prot, map.valueType);
+					}
+					prot.readMapEnd();
+
+				case TType.SET: 
+					var set:TSet = prot.readSetBegin();
+					for (j in 0 ... set.size) {
+						skip(prot, set.elemType);
+					}
+					prot.readSetEnd();
+
+				case TType.LIST: 
+					var list:TList = prot.readListBegin();
+					for (k in 0 ... list.size) {
+						skip(prot, list.elemType);
+					}
+					prot.readListEnd();
+
+				default:
+					trace("Unknown field type ",type," in skipMaxDepth()");
+			}
+			
+			prot.DecrementRecursionDepth();
+		}
+		catch(e:Dynamic)
+		{
+			prot.DecrementRecursionDepth();
+			throw e;
+		}
     }
 
 }
diff --git a/lib/haxe/src/org/apache/thrift/protocol/TRecursionTracker.hx b/lib/haxe/src/org/apache/thrift/protocol/TRecursionTracker.hx
new file mode 100644
index 0000000..b882cf2
--- /dev/null
+++ b/lib/haxe/src/org/apache/thrift/protocol/TRecursionTracker.hx
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.thrift.protocol;
+
+import org.apache.thrift.*;
+
+
+class TRecursionTracker {
+
+	// default 
+    private static inline var DEFAULT_RECURSION_DEPTH : Int = 64; 
+
+	// limit and actual value
+	public var recursionLimit : Int = DEFAULT_RECURSION_DEPTH;
+    private var recursionDepth : Int = 0;
+
+	public function IncrementRecursionDepth() : Void 
+	{
+		if (recursionDepth < recursionLimit)
+			++recursionDepth;
+		else
+			throw new TProtocolException(TProtocolException.DEPTH_LIMIT, "Depth limit exceeded");
+	}
+
+	public function DecrementRecursionDepth() : Void 
+	{
+		--recursionDepth;
+	}
+
+
+}
diff --git a/lib/haxe/src/org/apache/thrift/server/TSimpleServer.hx b/lib/haxe/src/org/apache/thrift/server/TSimpleServer.hx
index f3408e2..3b64b62 100644
--- a/lib/haxe/src/org/apache/thrift/server/TSimpleServer.hx
+++ b/lib/haxe/src/org/apache/thrift/server/TSimpleServer.hx
@@ -105,7 +105,7 @@
             }
             catch( pex : TProtocolException)
             {
-                logDelegate(pex); // Unexpected
+                logDelegate('$pex ${pex.errorID} ${pex.errorMsg}'); // Unexpected
             }
             catch( e : Dynamic)
             {
diff --git a/test/haxe/src/TestServer.hx b/test/haxe/src/TestServer.hx
index 4490a8c..bff5a47 100644
--- a/test/haxe/src/TestServer.hx
+++ b/test/haxe/src/TestServer.hx
@@ -106,7 +106,7 @@
         }
         catch (x : TException)
         {
-            trace('$x');
+			trace('$x ${x.errorID} ${x.errorMsg}');
         }
         catch (x : Dynamic)
         {