Thrift: Merging external patch.

Summary:
Merging a patch from Andy Lutomirski.
- Allow fields to be marked "required" or "optional" (only affects C++).
- Thrift structs now have operator ==.

Reviewed By: mcslee

Test Plan: test/OptionalRequiredTest.cpp

Revert Plan: ok


git-svn-id: https://svn.apache.org/repos/asf/incubator/thrift/trunk@665202 13f79535-47bb-0310-9956-ffa450edef68
diff --git a/compiler/cpp/src/generate/t_cpp_generator.cc b/compiler/cpp/src/generate/t_cpp_generator.cc
index e0313f7..ec83c89 100644
--- a/compiler/cpp/src/generate/t_cpp_generator.cc
+++ b/compiler/cpp/src/generate/t_cpp_generator.cc
@@ -439,17 +439,26 @@
       declare_field(*m_iter, false, pointers && !(*m_iter)->get_type()->is_xception(), !read) << endl;
   }
   
-  // Isset struct has boolean fields
-  if (members.size() > 0 && (!pointers || read)) {
+  // Isset struct has boolean fields, but only for non-required fields.
+  bool has_nonrequired_fields = false;
+  for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) {
+    if ((*m_iter)->get_req() != t_field::REQUIRED)
+      has_nonrequired_fields = true;
+  }
+
+  if (has_nonrequired_fields && (!pointers || read)) {
     out <<
       endl <<
-      indent() <<"struct __isset {" << endl;
+      indent() << "struct __isset {" << endl;
     indent_up();
       
       indent(out) <<
         "__isset() : ";
       bool first = true;
       for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) {
+        if ((*m_iter)->get_req() == t_field::REQUIRED) {
+          continue;
+        }
         if (first) {
           first = false;
           out <<
@@ -462,8 +471,10 @@
       out << " {}" << endl;
       
       for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) {
-        indent(out) <<
-          "bool " << (*m_iter)->get_name() << ";" << endl;
+        if ((*m_iter)->get_req() != t_field::REQUIRED) {
+          indent(out) <<
+            "bool " << (*m_iter)->get_name() << ";" << endl;
+        }
       }
 
     indent_down();
@@ -471,8 +482,40 @@
       "} __isset;" << endl;  
   }
 
-  out <<
-    endl;
+  out << endl;
+
+  if (!pointers) {
+    // Generate an equality testing operator.  Make it inline since the compiler
+    // will do a better job than we would when deciding whether to inline it.
+    out <<
+      indent() << "bool operator == (const " << tstruct->get_name() << " & rhs) const" << endl;
+    scope_up(out);
+    for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) {
+      // Most existing Thrift code does not use isset or optional/required,
+      // so we treat "default" fields as required.
+      if ((*m_iter)->get_req() != t_field::OPTIONAL) {
+        out <<
+          indent() << "if (!(" << (*m_iter)->get_name()
+                   << " == rhs." << (*m_iter)->get_name() << "))" << endl <<
+          indent() << "  return false;" << endl;
+      } else {
+        out <<
+          indent() << "if (__isset." << (*m_iter)->get_name()
+                   << " != rhs.__isset." << (*m_iter)->get_name() << ")" << endl <<
+          indent() << "  return false;" << endl <<
+          indent() << "else if (__isset." << (*m_iter)->get_name() << " && !("
+                   << (*m_iter)->get_name() << " == rhs." << (*m_iter)->get_name()
+                   << "))" << endl <<
+          indent() << "  return false;" << endl;
+      }
+    }
+    indent(out) << "return true;" << endl;
+    scope_down(out);
+    out <<
+      indent() << "bool operator != (const " << tstruct->get_name() << " &rhs) const {" << endl <<
+      indent() << "  return !(*this == rhs);" << endl <<
+      indent() << "}" << endl << endl;
+  }
   if (read) {
     out <<
       indent() << "uint32_t read(facebook::thrift::protocol::TProtocol* iprot);" << endl;
@@ -481,8 +524,7 @@
     out <<
       indent() << "uint32_t write(facebook::thrift::protocol::TProtocol* oprot) const;" << endl;
   }
-  out <<
-    endl;
+  out << endl;
 
   indent_down(); 
   indent(out) <<
