THRIFT-2878 Go validation support of required fields
Client: Go
Patch: Chi Vinh Le <cvl@chinet.info>
This closes #304
diff --git a/compiler/cpp/src/generate/t_go_generator.cc b/compiler/cpp/src/generate/t_go_generator.cc
index e2efcd3..c3cca60 100644
--- a/compiler/cpp/src/generate/t_go_generator.cc
+++ b/compiler/cpp/src/generate/t_go_generator.cc
@@ -1250,7 +1250,18 @@
out << indent() << "if _, err := iprot.ReadStructBegin(); err != nil {" << endl;
out << indent() << " return thrift.PrependError(fmt.Sprintf(\"%T read error: \", p), err)"
<< endl;
- out << indent() << "}" << endl;
+ out << indent() << "}" << endl << endl;
+
+ // Required variables does not have IsSet functions, so we need tmp vars to check them.
+ for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) {
+ if ((*f_iter)->get_req() == t_field::T_REQUIRED) {
+ const string field_name(
+ publicize(variable_name_to_go_name(escape_string((*f_iter)->get_name()))));
+ indent(out) << "var isset" << field_name << " bool = false;" << endl;
+ }
+ }
+ out << endl;
+
// Loop over reading in fields
indent(out) << "for {" << endl;
indent_up();
@@ -1297,6 +1308,14 @@
<< endl;
out << indent() << " return err" << endl;
out << indent() << "}" << endl;
+
+ // Mark required field as read
+ if ((*f_iter)->get_req() == t_field::T_REQUIRED) {
+ const string field_name(
+ publicize(variable_name_to_go_name(escape_string((*f_iter)->get_name()))));
+ out << indent() << "isset" << field_name << " = true" << endl;
+ }
+
indent_down();
}
@@ -1327,6 +1346,19 @@
out << indent() << " return thrift.PrependError(fmt.Sprintf("
"\"%T read struct end error: \", p), err)" << endl;
out << indent() << "}" << endl;
+
+ // Return error if any required fields are missing.
+ for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) {
+ if ((*f_iter)->get_req() == t_field::T_REQUIRED) {
+ const string field_name(
+ publicize(variable_name_to_go_name(escape_string((*f_iter)->get_name()))));
+ out << indent() << "if !isset" << field_name << "{" << endl;
+ out << indent() << " return thrift.NewTProtocolExceptionWithType(thrift.INVALID_DATA, "
+ "fmt.Errorf(\"Required field " << field_name << " is not set\"));" << endl;
+ out << indent() << "}" << endl;
+ }
+ }
+
out << indent() << "return nil" << endl;
indent_down();
out << indent() << "}" << endl << endl;
diff --git a/lib/go/test/OptionalFieldsTest.thrift b/lib/go/test/OptionalFieldsTest.thrift
index 25b0ef6..2afc157 100644
--- a/lib/go/test/OptionalFieldsTest.thrift
+++ b/lib/go/test/OptionalFieldsTest.thrift
@@ -41,3 +41,10 @@
1: required structA required_struct_thing
2: optional structA optional_struct_thing
}
+
+struct structC {
+ 1: string s,
+ 2: required i32 i,
+ 3: optional bool b,
+ 4: required string s2,
+}
diff --git a/lib/go/test/tests/required_fields_test.go b/lib/go/test/tests/required_fields_test.go
new file mode 100644
index 0000000..fcc7f25
--- /dev/null
+++ b/lib/go/test/tests/required_fields_test.go
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package tests
+
+import (
+ "code.google.com/p/gomock/gomock"
+ "optionalfieldstest"
+ "testing"
+ "thrift"
+)
+
+func TestStructReadRequiredFields(t *testing.T) {
+ mockCtrl := gomock.NewController(t)
+ protocol := NewMockTProtocol(mockCtrl)
+ testStruct := optionalfieldstest.NewStructC()
+
+ // None of required fields are set
+ gomock.InOrder(
+ protocol.EXPECT().ReadStructBegin().Return("StructC", nil),
+ protocol.EXPECT().ReadFieldBegin().Return("_", thrift.TType(thrift.STOP), int16(1), nil),
+ protocol.EXPECT().ReadStructEnd().Return(nil),
+ )
+
+ err := testStruct.Read(protocol)
+ mockCtrl.Finish()
+ if err == nil {
+ t.Fatal("Expected read to fail")
+ }
+ err2, ok := err.(thrift.TProtocolException)
+ if !ok {
+ t.Fatal("Expected a TProtocolException")
+ }
+ if err2.TypeId() != thrift.INVALID_DATA {
+ t.Fatal("Expected INVALID_DATA TProtocolException")
+ }
+
+ // One of the required fields is set
+ gomock.InOrder(
+ protocol.EXPECT().ReadStructBegin().Return("StructC", nil),
+ protocol.EXPECT().ReadFieldBegin().Return("I", thrift.TType(thrift.I32), int16(2), nil),
+ protocol.EXPECT().ReadI32().Return(int32(1), nil),
+ protocol.EXPECT().ReadFieldEnd().Return(nil),
+ protocol.EXPECT().ReadFieldBegin().Return("_", thrift.TType(thrift.STOP), int16(1), nil),
+ protocol.EXPECT().ReadStructEnd().Return(nil),
+ )
+
+ err = testStruct.Read(protocol)
+ mockCtrl.Finish()
+ if err == nil {
+ t.Fatal("Expected read to fail")
+ }
+ err2, ok = err.(thrift.TProtocolException)
+ if !ok {
+ t.Fatal("Expected a TProtocolException")
+ }
+ if err2.TypeId() != thrift.INVALID_DATA {
+ t.Fatal("Expected INVALID_DATA TProtocolException")
+ }
+
+ // Both of the required fields are set
+ gomock.InOrder(
+ protocol.EXPECT().ReadStructBegin().Return("StructC", nil),
+ protocol.EXPECT().ReadFieldBegin().Return("i", thrift.TType(thrift.I32), int16(2), nil),
+ protocol.EXPECT().ReadI32().Return(int32(1), nil),
+ protocol.EXPECT().ReadFieldEnd().Return(nil),
+ protocol.EXPECT().ReadFieldBegin().Return("s2", thrift.TType(thrift.STRING), int16(4), nil),
+ protocol.EXPECT().ReadString().Return("test", nil),
+ protocol.EXPECT().ReadFieldEnd().Return(nil),
+ protocol.EXPECT().ReadFieldBegin().Return("_", thrift.TType(thrift.STOP), int16(1), nil),
+ protocol.EXPECT().ReadStructEnd().Return(nil),
+ )
+
+ err = testStruct.Read(protocol)
+ mockCtrl.Finish()
+ if err != nil {
+ t.Fatal("Expected read to succeed")
+ }
+}