THRIFT-2471 Make cpp.ref annotation language agnostic
Client: compiler general
Patch: Dave Watson
This closes #113
commit 52b99af4ee1574253dcb77933d76a7ebb2d830df
Author: Dave Watson <davejwatson@fb.com>
Date: 2014-04-23T20:05:56Z
change cpp.ref to &
commit 3f9d31cc6140367529fd8f7b1b67056ec321786f
Author: Dave Watson <davejwatson@fb.com>
Date: 2014-04-23T21:50:29Z
Recursion depth limit
commit 61468e4534ce9e6a4f4f643bfd00542d13600d83
Author: Dave Watson <davejwatson@fb.com>
Date: 2014-04-25T19:59:18Z
shared_ptr for reference type
diff --git a/compiler/cpp/src/generate/t_cpp_generator.cc b/compiler/cpp/src/generate/t_cpp_generator.cc
index 92eab2f..f985492 100755
--- a/compiler/cpp/src/generate/t_cpp_generator.cc
+++ b/compiler/cpp/src/generate/t_cpp_generator.cc
@@ -235,7 +235,7 @@
void generate_local_reflection_pointer(std::ofstream& out, t_type* ttype);
bool is_reference(t_field* tfield) {
- return tfield->annotations_.count("cpp.ref") != 0;
+ return tfield->get_reference();
}
bool is_complex_type(t_type* ttype) {
@@ -832,14 +832,8 @@
const vector<t_field*>& members = tstruct->get_members();
vector<t_field*>::const_iterator f_iter;
for (f_iter = members.begin(); f_iter != members.end(); ++f_iter) {
- if (is_reference(*f_iter)) {
- std::string type = type_name((*f_iter)->get_type());
- indent(out) << (*f_iter)->get_name() << " = new " << type << "(*" << tmp_name << "." <<
- (*f_iter)->get_name() << ");" << endl;
- } else {
- indent(out) << (*f_iter)->get_name() << " = " << tmp_name << "." <<
- (*f_iter)->get_name() << ";" << endl;
- }
+ indent(out) << (*f_iter)->get_name() << " = " << tmp_name << "." <<
+ (*f_iter)->get_name() << ";" << endl;
}
indent_down();
@@ -859,20 +853,8 @@
const vector<t_field*>& members = tstruct->get_members();
vector<t_field*>::const_iterator f_iter;
for (f_iter = members.begin(); f_iter != members.end(); ++f_iter) {
- if (is_reference(*f_iter)) {
- std::string type = type_name((*f_iter)->get_type());
- indent(out) << "if (this == &" << tmp_name << ") return *this;" << endl;
- indent(out) << "if (" << (*f_iter)->get_name() << ") {" << endl;
- indent(out) << " *" << (*f_iter)->get_name() << " = *" << tmp_name << "." <<
- (*f_iter)->get_name() << ";" << endl;
- indent(out) << "} else {" << endl;
- indent(out) << " " << (*f_iter)->get_name() << " = new " << type << "(*" << tmp_name << "." <<
- (*f_iter)->get_name() << ");" << endl;
- indent(out) << "}" << endl;
- } else {
- indent(out) << (*f_iter)->get_name() << " = " << tmp_name << "." <<
- (*f_iter)->get_name() << ";" << endl;
- }
+ indent(out) << (*f_iter)->get_name() << " = " << tmp_name << "." <<
+ (*f_iter)->get_name() << ";" << endl;
}
indent(out) << "return *this;" << endl;
@@ -1031,7 +1013,7 @@
// Declare all fields
for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) {
indent(out) <<
- declare_field(*m_iter, false, (pointers && !(*m_iter)->get_type()->is_xception()) || is_reference(*m_iter), !read) << endl;
+ declare_field(*m_iter, false, (pointers && !(*m_iter)->get_type()->is_xception()), !read) << endl;
}
// Add the __isset data member if we need it, using the definition from above
@@ -1046,11 +1028,19 @@
if (pointers) {
continue;
}
- out <<
- endl <<
- indent() << "void __set_" << (*m_iter)->get_name() <<
+ if (is_reference((*m_iter))) {
+ out <<
+ endl <<
+ indent() << "void __set_" << (*m_iter)->get_name() <<
+ "(boost::shared_ptr<" << type_name((*m_iter)->get_type(), false, false) << ">";
+ out << " val);" << endl;
+ } else {
+ out <<
+ endl <<
+ indent() << "void __set_" << (*m_iter)->get_name() <<
"(" << type_name((*m_iter)->get_type(), false, true);
- out << " val);" << endl;
+ out << " val);" << endl;
+ }
}
out << endl;
@@ -1151,13 +1141,6 @@
indent() << tstruct->get_name() << "::~" << tstruct->get_name() << "() throw() {" << endl;
indent_up();
- for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) {
- if (is_reference(*m_iter)) {
- out << indent() <<
- "delete " << (*m_iter)->get_name() << ";" << endl;
- }
- }
-
indent_down();
out << indent() << "}" << endl << endl;
}
@@ -1165,22 +1148,22 @@
// Create a setter function for each field
if (setters) {
for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) {
- out <<
- endl <<
- indent() << "void " << tstruct->get_name() << "::__set_" << (*m_iter)->get_name() <<
- "(" << type_name((*m_iter)->get_type(), false, true);
- out << " val) {" << endl;
- indent_up();
if (is_reference((*m_iter))) {
std::string type = type_name((*m_iter)->get_type());
- indent(out) << "if (" << (*m_iter)->get_name() << ") {" << endl;
- indent(out) << " *" << (*m_iter)->get_name() << " = val;" << endl;
- indent(out) << "} else {" << endl;
- indent(out) << " " << (*m_iter)->get_name() << " = new " << type << "(val);" << endl;
- indent(out) << "}" << endl;
+ out <<
+ endl <<
+ indent() << "void " << tstruct->get_name() << "::__set_" << (*m_iter)->get_name() <<
+ "(boost::shared_ptr<" << type_name((*m_iter)->get_type(), false, false) << ">";
+ out << " val) {" << endl;
} else {
- out << indent() << (*m_iter)->get_name() << " = val;" << endl;
+ out <<
+ endl <<
+ indent() << "void " << tstruct->get_name() << "::__set_" << (*m_iter)->get_name() <<
+ "(" << type_name((*m_iter)->get_type(), false, true);
+ out << " val) {" << endl;
}
+ indent_up();
+ out << indent() << (*m_iter)->get_name() << " = val;" << endl;
indent_down();
// assume all fields are required except optional fields.
@@ -1545,6 +1528,7 @@
out <<
indent() << "uint32_t xfer = 0;" << endl;
+ indent(out) << "oprot->incrementRecursionDepth();" << endl;
indent(out) <<
"xfer += oprot->writeStructBegin(\"" << name << "\");" << endl;
@@ -1585,6 +1569,7 @@
out <<
indent() << "xfer += oprot->writeFieldStop();" << endl <<
indent() << "xfer += oprot->writeStructEnd();" << endl <<
+ indent() << "oprot->decrementRecursionDepth();" << endl <<
indent() << "return xfer;" << endl;
indent_down();
@@ -4078,7 +4063,7 @@
bool pointer) {
if (pointer) {
indent(out) << "if (!" << prefix << ") { " << endl;
- indent(out) << " " << prefix << " = new " << type_name(tstruct) << ";" << endl;
+ indent(out) << " " << prefix << " = boost::shared_ptr<" << type_name(tstruct) << ">(new " << type_name(tstruct) << ");" << endl;
indent(out) << "}" << endl;
indent(out) <<
"xfer += " << prefix << "->read(iprot);" << endl;
@@ -4614,6 +4599,9 @@
result += "const ";
}
result += type_name(tfield->get_type());
+ if (is_reference(tfield)) {
+ result = "boost::shared_ptr<" + result + ">";
+ }
if (pointer) {
result += "*";
}
diff --git a/compiler/cpp/src/parse/t_field.h b/compiler/cpp/src/parse/t_field.h
index 7bbcc0f..c05fdf3 100644
--- a/compiler/cpp/src/parse/t_field.h
+++ b/compiler/cpp/src/parse/t_field.h
@@ -42,7 +42,8 @@
value_(NULL),
xsd_optional_(false),
xsd_nillable_(false),
- xsd_attrs_(NULL) {}
+ xsd_attrs_(NULL),
+ reference_(false) {}
t_field(t_type* type, std::string name, int32_t key) :
type_(type),
@@ -52,7 +53,8 @@
value_(NULL),
xsd_optional_(false),
xsd_nillable_(false),
- xsd_attrs_(NULL) {}
+ xsd_attrs_(NULL),
+ reference_(false) {}
~t_field() {}
@@ -137,6 +139,14 @@
std::map<std::string, std::string> annotations_;
+ bool get_reference() {
+ return reference_;
+ }
+
+ void set_reference(bool reference) {
+ reference_ = reference;
+ }
+
private:
t_type* type_;
std::string name_;
@@ -147,7 +157,7 @@
bool xsd_optional_;
bool xsd_nillable_;
t_struct* xsd_attrs_;
-
+ bool reference_;
};
/**
diff --git a/compiler/cpp/src/thriftl.ll b/compiler/cpp/src/thriftl.ll
index 685bb54..aee4406 100644
--- a/compiler/cpp/src/thriftl.ll
+++ b/compiler/cpp/src/thriftl.ll
@@ -183,6 +183,7 @@
pwarning(0, "\"async\" is deprecated. It is called \"oneway\" now.\n");
return tok_oneway;
}
+"&" { return tok_reference; }
"BEGIN" { thrift_reserved_keyword(yytext); }
diff --git a/compiler/cpp/src/thrifty.yy b/compiler/cpp/src/thrifty.yy
index edb05f9..62e13ba 100644
--- a/compiler/cpp/src/thrifty.yy
+++ b/compiler/cpp/src/thrifty.yy
@@ -165,6 +165,7 @@
%token tok_required
%token tok_optional
%token tok_union
+%token tok_reference
/**
* Grammar nodes
@@ -193,6 +194,7 @@
%type<ttype> FieldType
%type<tconstv> FieldValue
%type<tstruct> FieldList
+%type<tbool> FieldReference
%type<tenum> Enum
%type<tenum> EnumDefList
@@ -955,35 +957,36 @@
}
Field:
- CaptureDocText FieldIdentifier FieldRequiredness FieldType tok_identifier FieldValue XsdOptional XsdNillable XsdAttributes TypeAnnotations CommaOrSemicolonOptional
+ CaptureDocText FieldIdentifier FieldRequiredness FieldType FieldReference tok_identifier FieldValue XsdOptional XsdNillable XsdAttributes TypeAnnotations CommaOrSemicolonOptional
{
pdebug("tok_int_constant : Field -> FieldType tok_identifier");
if ($2.auto_assigned) {
- pwarning(1, "No field key specified for %s, resulting protocol may have conflicts or not be backwards compatible!\n", $5);
+ pwarning(1, "No field key specified for %s, resulting protocol may have conflicts or not be backwards compatible!\n", $6);
if (g_strict >= 192) {
yyerror("Implicit field keys are deprecated and not allowed with -strict");
exit(1);
}
}
- validate_simple_identifier($5);
- $$ = new t_field($4, $5, $2.value);
+ validate_simple_identifier($6);
+ $$ = new t_field($4, $6, $2.value);
+ $$->set_reference($5);
$$->set_req($3);
- if ($6 != NULL) {
- g_scope->resolve_const_value($6, $4);
- validate_field_value($$, $6);
- $$->set_value($6);
+ if ($7 != NULL) {
+ g_scope->resolve_const_value($7, $4);
+ validate_field_value($$, $7);
+ $$->set_value($7);
}
- $$->set_xsd_optional($7);
- $$->set_xsd_nillable($8);
+ $$->set_xsd_optional($8);
+ $$->set_xsd_nillable($9);
if ($1 != NULL) {
$$->set_doc($1);
}
- if ($9 != NULL) {
- $$->set_xsd_attrs($9);
- }
if ($10 != NULL) {
- $$->annotations_ = $10->annotations_;
- delete $10;
+ $$->set_xsd_attrs($10);
+ }
+ if ($11 != NULL) {
+ $$->annotations_ = $11->annotations_;
+ delete $11;
}
}
@@ -1029,6 +1032,16 @@
$$.auto_assigned = true;
}
+FieldReference:
+ tok_reference
+ {
+ $$ = true;
+ }
+|
+ {
+ $$ = false;
+ }
+
FieldRequiredness:
tok_required
{
diff --git a/lib/cpp/src/thrift/protocol/TProtocol.h b/lib/cpp/src/thrift/protocol/TProtocol.h
index e72033a..e8ba429 100644
--- a/lib/cpp/src/thrift/protocol/TProtocol.h
+++ b/lib/cpp/src/thrift/protocol/TProtocol.h
@@ -283,6 +283,8 @@
return 0;
}
+static const uint32_t DEFAULT_RECURSION_LIMIT = 64;
+
/**
* Abstract class for a thrift protocol driver. These are all the methods that
* a protocol must implement. Essentially, there must be some way of reading
@@ -660,15 +662,28 @@
return ptrans_;
}
- protected:
- TProtocol(boost::shared_ptr<TTransport> ptrans):
- ptrans_(ptrans) {
+ void incrementRecursionDepth() {
+ if (recursion_limit_ < ++recursion_depth_) {
+ throw TProtocolException(TProtocolException::DEPTH_LIMIT);
+ }
}
+ void decrementRecursionDepth() {
+ --recursion_depth_;
+ }
+
+ protected:
+ TProtocol(boost::shared_ptr<TTransport> ptrans)
+ : ptrans_(ptrans)
+ , recursion_depth_(0)
+ , recursion_limit_(DEFAULT_RECURSION_LIMIT) {}
+
boost::shared_ptr<TTransport> ptrans_;
private:
TProtocol() {}
+ uint32_t recursion_depth_;
+ uint32_t recursion_limit_;
};
/**
diff --git a/lib/cpp/src/thrift/protocol/TProtocolException.h b/lib/cpp/src/thrift/protocol/TProtocolException.h
index a03d3c8..4ddb81e 100644
--- a/lib/cpp/src/thrift/protocol/TProtocolException.h
+++ b/lib/cpp/src/thrift/protocol/TProtocolException.h
@@ -45,6 +45,7 @@
, SIZE_LIMIT = 3
, BAD_VERSION = 4
, NOT_IMPLEMENTED = 5
+ , DEPTH_LIMIT = 6
};
TProtocolException() :
diff --git a/lib/cpp/test/RecursiveTest.cpp b/lib/cpp/test/RecursiveTest.cpp
index 00610c6..24c0f7c 100644
--- a/lib/cpp/test/RecursiveTest.cpp
+++ b/lib/cpp/test/RecursiveTest.cpp
@@ -44,7 +44,7 @@
assert(tree == result);
RecList l;
- RecList* l2(new RecList);
+ boost::shared_ptr<RecList> l2(new RecList);
l.nextitem = l2;
l.write(prot.get());
@@ -55,7 +55,7 @@
assert(resultlist.nextitem->nextitem == NULL);
CoRec c;
- CoRec2* r(new CoRec2);
+ boost::shared_ptr<CoRec2> r(new CoRec2);
c.other = r;
c.write(prot.get());
@@ -64,4 +64,12 @@
assert(c.other != NULL);
assert(c.other->other.other == NULL);
+ boost::shared_ptr<RecList> depthLimit(new RecList);
+ depthLimit->nextitem = depthLimit;
+ try {
+ depthLimit->write(prot.get());
+ assert(false);
+ } catch (const apache::thrift::protocol::TProtocolException& e) {
+ }
+
}
diff --git a/test/Recursive.thrift b/test/Recursive.thrift
index c55541b..9c29983 100644
--- a/test/Recursive.thrift
+++ b/test/Recursive.thrift
@@ -23,12 +23,12 @@
}
struct RecList {
- 1: RecList nextitem (cpp.ref = "true")
+ 1: RecList & nextitem
3: i16 item
}
struct CoRec {
- 1: CoRec2 other (cpp.ref = "true")
+ 1: CoRec2 & other
}
struct CoRec2 {