@@ -515,7 +557,17 @@
     indent() << "int16_t fid;" << endl <<
     endl <<
     indent() << "xfer += iprot->readStructBegin(fname);" << endl <<
+    endl <<
+    indent() << "using facebook::thrift::protocol::TProtocolException;" << endl <<
     endl;
+
+  // Required variables aren't in __isset, so we need tmp vars to check them.
+  for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) {
+    if ((*f_iter)->get_req() == t_field::REQUIRED)
+      indent(out) << "bool isset_" << (*f_iter)->get_name() << " = false;" << endl;
+  }
+  out << endl;
+  
   
   // Loop over reading in fields
   indent(out) <<
@@ -547,17 +599,33 @@
           "if (ftype == " << type_to_enum((*f_iter)->get_type()) << ") {" << endl;
         indent_up();
 
+        const char *isset_prefix =
+          ((*f_iter)->get_req() != t_field::REQUIRED) ? "this->__isset." : "isset_";
+
+#if 0
+        // This code throws an exception if the same field is encountered twice.
+        // We've decided to leave it out for performance reasons.
+        // TODO(dreiss): Generate this code and "if" it out to make it easier
+        // for people recompiling thrift to include it.
+        out <<
+          indent() << "if (" << isset_prefix << (*f_iter)->get_name() << ")" << endl <<
+          indent() << "  throw TProtocolException(TProtocolException::INVALID_DATA);" << endl;
+#endif
+
         if (pointers && !(*f_iter)->get_type()->is_xception()) {
           generate_deserialize_field(out, *f_iter, "(*(this->", "))");
         } else {
           generate_deserialize_field(out, *f_iter, "this->");
         }
         out <<
-          indent() << "this->__isset." << (*f_iter)->get_name() << " = true;" << endl;
+          indent() << isset_prefix << (*f_iter)->get_name() << " = true;" << endl;
         indent_down();
         out <<
           indent() << "} else {" << endl <<
           indent() << "  xfer += iprot->skip(ftype);" << endl <<
+          // TODO(dreiss): Make this an option when thrift structs
+          // have a common base class.
+          // indent() << "  throw TProtocolException(TProtocolException::INVALID_DATA);" << endl <<
           indent() << "}" << endl <<
           indent() << "break;" << endl;
         indent_down();
@@ -579,8 +647,20 @@
 
   out <<
     endl <<
-    indent() << "xfer += iprot->readStructEnd();" << endl <<
-    indent() <<"return xfer;" << endl;
+    indent() << "xfer += iprot->readStructEnd();" << endl;
+
+  // Throw if any required fields are missing.
+  // We do this after reading the struct end, so the whole struct has been
+  // consumed and the caller has at least a chance of continuing.
+  out << endl;
+  for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) {
+    if ((*f_iter)->get_req() == t_field::REQUIRED)
+      out <<
+        indent() << "if (!isset_" << (*f_iter)->get_name() << ')' << endl <<
+        indent() << "  throw TProtocolException(TProtocolException::INVALID_DATA);" << endl;
+  }
+
+  indent(out) << "return xfer;" << endl;
 
   indent_down();
   indent(out) <<
@@ -610,6 +690,10 @@
   indent(out) <<
     "xfer += oprot->writeStructBegin(\"" << name << "\");" << endl;
   for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) {
+    if ((*f_iter)->get_req() == t_field::OPTIONAL) {
+      indent(out) << "if (this->__isset." << (*f_iter)->get_name() << ") {" << endl;
+      indent_up();
+    }
     // Write field header
     out <<
       indent() << "xfer += oprot->writeFieldBegin(" <<
@@ -625,7 +709,12 @@
     // Write field closer
     indent(out) <<
       "xfer += oprot->writeFieldEnd();" << endl;
+    if ((*f_iter)->get_req() == t_field::OPTIONAL) {
+      indent_down();
+      indent(out) << '}' << endl;
+    }
   }
+
   // Write the struct map
   out <<
     indent() << "xfer += oprot->writeFieldStop();" << endl <<
