THRIFT-3235 C#: Limit recursion depth to 64
Client: C#
Patch: Jens Geyer
diff --git a/compiler/cpp/src/generate/t_csharp_generator.cc b/compiler/cpp/src/generate/t_csharp_generator.cc
index d6aad75..6a21801 100644
--- a/compiler/cpp/src/generate/t_csharp_generator.cc
+++ b/compiler/cpp/src/generate/t_csharp_generator.cc
@@ -909,6 +909,10 @@
indent(out) << "public void Read (TProtocol iprot)" << endl;
scope_up(out);
+ out << indent() << "iprot.IncrementRecursionDepth();" << endl;
+ out << indent() << "try" << endl;
+ scope_up(out);
+
const vector<t_field*>& fields = tstruct->get_members();
vector<t_field*>::const_iterator f_iter;
@@ -977,6 +981,12 @@
}
}
+ scope_down(out);
+ out << indent() << "finally" << endl;
+ scope_up(out);
+ out << indent() << "iprot.DecrementRecursionDepth();" << endl;
+ scope_down(out);
+
indent_down();
indent(out) << "}" << endl << endl;
@@ -985,6 +995,10 @@
void t_csharp_generator::generate_csharp_struct_writer(ofstream& out, t_struct* tstruct) {
out << indent() << "public void Write(TProtocol oprot) {" << endl;
indent_up();
+
+ out << indent() << "oprot.IncrementRecursionDepth();" << endl;
+ out << indent() << "try" << endl;
+ scope_up(out);
string name = tstruct->get_name();
const vector<t_field*>& fields = tstruct->get_sorted_members();
@@ -1030,8 +1044,14 @@
indent(out) << "oprot.WriteFieldStop();" << endl;
indent(out) << "oprot.WriteStructEnd();" << endl;
- indent_down();
+ scope_down(out);
+ out << indent() << "finally" << endl;
+ scope_up(out);
+ out << indent() << "oprot.DecrementRecursionDepth();" << endl;
+ scope_down(out);
+ indent_down();
+
indent(out) << "}" << endl << endl;
}
@@ -1039,6 +1059,10 @@
indent(out) << "public void Write(TProtocol oprot) {" << endl;
indent_up();
+ out << indent() << "oprot.IncrementRecursionDepth();" << endl;
+ out << indent() << "try" << endl;
+ scope_up(out);
+
string name = tstruct->get_name();
const vector<t_field*>& fields = tstruct->get_sorted_members();
vector<t_field*>::const_iterator f_iter;
@@ -1092,6 +1116,12 @@
out << endl << indent() << "oprot.WriteFieldStop();" << endl << indent()
<< "oprot.WriteStructEnd();" << endl;
+ scope_down(out);
+ out << indent() << "finally" << endl;
+ scope_up(out);
+ out << indent() << "oprot.DecrementRecursionDepth();" << endl;
+ scope_down(out);
+
indent_down();
indent(out) << "}" << endl << endl;
@@ -1249,6 +1279,11 @@
indent(out) << "}" << endl;
indent(out) << "public override void Write(TProtocol oprot) {" << endl;
indent_up();
+
+ out << indent() << "oprot.IncrementRecursionDepth();" << endl;
+ out << indent() << "try" << endl;
+ scope_up(out);
+
indent(out) << "TStruct struc = new TStruct(\"" << tunion->get_name() << "\");" << endl;
indent(out) << "oprot.WriteStructBegin(struc);" << endl;
@@ -1264,6 +1299,13 @@
indent(out) << "oprot.WriteFieldStop();" << endl;
indent(out) << "oprot.WriteStructEnd();" << endl;
indent_down();
+
+ scope_down(out);
+ out << indent() << "finally" << endl;
+ scope_up(out);
+ out << indent() << "oprot.DecrementRecursionDepth();" << endl;
+ scope_down(out);
+
indent(out) << "}" << endl;
indent_down();
@@ -1987,6 +2029,11 @@
indent(out) << "public static " << tunion->get_name() << " Read(TProtocol iprot)" << endl;
scope_up(out);
+
+ out << indent() << "iprot.IncrementRecursionDepth();" << endl;
+ out << indent() << "try" << endl;
+ scope_up(out);
+
indent(out) << tunion->get_name() << " retval;" << endl;
indent(out) << "iprot.ReadStructBegin();" << endl;
indent(out) << "TField field = iprot.ReadFieldBegin();" << endl;
@@ -2036,13 +2083,16 @@
// end of else for TStop
scope_down(out);
-
indent(out) << "iprot.ReadStructEnd();" << endl;
-
indent(out) << "return retval;" << endl;
-
indent_down();
+ scope_down(out);
+ out << indent() << "finally" << endl;
+ scope_up(out);
+ out << indent() << "iprot.DecrementRecursionDepth();" << endl;
+ scope_down(out);
+
indent(out) << "}" << endl << endl;
}
diff --git a/lib/csharp/src/Protocol/TProtocol.cs b/lib/csharp/src/Protocol/TProtocol.cs
index 1f5bd81..bf481ab 100644
--- a/lib/csharp/src/Protocol/TProtocol.cs
+++ b/lib/csharp/src/Protocol/TProtocol.cs
@@ -29,11 +29,17 @@
{
public abstract class TProtocol : IDisposable
{
+ private const int DEFAULT_RECURSION_DEPTH = 64;
+
protected TTransport trans;
+ protected int recursionLimit;
+ protected int recursionDepth;
protected TProtocol(TTransport trans)
{
this.trans = trans;
+ this.recursionLimit = DEFAULT_RECURSION_DEPTH;
+ this.recursionDepth = 0;
}
public TTransport Transport
@@ -41,6 +47,25 @@
get { return trans; }
}
+ public int RecursionLimit
+ {
+ get { return recursionLimit; }
+ set { recursionLimit = value; }
+ }
+
+ public void IncrementRecursionDepth()
+ {
+ if (recursionDepth < recursionLimit)
+ ++recursionDepth;
+ else
+ throw new TProtocolException(TProtocolException.DEPTH_LIMIT, "Depth limit exceeded");
+ }
+
+ public void DecrementRecursionDepth()
+ {
+ --recursionDepth;
+ }
+
#region " IDisposable Support "
private bool _IsDisposed;
diff --git a/lib/csharp/src/Protocol/TProtocolUtil.cs b/lib/csharp/src/Protocol/TProtocolUtil.cs
index 91140d3..0932a7f 100644
--- a/lib/csharp/src/Protocol/TProtocolUtil.cs
+++ b/lib/csharp/src/Protocol/TProtocolUtil.cs
@@ -29,69 +29,78 @@
{
public static void Skip(TProtocol prot, TType type)
{
- switch (type)
+ prot.IncrementRecursionDepth();
+ try
{
- case TType.Bool:
- prot.ReadBool();
- break;
- case TType.Byte:
- prot.ReadByte();
- break;
- case TType.I16:
- prot.ReadI16();
- break;
- case TType.I32:
- prot.ReadI32();
- break;
- case TType.I64:
- prot.ReadI64();
- break;
- case TType.Double:
- prot.ReadDouble();
- break;
- case TType.String:
- // Don't try to decode the string, just skip it.
- prot.ReadBinary();
- break;
- case TType.Struct:
- prot.ReadStructBegin();
- while (true)
- {
- TField field = prot.ReadFieldBegin();
- if (field.Type == TType.Stop)
+ switch (type)
+ {
+ case TType.Bool:
+ prot.ReadBool();
+ break;
+ case TType.Byte:
+ prot.ReadByte();
+ break;
+ case TType.I16:
+ prot.ReadI16();
+ break;
+ case TType.I32:
+ prot.ReadI32();
+ break;
+ case TType.I64:
+ prot.ReadI64();
+ break;
+ case TType.Double:
+ prot.ReadDouble();
+ break;
+ case TType.String:
+ // Don't try to decode the string, just skip it.
+ prot.ReadBinary();
+ break;
+ case TType.Struct:
+ prot.ReadStructBegin();
+ while (true)
{
- break;
+ TField field = prot.ReadFieldBegin();
+ if (field.Type == TType.Stop)
+ {
+ break;
+ }
+ Skip(prot, field.Type);
+ prot.ReadFieldEnd();
}
- Skip(prot, field.Type);
- prot.ReadFieldEnd();
- }
- prot.ReadStructEnd();
- break;
- case TType.Map:
- TMap map = prot.ReadMapBegin();
- for (int i = 0; i < map.Count; i++)
- {
- Skip(prot, map.KeyType);
- Skip(prot, map.ValueType);
- }
- prot.ReadMapEnd();
- break;
- case TType.Set:
- TSet set = prot.ReadSetBegin();
- for (int i = 0; i < set.Count; i++)
- {
- Skip(prot, set.ElementType);
- }
- prot.ReadSetEnd();
- break;
- case TType.List:
- TList list = prot.ReadListBegin();
- for (int i = 0; i < list.Count; i++)
- {
- Skip(prot, list.ElementType);
- }
- prot.ReadListEnd();
- break;
+ prot.ReadStructEnd();
+ break;
+ case TType.Map:
+ TMap map = prot.ReadMapBegin();
+ for (int i = 0; i < map.Count; i++)
+ {
+ Skip(prot, map.KeyType);
+ Skip(prot, map.ValueType);
+ }
+ prot.ReadMapEnd();
+ break;
+ case TType.Set:
+ TSet set = prot.ReadSetBegin();
+ for (int i = 0; i < set.Count; i++)
+ {
+ Skip(prot, set.ElementType);
+ }
+ prot.ReadSetEnd();
+ break;
+ case TType.List:
+ TList list = prot.ReadListBegin();
+ for (int i = 0; i < list.Count; i++)
+ {
+ Skip(prot, list.ElementType);
+ }
+ prot.ReadListEnd();
+ break;
+ }
+
+ }
+ finally
+ {
+ prot.DecrementRecursionDepth();
}
}
}