THRIFT-2471 Make cpp.ref annotation language agnostic
Client: compiler general
Patch: Dave Watson

This closes #113

commit 52b99af4ee1574253dcb77933d76a7ebb2d830df
 Author: Dave Watson <davejwatson@fb.com>
 Date: 2014-04-23T20:05:56Z

Replace the cpp.ref field annotation with a language-agnostic "&" field marker

commit 3f9d31cc6140367529fd8f7b1b67056ec321786f
 Author: Dave Watson <davejwatson@fb.com>
 Date: 2014-04-23T21:50:29Z

Add a recursion depth limit (default 64) to generated struct writes

commit 61468e4534ce9e6a4f4f643bfd00542d13600d83
 Author: Dave Watson <davejwatson@fb.com>
 Date: 2014-04-25T19:59:18Z

Use boost::shared_ptr for reference ("&") fields in generated C++ code
diff --git a/compiler/cpp/src/generate/t_cpp_generator.cc b/compiler/cpp/src/generate/t_cpp_generator.cc
index 92eab2f..f985492 100755
--- a/compiler/cpp/src/generate/t_cpp_generator.cc
+++ b/compiler/cpp/src/generate/t_cpp_generator.cc
@@ -235,7 +235,7 @@
   void generate_local_reflection_pointer(std::ofstream& out, t_type* ttype);
 
   bool is_reference(t_field* tfield) {
-    return tfield->annotations_.count("cpp.ref") != 0;
+    return tfield->get_reference();
   }
 
   bool is_complex_type(t_type* ttype) {
@@ -832,14 +832,8 @@
   const vector<t_field*>& members = tstruct->get_members();
   vector<t_field*>::const_iterator f_iter;
   for (f_iter = members.begin(); f_iter != members.end(); ++f_iter) {
-    if (is_reference(*f_iter)) {
-      std::string type = type_name((*f_iter)->get_type());
-      indent(out) << (*f_iter)->get_name() << " = new " << type << "(*" << tmp_name << "." <<
-        (*f_iter)->get_name() << ");" << endl;
-    } else {
-      indent(out) << (*f_iter)->get_name() << " = " << tmp_name << "." <<
-        (*f_iter)->get_name() << ";" << endl;
-    }
+    indent(out) << (*f_iter)->get_name() << " = " << tmp_name << "." <<
+      (*f_iter)->get_name() << ";" << endl;
   }
 
   indent_down();
@@ -859,20 +853,8 @@
   const vector<t_field*>& members = tstruct->get_members();
   vector<t_field*>::const_iterator f_iter;
   for (f_iter = members.begin(); f_iter != members.end(); ++f_iter) {
-    if (is_reference(*f_iter)) {
-      std::string type = type_name((*f_iter)->get_type());
-      indent(out) << "if (this == &" << tmp_name << ") return *this;" << endl;
-      indent(out) << "if (" << (*f_iter)->get_name() << ") {" << endl;
-      indent(out) << "  *" << (*f_iter)->get_name() << " = *" << tmp_name << "." << 
-        (*f_iter)->get_name() << ";" << endl;
-      indent(out) << "} else {" << endl;
-      indent(out) << "  " << (*f_iter)->get_name() << " = new " << type << "(*" << tmp_name << "." <<
-        (*f_iter)->get_name() << ");" << endl;
-      indent(out) << "}" << endl;
-    } else {
-      indent(out) << (*f_iter)->get_name() << " = " << tmp_name << "." <<
-        (*f_iter)->get_name() << ";" << endl;
-    }
+    indent(out) << (*f_iter)->get_name() << " = " << tmp_name << "." <<
+      (*f_iter)->get_name() << ";" << endl;
   }
 
   indent(out) << "return *this;" << endl;
@@ -1031,7 +1013,7 @@
   // Declare all fields
   for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) {
     indent(out) <<
-      declare_field(*m_iter, false, (pointers && !(*m_iter)->get_type()->is_xception()) || is_reference(*m_iter), !read) << endl;
+      declare_field(*m_iter, false, (pointers && !(*m_iter)->get_type()->is_xception()), !read) << endl;
   }
 
   // Add the __isset data member if we need it, using the definition from above
@@ -1046,11 +1028,19 @@
     if (pointers) {
       continue;
     }