diff --git a/compiler/cpp/src/parse/t_field.h b/compiler/cpp/src/parse/t_field.h
index d342bff..a4a2cc7 100644
--- a/compiler/cpp/src/parse/t_field.h
+++ b/compiler/cpp/src/parse/t_field.h
@@ -33,6 +33,7 @@
     type_(type),
     name_(name),
     key_(key),
+    req_(OPT_IN_REQ_OUT),
     value_(NULL),
     xsd_optional_(false),
     xsd_nillable_(false),
@@ -52,6 +53,20 @@
     return key_;
   }
 
+  enum e_req {
+    REQUIRED,        // field was declared with the "required" keyword
+    OPTIONAL,        // field was declared with the "optional" keyword
+    OPT_IN_REQ_OUT   // no keyword given; also the value the constructor starts with
+  };
+
+  void set_req(e_req req) {  // records this field's requiredness
+    req_ = req;
+  }
+
+  e_req get_req() const {  // returns the stored requiredness
+    return req_;
+  }
+
   void set_value(t_const_value* value) {
     value_ = value;
   }
@@ -101,6 +116,7 @@
   t_type* type_;
   std::string name_;
   int32_t key_;
+  e_req req_;
   t_const_value* value_;
 
   bool xsd_optional_;
diff --git a/compiler/cpp/src/thriftl.ll b/compiler/cpp/src/thriftl.ll
index 4612858..7693f33 100644
--- a/compiler/cpp/src/thriftl.ll
+++ b/compiler/cpp/src/thriftl.ll
@@ -103,6 +103,8 @@
 "service"        { return tok_service;        }
 "enum"           { return tok_enum;           }
 "const"          { return tok_const;          }
+"required"       { return tok_required;       }
+"optional"       { return tok_optional;       }
 
 "abstract" { thrift_reserved_keyword(yytext); }
 "and" { thrift_reserved_keyword(yytext); }
diff --git a/compiler/cpp/src/thrifty.yy b/compiler/cpp/src/thrifty.yy
index c8be3f0..e159ef0 100644
--- a/compiler/cpp/src/thrifty.yy
+++ b/compiler/cpp/src/thrifty.yy
@@ -50,6 +50,7 @@
   t_function*    tfunction;
   t_field*       tfield;
   char*          dtext;
