THRIFT-2189 add union support for go generator
Client: Go
Patch: Anton Lindström <lindztr@gmail.com> and Jens Geyer
This closes #369
This makes it possible to check if a field is set and skips adding new ones.
Fields in unions are pointers by setting them as `t_field::T_OPTIONAL`.
To be sure that exactly one and only one field is set in a union, we count the
number of fields set and return an error if not exactly one field is set.
This is a breaking change and will require fields in unions to be passed in as
pointers.
diff --git a/compiler/cpp/src/generate/t_go_generator.cc b/compiler/cpp/src/generate/t_go_generator.cc
index c09b8a3..2bae622 100644
--- a/compiler/cpp/src/generate/t_go_generator.cc
+++ b/compiler/cpp/src/generate/t_go_generator.cc
@@ -132,6 +132,10 @@
t_struct* tstruct,
const string& tstruct_name,
bool is_result = false);
+ void generate_countsetfields_helper(std::ofstream& out,
+ t_struct* tstruct,
+ const string& tstruct_name,
+ bool is_result = false);
void generate_go_struct_reader(std::ofstream& out,
t_struct* tstruct,
const string& tstruct_name,
@@ -139,7 +143,8 @@
void generate_go_struct_writer(std::ofstream& out,
t_struct* tstruct,
const string& tstruct_name,
- bool is_result = false);
+ bool is_result = false,
+ bool uses_countsetfields = false);
void generate_go_function_helpers(t_function* tfunction);
void get_publicized_name_and_def_value(t_field* tfield,
string* OUT_pub_name,
@@ -1092,10 +1097,15 @@
// don't have thrift_spec.
indent_up();
+ int num_setable = 0;
if (sorted_members.empty() || (sorted_members[0]->get_key() >= 0)) {
int sorted_keys_pos = 0;
for (m_iter = sorted_members.begin(); m_iter != sorted_members.end(); ++m_iter) {
+ // Set field to optional if field is union, this is so we can get a
+ // pointer to the field.
+ if (tstruct->is_union())
+ (*m_iter)->set_req(t_field::T_OPTIONAL);
if (sorted_keys_pos != (*m_iter)->get_key()) {
int first_unused = std::max(1, sorted_keys_pos++);
while (sorted_keys_pos != (*m_iter)->get_key()) {
@@ -1166,6 +1176,7 @@
out << indent() << " }" << endl;
out << indent() << "return " << maybepointer << "p." << publicized_name << endl;
out << indent() << "}" << endl;
+ num_setable += 1;
} else {
out << endl;
out << indent() << "func (p *" << tstruct_name << ") Get" << publicized_name << "() "
@@ -1175,9 +1186,14 @@
}
}
+
+ if (tstruct->is_union() && num_setable > 0) {
+ generate_countsetfields_helper(out, tstruct, tstruct_name, is_result);
+ }
+
generate_isset_helpers(out, tstruct, tstruct_name, is_result);
generate_go_struct_reader(out, tstruct, tstruct_name, is_result);
- generate_go_struct_writer(out, tstruct, tstruct_name, is_result);
+ generate_go_struct_writer(out, tstruct, tstruct_name, is_result, num_setable > 0);
out << indent() << "func (p *" << tstruct_name << ") String() string {" << endl;
out << indent() << " if p == nil {" << endl;
@@ -1235,6 +1251,46 @@
}
/**
+ * Generates the CountSetFields helper method for a struct
+ */
+void t_go_generator::generate_countsetfields_helper(ofstream& out,
+ t_struct* tstruct,
+ const string& tstruct_name,
+ bool is_result) {
+ (void)is_result;
+ const vector<t_field*>& fields = tstruct->get_members();
+ vector<t_field*>::const_iterator f_iter;
+ const string escaped_tstruct_name(escape_string(tstruct->get_name()));
+
+ out << indent() << "func (p *" << tstruct_name << ") CountSetFields" << tstruct_name
+ << "() int {"
+ << endl;
+ indent_up();
+ out << indent() << "count := 0" << endl;
+ for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) {
+ if ((*f_iter)->get_req() == t_field::T_REQUIRED)
+ continue;
+
+ if (!is_pointer_field(*f_iter))
+ continue;
+
+ const string field_name(
+ publicize(variable_name_to_go_name(escape_string((*f_iter)->get_name()))));
+
+ out << indent() << "if (p.IsSet" << field_name << "()) {" << endl;
+ indent_up();
+ out << indent() << "count++" << endl;
+ indent_down();
+ out << indent() << "}" << endl;
+ }
+
+ out << indent() << "return count" << endl << endl;
+ indent_down();
+ out << indent() << "}" << endl << endl;
+}
+
+
+/**
* Generates the read method for a struct
*/
void t_go_generator::generate_go_struct_reader(ofstream& out,
@@ -1388,13 +1444,20 @@
void t_go_generator::generate_go_struct_writer(ofstream& out,
t_struct* tstruct,
const string& tstruct_name,
- bool is_result) {
+ bool is_result,
+ bool uses_countsetfields) {
(void)is_result;
string name(tstruct->get_name());
const vector<t_field*>& fields = tstruct->get_sorted_members();
vector<t_field*>::const_iterator f_iter;
indent(out) << "func (p *" << tstruct_name << ") Write(oprot thrift.TProtocol) error {" << endl;
indent_up();
+ if (tstruct->is_union() && uses_countsetfields) {
+ std::string tstruct_name(publicize(tstruct->get_name()));
+ out << indent() << "if c := p.CountSetFields" << tstruct_name << "(); c != 1 {" << endl
+ << indent() << " return fmt.Errorf(\"%T write union: exactly one field must be set (%d set).\", p, c)" << endl
+ << indent() << "}" << endl;
+ }
out << indent() << "if err := oprot.WriteStructBegin(\"" << name << "\"); err != nil {" << endl;
out << indent() << " return thrift.PrependError(fmt.Sprintf("
"\"%T write struct begin error: \", p), err) }" << endl;