THRIFT-2878 Go validation support of required fields
Client: Go
Patch: Chi Vinh Le <cvl@chinet.info>

This closes #304
diff --git a/compiler/cpp/src/generate/t_go_generator.cc b/compiler/cpp/src/generate/t_go_generator.cc
index e2efcd3..c3cca60 100644
--- a/compiler/cpp/src/generate/t_go_generator.cc
+++ b/compiler/cpp/src/generate/t_go_generator.cc
@@ -1250,7 +1250,18 @@
   out << indent() << "if _, err := iprot.ReadStructBegin(); err != nil {" << endl;
   out << indent() << "  return thrift.PrependError(fmt.Sprintf(\"%T read error: \", p), err)"
       << endl;
-  out << indent() << "}" << endl;
+  out << indent() << "}" << endl << endl;
+
+  // Required variables does not have IsSet functions, so we need tmp vars to check them.
+  for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) {
+    if ((*f_iter)->get_req() == t_field::T_REQUIRED) {
+      const string field_name(
+          publicize(variable_name_to_go_name(escape_string((*f_iter)->get_name()))));
+      indent(out) << "var isset" << field_name << " bool = false;" << endl;
+    }
+  }
+  out << endl;
+
   // Loop over reading in fields
   indent(out) << "for {" << endl;
   indent_up();
@@ -1297,6 +1308,14 @@
         << endl;
     out << indent() << "  return err" << endl;
     out << indent() << "}" << endl;
+
+    // Mark required field as read
+    if ((*f_iter)->get_req() == t_field::T_REQUIRED) {
+      const string field_name(
+          publicize(variable_name_to_go_name(escape_string((*f_iter)->get_name()))));
+      out << indent() << "isset" << field_name << " = true" << endl;
+    }
+
     indent_down();
   }
 
@@ -1327,6 +1346,19 @@
   out << indent() << "  return thrift.PrependError(fmt.Sprintf("
                      "\"%T read struct end error: \", p), err)" << endl;
   out << indent() << "}" << endl;
+
+  // Return error if any required fields are missing.
+  for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) {
+    if ((*f_iter)->get_req() == t_field::T_REQUIRED) {
+      const string field_name(
+          publicize(variable_name_to_go_name(escape_string((*f_iter)->get_name()))));
+      out << indent() << "if !isset" << field_name << "{" << endl;
+      out << indent() << "  return thrift.NewTProtocolExceptionWithType(thrift.INVALID_DATA, "
+                         "fmt.Errorf(\"Required field " << field_name << " is not set\"));" << endl;
+      out << indent() << "}" << endl;
+    }
+  }
+
   out << indent() << "return nil" << endl;
   indent_down();
   out << indent() << "}" << endl << endl;
diff --git a/lib/go/test/OptionalFieldsTest.thrift b/lib/go/test/OptionalFieldsTest.thrift
index 25b0ef6..2afc157 100644
--- a/lib/go/test/OptionalFieldsTest.thrift
+++ b/lib/go/test/OptionalFieldsTest.thrift
@@ -41,3 +41,10 @@
  1: required structA required_struct_thing
  2: optional structA optional_struct_thing
 }
+
+struct structC {
+ 1: string s,
+ 2: required i32 i,
+ 3: optional bool b,
+ 4: required string s2,
+}
diff --git a/lib/go/test/tests/required_fields_test.go b/lib/go/test/tests/required_fields_test.go
new file mode 100644
index 0000000..fcc7f25
--- /dev/null
+++ b/lib/go/test/tests/required_fields_test.go
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package tests
+
+import (
+	"code.google.com/p/gomock/gomock"
+	"optionalfieldstest"
+	"testing"
+	"thrift"
+)
+
+func TestStructReadRequiredFields(t *testing.T) {
+	mockCtrl := gomock.NewController(t)
+	protocol := NewMockTProtocol(mockCtrl)
+	testStruct := optionalfieldstest.NewStructC()
+
+	// None of required fields are set
+	gomock.InOrder(
+		protocol.EXPECT().ReadStructBegin().Return("StructC", nil),
+		protocol.EXPECT().ReadFieldBegin().Return("_", thrift.TType(thrift.STOP), int16(1), nil),
+		protocol.EXPECT().ReadStructEnd().Return(nil),
+	)
+
+	err := testStruct.Read(protocol)
+	mockCtrl.Finish()
+	if err == nil {
+		t.Fatal("Expected read to fail")
+	}
+	err2, ok := err.(thrift.TProtocolException)
+	if !ok {
+		t.Fatal("Expected a TProtocolException")
+	}
+	if err2.TypeId() != thrift.INVALID_DATA {
+		t.Fatal("Expected INVALID_DATA TProtocolException")
+	}
+
+	// One of the required fields is set
+	gomock.InOrder(
+		protocol.EXPECT().ReadStructBegin().Return("StructC", nil),
+		protocol.EXPECT().ReadFieldBegin().Return("I", thrift.TType(thrift.I32), int16(2), nil),
+		protocol.EXPECT().ReadI32().Return(int32(1), nil),
+		protocol.EXPECT().ReadFieldEnd().Return(nil),
+		protocol.EXPECT().ReadFieldBegin().Return("_", thrift.TType(thrift.STOP), int16(1), nil),
+		protocol.EXPECT().ReadStructEnd().Return(nil),
+	)
+
+	err = testStruct.Read(protocol)
+	mockCtrl.Finish()
+	if err == nil {
+		t.Fatal("Expected read to fail")
+	}
+	err2, ok = err.(thrift.TProtocolException)
+	if !ok {
+		t.Fatal("Expected a TProtocolException")
+	}
+	if err2.TypeId() != thrift.INVALID_DATA {
+		t.Fatal("Expected INVALID_DATA TProtocolException")
+	}
+
+	// Both of the required fields are set
+	gomock.InOrder(
+		protocol.EXPECT().ReadStructBegin().Return("StructC", nil),
+		protocol.EXPECT().ReadFieldBegin().Return("i", thrift.TType(thrift.I32), int16(2), nil),
+		protocol.EXPECT().ReadI32().Return(int32(1), nil),
+		protocol.EXPECT().ReadFieldEnd().Return(nil),
+		protocol.EXPECT().ReadFieldBegin().Return("s2", thrift.TType(thrift.STRING), int16(4), nil),
+		protocol.EXPECT().ReadString().Return("test", nil),
+		protocol.EXPECT().ReadFieldEnd().Return(nil),
+		protocol.EXPECT().ReadFieldBegin().Return("_", thrift.TType(thrift.STOP), int16(1), nil),
+		protocol.EXPECT().ReadStructEnd().Return(nil),
+	)
+
+	err = testStruct.Read(protocol)
+	mockCtrl.Finish()
+	if err != nil {
+		t.Fatal("Expected read to succeed")
+	}
+}