+  t_field::e_req ereq;
 }
 
 /**
@@ -120,6 +121,8 @@
 %token tok_service
 %token tok_enum
 %token tok_const
+%token tok_required
+%token tok_optional
 
 /**
  * Grammar nodes
@@ -139,6 +142,7 @@
 
 %type<tfield>    Field
 %type<iconst>    FieldIdentifier
+%type<ereq>      FieldRequiredness
 %type<ttype>     FieldType
 %type<tconstv>   FieldValue
 %type<tstruct>   FieldList
@@ -723,24 +727,25 @@
     }
 
 Field:
-  CaptureDocText FieldIdentifier FieldType tok_identifier FieldValue XsdOptional XsdNillable XsdAttributes CommaOrSemicolonOptional
+  CaptureDocText FieldIdentifier FieldRequiredness FieldType tok_identifier FieldValue XsdOptional XsdNillable XsdAttributes CommaOrSemicolonOptional
     {
       pdebug("tok_int_constant : Field -> FieldType tok_identifier");
       if ($2 < 0) {
-        pwarning(2, "No field key specified for %s, resulting protocol may have conflicts or not be backwards compatible!\n", $4);
+        pwarning(2, "No field key specified for %s, resulting protocol may have conflicts or not be backwards compatible!\n", $5);
       }
-      $$ = new t_field($3, $4, $2);
-      if ($5 != NULL) {
-        validate_field_value($$, $5);
-        $$->set_value($5);
+      $$ = new t_field($4, $5, $2);
+      $$->set_req($3);
+      if ($6 != NULL) {
+        validate_field_value($$, $6);
+        $$->set_value($6);
       }
-      $$->set_xsd_optional($6);
-      $$->set_xsd_nillable($7);
+      $$->set_xsd_optional($7);
+      $$->set_xsd_nillable($8);
       if ($1 != NULL) {
         $$->set_doc($1);
       }
-      if ($8 != NULL) {
-        $$->set_xsd_attrs($8);
+      if ($9 != NULL) {
+        $$->set_xsd_attrs($9);
       }
     }
 
@@ -758,6 +763,20 @@
       $$ = y_field_val--;
     }
 
+FieldRequiredness:
+  tok_required
+    {
+      $$ = t_field::REQUIRED;
+    }
+| tok_optional
+    {
+      $$ = t_field::OPTIONAL;
+    }
+| /* empty: neither keyword was written before the field type */
+    {
+      $$ = t_field::OPT_IN_REQ_OUT;
+    }
+
 FieldValue:
   '=' ConstValue
     {
diff --git a/test/OptionalRequiredTest.cpp b/test/OptionalRequiredTest.cpp
new file mode 100644
index 0000000..73574ee
--- /dev/null
+++ b/test/OptionalRequiredTest.cpp
@@ -0,0 +1,231 @@
+/*
+../compiler/cpp/thrift -cpp OptionalRequiredTest.thrift
+g++ -Wall -I../lib/cpp/src -I/usr/local/include/boost-1_33_1 \
+  OptionalRequiredTest.cpp gen-cpp/OptionalRequiredTest_types.cpp \
+  ../lib/cpp/.libs/libthrift.a -o OptionalRequiredTest
+./OptionalRequiredTest
+*/
+
+#include <cassert>
+#include <map>
+#include <iostream>
+#include <protocol/TDebugProtocol.h>
+#include <protocol/TBinaryProtocol.h>
+#include <transport/TTransportUtils.h>
+#include "gen-cpp/OptionalRequiredTest_types.h"
+
+using std::cout;
+using std::endl;
+using std::map;
+using std::string;
+using namespace thrift::test;
+using namespace facebook::thrift;
+using namespace facebook::thrift::transport;
+using namespace facebook::thrift::protocol;
+
+
+/*
+template<typename Struct>
+void trywrite(const Struct& s, bool should_work) {
+  bool worked;
+  try {
+    TBinaryProtocol protocol(boost::shared_ptr<TTransport>(new TMemoryBuffer));
+    s.write(&protocol);
+    worked = true;
+  } catch (TProtocolException & ex) {
+    worked = false;
+  }
+  assert(worked == should_work);
+}
+*/
+
+// Writes w through a TBinaryProtocol over a fresh TMemoryBuffer, then reads r back from it.
+template <typename Struct1, typename Struct2>
+void write_to_read(const Struct1 & w, Struct2 & r) {
+  TBinaryProtocol protocol(boost::shared_ptr<TTransport>(new TMemoryBuffer));
+  w.write(&protocol);
+  r.read(&protocol);
+}
+
+int main() {
+  // Exercises required/optional/default fields and operator==; failures trip assert or abort.
+  cout << "This old school struct should have three fields." << endl;
+  {
+    OldSchool o;
+    cout << ThriftDebugString(o) << endl;
+  }
+  cout << endl;
+
+  cout << "Setting a value before setting isset." << endl;
+  {
+    Simple s;
+    cout << ThriftDebugString(s) << endl;
+    s.im_optional = 10;
+    cout << ThriftDebugString(s) << endl;
+    s.__isset.im_optional = true;
+    cout << ThriftDebugString(s) << endl;
+  }
+  cout << endl;
+
+  cout << "Setting isset before setting a value." << endl;
+  {
+    Simple s;
+    cout << ThriftDebugString(s) << endl;
+    s.__isset.im_optional = true;
+    cout << ThriftDebugString(s) << endl;
+    s.im_optional = 10;
+    cout << ThriftDebugString(s) << endl;
+  }
+  cout << endl;
+
+  // Write-to-read with optional fields.
+  {
+    Simple s1, s2, s3;
+    s1.im_optional = 10;
+    assert(!s1.__isset.im_default);
+  //assert(!s1.__isset.im_required);  // Compile error.
+    assert(!s1.__isset.im_optional);
+
+    write_to_read(s1, s2);
+
+    assert( s2.__isset.im_default);
+  //assert( s2.__isset.im_required);  // Compile error.
+    assert(!s2.__isset.im_optional);
+    assert(s2.im_optional == 0);  // unset optional was not written, so s2 never received the 10
+
+    s1.__isset.im_optional = true;
+    write_to_read(s1, s3);
+
+    assert( s3.__isset.im_default);
+  //assert( s3.__isset.im_required);  // Compile error.
+    assert( s3.__isset.im_optional);
+    assert(s3.im_optional == 10);
+  }
+
+  // Writing between optional and default.
+  {
+    Tricky1 t1;
+    Tricky2 t2;
+
+    t2.im_optional = 10;
+    write_to_read(t2, t1);
+    write_to_read(t1, t2);
+    assert(!t1.__isset.im_default);
+    assert( t2.__isset.im_optional);
+    assert(t1.im_default == t2.im_optional);
+    assert(t1.im_default == 0);
+  }
+
+  // Writing between default and required.
+  {
+    Tricky1 t1;
+    Tricky3 t3;
+    write_to_read(t1, t3);
+    write_to_read(t3, t1);
+    assert(t1.__isset.im_default);
+  }
+
+  // Writing between optional and required.
+  {
+    Tricky2 t2;
+    Tricky3 t3;
+    t2.__isset.im_optional = true;
+    write_to_read(t2, t3);
+    write_to_read(t3, t2);
+  }
+
+  // An unset optional is not written, so reading it as a required field must throw.
+  {
+    Tricky2 t2;
+    Tricky3 t3;
+    try {
+      write_to_read(t2, t3);
+      abort();
+    }
+    catch (const TProtocolException&) {}
+
+    write_to_read(t3, t2);
+    assert(t2.__isset.im_optional);
+  }
+
+  cout << "Complex struct, simple test." << endl;
+  {
+    Complex c;
+    cout << ThriftDebugString(c) << endl;
+  }
+
+  // operator== is only generated between two objects of the same struct type.
+  {
+    Tricky1 t1;
+    Tricky2 t2;
+    // Compile error.
+    //(void)(t1 == t2);
+  }
+
+  {
+    OldSchool o1, o2, o3;
+    assert(o1 == o2);
+    o1.im_int = o2.im_int = 10;
+    assert(o1 == o2);
+    o1.__isset.im_int = true;
+    o2.__isset.im_int = false;
+    assert(o1 == o2);  // __isset is ignored when comparing default-requiredness fields
+    o1.im_int = 20;
+    o1.__isset.im_int = false;
+    assert(o1 != o2);
+    o1.im_int = 10;
+    assert(o1 == o2);
+    o1.im_str = o2.im_str = "foo";
+    assert(o1 == o2);
+    o1.__isset.im_str = o2.__isset.im_str = true;
+    assert(o1 == o2);
+    map<int32_t,string> mymap;
+    mymap[1] = "bar";
+    mymap[2] = "baz";
+    o1.im_big.push_back(map<int32_t,string>());
+    assert(o1 != o2);
+    o2.im_big.push_back(map<int32_t,string>());
+    assert(o1 == o2);
+    o2.im_big.push_back(mymap);
+    assert(o1 != o2);
+    o1.im_big.push_back(mymap);
+    assert(o1 == o2);
+
+    TBinaryProtocol protocol(boost::shared_ptr<TTransport>(new TMemoryBuffer));
+    o1.write(&protocol);
+
+    o1.im_big.push_back(mymap);
+    mymap[3] = "qux";
+    o2.im_big.push_back(mymap);
+    assert(o1 != o2);
+    o1.im_big.back()[3] = "qux";
+    assert(o1 == o2);
+
+    o3.read(&protocol);
+    o3.im_big.push_back(mymap);
+    assert(o1 == o3);
+
+    //cout << ThriftDebugString(o3) << endl;
+  }
+
+  {
+    Tricky2 t1, t2;
+    assert(t1.__isset.im_optional == false);
+    assert(t2.__isset.im_optional == false);
+    assert(t1 == t2);
+    t1.im_optional = 5;
+    assert(t1 == t2);  // values of optional fields are ignored while both are unset
+    t2.im_optional = 5;
+    assert(t1 == t2);
+    t1.__isset.im_optional = true;
+    assert(t1 != t2);  // differing __isset flags alone make the structs unequal
+    t2.__isset.im_optional = true;
+    assert(t1 == t2);
+    t1.im_optional = 10;
+    assert(t1 != t2);
+    t2.__isset.im_optional = false;
+    assert(t1 != t2);
+  }
+
+  return 0;
+}
diff --git a/test/OptionalRequiredTest.thrift b/test/OptionalRequiredTest.thrift
new file mode 100644
index 0000000..9738bd6
--- /dev/null
+++ b/test/OptionalRequiredTest.thrift
@@ -0,0 +1,42 @@
+/*
+../compiler/cpp/thrift -cpp OptionalRequiredTest.thrift
+g++ -Wall -I../lib/cpp/src -I/usr/local/include/boost-1_33_1 \
+  OptionalRequiredTest.cpp gen-cpp/OptionalRequiredTest_types.cpp \
+  ../lib/cpp/.libs/libthrift.a -o OptionalRequiredTest
+./OptionalRequiredTest
+*/
+
+cpp_namespace thrift.test
+
+struct OldSchool {  /* no requiredness keywords on any field */
+  1: i16    im_int;
+  2: string im_str;
+  3: list<map<i32,string>> im_big;
+}
+
+struct Simple {  /* one field of each requiredness */
+  1: /* :) */ i16 im_default;
+  2: required i16 im_required;
+  3: optional i16 im_optional;
+}
+
+struct Tricky1 {  /* Tricky1-3 share field id 1 and type i16; only requiredness differs */
+  1: /* :) */ i16 im_default;
+}
+
+struct Tricky2 {
+  1: optional i16 im_optional;
+}
+
+struct Tricky3 {
+  1: required i16 im_required;
+}
+
+struct Complex {  /* mixes requiredness with a map and nested Simple structs */
+  1:          i16 cp_default;
+  2: required i16 cp_required;
+  3: optional i16 cp_optional;
+  4:          map<i16,Simple> the_map;
+  5: required Simple req_simp;
+  6: optional Simple opt_simp;
+}
diff --git a/thrift.el b/thrift.el
index 20ee915..4abb04e 100644
--- a/thrift.el
+++ b/thrift.el
@@ -10,7 +10,7 @@
 (defconst thrift-font-lock-keywords
   (list
    '("#.*$" . font-lock-comment-face)  ;; perl style comments
-   '("\\<\\(include\\|struct\\|exception\\|typedef\\|cpp_namespace\\|java_package\\|php_namespace\\|const\\|enum\\|service\\|extends\\|void\\|async\\|throws\\)\\>" . font-lock-keyword-face)  ;; keywords
+   '("\\<\\(include\\|struct\\|exception\\|typedef\\|cpp_namespace\\|java_package\\|php_namespace\\|const\\|enum\\|service\\|extends\\|void\\|async\\|throws\\|optional\\|required\\)\\>" . font-lock-keyword-face)  ;; keywords
    '("\\<\\(bool\\|byte\\|i16\\|i32\\|i64\\|double\\|string\\|binary\\|map\\|list\\|set\\)\\>" . font-lock-type-face)  ;; built-in types
    '("\\<\\([0-9]+\\)\\>" . font-lock-variable-name-face)   ;; ordinals
    '("\\<\\(\\w+\\)\\s-*(" (1 font-lock-function-name-face))  ;; functions
diff --git a/thrift.vim b/thrift.vim
index d81be13..daec30e 100644
--- a/thrift.vim
+++ b/thrift.vim
@@ -30,10 +30,10 @@
 syn match thriftNumber "-\=\<\d\+\>" contained
 
 " Keywords
-syn keyword thriftKeyword namespace cpp_namespace cpp_include
-syn keyword thriftKeyword cpp_type java_package include const
-syn keyword thriftBasicTypes void bool byte i16 i32 i64 double string
-
+syn keyword thriftKeyword namespace cpp_namespace java_package php_namespace ruby_namespace
+syn keyword thriftKeyword xsd_all xsd_optional xsd_nillable xsd_namespace xsd_attrs
+syn keyword thriftKeyword include cpp_include cpp_type const optional required
+syn keyword thriftBasicTypes void bool byte i16 i32 i64 double string binary
 syn keyword thriftStructure map list set struct typedef exception enum throws
 
 " Special