Thrift: Merging external patch.

Summary:
Merging a patch from Andy Lutomirski.
- Allow fields to be marked "required" or "optional" (only affects generated C++ code).
- Generated C++ structs now have operator == and operator !=.

Reviewed By: mcslee

Test Plan: test/OptionalRequiredTest.cpp

Revert Plan: ok


git-svn-id: https://svn.apache.org/repos/asf/incubator/thrift/trunk@665202 13f79535-47bb-0310-9956-ffa450edef68
diff --git a/compiler/cpp/src/generate/t_cpp_generator.cc b/compiler/cpp/src/generate/t_cpp_generator.cc
index e0313f7..ec83c89 100644
--- a/compiler/cpp/src/generate/t_cpp_generator.cc
+++ b/compiler/cpp/src/generate/t_cpp_generator.cc
@@ -439,17 +439,26 @@
       declare_field(*m_iter, false, pointers && !(*m_iter)->get_type()->is_xception(), !read) << endl;
   }
   
-  // Isset struct has boolean fields
-  if (members.size() > 0 && (!pointers || read)) {
+  // Isset struct has boolean fields, but only for non-required fields.
+  bool has_nonrequired_fields = false;
+  for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) {
+    if ((*m_iter)->get_req() != t_field::REQUIRED)
+      has_nonrequired_fields = true;
+  }
+
+  if (has_nonrequired_fields && (!pointers || read)) {
     out <<
       endl <<
-      indent() <<"struct __isset {" << endl;
+      indent() << "struct __isset {" << endl;
     indent_up();
       
       indent(out) <<
         "__isset() : ";
       bool first = true;
       for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) {
+        if ((*m_iter)->get_req() == t_field::REQUIRED) {
+          continue;
+        }
         if (first) {
           first = false;
           out <<
@@ -462,8 +471,10 @@
       out << " {}" << endl;
       
       for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) {
-        indent(out) <<
-          "bool " << (*m_iter)->get_name() << ";" << endl;
+        if ((*m_iter)->get_req() != t_field::REQUIRED) {
+          indent(out) <<
+            "bool " << (*m_iter)->get_name() << ";" << endl;
+        }
       }
 
     indent_down();
@@ -471,8 +482,40 @@
       "} __isset;" << endl;  
   }
 
-  out <<
-    endl;
+  out << endl;
+
+  if (!pointers) {
+    // Generate an equality testing operator.  Make it inline since the compiler
+    // will do a better job than we would when deciding whether to inline it.
+    out <<
+      indent() << "bool operator == (const " << tstruct->get_name() << " & rhs) const" << endl;
+    scope_up(out);
+    for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) {
+      // Most existing Thrift code does not use isset or optional/required,
+      // so we treat "default" fields as required.
+      if ((*m_iter)->get_req() != t_field::OPTIONAL) {
+        out <<
+          indent() << "if (!(" << (*m_iter)->get_name()
+                   << " == rhs." << (*m_iter)->get_name() << "))" << endl <<
+          indent() << "  return false;" << endl;
+      } else {
+        out <<
+          indent() << "if (__isset." << (*m_iter)->get_name()
+                   << " != rhs.__isset." << (*m_iter)->get_name() << ")" << endl <<
+          indent() << "  return false;" << endl <<
+          indent() << "else if (__isset." << (*m_iter)->get_name() << " && !("
+                   << (*m_iter)->get_name() << " == rhs." << (*m_iter)->get_name()
+                   << "))" << endl <<
+          indent() << "  return false;" << endl;
+      }
+    }
+    indent(out) << "return true;" << endl;
+    scope_down(out);
+    out <<
+      indent() << "bool operator != (const " << tstruct->get_name() << " &rhs) const {" << endl <<
+      indent() << "  return !(*this == rhs);" << endl <<
+      indent() << "}" << endl << endl;
+  }
   if (read) {
     out <<
       indent() << "uint32_t read(facebook::thrift::protocol::TProtocol* iprot);" << endl;
@@ -481,8 +524,7 @@
     out <<
       indent() << "uint32_t write(facebook::thrift::protocol::TProtocol* oprot) const;" << endl;
   }
-  out <<
-    endl;
+  out << endl;
 
   indent_down(); 
   indent(out) <<
@@ -515,7 +557,17 @@
     indent() << "int16_t fid;" << endl <<
     endl <<
     indent() << "xfer += iprot->readStructBegin(fname);" << endl <<
