Thrift: Merging external patch.
Summary:
Merging a patch from Andy Lutomirsky.
- Allow fields to be marked "required" or "optional" (only affects C++).
- Thrift structs now have operator ==.
Reviewed By: mcslee
Test Plan: test/OptionalRequiredTest.cpp
Revert Plan: ok
git-svn-id: https://svn.apache.org/repos/asf/incubator/thrift/trunk@665202 13f79535-47bb-0310-9956-ffa450edef68
diff --git a/compiler/cpp/src/generate/t_cpp_generator.cc b/compiler/cpp/src/generate/t_cpp_generator.cc
index e0313f7..ec83c89 100644
--- a/compiler/cpp/src/generate/t_cpp_generator.cc
+++ b/compiler/cpp/src/generate/t_cpp_generator.cc
@@ -439,17 +439,26 @@
declare_field(*m_iter, false, pointers && !(*m_iter)->get_type()->is_xception(), !read) << endl;
}
- // Isset struct has boolean fields
- if (members.size() > 0 && (!pointers || read)) {
+ // Isset struct has boolean fields, but only for non-required fields.
+ bool has_nonrequired_fields = false;
+ for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) {
+ if ((*m_iter)->get_req() != t_field::REQUIRED)
+ has_nonrequired_fields = true;
+ }
+
+ if (has_nonrequired_fields && (!pointers || read)) {
out <<
endl <<
- indent() <<"struct __isset {" << endl;
+ indent() << "struct __isset {" << endl;
indent_up();
indent(out) <<
"__isset() : ";
bool first = true;
for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) {
+ if ((*m_iter)->get_req() == t_field::REQUIRED) {
+ continue;
+ }
if (first) {
first = false;
out <<
@@ -462,8 +471,10 @@
out << " {}" << endl;
for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) {
- indent(out) <<
- "bool " << (*m_iter)->get_name() << ";" << endl;
+ if ((*m_iter)->get_req() != t_field::REQUIRED) {
+ indent(out) <<
+ "bool " << (*m_iter)->get_name() << ";" << endl;
+ }
}
indent_down();
@@ -471,8 +482,40 @@
"} __isset;" << endl;
}
- out <<
- endl;
+ out << endl;
+
+ if (!pointers) {
+ // Generate an equality testing operator. Make it inline since the compiler
+ // will do a better job than we would when deciding whether to inline it.
+ out <<
+ indent() << "bool operator == (const " << tstruct->get_name() << " & rhs) const" << endl;
+ scope_up(out);
+ for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) {
+ // Most existing Thrift code does not use isset or optional/required,
+ // so we treat "default" fields as required.
+ if ((*m_iter)->get_req() != t_field::OPTIONAL) {
+ out <<
+ indent() << "if (!(" << (*m_iter)->get_name()
+ << " == rhs." << (*m_iter)->get_name() << "))" << endl <<
+ indent() << " return false;" << endl;
+ } else {
+ out <<
+ indent() << "if (__isset." << (*m_iter)->get_name()
+ << " != rhs.__isset." << (*m_iter)->get_name() << ")" << endl <<
+ indent() << " return false;" << endl <<
+ indent() << "else if (__isset." << (*m_iter)->get_name() << " && !("
+ << (*m_iter)->get_name() << " == rhs." << (*m_iter)->get_name()
+ << "))" << endl <<
+ indent() << " return false;" << endl;
+ }
+ }
+ indent(out) << "return true;" << endl;
+ scope_down(out);
+ out <<
+ indent() << "bool operator != (const " << tstruct->get_name() << " &rhs) const {" << endl <<
+ indent() << " return !(*this == rhs);" << endl <<
+ indent() << "}" << endl << endl;
+ }
if (read) {
out <<
indent() << "uint32_t read(facebook::thrift::protocol::TProtocol* iprot);" << endl;
@@ -481,8 +524,7 @@
out <<
indent() << "uint32_t write(facebook::thrift::protocol::TProtocol* oprot) const;" << endl;
}
- out <<
- endl;
+ out << endl;
indent_down();
indent(out) <<
@@ -515,7 +557,17 @@
indent() << "int16_t fid;" << endl <<
endl <<
indent() << "xfer += iprot->readStructBegin(fname);" << endl <<
+ endl <<
+ indent() << "using facebook::thrift::protocol::TProtocolException;" << endl <<
endl;
+
+ // Required variables aren't in __isset, so we need tmp vars to check them.
+ for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) {
+ if ((*f_iter)->get_req() == t_field::REQUIRED)
+ indent(out) << "bool isset_" << (*f_iter)->get_name() << " = false;" << endl;
+ }
+ out << endl;
+
// Loop over reading in fields
indent(out) <<
@@ -547,17 +599,33 @@
"if (ftype == " << type_to_enum((*f_iter)->get_type()) << ") {" << endl;
indent_up();
+ const char *isset_prefix =
+ ((*f_iter)->get_req() != t_field::REQUIRED) ? "this->__isset." : "isset_";
+
+#if 0
+ // This code throws an exception if the same field is encountered twice.
+ // We've decided to leave it out for performance reasons.
+ // TODO(dreiss): Generate this code and "if" it out to make it easier
+ // for people recompiling thrift to include it.
+ out <<
+ indent() << "if (" << isset_prefix << (*f_iter)->get_name() << ")" << endl <<
+ indent() << " throw TProtocolException(TProtocolException::INVALID_DATA);" << endl;
+#endif
+
if (pointers && !(*f_iter)->get_type()->is_xception()) {
generate_deserialize_field(out, *f_iter, "(*(this->", "))");
} else {
generate_deserialize_field(out, *f_iter, "this->");
}
out <<
- indent() << "this->__isset." << (*f_iter)->get_name() << " = true;" << endl;
+ indent() << isset_prefix << (*f_iter)->get_name() << " = true;" << endl;
indent_down();
out <<
indent() << "} else {" << endl <<
indent() << " xfer += iprot->skip(ftype);" << endl <<
+ // TODO(dreiss): Make this an option when thrift structs
+ // have a common base class.
+ // indent() << " throw TProtocolException(TProtocolException::INVALID_DATA);" << endl <<
indent() << "}" << endl <<
indent() << "break;" << endl;
indent_down();
@@ -579,8 +647,20 @@
out <<
endl <<
- indent() << "xfer += iprot->readStructEnd();" << endl <<
- indent() <<"return xfer;" << endl;
+ indent() << "xfer += iprot->readStructEnd();" << endl;
+
+ // Throw if any required fields are missing.
+ // We do this after reading the struct end so that
+ // there might possibly be a chance of continuing.
+ out << endl;
+ for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) {
+ if ((*f_iter)->get_req() == t_field::REQUIRED)
+ out <<
+ indent() << "if (!isset_" << (*f_iter)->get_name() << ')' << endl <<
+ indent() << " throw TProtocolException(TProtocolException::INVALID_DATA);" << endl;
+ }
+
+ indent(out) << "return xfer;" << endl;
indent_down();
indent(out) <<
@@ -610,6 +690,10 @@
indent(out) <<
"xfer += oprot->writeStructBegin(\"" << name << "\");" << endl;
for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) {
+ if ((*f_iter)->get_req() == t_field::OPTIONAL) {
+ indent(out) << "if (this->__isset." << (*f_iter)->get_name() << ") {" << endl;
+ indent_up();
+ }
// Write field header
out <<
indent() << "xfer += oprot->writeFieldBegin(" <<
@@ -625,7 +709,12 @@
// Write field closer
indent(out) <<
"xfer += oprot->writeFieldEnd();" << endl;
+ if ((*f_iter)->get_req() == t_field::OPTIONAL) {
+ indent_down();
+ indent(out) << '}' << endl;
+ }
}
+
// Write the struct map
out <<
indent() << "xfer += oprot->writeFieldStop();" << endl <<
diff --git a/compiler/cpp/src/parse/t_field.h b/compiler/cpp/src/parse/t_field.h
index d342bff..a4a2cc7 100644
--- a/compiler/cpp/src/parse/t_field.h
+++ b/compiler/cpp/src/parse/t_field.h
@@ -33,6 +33,7 @@
type_(type),
name_(name),
key_(key),
+ req_(OPT_IN_REQ_OUT),
value_(NULL),
xsd_optional_(false),
xsd_nillable_(false),
@@ -52,6 +53,20 @@
return key_;
}
+ enum e_req {
+ REQUIRED,
+ OPTIONAL,
+ OPT_IN_REQ_OUT
+ };
+
+ void set_req(e_req req) {
+ req_ = req;
+ }
+
+ e_req get_req() const {
+ return req_;
+ }
+
void set_value(t_const_value* value) {
value_ = value;
}
@@ -101,6 +116,7 @@
t_type* type_;
std::string name_;
int32_t key_;
+ e_req req_;
t_const_value* value_;
bool xsd_optional_;
diff --git a/compiler/cpp/src/thriftl.ll b/compiler/cpp/src/thriftl.ll
index 4612858..7693f33 100644
--- a/compiler/cpp/src/thriftl.ll
+++ b/compiler/cpp/src/thriftl.ll
@@ -103,6 +103,8 @@
"service" { return tok_service; }
"enum" { return tok_enum; }
"const" { return tok_const; }
+"required" { return tok_required; }
+"optional" { return tok_optional; }
"abstract" { thrift_reserved_keyword(yytext); }
"and" { thrift_reserved_keyword(yytext); }
diff --git a/compiler/cpp/src/thrifty.yy b/compiler/cpp/src/thrifty.yy
index c8be3f0..e159ef0 100644
--- a/compiler/cpp/src/thrifty.yy
+++ b/compiler/cpp/src/thrifty.yy
@@ -50,6 +50,7 @@
t_function* tfunction;
t_field* tfield;
char* dtext;
+ t_field::e_req ereq;
}
/**
@@ -120,6 +121,8 @@
%token tok_service
%token tok_enum
%token tok_const
+%token tok_required
+%token tok_optional
/**
* Grammar nodes
@@ -139,6 +142,7 @@
%type<tfield> Field
%type<iconst> FieldIdentifier
+%type<ereq> FieldRequiredness
%type<ttype> FieldType
%type<tconstv> FieldValue
%type<tstruct> FieldList
@@ -723,24 +727,25 @@
}
Field:
- CaptureDocText FieldIdentifier FieldType tok_identifier FieldValue XsdOptional XsdNillable XsdAttributes CommaOrSemicolonOptional
+ CaptureDocText FieldIdentifier FieldRequiredness FieldType tok_identifier FieldValue XsdOptional XsdNillable XsdAttributes CommaOrSemicolonOptional
{
pdebug("tok_int_constant : Field -> FieldType tok_identifier");
if ($2 < 0) {
- pwarning(2, "No field key specified for %s, resulting protocol may have conflicts or not be backwards compatible!\n", $4);
+ pwarning(2, "No field key specified for %s, resulting protocol may have conflicts or not be backwards compatible!\n", $5);
}
- $$ = new t_field($3, $4, $2);
- if ($5 != NULL) {
- validate_field_value($$, $5);
- $$->set_value($5);
+ $$ = new t_field($4, $5, $2);
+ $$->set_req($3);
+ if ($6 != NULL) {
+ validate_field_value($$, $6);
+ $$->set_value($6);
}
- $$->set_xsd_optional($6);
- $$->set_xsd_nillable($7);
+ $$->set_xsd_optional($7);
+ $$->set_xsd_nillable($8);
if ($1 != NULL) {
$$->set_doc($1);
}
- if ($8 != NULL) {
- $$->set_xsd_attrs($8);
+ if ($9 != NULL) {
+ $$->set_xsd_attrs($9);
}
}
@@ -758,6 +763,20 @@
$$ = y_field_val--;
}
+FieldRequiredness:
+ tok_required
+ {
+ $$ = t_field::REQUIRED;
+ }
+| tok_optional
+ {
+ $$ = t_field::OPTIONAL;
+ }
+|
+ {
+ $$ = t_field::OPT_IN_REQ_OUT;
+ }
+
FieldValue:
'=' ConstValue
{
diff --git a/test/OptionalRequiredTest.cpp b/test/OptionalRequiredTest.cpp
new file mode 100644
index 0000000..73574ee
--- /dev/null
+++ b/test/OptionalRequiredTest.cpp
@@ -0,0 +1,231 @@
+/*
+../compiler/cpp/thrift -cpp OptionalRequiredTest.thrift
+g++ -Wall -I../lib/cpp/src -I/usr/local/include/boost-1_33_1 \
+ OptionalRequiredTest.cpp gen-cpp/OptionalRequiredTest_types.cpp \
+ ../lib/cpp/.libs/libthrift.a -o OptionalRequiredTest
+./OptionalRequiredTest
+*/
+
+#include <cassert>
+#include <map>
+#include <iostream>
+#include <protocol/TDebugProtocol.h>
+#include <protocol/TBinaryProtocol.h>
+#include <transport/TTransportUtils.h>
+#include "gen-cpp/OptionalRequiredTest_types.h"
+
+using std::cout;
+using std::endl;
+using std::map;
+using std::string;
+using namespace thrift::test;
+using namespace facebook::thrift;
+using namespace facebook::thrift::transport;
+using namespace facebook::thrift::protocol;
+
+
+/*
+template<typename Struct>
+void trywrite(const Struct& s, bool should_work) {
+ bool worked;
+ try {
+ TBinaryProtocol protocol(boost::shared_ptr<TTransport>(new TMemoryBuffer));
+ s.write(&protocol);
+ worked = true;
+ } catch (TProtocolException & ex) {
+ worked = false;
+ }
+ assert(worked == should_work);
+}
+*/
+
+template <typename Struct1, typename Struct2>
+void write_to_read(const Struct1 & w, Struct2 & r) {
+ TBinaryProtocol protocol(boost::shared_ptr<TTransport>(new TMemoryBuffer));
+ w.write(&protocol);
+ r.read(&protocol);
+}
+
+
+int main() {
+
+ cout << "This old school struct should have three fields." << endl;
+ {
+ OldSchool o;
+ cout << ThriftDebugString(o) << endl;
+ }
+ cout << endl;
+
+ cout << "Setting a value before setting isset." << endl;
+ {
+ Simple s;
+ cout << ThriftDebugString(s) << endl;
+ s.im_optional = 10;
+ cout << ThriftDebugString(s) << endl;
+ s.__isset.im_optional = true;
+ cout << ThriftDebugString(s) << endl;
+ }
+ cout << endl;
+
+ cout << "Setting isset before setting a value." << endl;
+ {
+ Simple s;
+ cout << ThriftDebugString(s) << endl;
+ s.__isset.im_optional = true;
+ cout << ThriftDebugString(s) << endl;
+ s.im_optional = 10;
+ cout << ThriftDebugString(s) << endl;
+ }
+ cout << endl;
+
+ // Write-to-read with optional fields.
+ {
+ Simple s1, s2, s3;
+ s1.im_optional = 10;
+ assert(!s1.__isset.im_default);
+ //assert(!s1.__isset.im_required); // Compile error.
+ assert(!s1.__isset.im_optional);
+
+ write_to_read(s1, s2);
+
+ assert( s2.__isset.im_default);
+ //assert( s2.__isset.im_required); // Compile error.
+ assert(!s2.__isset.im_optional);
+ assert(s3.im_optional == 0);
+
+ s1.__isset.im_optional = true;
+ write_to_read(s1, s3);
+
+ assert( s3.__isset.im_default);
+ //assert( s3.__isset.im_required); // Compile error.
+ assert( s3.__isset.im_optional);
+ assert(s3.im_optional == 10);
+ }
+
+ // Writing between optional and default.
+ {
+ Tricky1 t1;
+ Tricky2 t2;
+
+ t2.im_optional = 10;
+ write_to_read(t2, t1);
+ write_to_read(t1, t2);
+ assert(!t1.__isset.im_default);
+ assert( t2.__isset.im_optional);
+ assert(t1.im_default == t2.im_optional);
+ assert(t1.im_default == 0);
+ }
+
+ // Writing between default and required.
+ {
+ Tricky1 t1;
+ Tricky3 t3;
+ write_to_read(t1, t3);
+ write_to_read(t3, t1);
+ assert(t1.__isset.im_default);
+ }
+
+ // Writing between optional and required.
+ {
+ Tricky2 t2;
+ Tricky3 t3;
+ t2.__isset.im_optional = true;
+ write_to_read(t2, t3);
+ write_to_read(t3, t2);
+ }
+
+ // Mu-hu-ha-ha-ha!
+ {
+ Tricky2 t2;
+ Tricky3 t3;
+ try {
+ write_to_read(t2, t3);
+ abort();
+ }
+ catch (TProtocolException& ex) {}
+
+ write_to_read(t3, t2);
+ assert(t2.__isset.im_optional);
+ }
+
+ cout << "Complex struct, simple test." << endl;
+ {
+ Complex c;
+ cout << ThriftDebugString(c) << endl;
+ }
+
+
+ {
+ Tricky1 t1;
+ Tricky2 t2;
+ // Compile error.
+ //(void)(t1 == t2);
+ }
+
+ {
+ OldSchool o1, o2, o3;
+ assert(o1 == o2);
+ o1.im_int = o2.im_int = 10;
+ assert(o1 == o2);
+ o1.__isset.im_int = true;
+ o2.__isset.im_int = false;
+ assert(o1 == o2);
+ o1.im_int = 20;
+ o1.__isset.im_int = false;
+ assert(o1 != o2);
+ o1.im_int = 10;
+ assert(o1 == o2);
+ o1.im_str = o2.im_str = "foo";
+ assert(o1 == o2);
+ o1.__isset.im_str = o2.__isset.im_str = true;
+ assert(o1 == o2);
+ map<int32_t,string> mymap;
+ mymap[1] = "bar";
+ mymap[2] = "baz";
+ o1.im_big.push_back(map<int32_t,string>());
+ assert(o1 != o2);
+ o2.im_big.push_back(map<int32_t,string>());
+ assert(o1 == o2);
+ o2.im_big.push_back(mymap);
+ assert(o1 != o2);
+ o1.im_big.push_back(mymap);
+ assert(o1 == o2);
+
+ TBinaryProtocol protocol(boost::shared_ptr<TTransport>(new TMemoryBuffer));
+ o1.write(&protocol);
+
+ o1.im_big.push_back(mymap);
+ mymap[3] = "qux";
+ o2.im_big.push_back(mymap);
+ assert(o1 != o2);
+ o1.im_big.back()[3] = "qux";
+ assert(o1 == o2);
+
+ o3.read(&protocol);
+ o3.im_big.push_back(mymap);
+ assert(o1 == o3);
+
+ //cout << ThriftDebugString(o3) << endl;
+ }
+
+ {
+ Tricky2 t1, t2;
+ assert(t1.__isset.im_optional == false);
+ assert(t2.__isset.im_optional == false);
+ assert(t1 == t2);
+ t1.im_optional = 5;
+ assert(t1 == t2);
+ t2.im_optional = 5;
+ assert(t1 == t2);
+ t1.__isset.im_optional = true;
+ assert(t1 != t2);
+ t2.__isset.im_optional = true;
+ assert(t1 == t2);
+ t1.im_optional = 10;
+ assert(t1 != t2);
+ t2.__isset.im_optional = false;
+ assert(t1 != t2);
+ }
+
+ return 0;
+}
diff --git a/test/OptionalRequiredTest.thrift b/test/OptionalRequiredTest.thrift
new file mode 100644
index 0000000..9738bd6
--- /dev/null
+++ b/test/OptionalRequiredTest.thrift
@@ -0,0 +1,42 @@
+/*
+../compiler/cpp/thrift -cpp OptionalRequiredTest.thrift
+g++ -Wall -I../lib/cpp/src -I/usr/local/include/boost-1_33_1 \
+ OptionalRequiredTest.cpp gen-cpp/OptionalRequiredTest_types.cpp \
+ ../lib/cpp/.libs/libthrift.a -o OptionalRequiredTest
+./OptionalRequiredTest
+*/
+
+cpp_namespace thrift.test
+
+struct OldSchool {
+ 1: i16 im_int;
+ 2: string im_str;
+ 3: list<map<i32,string>> im_big;
+}
+
+struct Simple {
+ 1: /* :) */ i16 im_default;
+ 2: required i16 im_required;
+ 3: optional i16 im_optional;
+}
+
+struct Tricky1 {
+ 1: /* :) */ i16 im_default;
+}
+
+struct Tricky2 {
+ 1: optional i16 im_optional;
+}
+
+struct Tricky3 {
+ 1: required i16 im_required;
+}
+
+struct Complex {
+ 1: i16 cp_default;
+ 2: required i16 cp_required;
+ 3: optional i16 cp_optional;
+ 4: map<i16,Simple> the_map;
+ 5: required Simple req_simp;
+ 6: optional Simple opt_simp;
+}
diff --git a/thrift.el b/thrift.el
index 20ee915..4abb04e 100644
--- a/thrift.el
+++ b/thrift.el
@@ -10,7 +10,7 @@
(defconst thrift-font-lock-keywords
(list
'("#.*$" . font-lock-comment-face) ;; perl style comments
- '("\\<\\(include\\|struct\\|exception\\|typedef\\|cpp_namespace\\|java_package\\|php_namespace\\|const\\|enum\\|service\\|extends\\|void\\|async\\|throws\\)\\>" . font-lock-keyword-face) ;; keywords
+ '("\\<\\(include\\|struct\\|exception\\|typedef\\|cpp_namespace\\|java_package\\|php_namespace\\|const\\|enum\\|service\\|extends\\|void\\|async\\|throws\\|optional\\|required\\)\\>" . font-lock-keyword-face) ;; keywords
'("\\<\\(bool\\|byte\\|i16\\|i32\\|i64\\|double\\|string\\|binary\\|map\\|list\\|set\\)\\>" . font-lock-type-face) ;; built-in types
'("\\<\\([0-9]+\\)\\>" . font-lock-variable-name-face) ;; ordinals
'("\\<\\(\\w+\\)\\s-*(" (1 font-lock-function-name-face)) ;; functions
diff --git a/thrift.vim b/thrift.vim
index d81be13..daec30e 100644
--- a/thrift.vim
+++ b/thrift.vim
@@ -30,10 +30,10 @@
syn match thriftNumber "-\=\<\d\+\>" contained
" Keywords
-syn keyword thriftKeyword namespace cpp_namespace cpp_include
-syn keyword thriftKeyword cpp_type java_package include const
-syn keyword thriftBasicTypes void bool byte i16 i32 i64 double string
-
+syn keyword thriftKeyword namespace cpp_namespace java_package php_namespace ruby_namespace
+syn keyword thriftKeyword xsd_all xsd_optional xsd_nillable xsd_namespace xsd_attrs
+syn keyword thriftKeyword include cpp_include cpp_type const optional required
+syn keyword thriftBasicTypes void bool byte i16 i32 i64 double string binary
syn keyword thriftStructure map list set struct typedef exception enum throws
" Special