THRIFT-2868 Enhance error handling in the Go client
Client: Go
Patch: Chi Vinh Le <cvl@chinet.info>

This closes #297
diff --git a/compiler/cpp/src/generate/t_go_generator.cc b/compiler/cpp/src/generate/t_go_generator.cc
index 691684a..e2efcd3 100644
--- a/compiler/cpp/src/generate/t_go_generator.cc
+++ b/compiler/cpp/src/generate/t_go_generator.cc
@@ -1765,10 +1765,22 @@
       f_service_ << indent() << "  iprot = p.ProtocolFactory.GetProtocol(p.Transport)" << endl;
       f_service_ << indent() << "  p.InputProtocol = iprot" << endl;
       f_service_ << indent() << "}" << endl;
-      f_service_ << indent() << "_, mTypeId, seqId, err := iprot.ReadMessageBegin()" << endl;
+      f_service_ << indent() << "method, mTypeId, seqId, err := iprot.ReadMessageBegin()" << endl;
       f_service_ << indent() << "if err != nil {" << endl;
       f_service_ << indent() << "  return" << endl;
       f_service_ << indent() << "}" << endl;
+      f_service_ << indent() << "if method != \"" << (*f_iter)->get_name() << "\" {" << endl;
+      f_service_ << indent() << "  err = thrift.NewTApplicationException("
+                 << "thrift.WRONG_METHOD_NAME, \"" << (*f_iter)->get_name()
+                 << " failed: wrong method name\")" << endl;
+      f_service_ << indent() << "  return" << endl;
+      f_service_ << indent() << "}" << endl;
+      f_service_ << indent() << "if p.SeqId != seqId {" << endl;
+      f_service_ << indent() << "  err = thrift.NewTApplicationException("
+                 << "thrift.BAD_SEQUENCE_ID, \"" << (*f_iter)->get_name()
+                 << " failed: out of sequence response\")" << endl;
+      f_service_ << indent() << "  return" << endl;
+      f_service_ << indent() << "}" << endl;
       f_service_ << indent() << "if mTypeId == thrift.EXCEPTION {" << endl;
       f_service_ << indent() << "  " << error
                  << " := thrift.NewTApplicationException(thrift.UNKNOWN_APPLICATION_EXCEPTION, "
@@ -1784,9 +1796,10 @@
       f_service_ << indent() << "  err = " << error2 << endl;
       f_service_ << indent() << "  return" << endl;
       f_service_ << indent() << "}" << endl;
-      f_service_ << indent() << "if p.SeqId != seqId {" << endl;
-      f_service_ << indent() << "  err = thrift.NewTApplicationException(thrift.BAD_SEQUENCE_ID, \""
-                 << (*f_iter)->get_name() << " failed: out of sequence response\")" << endl;
+      f_service_ << indent() << "if mTypeId != thrift.REPLY {" << endl;
+      f_service_ << indent() << "  err = thrift.NewTApplicationException("
+                 << "thrift.INVALID_MESSAGE_TYPE_EXCEPTION, \"" << (*f_iter)->get_name()
+                 << " failed: invalid message type\")" << endl;
       f_service_ << indent() << "  return" << endl;
       f_service_ << indent() << "}" << endl;
       f_service_ << indent() << "result := " << resultname << "{}" << endl;
diff --git a/lib/go/test/ErrorTest.thrift b/lib/go/test/ErrorTest.thrift
index eeaeb95..33b6644 100644
--- a/lib/go/test/ErrorTest.thrift
+++ b/lib/go/test/ErrorTest.thrift
@@ -31,4 +31,5 @@
 service ErrorTest 
 {
   TestStruct         testStruct(1: TestStruct thing)
+  string             testString(1: string s)
 }
diff --git a/lib/go/test/tests/client_error_test.go b/lib/go/test/tests/client_error_test.go
index 11ea1c7..8fa81cc 100644
--- a/lib/go/test/tests/client_error_test.go
+++ b/lib/go/test/tests/client_error_test.go
@@ -27,14 +27,15 @@
 	"thrift"
 )
 
+// TestCase: Comprehensive call and reply workflow in the client.
 // Setup mock to fail at a certain position. Return true if position exists otherwise false.
