THRIFT-3235 C#: Limit recursion depth to 64
Client: C#
Patch: Jens Geyer
diff --git a/lib/csharp/src/Protocol/TProtocol.cs b/lib/csharp/src/Protocol/TProtocol.cs
index 1f5bd81..bf481ab 100644
--- a/lib/csharp/src/Protocol/TProtocol.cs
+++ b/lib/csharp/src/Protocol/TProtocol.cs
@@ -29,11 +29,21 @@
 {
     public abstract class TProtocol : IDisposable
     {
+        // Initial value of recursionLimit for every protocol instance; it can be
+        // changed afterwards through the RecursionLimit property.
+        private const int DEFAULT_RECURSION_DEPTH = 64;
+
         protected TTransport trans;
+        // Upper bound on recursionDepth enforced by IncrementRecursionDepth().
+        protected int recursionLimit;
+        // Current nesting depth, maintained by Increment/DecrementRecursionDepth().
+        protected int recursionDepth;
 
         protected TProtocol(TTransport trans)
         {
             this.trans = trans;
+            recursionDepth = 0;
+            recursionLimit = DEFAULT_RECURSION_DEPTH;
         }
 
         public TTransport Transport
@@ -41,6 +51,41 @@
             get { return trans; }
         }
 
+        /// <summary>
+        /// Maximum nesting depth accepted by <see cref="IncrementRecursionDepth"/>;
+        /// the TProtocol constructor initializes it to 64.
+        /// </summary>
+        public int RecursionLimit
+        {
+            get { return recursionLimit; }
+            set { recursionLimit = value; }
+        }
+
+        /// <summary>
+        /// Enters one more level of nesting.
+        /// </summary>
+        /// <exception cref="TProtocolException">
+        /// Thrown with <see cref="TProtocolException.DEPTH_LIMIT"/> when the current depth has
+        /// already reached <see cref="RecursionLimit"/>; the depth is left unchanged in that case.
+        /// </exception>
+        public void IncrementRecursionDepth()
+        {
+            if (recursionDepth >= recursionLimit)
+            {
+                throw new TProtocolException(TProtocolException.DEPTH_LIMIT, "Depth limit exceeded");
+            }
+            recursionDepth++;
+        }
+
+        /// <summary>
+        /// Leaves one level of nesting. Call exactly once per successful
+        /// <see cref="IncrementRecursionDepth"/>; nothing stops the depth from going negative.
+        /// </summary>
+        public void DecrementRecursionDepth()
+        {
+            recursionDepth -= 1;
+        }
+
         #region " IDisposable Support "
         private bool _IsDisposed;
 
diff --git a/lib/csharp/src/Protocol/TProtocolUtil.cs b/lib/csharp/src/Protocol/TProtocolUtil.cs
index 91140d3..0932a7f 100644
--- a/lib/csharp/src/Protocol/TProtocolUtil.cs
+++ b/lib/csharp/src/Protocol/TProtocolUtil.cs
@@ -29,69 +29,95 @@
     {
+        /// <summary>
+        /// Reads and discards one value of wire type <paramref name="type"/> from
+        /// <paramref name="prot"/>, recursing into struct fields and container elements.
+        /// </summary>
+        /// <param name="prot">Protocol to read from; its recursion depth is raised for the duration of the call.</param>
+        /// <param name="type">Wire type of the value to discard.</param>
+        /// <exception cref="TProtocolException">
+        /// Thrown with <see cref="TProtocolException.DEPTH_LIMIT"/> when the protocol's recursion limit
+        /// is reached, or with <see cref="TProtocolException.INVALID_DATA"/> for an unrecognized type.
+        /// Exceptions from the protocol's read methods propagate unchanged.
+        /// </exception>
         public static void Skip(TProtocol prot, TType type)
         {
-            switch (type)
+            // Claim a nesting level before entering the try block: if the limit is already
+            // reached this throws, and the finally below must not release a level never taken.
+            prot.IncrementRecursionDepth();
+            try
             {
-                case TType.Bool:
-                    prot.ReadBool();
-                    break;
-                case TType.Byte:
-                    prot.ReadByte();
-                    break;
-                case TType.I16:
-                    prot.ReadI16();
-                    break;
-                case TType.I32:
-                    prot.ReadI32();
-                    break;
-                case TType.I64:
-                    prot.ReadI64();
-                    break;
-                case TType.Double:
-                    prot.ReadDouble();
-                    break;
-                case TType.String:
-                    // Don't try to decode the string, just skip it.
-                    prot.ReadBinary();
-                    break;
-                case TType.Struct:
-                    prot.ReadStructBegin();
-                    while (true)
-                    {
-                        TField field = prot.ReadFieldBegin();
-                        if (field.Type == TType.Stop)
+                switch (type)
+                {
+                    case TType.Bool:
+                        prot.ReadBool();
+                        break;
+                    case TType.Byte:
+                        prot.ReadByte();
+                        break;
+                    case TType.I16:
+                        prot.ReadI16();
+                        break;
+                    case TType.I32:
+                        prot.ReadI32();
+                        break;
+                    case TType.I64:
+                        prot.ReadI64();
+                        break;
+                    case TType.Double:
+                        prot.ReadDouble();
+                        break;
+                    case TType.String:
+                        // Don't try to decode the string, just skip it.
+                        prot.ReadBinary();
+                        break;
+                    case TType.Struct:
+                        prot.ReadStructBegin();
+                        // Skip each field's value until the Stop marker closes the struct.
+                        while (true)
                         {
-                            break;
+                            TField field = prot.ReadFieldBegin();
+                            if (field.Type == TType.Stop)
+                            {
+                                break;
+                            }
+                            Skip(prot, field.Type);
+                            prot.ReadFieldEnd();
                         }
-                        Skip(prot, field.Type);
-                        prot.ReadFieldEnd();
-                    }
-                    prot.ReadStructEnd();
-                    break;
-                case TType.Map:
-                    TMap map = prot.ReadMapBegin();
-                    for (int i = 0; i < map.Count; i++)
-                    {
-                        Skip(prot, map.KeyType);
-                        Skip(prot, map.ValueType);
-                    }
-                    prot.ReadMapEnd();
-                    break;
-                case TType.Set:
-                    TSet set = prot.ReadSetBegin();
-                    for (int i = 0; i < set.Count; i++)
-                    {
-                        Skip(prot, set.ElementType);
-                    }
-                    prot.ReadSetEnd();
-                    break;
-                case TType.List:
-                    TList list = prot.ReadListBegin();
-                    for (int i = 0; i < list.Count; i++)
-                    {
-                        Skip(prot, list.ElementType);
-                    }
-                    prot.ReadListEnd();
-                    break;
+                        prot.ReadStructEnd();
+                        break;
+                    case TType.Map:
+                        TMap map = prot.ReadMapBegin();
+                        for (int i = 0; i < map.Count; i++)
+                        {
+                            Skip(prot, map.KeyType);
+                            Skip(prot, map.ValueType);
+                        }
+                        prot.ReadMapEnd();
+                        break;
+                    case TType.Set:
+                        TSet set = prot.ReadSetBegin();
+                        for (int i = 0; i < set.Count; i++)
+                        {
+                            Skip(prot, set.ElementType);
+                        }
+                        prot.ReadSetEnd();
+                        break;
+                    case TType.List:
+                        TList list = prot.ReadListBegin();
+                        for (int i = 0; i < list.Count; i++)
+                        {
+                            Skip(prot, list.ElementType);
+                        }
+                        prot.ReadListEnd();
+                        break;
+                    default:
+                        // Consuming nothing for an unrecognized type would leave its bytes in the
+                        // stream and desynchronize every later read, so fail instead.
+                        throw new TProtocolException(TProtocolException.INVALID_DATA, "Unknown data type " + type.ToString("d"));
+                }
+            }
+            finally
+            {
+                prot.DecrementRecursionDepth();
             }
         }
     }