THRIFT-3235 C#: Limit recursion depth to 64
Client: C#
Patch: Jens Geyer
diff --git a/compiler/cpp/src/generate/t_csharp_generator.cc b/compiler/cpp/src/generate/t_csharp_generator.cc
index d6aad75..6a21801 100644
--- a/compiler/cpp/src/generate/t_csharp_generator.cc
+++ b/compiler/cpp/src/generate/t_csharp_generator.cc
@@ -909,6 +909,10 @@
   indent(out) << "public void Read (TProtocol iprot)" << endl;
   scope_up(out);
 
+  out << indent() << "iprot.IncrementRecursionDepth();" << endl;
+  out << indent() << "try" << endl;
+  scope_up(out);
+
   const vector<t_field*>& fields = tstruct->get_members();
   vector<t_field*>::const_iterator f_iter;
 
@@ -977,6 +981,12 @@
     }
   }
 
+  scope_down(out);
+  out << indent() << "finally" << endl;
+  scope_up(out);
+  out << indent() << "iprot.DecrementRecursionDepth();" << endl;
+  scope_down(out);
+
   indent_down();
 
   indent(out) << "}" << endl << endl;
@@ -985,6 +995,10 @@
 void t_csharp_generator::generate_csharp_struct_writer(ofstream& out, t_struct* tstruct) {
   out << indent() << "public void Write(TProtocol oprot) {" << endl;
   indent_up();
+  
+  out << indent() << "oprot.IncrementRecursionDepth();" << endl;
+  out << indent() << "try" << endl;
+  scope_up(out);
 
   string name = tstruct->get_name();
   const vector<t_field*>& fields = tstruct->get_sorted_members();
@@ -1030,8 +1044,14 @@
   indent(out) << "oprot.WriteFieldStop();" << endl;
   indent(out) << "oprot.WriteStructEnd();" << endl;
 
-  indent_down();
+  scope_down(out);
+  out << indent() << "finally" << endl;
+  scope_up(out);
+  out << indent() << "oprot.DecrementRecursionDepth();" << endl;
+  scope_down(out);
 
+  indent_down();
+  
   indent(out) << "}" << endl << endl;
 }
 
@@ -1039,6 +1059,10 @@
   indent(out) << "public void Write(TProtocol oprot) {" << endl;
   indent_up();
 
+  out << indent() << "oprot.IncrementRecursionDepth();" << endl;
+  out << indent() << "try" << endl;
+  scope_up(out);
+
   string name = tstruct->get_name();
   const vector<t_field*>& fields = tstruct->get_sorted_members();
   vector<t_field*>::const_iterator f_iter;
@@ -1092,6 +1116,12 @@
   out << endl << indent() << "oprot.WriteFieldStop();" << endl << indent()
       << "oprot.WriteStructEnd();" << endl;
 
+  scope_down(out);
+  out << indent() << "finally" << endl;
+  scope_up(out);
+  out << indent() << "oprot.DecrementRecursionDepth();" << endl;
+  scope_down(out);
+
   indent_down();
 
   indent(out) << "}" << endl << endl;
@@ -1249,6 +1279,11 @@
   indent(out) << "}" << endl;
   indent(out) << "public override void Write(TProtocol oprot) {" << endl;
   indent_up();
+
+  out << indent() << "oprot.IncrementRecursionDepth();" << endl;
+  out << indent() << "try" << endl;
+  scope_up(out);
+
   indent(out) << "TStruct struc = new TStruct(\"" << tunion->get_name() << "\");" << endl;
   indent(out) << "oprot.WriteStructBegin(struc);" << endl;
 
@@ -1264,6 +1299,13 @@
   indent(out) << "oprot.WriteFieldStop();" << endl;
   indent(out) << "oprot.WriteStructEnd();" << endl;
   indent_down();
+
+  scope_down(out);
+  out << indent() << "finally" << endl;
+  scope_up(out);
+  out << indent() << "oprot.DecrementRecursionDepth();" << endl;
+  scope_down(out);
+
   indent(out) << "}" << endl;
 
   indent_down();
@@ -1987,6 +2029,11 @@
 
   indent(out) << "public static " << tunion->get_name() << " Read(TProtocol iprot)" << endl;
   scope_up(out);
+
+  out << indent() << "iprot.IncrementRecursionDepth();" << endl;
+  out << indent() << "try" << endl;
+  scope_up(out);
+
   indent(out) << tunion->get_name() << " retval;" << endl;
   indent(out) << "iprot.ReadStructBegin();" << endl;
   indent(out) << "TField field = iprot.ReadFieldBegin();" << endl;
@@ -2036,13 +2083,16 @@
 
   // end of else for TStop
   scope_down(out);
-
   indent(out) << "iprot.ReadStructEnd();" << endl;
-
   indent(out) << "return retval;" << endl;
-
   indent_down();
 
+  scope_down(out);
+  out << indent() << "finally" << endl;
+  scope_up(out);
+  out << indent() << "iprot.DecrementRecursionDepth();" << endl;
+  scope_down(out);
+
   indent(out) << "}" << endl << endl;
 }
 
