THRIFT-1783 C# doesn't handle required fields correctly
Patch: Carl Yeksigian
diff --git a/compiler/cpp/src/generate/t_csharp_generator.cc b/compiler/cpp/src/generate/t_csharp_generator.cc
index 3bba2b7..924d372 100644
--- a/compiler/cpp/src/generate/t_csharp_generator.cc
+++ b/compiler/cpp/src/generate/t_csharp_generator.cc
@@ -56,15 +56,15 @@
iter = parsed_options.find("serial");
serialize_ = (iter != parsed_options.end());
- if (serialize_) {
- wcf_namespace_ = iter->second; // since there can be only one namespace
- }
-
- iter = parsed_options.find("wcf");
- wcf_ = (iter != parsed_options.end());
- if (wcf_) {
- wcf_namespace_ = iter->second;
- }
+ if (serialize_) {
+ wcf_namespace_ = iter->second; // since there can be only one namespace
+ }
+
+ iter = parsed_options.find("wcf");
+ wcf_ = (iter != parsed_options.end());
+ if (wcf_) {
+ wcf_namespace_ = iter->second;
+ }
out_dir_base_ = "gen-csharp";
}
@@ -124,8 +124,8 @@
std::string csharp_type_usings();
std::string csharp_thrift_usings();
- std::string type_name(t_type* ttype, bool in_countainer=false, bool in_init=false, bool in_param=false);
- std::string base_type_name(t_base_type* tbase, bool in_container=false, bool in_param=false);
+ std::string type_name(t_type* ttype, bool in_countainer=false, bool in_init=false, bool in_param=false, bool is_required=false);
+ std::string base_type_name(t_base_type* tbase, bool in_container=false, bool in_param=false, bool is_required=false);
std::string declare_field(t_field* tfield, bool init=false, std::string prefix="");
std::string function_signature_async_begin(t_function* tfunction, std::string prefix = "");
std::string function_signature_async_end(t_function* tfunction, std::string prefix = "");
@@ -140,6 +140,10 @@
return tfield->get_value() != NULL;
}
+ bool field_is_required(t_field* tfield) {
+ return tfield->get_req() == t_field::T_REQUIRED;
+ }
+
bool type_can_be_null(t_type* ttype) {
while (ttype->is_typedef()) {
ttype = ((t_typedef*)ttype)->get_type();
@@ -463,7 +467,7 @@
indent(out) << "[Serializable]" << endl;
indent(out) << "#endif" << endl;
if ((serialize_||wcf_) &&!is_exception) {
- indent(out) << "[DataContract(Namespace=\"" << wcf_namespace_ << "\")]" << endl; // do not make exception classes directly WCF serializable, we provide a seperate "fault" for that
+ indent(out) << "[DataContract(Namespace=\"" << wcf_namespace_ << "\")]" << endl; // do not make exception classes directly WCF serializable, we provide a seperate "fault" for that
}
bool is_final = (tstruct->annotations_.find("final") != tstruct->annotations_.end());
@@ -483,21 +487,38 @@
//make private members with public Properties
for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) {
- if (!nullable_ || field_has_default((*m_iter))) {
+ // if the field is requied, then we use auto-properties
+ if (!field_is_required((*m_iter)) && (!nullable_ || field_has_default((*m_iter)))) {
indent(out) << "private " << declare_field(*m_iter, false, "_") << endl;
}
}
out << endl;
- bool generate_isset = !nullable_;
+ bool has_non_required_fields = false;
+ bool has_non_required_default_value_fields = false;
+ bool has_required_fields = false;
+ bool has_default_values = false;
for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) {
generate_csharp_doc(out, *m_iter);
generate_property(out, *m_iter, true, true);
- if (field_has_default((*m_iter))) {
- generate_isset = true;
+ bool is_required = field_is_required((*m_iter));
+ bool has_default = field_has_default((*m_iter));
+ if (is_required) {
+ has_required_fields = true;
+ } else {
+ if (has_default) {
+ has_non_required_default_value_fields = true;
+ }
+ has_non_required_fields = true;
+ }
+ if (has_default) {
+ has_default_values = true;
}
}
+ bool generate_isset =
+ (nullable_ && has_non_required_default_value_fields)
+ || (!nullable_ && has_non_required_fields);
if (generate_isset) {
out <<
endl <<
@@ -512,7 +533,12 @@
indent(out) << "public struct Isset {" << endl;
indent_up();
for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) {
- if (!nullable_ || field_has_default((*m_iter))) {
+ bool is_required = field_is_required((*m_iter));
+ bool has_default = field_has_default((*m_iter));
+ // if it is required, don't need Isset for that variable
+ // if it is not required, if it has a default value, we need to generate Isset
+ // if we are not nullable, then we generate Isset
+ if (!is_required && (!nullable_ || has_default)) {
indent(out) << "public bool " << (*m_iter)->get_name() << ";" << endl;
}
}
@@ -521,8 +547,8 @@
indent(out) << "}" << endl << endl;
}
- indent(out) <<
- "public " << tstruct->get_name() << "() {" << endl;
+ // We always want a default, no argument constructor for Reading
+ indent(out) << "public " << tstruct->get_name() << "() {" << endl;
indent_up();
for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) {
t_type* t = (*m_iter)->get_type();
@@ -533,10 +559,35 @@
print_const_value(out, "this._" + (*m_iter)->get_name(), t, (*m_iter)->get_value(), true, true);
}
}
-
indent_down();
indent(out) << "}" << endl << endl;
+
+ if (has_required_fields) {
+ indent(out) << "public " << tstruct->get_name() << "(";
+ bool first = true;
+ for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) {
+ if (field_is_required((*m_iter))) {
+ if (first) {
+ first = false;
+ } else {
+ out << ", ";
+ }
+ out << type_name((*m_iter)->get_type()) << " " << (*m_iter)->get_name();
+ }
+ }
+ out << ") : this() {" << endl;
+ indent_up();
+ for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) {
+ if (field_is_required((*m_iter))) {
+ indent(out) << "this." << prop_name((*m_iter)) << " = " << (*m_iter)->get_name() << ";" << endl;
+ }
+ }
+
+ indent_down();
+ indent(out) << "}" << endl << endl;
+ }
+
generate_csharp_struct_reader(out, tstruct);
if (is_result) {
generate_csharp_struct_result_writer(out, tstruct);
@@ -549,11 +600,10 @@
// generate a corresponding WCF fault to wrap the exception
if((serialize_||wcf_) && is_exception) {
- generate_csharp_wcffault(out, tstruct);
+ generate_csharp_wcffault(out, tstruct);
}
- if (!in_class)
- {
+ if (!in_class) {
end_csharp_namespace(out);
}
}
@@ -596,6 +646,13 @@
const vector<t_field*>& fields = tstruct->get_members();
vector<t_field*>::const_iterator f_iter;
+ // Required variables aren't in __isset, so we need tmp vars to check them
+ for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) {
+ if (field_is_required((*f_iter))) {
+ indent(out) << "bool isset_" << (*f_iter)->get_name() << " = false;" << endl;
+ }
+ }
+
indent(out) <<
"TField field;" << endl <<
indent() << "iprot.ReadStructBegin();" << endl;
@@ -622,6 +679,7 @@
scope_up(out);
for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) {
+ bool is_required = field_is_required((*f_iter));
indent(out) <<
"case " << (*f_iter)->get_key() << ":" << endl;
indent_up();
@@ -630,7 +688,10 @@
indent_up();
generate_deserialize_field(out, *f_iter);
-
+ if (is_required) {
+ indent(out) << "isset_" << (*f_iter)->get_name() << " = true;" << endl;
+ }
+
indent_down();
out <<
indent() << "} else { " << endl <<
@@ -657,6 +718,15 @@
indent(out) <<
"iprot.ReadStructEnd();" << endl;
+ for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) {
+ if (field_is_required((*f_iter))) {
+ indent(out) << "if (!isset_" << (*f_iter)->get_name() << ")" << endl;
+ indent_up();
+ indent(out) << "throw new TProtocolException(TProtocolException.INVALID_DATA);" << endl;
+ indent_down();
+ }
+ }
+
indent_down();
indent(out) << "}" << endl << endl;
@@ -680,53 +750,44 @@
if (fields.size() > 0) {
indent(out) << "TField field = new TField();" << endl;
for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) {
- bool use_nullable = nullable_ && !field_has_default((*f_iter));
- if (use_nullable) {
- indent(out) <<
- "if (" << prop_name((*f_iter)) << " != null) {" << endl;
+ bool is_required = field_is_required((*f_iter));
+ bool has_default = field_has_default((*f_iter));
+ if (nullable_ && !has_default && !is_required) {
+ indent(out) << "if (" << prop_name((*f_iter)) << " != null) {" << endl;
indent_up();
- } else {
+ } else if (!is_required) {
bool null_allowed = type_can_be_null((*f_iter)->get_type());
if (null_allowed) {
indent(out) <<
"if (" << prop_name((*f_iter)) << " != null && __isset." << (*f_iter)->get_name() << ") {" << endl;
indent_up();
- }
- else
- {
+ } else {
indent(out) <<
"if (__isset." << (*f_iter)->get_name() << ") {" << endl;
indent_up();
}
}
- indent(out) <<
- "field.Name = \"" << (*f_iter)->get_name() << "\";" << endl;
- indent(out) <<
- "field.Type = " << type_to_enum((*f_iter)->get_type()) << ";" << endl;
- indent(out) <<
- "field.ID = " << (*f_iter)->get_key() << ";" << endl;
- indent(out) <<
- "oprot.WriteFieldBegin(field);" << endl;
+ indent(out) << "field.Name = \"" << (*f_iter)->get_name() << "\";" << endl;
+ indent(out) << "field.Type = " << type_to_enum((*f_iter)->get_type()) << ";" << endl;
+ indent(out) << "field.ID = " << (*f_iter)->get_key() << ";" << endl;
+ indent(out) << "oprot.WriteFieldBegin(field);" << endl;
generate_serialize_field(out, *f_iter);
- indent(out) <<
- "oprot.WriteFieldEnd();" << endl;
-
- indent_down();
- indent(out) << "}" << endl;
+ indent(out) << "oprot.WriteFieldEnd();" << endl;
+ if (!is_required) {
+ indent_down();
+ indent(out) << "}" << endl;
+ }
}
}
- indent(out) <<
- "oprot.WriteFieldStop();" << endl;
- indent(out) <<
- "oprot.WriteStructEnd();" << endl;
+ indent(out) << "oprot.WriteFieldStop();" << endl;
+ indent(out) << "oprot.WriteStructEnd();" << endl;
indent_down();
- indent(out) <<
- "}" << endl << endl;
+ indent(out) << "}" << endl << endl;
}
void t_csharp_generator::generate_csharp_struct_result_writer(ofstream& out, t_struct* tstruct) {
@@ -749,11 +810,9 @@
for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) {
if (first) {
first = false;
- out <<
- endl << indent() << "if ";
+ out << endl << indent() << "if ";
} else {
- out <<
- " else if ";
+ out << " else if ";
}
if (nullable_) {
@@ -786,8 +845,8 @@
"oprot.WriteFieldEnd();" << endl;
if (null_allowed) {
- indent_down();
- indent(out) << "}" << endl;
+ indent_down();
+ indent(out) << "}" << endl;
}
indent_down();
@@ -1692,7 +1751,7 @@
indent(out) <<
"oprot.";
- string nullable_name = nullable_ && !is_element
+ string nullable_name = nullable_ && !is_element && !field_is_required(tfield)
? name + ".Value"
: name;
@@ -1839,8 +1898,10 @@
if((serialize_||wcf_) && isPublic) {
indent(out) << "[DataMember]" << endl;
}
- if (nullable_ && !field_has_default(tfield)) {
- indent(out) << (isPublic ? "public " : "private ") << type_name(tfield->get_type(), false, false, true)
+ bool has_default = field_has_default(tfield);
+ bool is_required = field_is_required(tfield);
+ if ((nullable_ && !has_default) || (is_required)) {
+ indent(out) << (isPublic ? "public " : "private ") << type_name(tfield->get_type(), false, false, true, is_required)
<< " " << prop_name(tfield) << " { get; set; }" << endl;
} else {
indent(out) << (isPublic ? "public " : "private ") << type_name(tfield->get_type(), false, false, true)
@@ -1885,14 +1946,14 @@
return name;
}
-string t_csharp_generator::type_name(t_type* ttype, bool in_container, bool in_init, bool in_param) {
+string t_csharp_generator::type_name(t_type* ttype, bool in_container, bool in_init, bool in_param, bool is_required) {
(void) in_init;
while (ttype->is_typedef()) {
ttype = ((t_typedef*)ttype)->get_type();
}
if (ttype->is_base_type()) {
- return base_type_name((t_base_type*)ttype, in_container, in_param);
+ return base_type_name((t_base_type*)ttype, in_container, in_param, is_required);
} else if (ttype->is_map()) {
t_map *tmap = (t_map*) ttype;
return "Dictionary<" + type_name(tmap->get_key_type(), true) +
@@ -1906,7 +1967,7 @@
}
t_program* program = ttype->get_program();
- string postfix = (nullable_ && ttype->is_enum()) ? "?" : "";
+ string postfix = (!is_required && nullable_ && in_param && ttype->is_enum()) ? "?" : "";
if (program != NULL && program != program_) {
string ns = program->get_namespace("csharp");
if (!ns.empty()) {
@@ -1917,8 +1978,9 @@
return ttype->get_name() + postfix;
}
-string t_csharp_generator::base_type_name(t_base_type* tbase, bool in_container, bool in_param) {
+string t_csharp_generator::base_type_name(t_base_type* tbase, bool in_container, bool in_param, bool is_required) {
(void) in_container;
+ string postfix = (!is_required && nullable_ && in_param) ? "?" : "";
switch (tbase->get_base()) {
case t_base_type::TYPE_VOID:
return "void";
@@ -1929,35 +1991,17 @@
return "string";
}
case t_base_type::TYPE_BOOL:
- if (nullable_ && in_param) {
- return "bool?";
- }
- return "bool";
+ return "bool" + postfix;
case t_base_type::TYPE_BYTE:
- if (nullable_ && in_param) {
- return "byte?";
- }
- return "byte";
+ return "byte" + postfix;
case t_base_type::TYPE_I16:
- if (nullable_ && in_param) {
- return "short?";
- }
- return "short";
+ return "short" + postfix;
case t_base_type::TYPE_I32:
- if (nullable_ && in_param) {
- return "int?";
- }
- return "int";
+ return "int" + postfix;
case t_base_type::TYPE_I64:
- if (nullable_ && in_param) {
- return "long?";
- }
- return "long";
+ return "long" + postfix;
case t_base_type::TYPE_DOUBLE:
- if (nullable_ && in_param) {
- return "double?";
- }
- return "double";
+ return "double" + postfix;
default:
throw "compiler error: no C# name for base type " + tbase->get_base();
}