+    endl <<
+    indent() << "using facebook::thrift::protocol::TProtocolException;" << endl <<
     endl;
+
+  // Required variables aren't in __isset, so we need tmp vars to check them.
+  for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) {
+    if ((*f_iter)->get_req() == t_field::REQUIRED)
+      indent(out) << "bool isset_" << (*f_iter)->get_name() << " = false;" << endl;
+  }
+  out << endl;
+  
   
   // Loop over reading in fields
   indent(out) <<
@@ -547,17 +599,33 @@
           "if (ftype == " << type_to_enum((*f_iter)->get_type()) << ") {" << endl;
         indent_up();
 
+        const char *isset_prefix =
+          ((*f_iter)->get_req() != t_field::REQUIRED) ? "this->__isset." : "isset_";
+
+#if 0
+        // This code throws an exception if the same field is encountered twice.
+        // We've decided to leave it out for performance reasons.
+        // TODO(dreiss): Generate this code and "if" it out to make it easier
+        // for people recompiling thrift to include it.
+        out <<
+          indent() << "if (" << isset_prefix << (*f_iter)->get_name() << ")" << endl <<
+          indent() << "  throw TProtocolException(TProtocolException::INVALID_DATA);" << endl;
+#endif
+
         if (pointers && !(*f_iter)->get_type()->is_xception()) {
           generate_deserialize_field(out, *f_iter, "(*(this->", "))");
         } else {
           generate_deserialize_field(out, *f_iter, "this->");
         }
         out <<
-          indent() << "this->__isset." << (*f_iter)->get_name() << " = true;" << endl;
+          indent() << isset_prefix << (*f_iter)->get_name() << " = true;" << endl;
         indent_down();
         out <<
           indent() << "} else {" << endl <<
           indent() << "  xfer += iprot->skip(ftype);" << endl <<
+          // TODO(dreiss): Make this an option when thrift structs
+          // have a common base class.
+          // indent() << "  throw TProtocolException(TProtocolException::INVALID_DATA);" << endl <<
           indent() << "}" << endl <<
           indent() << "break;" << endl;
         indent_down();
@@ -579,8 +647,20 @@
 
   out <<
     endl <<
-    indent() << "xfer += iprot->readStructEnd();" << endl <<
-    indent() <<"return xfer;" << endl;
+    indent() << "xfer += iprot->readStructEnd();" << endl;
+
+  // Throw if any required fields are missing.
+  // We do this after reading the struct end so that
+  // there might possibly be a chance of continuing.
+  out << endl;
+  for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) {
+    if ((*f_iter)->get_req() == t_field::REQUIRED)
+      out <<
+        indent() << "if (!isset_" << (*f_iter)->get_name() << ')' << endl <<
+        indent() << "  throw TProtocolException(TProtocolException::INVALID_DATA);" << endl;
+  }
+
+  indent(out) << "return xfer;" << endl;
 
   indent_down();
   indent(out) <<
@@ -610,6 +690,10 @@
   indent(out) <<
     "xfer += oprot->writeStructBegin(\"" << name << "\");" << endl;
   for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) {
+    if ((*f_iter)->get_req() == t_field::OPTIONAL) {
+      indent(out) << "if (this->__isset." << (*f_iter)->get_name() << ") {" << endl;
+      indent_up();
+    }
     // Write field header
     out <<
       indent() << "xfer += oprot->writeFieldBegin(" <<
@@ -625,7 +709,12 @@
     // Write field closer
     indent(out) <<
       "xfer += oprot->writeFieldEnd();" << endl;
+    if ((*f_iter)->get_req() == t_field::OPTIONAL) {
+      indent_down();
+      indent(out) << '}' << endl;
+    }
   }
+
   // Write the struct map
   out <<
     indent() << "xfer += oprot->writeFieldStop();" << endl <<
diff --git a/compiler/cpp/src/parse/t_field.h b/compiler/cpp/src/parse/t_field.h
index d342bff..a4a2cc7 100644
--- a/compiler/cpp/src/parse/t_field.h
+++ b/compiler/cpp/src/parse/t_field.h
@@ -33,6 +33,7 @@
     type_(type),
     name_(name),
     key_(key),
+    req_(OPT_IN_REQ_OUT),
     value_(NULL),
     xsd_optional_(false),
     xsd_nillable_(false),
@@ -52,6 +53,20 @@
     return key_;
   }
 