diff --git a/lib/csharp/src/Protocol/TProtocol.cs b/lib/csharp/src/Protocol/TProtocol.cs
index 1f5bd81..bf481ab 100644
--- a/lib/csharp/src/Protocol/TProtocol.cs
+++ b/lib/csharp/src/Protocol/TProtocol.cs
@@ -29,11 +29,17 @@
 {
     public abstract class TProtocol : IDisposable
     {
+        private const int DEFAULT_RECURSION_DEPTH = 64;
+
         protected TTransport trans;
+        protected int recursionLimit;
+        protected int recursionDepth;
 
         protected TProtocol(TTransport trans)
         {
             this.trans = trans;
+            this.recursionLimit = DEFAULT_RECURSION_DEPTH;
+            this.recursionDepth = 0;
         }
 
         public TTransport Transport
@@ -41,6 +47,25 @@
             get { return trans; }
         }
 
+        public int RecursionLimit
+        {
+            get { return recursionLimit; }
+            set { recursionLimit = value; }
+        }
+
+        public void IncrementRecursionDepth()
+        {
+            if (recursionDepth < recursionLimit)
+                ++recursionDepth;
+            else
+                throw new TProtocolException(TProtocolException.DEPTH_LIMIT, "Depth limit exceeded");
+        }
+
+        public void DecrementRecursionDepth()
+        {
+            --recursionDepth;
+        }
+
         #region " IDisposable Support "
         private bool _IsDisposed;
 
diff --git a/lib/csharp/src/Protocol/TProtocolUtil.cs b/lib/csharp/src/Protocol/TProtocolUtil.cs
index 91140d3..0932a7f 100644
--- a/lib/csharp/src/Protocol/TProtocolUtil.cs
+++ b/lib/csharp/src/Protocol/TProtocolUtil.cs
@@ -29,69 +29,78 @@
     {
         public static void Skip(TProtocol prot, TType type)
         {
-            switch (type)
+            prot.IncrementRecursionDepth();
+            try
             {
-                case TType.Bool:
-                    prot.ReadBool();
-                    break;
-                case TType.Byte:
-                    prot.ReadByte();
-                    break;
-                case TType.I16:
-                    prot.ReadI16();
-                    break;
-                case TType.I32:
-                    prot.ReadI32();
-                    break;
-                case TType.I64:
-                    prot.ReadI64();
-                    break;
-                case TType.Double:
-                    prot.ReadDouble();
-                    break;
-                case TType.String:
-                    // Don't try to decode the string, just skip it.
-                    prot.ReadBinary();
-                    break;
-                case TType.Struct:
-                    prot.ReadStructBegin();
-                    while (true)
-                    {
-                        TField field = prot.ReadFieldBegin();
-                        if (field.Type == TType.Stop)
+                switch (type)
+                {
+                    case TType.Bool:
+                        prot.ReadBool();
+                        break;
+                    case TType.Byte:
+                        prot.ReadByte();
+                        break;
+                    case TType.I16:
+                        prot.ReadI16();
+                        break;
+                    case TType.I32:
+                        prot.ReadI32();
+                        break;
+                    case TType.I64:
+                        prot.ReadI64();
+                        break;
+                    case TType.Double:
+                        prot.ReadDouble();
+                        break;
+                    case TType.String:
+                        // Don't try to decode the string, just skip it.
+                        prot.ReadBinary();
+                        break;
+                    case TType.Struct:
+                        prot.ReadStructBegin();
+                        while (true)
                         {
-                            break;
+                            TField field = prot.ReadFieldBegin();
+                            if (field.Type == TType.Stop)
+                            {
+                                break;
+                            }
+                            Skip(prot, field.Type);
+                            prot.ReadFieldEnd();
                         }
-                        Skip(prot, field.Type);
-                        prot.ReadFieldEnd();
-                    }
-                    prot.ReadStructEnd();
-                    break;
-                case TType.Map:
-                    TMap map = prot.ReadMapBegin();
-                    for (int i = 0; i < map.Count; i++)
-                    {
-                        Skip(prot, map.KeyType);
-                        Skip(prot, map.ValueType);
-                    }
-                    prot.ReadMapEnd();
-                    break;
-                case TType.Set:
-                    TSet set = prot.ReadSetBegin();
-                    for (int i = 0; i < set.Count; i++)
-                    {
-                        Skip(prot, set.ElementType);
-                    }
-                    prot.ReadSetEnd();
-                    break;
-                case TType.List:
-                    TList list = prot.ReadListBegin();
-                    for (int i = 0; i < list.Count; i++)
-                    {
-                        Skip(prot, list.ElementType);
-                    }
-                    prot.ReadListEnd();
-                    break;
+                        prot.ReadStructEnd();
+                        break;
+                    case TType.Map:
+                        TMap map = prot.ReadMapBegin();
+                        for (int i = 0; i < map.Count; i++)
+                        {
+                            Skip(prot, map.KeyType);
+                            Skip(prot, map.ValueType);
+                        }
+                        prot.ReadMapEnd();
+                        break;
+                    case TType.Set:
+                        TSet set = prot.ReadSetBegin();
+                        for (int i = 0; i < set.Count; i++)
+                        {
+                            Skip(prot, set.ElementType);
+                        }
+                        prot.ReadSetEnd();
+                        break;
+                    case TType.List:
+                        TList list = prot.ReadListBegin();
+                        for (int i = 0; i < list.Count; i++)
+                        {
+                            Skip(prot, list.ElementType);
+                        }
+                        prot.ReadListEnd();
+                        break;
+                }
+
+            }
+            finally
+            {
+                prot.DecrementRecursionDepth();
             }
         }
     }