-func prepareClientProtocolFailure(protocol *MockTProtocol, failAt int, failWith error) bool {
+func prepareClientCallReply(protocol *MockTProtocol, failAt int, failWith error) bool {
 	var err error = nil
 
 	if failAt == 0 {
 		err = failWith
 	}
-	last := protocol.EXPECT().WriteMessageBegin("testStruct", thrift.TMessageType(1), int32(1)).Return(err)
+	last := protocol.EXPECT().WriteMessageBegin("testStruct", thrift.CALL, int32(1)).Return(err)
 	if failAt == 0 {
 		return true
 	}
@@ -217,7 +218,7 @@
 	if failAt == 26 {
 		err = failWith
 	}
-	last = protocol.EXPECT().ReadMessageBegin().Return("testStruct", thrift.TMessageType(1), int32(1), err).After(last)
+	last = protocol.EXPECT().ReadMessageBegin().Return("testStruct", thrift.REPLY, int32(1), err).After(last)
 	if failAt == 26 {
 		return true
 	}
@@ -392,6 +393,8 @@
 	return false
 }
 
+// TestCase: Comprehensive call and reply workflow in the client.
+// Expecting TTransportError on fail.
 func TestClientReportTTransportErrors(t *testing.T) {
 	mockCtrl := gomock.NewController(t)
 	transport := thrift.NewTMemoryBuffer()
@@ -405,24 +408,25 @@
 	err := thrift.NewTTransportException(thrift.TIMED_OUT, "test")
 	for i := 0; ; i++ {
 		protocol := NewMockTProtocol(mockCtrl)
-		if !prepareClientProtocolFailure(protocol, i, err) {
+		if !prepareClientCallReply(protocol, i, err) {
 			return
 		}
 		client := errortest.NewErrorTestClientProtocol(transport, protocol, protocol)
 		_, retErr := client.TestStruct(thing)
+		mockCtrl.Finish()
 		err2, ok := retErr.(thrift.TTransportException)
 		if !ok {
 			t.Fatal("Expected a TTrasportException")
 		}
 
-		if err2.TypeId() != err.TypeId() {
-			t.Fatal("Expected a same error type id")
+		if err2.TypeId() != thrift.TIMED_OUT {
+			t.Fatal("Expected TIMED_OUT error")
 		}
-
-		mockCtrl.Finish()
 	}
 }
 
+// TestCase: Comprehensive call and reply workflow in the client.
+// Expecting TTProtocolErrors on fail.
 func TestClientReportTProtocolErrors(t *testing.T) {
 	mockCtrl := gomock.NewController(t)
 	transport := thrift.NewTMemoryBuffer()
@@ -436,20 +440,241 @@
 	err := thrift.NewTProtocolExceptionWithType(thrift.INVALID_DATA, errors.New("test"))
 	for i := 0; ; i++ {
 		protocol := NewMockTProtocol(mockCtrl)
-		if !prepareClientProtocolFailure(protocol, i, err) {
+		if !prepareClientCallReply(protocol, i, err) {
 			return
 		}
 		client := errortest.NewErrorTestClientProtocol(transport, protocol, protocol)
 		_, retErr := client.TestStruct(thing)
+		mockCtrl.Finish()
 		err2, ok := retErr.(thrift.TProtocolException)
 		if !ok {
 			t.Fatal("Expected a TProtocolException")
 		}
-
-		if err2.TypeId() != err.TypeId() {
-			t.Fatal("Expected a same error type id")
+		if err2.TypeId() != thrift.INVALID_DATA {
+			t.Fatal("Expected INVALID_DATA error")
 		}
+	}
+}
 
+// TestCase: call and reply with exception workflow in the client.
+// Setup mock to fail at a certain position. Return true if position exists otherwise false.
+func prepareClientCallException(protocol *MockTProtocol, failAt int, failWith error) bool {
+	var err error = nil
+
+	// No need to test failure in this block, because it is covered in other test cases
+	last := protocol.EXPECT().WriteMessageBegin("testString", thrift.CALL, int32(1))
+	last = protocol.EXPECT().WriteStructBegin("testString_args").After(last)
+	last = protocol.EXPECT().WriteFieldBegin("s", thrift.TType(thrift.STRING), int16(1)).After(last)
+	last = protocol.EXPECT().WriteString("test").After(last)
+	last = protocol.EXPECT().WriteFieldEnd().After(last)
+	last = protocol.EXPECT().WriteFieldStop().After(last)
+	last = protocol.EXPECT().WriteStructEnd().After(last)
+	last = protocol.EXPECT().WriteMessageEnd().After(last)
+	last = protocol.EXPECT().Flush().After(last)
+
+	// Reading the exception, might fail.
+	if failAt == 0 {
+		err = failWith
+	}
+	last = protocol.EXPECT().ReadMessageBegin().Return("testString", thrift.EXCEPTION, int32(1), err).After(last)
+	if failAt == 0 {
+		return true
+	}
+	if failAt == 1 {
+		err = failWith
+	}
+	last = protocol.EXPECT().ReadStructBegin().Return("TApplicationException", err).After(last)
+	if failAt == 1 {
+		return true
+	}
+	if failAt == 2 {
+		err = failWith
+	}
+	last = protocol.EXPECT().ReadFieldBegin().Return("message", thrift.TType(thrift.STRING), int16(1), err).After(last)
+	if failAt == 2 {
+		return true
+	}
+	if failAt == 3 {
+		err = failWith
+	}
+	last = protocol.EXPECT().ReadString().Return("test", err).After(last)
+	if failAt == 3 {
+		return true
+	}
+	if failAt == 4 {
+		err = failWith
+	}
+	last = protocol.EXPECT().ReadFieldEnd().Return(err).After(last)
+	if failAt == 4 {
+		return true
+	}
+	if failAt == 5 {
+		err = failWith
+	}
+	last = protocol.EXPECT().ReadFieldBegin().Return("type", thrift.TType(thrift.I32), int16(2), err).After(last)
+	if failAt == 5 {
+		return true
+	}
+	if failAt == 6 {
+		err = failWith
+	}
+	last = protocol.EXPECT().ReadI32().Return(int32(thrift.PROTOCOL_ERROR), err).After(last)
+	if failAt == 6 {
+		return true
+	}
+	if failAt == 7 {
+		err = failWith
+	}
+	last = protocol.EXPECT().ReadFieldEnd().Return(err).After(last)
+	if failAt == 7 {
+		return true
+	}
+	if failAt == 8 {
+		err = failWith
+	}
+	last = protocol.EXPECT().ReadFieldBegin().Return("_", thrift.TType(thrift.STOP), int16(2), err).After(last)
+	if failAt == 8 {
+		return true
+	}
+	if failAt == 9 {
+		err = failWith
+	}
+	last = protocol.EXPECT().ReadStructEnd().Return(err).After(last)
+	if failAt == 9 {
+		return true
+	}
+	if failAt == 10 {
+		err = failWith
+	}
+	last = protocol.EXPECT().ReadMessageEnd().Return(err).After(last)
+	if failAt == 10 {
+		return true
+	}
+
+	return false
+}
+
+// TestCase: call and reply with exception workflow in the client.
+func TestClientCallException(t *testing.T) {
+	mockCtrl := gomock.NewController(t)
+	transport := thrift.NewTMemoryBuffer()
+
+	err := thrift.NewTTransportException(thrift.TIMED_OUT, "test")
+	for i := 0; ; i++ {
+		protocol := NewMockTProtocol(mockCtrl)
+		willComplete := !prepareClientCallException(protocol, i, err)
+
+		client := errortest.NewErrorTestClientProtocol(transport, protocol, protocol)
+		_, retErr := client.TestString("test")
 		mockCtrl.Finish()
+
+		if !willComplete {
+			err2, ok := retErr.(thrift.TTransportException)
+			if !ok {
+				t.Fatal("Expected a TTransportException")
+			}
+			if err2.TypeId() != thrift.TIMED_OUT {
+				t.Fatal("Expected TIMED_OUT error")
+			}
+		} else {
+			err2, ok := retErr.(thrift.TApplicationException)
+			if !ok {
+				t.Fatal("Expected a TApplicationException")
+			}
+			if err2.TypeId() != thrift.PROTOCOL_ERROR {
+				t.Fatal("Expected PROTOCOL_ERROR error")
+			}
+			break
+		}
+	}
+}
+
+// TestCase: Mismatching sequence id has been received in the client.
+func TestClientSeqIdMismatch(t *testing.T) {
+	mockCtrl := gomock.NewController(t)
+	transport := thrift.NewTMemoryBuffer()
+	protocol := NewMockTProtocol(mockCtrl)
+	gomock.InOrder(
+		protocol.EXPECT().WriteMessageBegin("testString", thrift.CALL, int32(1)),
+		protocol.EXPECT().WriteStructBegin("testString_args"),
+		protocol.EXPECT().WriteFieldBegin("s", thrift.TType(thrift.STRING), int16(1)),
+		protocol.EXPECT().WriteString("test"),
+		protocol.EXPECT().WriteFieldEnd(),
+		protocol.EXPECT().WriteFieldStop(),
+		protocol.EXPECT().WriteStructEnd(),
+		protocol.EXPECT().WriteMessageEnd(),
+		protocol.EXPECT().Flush(),
+		protocol.EXPECT().ReadMessageBegin().Return("testString", thrift.REPLY, int32(2), nil),
+	)
+
+	client := errortest.NewErrorTestClientProtocol(transport, protocol, protocol)
+	_, err := client.TestString("test")
+	mockCtrl.Finish()
+	appErr, ok := err.(thrift.TApplicationException)
+	if !ok {
+		t.Fatal("Expected TApplicationException")
+	}
+	if appErr.TypeId() != thrift.BAD_SEQUENCE_ID {
+		t.Fatal("Expected BAD_SEQUENCE_ID error")
+	}
+}
+
+// TestCase: Wrong method name has been received in the client.
+func TestClientWrongMethodName(t *testing.T) {
+	mockCtrl := gomock.NewController(t)
+	transport := thrift.NewTMemoryBuffer()
+	protocol := NewMockTProtocol(mockCtrl)
+	gomock.InOrder(
+		protocol.EXPECT().WriteMessageBegin("testString", thrift.CALL, int32(1)),
+		protocol.EXPECT().WriteStructBegin("testString_args"),
+		protocol.EXPECT().WriteFieldBegin("s", thrift.TType(thrift.STRING), int16(1)),
+		protocol.EXPECT().WriteString("test"),
+		protocol.EXPECT().WriteFieldEnd(),
+		protocol.EXPECT().WriteFieldStop(),
+		protocol.EXPECT().WriteStructEnd(),
+		protocol.EXPECT().WriteMessageEnd(),
+		protocol.EXPECT().Flush(),
+		protocol.EXPECT().ReadMessageBegin().Return("unknown", thrift.REPLY, int32(1), nil),
+	)
+
+	client := errortest.NewErrorTestClientProtocol(transport, protocol, protocol)
+	_, err := client.TestString("test")
+	mockCtrl.Finish()
+	appErr, ok := err.(thrift.TApplicationException)
+	if !ok {
+		t.Fatal("Expected TApplicationException")
+	}
+	if appErr.TypeId() != thrift.WRONG_METHOD_NAME {
+		t.Fatal("Expected WRONG_METHOD_NAME error")
+	}
+}
+
+// TestCase: Wrong message type has been received in the client.
+func TestClientWrongMessageType(t *testing.T) {
+	mockCtrl := gomock.NewController(t)
+	transport := thrift.NewTMemoryBuffer()
+	protocol := NewMockTProtocol(mockCtrl)
+	gomock.InOrder(
+		protocol.EXPECT().WriteMessageBegin("testString", thrift.CALL, int32(1)),
+		protocol.EXPECT().WriteStructBegin("testString_args"),
+		protocol.EXPECT().WriteFieldBegin("s", thrift.TType(thrift.STRING), int16(1)),
+		protocol.EXPECT().WriteString("test"),
+		protocol.EXPECT().WriteFieldEnd(),
+		protocol.EXPECT().WriteFieldStop(),
+		protocol.EXPECT().WriteStructEnd(),
+		protocol.EXPECT().WriteMessageEnd(),
+		protocol.EXPECT().Flush(),
+		protocol.EXPECT().ReadMessageBegin().Return("testString", thrift.INVALID_TMESSAGE_TYPE, int32(1), nil),
+	)
+
+	client := errortest.NewErrorTestClientProtocol(transport, protocol, protocol)
+	_, err := client.TestString("test")
+	mockCtrl.Finish()
+	appErr, ok := err.(thrift.TApplicationException)
+	if !ok {
+		t.Fatal("Expected TApplicationException")
+	}
+	if appErr.TypeId() != thrift.INVALID_MESSAGE_TYPE_EXCEPTION {
+		t.Fatal("Expected INVALID_MESSAGE_TYPE_EXCEPTION error")
 	}
 }