-    out <<
-      endl <<
-      indent() << "void __set_" << (*m_iter)->get_name() <<
+    if (is_reference((*m_iter))) {
+      out <<
+	endl <<
+	indent() << "void __set_" << (*m_iter)->get_name() <<
+        "(boost::shared_ptr<" << type_name((*m_iter)->get_type(), false, false) << ">";
+      out << " val);" << endl;
+    } else {
+      out <<
+	endl <<
+	indent() << "void __set_" << (*m_iter)->get_name() <<
         "(" << type_name((*m_iter)->get_type(), false, true);
-    out << " val);" << endl;
+      out << " val);" << endl;
+    }
   }
   out << endl;
 
@@ -1151,13 +1141,6 @@
       indent() << tstruct->get_name() << "::~" << tstruct->get_name() << "() throw() {" << endl;
     indent_up();
 
-    for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) {
-      if (is_reference(*m_iter)) {
-        out << indent() <<
-          "delete " << (*m_iter)->get_name() << ";" << endl;
-      }
-    }    
-
     indent_down();
     out << indent() << "}" << endl << endl;
   }
@@ -1165,22 +1148,22 @@
   // Create a setter function for each field
   if (setters) {
     for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) {
-      out <<
-        endl <<
-        indent() << "void " << tstruct->get_name() << "::__set_" << (*m_iter)->get_name() <<
-        "(" << type_name((*m_iter)->get_type(), false, true);
-      out << " val) {" << endl;
-      indent_up();
       if (is_reference((*m_iter))) {
         std::string type = type_name((*m_iter)->get_type());
-        indent(out) << "if (" << (*m_iter)->get_name() << ") {" << endl;
-        indent(out) << "  *" << (*m_iter)->get_name() << " = val;" << endl;
-        indent(out) << "} else {" << endl;
-        indent(out) << "  " << (*m_iter)->get_name() << " = new " << type << "(val);" << endl;
-        indent(out) << "}" << endl;
+	out <<
+	  endl <<
+	  indent() << "void " << tstruct->get_name() << "::__set_" << (*m_iter)->get_name() <<
+	  "(boost::shared_ptr<" << type_name((*m_iter)->get_type(), false, false) << ">";
+	out << " val) {" << endl;
       } else {
-        out << indent() << (*m_iter)->get_name() << " = val;" << endl;
+	out <<
+	  endl <<
+	  indent() << "void " << tstruct->get_name() << "::__set_" << (*m_iter)->get_name() <<
+	  "(" << type_name((*m_iter)->get_type(), false, true);
+	out << " val) {" << endl;
       }
+      indent_up();
+      out << indent() << (*m_iter)->get_name() << " = val;" << endl;
       indent_down();
 
       // assume all fields are required except optional fields.
@@ -1545,6 +1528,7 @@
   out <<
     indent() << "uint32_t xfer = 0;" << endl;
 
+  indent(out) << "oprot->incrementRecursionDepth();" << endl;
   indent(out) <<
     "xfer += oprot->writeStructBegin(\"" << name << "\");" << endl;
 
@@ -1585,6 +1569,7 @@
   out <<
     indent() << "xfer += oprot->writeFieldStop();" << endl <<
     indent() << "xfer += oprot->writeStructEnd();" << endl <<
+    indent() << "oprot->decrementRecursionDepth();" << endl <<
     indent() << "return xfer;" << endl;
 
   indent_down();
@@ -4078,7 +4063,7 @@
                                                   bool pointer) {
   if (pointer) {
     indent(out) << "if (!" << prefix << ") { " << endl;
-    indent(out) << "  " << prefix << " = new " << type_name(tstruct) << ";" << endl;
+    indent(out) << "  " << prefix << " = boost::shared_ptr<" << type_name(tstruct) << ">(new " << type_name(tstruct) << ");" << endl;
     indent(out) << "}" << endl;
     indent(out) <<
       "xfer += " << prefix << "->read(iprot);" << endl;
@@ -4614,6 +4599,9 @@
     result += "const ";
   }
   result += type_name(tfield->get_type());
+  if (is_reference(tfield)) {
+    result = "boost::shared_ptr<" + result + ">";
+  }
   if (pointer) {
     result += "*";
   }
diff --git a/compiler/cpp/src/parse/t_field.h b/compiler/cpp/src/parse/t_field.h
index 7bbcc0f..c05fdf3 100644
--- a/compiler/cpp/src/parse/t_field.h
+++ b/compiler/cpp/src/parse/t_field.h
@@ -42,7 +42,8 @@
     value_(NULL),
     xsd_optional_(false),
     xsd_nillable_(false),
-    xsd_attrs_(NULL) {}
+    xsd_attrs_(NULL),
+    reference_(false) {}
 
   t_field(t_type* type, std::string name, int32_t key) :
     type_(type),
@@ -52,7 +53,8 @@
     value_(NULL),
     xsd_optional_(false),
     xsd_nillable_(false),
-    xsd_attrs_(NULL) {}
+    xsd_attrs_(NULL),
+    reference_(false) {}
 
   ~t_field() {}
 
@@ -137,6 +139,14 @@
 
   std::map<std::string, std::string> annotations_;
 
+  bool get_reference() {
+    return reference_;
+  }
+
+  void set_reference(bool reference) {
+    reference_ = reference;
+  }
+
  private:
   t_type* type_;
   std::string name_;
@@ -147,7 +157,7 @@
   bool xsd_optional_;
   bool xsd_nillable_;
   t_struct* xsd_attrs_;
-
+  bool reference_;
 };
 
 /**
diff --git a/compiler/cpp/src/thriftl.ll b/compiler/cpp/src/thriftl.ll
index 685bb54..aee4406 100644
--- a/compiler/cpp/src/thriftl.ll
+++ b/compiler/cpp/src/thriftl.ll
@@ -183,6 +183,7 @@
   pwarning(0, "\"async\" is deprecated.  It is called \"oneway\" now.\n");
   return tok_oneway;
 }
+"&"                  { return tok_reference;            }
 
 
 "BEGIN"              { thrift_reserved_keyword(yytext); }
diff --git a/compiler/cpp/src/thrifty.yy b/compiler/cpp/src/thrifty.yy
index edb05f9..62e13ba 100644
--- a/compiler/cpp/src/thrifty.yy
+++ b/compiler/cpp/src/thrifty.yy
@@ -165,6 +165,7 @@
 %token tok_required
 %token tok_optional
 %token tok_union
+%token tok_reference
 
 /**
  * Grammar nodes
@@ -193,6 +194,7 @@
 %type<ttype>     FieldType
 %type<tconstv>   FieldValue
 %type<tstruct>   FieldList
+%type<tbool>     FieldReference
 
 %type<tenum>     Enum
 %type<tenum>     EnumDefList
@@ -955,35 +957,36 @@
     }
 
 Field:
-  CaptureDocText FieldIdentifier FieldRequiredness FieldType tok_identifier FieldValue XsdOptional XsdNillable XsdAttributes TypeAnnotations CommaOrSemicolonOptional
+  CaptureDocText FieldIdentifier FieldRequiredness FieldType FieldReference tok_identifier FieldValue XsdOptional XsdNillable XsdAttributes TypeAnnotations CommaOrSemicolonOptional
     {
       pdebug("tok_int_constant : Field -> FieldType tok_identifier");
       if ($2.auto_assigned) {
-        pwarning(1, "No field key specified for %s, resulting protocol may have conflicts or not be backwards compatible!\n", $5);
+        pwarning(1, "No field key specified for %s, resulting protocol may have conflicts or not be backwards compatible!\n", $6);
         if (g_strict >= 192) {
           yyerror("Implicit field keys are deprecated and not allowed with -strict");
           exit(1);
         }
       }
-      validate_simple_identifier($5);
-      $$ = new t_field($4, $5, $2.value);
+      validate_simple_identifier($6);
+      $$ = new t_field($4, $6, $2.value);
+      $$->set_reference($5);
       $$->set_req($3);
-      if ($6 != NULL) {
-        g_scope->resolve_const_value($6, $4);
-        validate_field_value($$, $6);
-        $$->set_value($6);
+      if ($7 != NULL) {
+        g_scope->resolve_const_value($7, $4);
+        validate_field_value($$, $7);
+        $$->set_value($7);
       }
-      $$->set_xsd_optional($7);
-      $$->set_xsd_nillable($8);
+      $$->set_xsd_optional($8);
+      $$->set_xsd_nillable($9);
       if ($1 != NULL) {
         $$->set_doc($1);
       }
-      if ($9 != NULL) {
-        $$->set_xsd_attrs($9);
-      }
       if ($10 != NULL) {
-        $$->annotations_ = $10->annotations_;
-        delete $10;
+        $$->set_xsd_attrs($10);
+      }
+      if ($11 != NULL) {
+        $$->annotations_ = $11->annotations_;
+        delete $11;
       }
     }
 
@@ -1029,6 +1032,16 @@
       $$.auto_assigned = true;
     }
 
+FieldReference:
+  tok_reference
+    {
+      $$ = true;
+    }
+|
+   {
+     $$ = false;
+   }
+
 FieldRequiredness:
   tok_required
     {
diff --git a/lib/cpp/src/thrift/protocol/TProtocol.h b/lib/cpp/src/thrift/protocol/TProtocol.h
index e72033a..e8ba429 100644
--- a/lib/cpp/src/thrift/protocol/TProtocol.h
+++ b/lib/cpp/src/thrift/protocol/TProtocol.h
@@ -283,6 +283,8 @@
   return 0;
 }
 
+static const uint32_t DEFAULT_RECURSION_LIMIT = 64;
+
 /**
  * Abstract class for a thrift protocol driver. These are all the methods that
  * a protocol must implement. Essentially, there must be some way of reading
@@ -660,15 +662,28 @@
     return ptrans_;
   }
 
- protected:
-  TProtocol(boost::shared_ptr<TTransport> ptrans):
-    ptrans_(ptrans) {
+  void incrementRecursionDepth() {
+    if (recursion_limit_ < ++recursion_depth_) {
+      throw TProtocolException(TProtocolException::DEPTH_LIMIT);
+    }
   }
 
+  void decrementRecursionDepth() {
+    --recursion_depth_;
+  }
+
+ protected:
+  TProtocol(boost::shared_ptr<TTransport> ptrans)
+    : ptrans_(ptrans) 
+    , recursion_depth_(0)
+    , recursion_limit_(DEFAULT_RECURSION_LIMIT) {}
+
   boost::shared_ptr<TTransport> ptrans_;
 
  private:
   TProtocol() {}
+  uint32_t recursion_depth_;
+  uint32_t recursion_limit_;
 };
 
 /**
diff --git a/lib/cpp/src/thrift/protocol/TProtocolException.h b/lib/cpp/src/thrift/protocol/TProtocolException.h
index a03d3c8..4ddb81e 100644
--- a/lib/cpp/src/thrift/protocol/TProtocolException.h
+++ b/lib/cpp/src/thrift/protocol/TProtocolException.h
@@ -45,6 +45,7 @@
   , SIZE_LIMIT = 3
   , BAD_VERSION = 4
   , NOT_IMPLEMENTED = 5
+  , DEPTH_LIMIT = 6
   };
 
   TProtocolException() :
diff --git a/lib/cpp/test/RecursiveTest.cpp b/lib/cpp/test/RecursiveTest.cpp
index 00610c6..24c0f7c 100644
--- a/lib/cpp/test/RecursiveTest.cpp
+++ b/lib/cpp/test/RecursiveTest.cpp
@@ -44,7 +44,7 @@
   assert(tree == result);
 
   RecList l;
-  RecList* l2(new RecList);
+  boost::shared_ptr<RecList> l2(new RecList);
   l.nextitem = l2;
 
   l.write(prot.get());
@@ -55,7 +55,7 @@
   assert(resultlist.nextitem->nextitem == NULL);
 
   CoRec c;
-  CoRec2* r(new CoRec2);
+  boost::shared_ptr<CoRec2> r(new CoRec2);
   c.other = r;
 
   c.write(prot.get());
@@ -64,4 +64,12 @@
   assert(c.other != NULL);
   assert(c.other->other.other == NULL);
 
+  boost::shared_ptr<RecList> depthLimit(new RecList);
+  depthLimit->nextitem = depthLimit;
+  try {
+    depthLimit->write(prot.get());
+    assert(false);
+  } catch (const apache::thrift::protocol::TProtocolException& e) {
+  }
+
 }
diff --git a/test/Recursive.thrift b/test/Recursive.thrift
index c55541b..9c29983 100644
--- a/test/Recursive.thrift
+++ b/test/Recursive.thrift
@@ -23,12 +23,12 @@
 }
 
 struct RecList {
-  1: RecList nextitem (cpp.ref = "true")
+  1: RecList & nextitem 
   3: i16 item
 }
 
 struct CoRec {
-  1:  CoRec2  other (cpp.ref = "true")
+  1:  CoRec2 & other 
 }
 
 struct CoRec2 {