THRIFT-4650: fix required fields incorrectly being marked as set
This closes #1610.
Client: go
diff --git a/compiler/cpp/src/thrift/generate/t_go_generator.cc b/compiler/cpp/src/thrift/generate/t_go_generator.cc
index b5742f6..0807efb 100644
--- a/compiler/cpp/src/thrift/generate/t_go_generator.cc
+++ b/compiler/cpp/src/thrift/generate/t_go_generator.cc
@@ -1591,17 +1591,19 @@
<< endl;
out << indent() << " return err" << endl;
out << indent() << " }" << endl;
+
+ // Mark required field as read
+ if ((*f_iter)->get_req() == t_field::T_REQUIRED) {
+ const string field_name(publicize(escape_string((*f_iter)->get_name())));
+ out << indent() << " isset" << field_name << " = true" << endl;
+ }
+
out << indent() << "} else {" << endl;
out << indent() << " if err := iprot.Skip(fieldTypeId); err != nil {" << endl;
out << indent() << " return err" << endl;
out << indent() << " }" << endl;
out << indent() << "}" << endl;
- // Mark required field as read
- if ((*f_iter)->get_req() == t_field::T_REQUIRED) {
- const string field_name(publicize(escape_string((*f_iter)->get_name())));
- out << indent() << "isset" << field_name << " = true" << endl;
- }
indent_down();
}
diff --git a/lib/go/test/Makefile.am b/lib/go/test/Makefile.am
index b7ba870..78d4681 100644
--- a/lib/go/test/Makefile.am
+++ b/lib/go/test/Makefile.am
@@ -27,6 +27,7 @@
MultiplexedProtocolTest.thrift \
OnewayTest.thrift \
OptionalFieldsTest.thrift \
+ RequiredFieldTest.thrift \
ServicesTest.thrift \
GoTagTest.thrift \
TypedefFieldTest.thrift \
@@ -46,6 +47,7 @@
$(THRIFT) $(THRIFTARGS) MultiplexedProtocolTest.thrift
$(THRIFT) $(THRIFTARGS) OnewayTest.thrift
$(THRIFT) $(THRIFTARGS) OptionalFieldsTest.thrift
+ $(THRIFT) $(THRIFTARGS) RequiredFieldTest.thrift
$(THRIFT) $(THRIFTARGS) ServicesTest.thrift
$(THRIFT) $(THRIFTARGS) GoTagTest.thrift
$(THRIFT) $(THRIFTARGS) TypedefFieldTest.thrift
@@ -96,6 +98,7 @@
NamespacedTest.thrift \
OnewayTest.thrift \
OptionalFieldsTest.thrift \
+ RequiredFieldTest.thrift \
RefAnnotationFieldsTest.thrift \
UnionDefaultValueTest.thrift \
UnionBinaryTest.thrift \
diff --git a/lib/go/test/RequiredFieldTest.thrift b/lib/go/test/RequiredFieldTest.thrift
new file mode 100644
index 0000000..4a2dcae
--- /dev/null
+++ b/lib/go/test/RequiredFieldTest.thrift
@@ -0,0 +1,7 @@
+struct RequiredField {
+ 1: required string name
+}
+
+struct OtherThing {
+ 1: required i16 value
+}
diff --git a/lib/go/test/tests/required_fields_test.go b/lib/go/test/tests/required_fields_test.go
index 287ef60..3fa414a 100644
--- a/lib/go/test/tests/required_fields_test.go
+++ b/lib/go/test/tests/required_fields_test.go
@@ -20,12 +20,45 @@
package tests
import (
+ "context"
"github.com/golang/mock/gomock"
"optionalfieldstest"
+ "requiredfieldtest"
"testing"
"thrift"
)
+func TestRequiredField_SucecssWhenSet(t *testing.T) {
+ // create a new RequiredField instance with the required field set
+ source := &requiredfieldtest.RequiredField{Name: "this is a test"}
+ sourceData, err := thrift.NewTSerializer().Write(context.Background(), source)
+ if err != nil {
+ t.Fatalf("failed to serialize %T: %v", source, err)
+ }
+
+ d := thrift.NewTDeserializer()
+ err = d.Read(&requiredfieldtest.RequiredField{}, sourceData)
+ if err != nil {
+ t.Fatalf("Did not expect an error when trying to deserialize the requiredfieldtest.RequiredField: %v", err)
+ }
+}
+
+func TestRequiredField_ErrorWhenMissing(t *testing.T) {
+ // create a new OtherThing instance, without setting the required field
+ source := &requiredfieldtest.OtherThing{}
+ sourceData, err := thrift.NewTSerializer().Write(context.Background(), source)
+ if err != nil {
+ t.Fatalf("failed to serialize %T: %v", source, err)
+ }
+
+ // attempt to deserialize into a different type (which should fail)
+ d := thrift.NewTDeserializer()
+ err = d.Read(&requiredfieldtest.RequiredField{}, sourceData)
+ if err == nil {
+ t.Fatal("Expected an error when trying to deserialize an object which is missing a required field")
+ }
+}
+
func TestStructReadRequiredFields(t *testing.T) {
mockCtrl := gomock.NewController(t)
protocol := NewMockTProtocol(mockCtrl)