+  enum e_req {       // How a field's presence was declared in the IDL.
+    REQUIRED,        // Declared with the "required" keyword.
+    OPTIONAL,        // Declared with the "optional" keyword.
+    OPT_IN_REQ_OUT   // No keyword given; this is the constructor's default.
+  };
+
+  void set_req(e_req req) {  // Stores the requiredness value for this field.
+    req_ = req;
+  }
+
+  e_req get_req() const {  // Returns the requiredness value stored for this field.
+    return req_;
+  }
+
   void set_value(t_const_value* value) {
     value_ = value;
   }
@@ -101,6 +116,7 @@
   t_type* type_;
   std::string name_;
   int32_t key_;
+  e_req req_;
   t_const_value* value_;
 
   bool xsd_optional_;
diff --git a/compiler/cpp/src/thriftl.ll b/compiler/cpp/src/thriftl.ll
index 4612858..7693f33 100644
--- a/compiler/cpp/src/thriftl.ll
+++ b/compiler/cpp/src/thriftl.ll
@@ -103,6 +103,8 @@
 "service"        { return tok_service;        }
 "enum"           { return tok_enum;           }
 "const"          { return tok_const;          }
+"required"       { return tok_required;       }
+"optional"       { return tok_optional;       }
 
 "abstract" { thrift_reserved_keyword(yytext); }
 "and" { thrift_reserved_keyword(yytext); }
diff --git a/compiler/cpp/src/thrifty.yy b/compiler/cpp/src/thrifty.yy
index c8be3f0..e159ef0 100644
--- a/compiler/cpp/src/thrifty.yy
+++ b/compiler/cpp/src/thrifty.yy
@@ -50,6 +50,7 @@
   t_function*    tfunction;
   t_field*       tfield;
   char*          dtext;
+  t_field::e_req ereq;
 }
 
 /**
@@ -120,6 +121,8 @@
 %token tok_service
 %token tok_enum
 %token tok_const
+%token tok_required
+%token tok_optional
 
 /**
  * Grammar nodes
@@ -139,6 +142,7 @@
 
 %type<tfield>    Field
 %type<iconst>    FieldIdentifier
+%type<ereq>      FieldRequiredness
 %type<ttype>     FieldType
 %type<tconstv>   FieldValue
 %type<tstruct>   FieldList
@@ -723,24 +727,25 @@
     }
 
 Field:
-  CaptureDocText FieldIdentifier FieldType tok_identifier FieldValue XsdOptional XsdNillable XsdAttributes CommaOrSemicolonOptional
+  CaptureDocText FieldIdentifier FieldRequiredness FieldType tok_identifier FieldValue XsdOptional XsdNillable XsdAttributes CommaOrSemicolonOptional
     {
       pdebug("tok_int_constant : Field -> FieldType tok_identifier");
       if ($2 < 0) {
-        pwarning(2, "No field key specified for %s, resulting protocol may have conflicts or not be backwards compatible!\n", $4);
+        pwarning(2, "No field key specified for %s, resulting protocol may have conflicts or not be backwards compatible!\n", $5);
       }
-      $$ = new t_field($3, $4, $2);
-      if ($5 != NULL) {
-        validate_field_value($$, $5);
-        $$->set_value($5);
+      $$ = new t_field($4, $5, $2);
+      $$->set_req($3);
+      if ($6 != NULL) {
+        validate_field_value($$, $6);
+        $$->set_value($6);
       }
-      $$->set_xsd_optional($6);
-      $$->set_xsd_nillable($7);
+      $$->set_xsd_optional($7);
+      $$->set_xsd_nillable($8);
       if ($1 != NULL) {
         $$->set_doc($1);
       }
-      if ($8 != NULL) {
-        $$->set_xsd_attrs($8);
+      if ($9 != NULL) {
+        $$->set_xsd_attrs($9);
       }
     }
 
@@ -758,6 +763,20 @@
       $$ = y_field_val--;
     }
 
+FieldRequiredness:  /* Optional requiredness keyword that precedes a field's type. */
+  tok_required
+    {
+      $$ = t_field::REQUIRED;
+    }
+| tok_optional
+    {
+      $$ = t_field::OPTIONAL;
+    }
+|  /* empty: neither "required" nor "optional" was written */
+    {
+      $$ = t_field::OPT_IN_REQ_OUT;
+    }
+
 FieldValue:
   '=' ConstValue
     {