THRIFT-2868 Enhance error handling in the Go client
Client: Go
Patch: Chi Vinh Le <cvl@chinet.info>
This closes #297
diff --git a/compiler/cpp/src/generate/t_go_generator.cc b/compiler/cpp/src/generate/t_go_generator.cc
index 691684a..e2efcd3 100644
--- a/compiler/cpp/src/generate/t_go_generator.cc
+++ b/compiler/cpp/src/generate/t_go_generator.cc
@@ -1765,10 +1765,22 @@
f_service_ << indent() << " iprot = p.ProtocolFactory.GetProtocol(p.Transport)" << endl;
f_service_ << indent() << " p.InputProtocol = iprot" << endl;
f_service_ << indent() << "}" << endl;
- f_service_ << indent() << "_, mTypeId, seqId, err := iprot.ReadMessageBegin()" << endl;
+ f_service_ << indent() << "method, mTypeId, seqId, err := iprot.ReadMessageBegin()" << endl;
f_service_ << indent() << "if err != nil {" << endl;
f_service_ << indent() << " return" << endl;
f_service_ << indent() << "}" << endl;
+ f_service_ << indent() << "if method != \"" << (*f_iter)->get_name() << "\" {" << endl;
+ f_service_ << indent() << " err = thrift.NewTApplicationException("
+ << "thrift.WRONG_METHOD_NAME, \"" << (*f_iter)->get_name()
+ << " failed: wrong method name\")" << endl;
+ f_service_ << indent() << " return" << endl;
+ f_service_ << indent() << "}" << endl;
+ f_service_ << indent() << "if p.SeqId != seqId {" << endl;
+ f_service_ << indent() << " err = thrift.NewTApplicationException("
+ << "thrift.BAD_SEQUENCE_ID, \"" << (*f_iter)->get_name()
+ << " failed: out of sequence response\")" << endl;
+ f_service_ << indent() << " return" << endl;
+ f_service_ << indent() << "}" << endl;
f_service_ << indent() << "if mTypeId == thrift.EXCEPTION {" << endl;
f_service_ << indent() << " " << error
<< " := thrift.NewTApplicationException(thrift.UNKNOWN_APPLICATION_EXCEPTION, "
@@ -1784,9 +1796,10 @@
f_service_ << indent() << " err = " << error2 << endl;
f_service_ << indent() << " return" << endl;
f_service_ << indent() << "}" << endl;
- f_service_ << indent() << "if p.SeqId != seqId {" << endl;
- f_service_ << indent() << " err = thrift.NewTApplicationException(thrift.BAD_SEQUENCE_ID, \""
- << (*f_iter)->get_name() << " failed: out of sequence response\")" << endl;
+ f_service_ << indent() << "if mTypeId != thrift.REPLY {" << endl;
+ f_service_ << indent() << " err = thrift.NewTApplicationException("
+ << "thrift.INVALID_MESSAGE_TYPE_EXCEPTION, \"" << (*f_iter)->get_name()
+ << " failed: invalid message type\")" << endl;
f_service_ << indent() << " return" << endl;
f_service_ << indent() << "}" << endl;
f_service_ << indent() << "result := " << resultname << "{}" << endl;
diff --git a/lib/go/test/ErrorTest.thrift b/lib/go/test/ErrorTest.thrift
index eeaeb95..33b6644 100644
--- a/lib/go/test/ErrorTest.thrift
+++ b/lib/go/test/ErrorTest.thrift
@@ -31,4 +31,5 @@
service ErrorTest
{
TestStruct testStruct(1: TestStruct thing)
+ string testString(1: string s)
}
diff --git a/lib/go/test/tests/client_error_test.go b/lib/go/test/tests/client_error_test.go
index 11ea1c7..8fa81cc 100644
--- a/lib/go/test/tests/client_error_test.go
+++ b/lib/go/test/tests/client_error_test.go
@@ -27,14 +27,15 @@
"thrift"
)
+// TestCase: Comprehensive call and reply workflow in the client.
// Setup mock to fail at a certain position. Return true if position exists otherwise false.
-func prepareClientProtocolFailure(protocol *MockTProtocol, failAt int, failWith error) bool {
+func prepareClientCallReply(protocol *MockTProtocol, failAt int, failWith error) bool {
var err error = nil
if failAt == 0 {
err = failWith
}
- last := protocol.EXPECT().WriteMessageBegin("testStruct", thrift.TMessageType(1), int32(1)).Return(err)
+ last := protocol.EXPECT().WriteMessageBegin("testStruct", thrift.CALL, int32(1)).Return(err)
if failAt == 0 {
return true
}
@@ -217,7 +218,7 @@
if failAt == 26 {
err = failWith
}
- last = protocol.EXPECT().ReadMessageBegin().Return("testStruct", thrift.TMessageType(1), int32(1), err).After(last)
+ last = protocol.EXPECT().ReadMessageBegin().Return("testStruct", thrift.REPLY, int32(1), err).After(last)
if failAt == 26 {
return true
}
@@ -392,6 +393,8 @@
return false
}
+// TestCase: Comprehensive call and reply workflow in the client.
+// Expecting TTransportError on fail.
func TestClientReportTTransportErrors(t *testing.T) {
mockCtrl := gomock.NewController(t)
transport := thrift.NewTMemoryBuffer()
@@ -405,24 +408,25 @@
err := thrift.NewTTransportException(thrift.TIMED_OUT, "test")
for i := 0; ; i++ {
protocol := NewMockTProtocol(mockCtrl)
- if !prepareClientProtocolFailure(protocol, i, err) {
+ if !prepareClientCallReply(protocol, i, err) {
return
}
client := errortest.NewErrorTestClientProtocol(transport, protocol, protocol)
_, retErr := client.TestStruct(thing)
+ mockCtrl.Finish()
err2, ok := retErr.(thrift.TTransportException)
if !ok {
t.Fatal("Expected a TTrasportException")
}
- if err2.TypeId() != err.TypeId() {
- t.Fatal("Expected a same error type id")
+ if err2.TypeId() != thrift.TIMED_OUT {
+ t.Fatal("Expected TIMED_OUT error")
}
-
- mockCtrl.Finish()
}
}
+// TestCase: Comprehensive call and reply workflow in the client.
+// Expecting TTProtocolErrors on fail.
func TestClientReportTProtocolErrors(t *testing.T) {
mockCtrl := gomock.NewController(t)
transport := thrift.NewTMemoryBuffer()
@@ -436,20 +440,241 @@
err := thrift.NewTProtocolExceptionWithType(thrift.INVALID_DATA, errors.New("test"))
for i := 0; ; i++ {
protocol := NewMockTProtocol(mockCtrl)
- if !prepareClientProtocolFailure(protocol, i, err) {
+ if !prepareClientCallReply(protocol, i, err) {
return
}
client := errortest.NewErrorTestClientProtocol(transport, protocol, protocol)
_, retErr := client.TestStruct(thing)
+ mockCtrl.Finish()
err2, ok := retErr.(thrift.TProtocolException)
if !ok {
t.Fatal("Expected a TProtocolException")
}
-
- if err2.TypeId() != err.TypeId() {
- t.Fatal("Expected a same error type id")
+ if err2.TypeId() != thrift.INVALID_DATA {
+ t.Fatal("Expected INVALID_DATA error")
}
+ }
+}
+// TestCase: call and reply with exception workflow in the client.
+// Setup mock to fail at a certain position. Return true if position exists otherwise false.
+func prepareClientCallException(protocol *MockTProtocol, failAt int, failWith error) bool {
+ var err error = nil
+
+ // No need to test failure in this block, because it is covered in other test cases
+ last := protocol.EXPECT().WriteMessageBegin("testString", thrift.CALL, int32(1))
+ last = protocol.EXPECT().WriteStructBegin("testString_args").After(last)
+ last = protocol.EXPECT().WriteFieldBegin("s", thrift.TType(thrift.STRING), int16(1)).After(last)
+ last = protocol.EXPECT().WriteString("test").After(last)
+ last = protocol.EXPECT().WriteFieldEnd().After(last)
+ last = protocol.EXPECT().WriteFieldStop().After(last)
+ last = protocol.EXPECT().WriteStructEnd().After(last)
+ last = protocol.EXPECT().WriteMessageEnd().After(last)
+ last = protocol.EXPECT().Flush().After(last)
+
+ // Reading the exception, might fail.
+ if failAt == 0 {
+ err = failWith
+ }
+ last = protocol.EXPECT().ReadMessageBegin().Return("testString", thrift.EXCEPTION, int32(1), err).After(last)
+ if failAt == 0 {
+ return true
+ }
+ if failAt == 1 {
+ err = failWith
+ }
+ last = protocol.EXPECT().ReadStructBegin().Return("TApplicationException", err).After(last)
+ if failAt == 1 {
+ return true
+ }
+ if failAt == 2 {
+ err = failWith
+ }
+ last = protocol.EXPECT().ReadFieldBegin().Return("message", thrift.TType(thrift.STRING), int16(1), err).After(last)
+ if failAt == 2 {
+ return true
+ }
+ if failAt == 3 {
+ err = failWith
+ }
+ last = protocol.EXPECT().ReadString().Return("test", err).After(last)
+ if failAt == 3 {
+ return true
+ }
+ if failAt == 4 {
+ err = failWith
+ }
+ last = protocol.EXPECT().ReadFieldEnd().Return(err).After(last)
+ if failAt == 4 {
+ return true
+ }
+ if failAt == 5 {
+ err = failWith
+ }
+ last = protocol.EXPECT().ReadFieldBegin().Return("type", thrift.TType(thrift.I32), int16(2), err).After(last)
+ if failAt == 5 {
+ return true
+ }
+ if failAt == 6 {
+ err = failWith
+ }
+ last = protocol.EXPECT().ReadI32().Return(int32(thrift.PROTOCOL_ERROR), err).After(last)
+ if failAt == 6 {
+ return true
+ }
+ if failAt == 7 {
+ err = failWith
+ }
+ last = protocol.EXPECT().ReadFieldEnd().Return(err).After(last)
+ if failAt == 7 {
+ return true
+ }
+ if failAt == 8 {
+ err = failWith
+ }
+ last = protocol.EXPECT().ReadFieldBegin().Return("_", thrift.TType(thrift.STOP), int16(2), err).After(last)
+ if failAt == 8 {
+ return true
+ }
+ if failAt == 9 {
+ err = failWith
+ }
+ last = protocol.EXPECT().ReadStructEnd().Return(err).After(last)
+ if failAt == 9 {
+ return true
+ }
+ if failAt == 10 {
+ err = failWith
+ }
+ last = protocol.EXPECT().ReadMessageEnd().Return(err).After(last)
+ if failAt == 10 {
+ return true
+ }
+
+ return false
+}
+
+// TestCase: call and reply with exception workflow in the client.
+func TestClientCallException(t *testing.T) {
+ mockCtrl := gomock.NewController(t)
+ transport := thrift.NewTMemoryBuffer()
+
+ err := thrift.NewTTransportException(thrift.TIMED_OUT, "test")
+ for i := 0; ; i++ {
+ protocol := NewMockTProtocol(mockCtrl)
+ willComplete := !prepareClientCallException(protocol, i, err)
+
+ client := errortest.NewErrorTestClientProtocol(transport, protocol, protocol)
+ _, retErr := client.TestString("test")
mockCtrl.Finish()
+
+ if !willComplete {
+ err2, ok := retErr.(thrift.TTransportException)
+ if !ok {
+ t.Fatal("Expected a TTransportException")
+ }
+ if err2.TypeId() != thrift.TIMED_OUT {
+ t.Fatal("Expected TIMED_OUT error")
+ }
+ } else {
+ err2, ok := retErr.(thrift.TApplicationException)
+ if !ok {
+ t.Fatal("Expected a TApplicationException")
+ }
+ if err2.TypeId() != thrift.PROTOCOL_ERROR {
+ t.Fatal("Expected PROTOCOL_ERROR error")
+ }
+ break
+ }
+ }
+}
+
+// TestCase: Mismatching sequence id has been received in the client.
+func TestClientSeqIdMismatch(t *testing.T) {
+ mockCtrl := gomock.NewController(t)
+ transport := thrift.NewTMemoryBuffer()
+ protocol := NewMockTProtocol(mockCtrl)
+ gomock.InOrder(
+ protocol.EXPECT().WriteMessageBegin("testString", thrift.CALL, int32(1)),
+ protocol.EXPECT().WriteStructBegin("testString_args"),
+ protocol.EXPECT().WriteFieldBegin("s", thrift.TType(thrift.STRING), int16(1)),
+ protocol.EXPECT().WriteString("test"),
+ protocol.EXPECT().WriteFieldEnd(),
+ protocol.EXPECT().WriteFieldStop(),
+ protocol.EXPECT().WriteStructEnd(),
+ protocol.EXPECT().WriteMessageEnd(),
+ protocol.EXPECT().Flush(),
+ protocol.EXPECT().ReadMessageBegin().Return("testString", thrift.REPLY, int32(2), nil),
+ )
+
+ client := errortest.NewErrorTestClientProtocol(transport, protocol, protocol)
+ _, err := client.TestString("test")
+ mockCtrl.Finish()
+ appErr, ok := err.(thrift.TApplicationException)
+ if !ok {
+ t.Fatal("Expected TApplicationException")
+ }
+ if appErr.TypeId() != thrift.BAD_SEQUENCE_ID {
+ t.Fatal("Expected BAD_SEQUENCE_ID error")
+ }
+}
+
+// TestCase: Wrong method name has been received in the client.
+func TestClientWrongMethodName(t *testing.T) {
+ mockCtrl := gomock.NewController(t)
+ transport := thrift.NewTMemoryBuffer()
+ protocol := NewMockTProtocol(mockCtrl)
+ gomock.InOrder(
+ protocol.EXPECT().WriteMessageBegin("testString", thrift.CALL, int32(1)),
+ protocol.EXPECT().WriteStructBegin("testString_args"),
+ protocol.EXPECT().WriteFieldBegin("s", thrift.TType(thrift.STRING), int16(1)),
+ protocol.EXPECT().WriteString("test"),
+ protocol.EXPECT().WriteFieldEnd(),
+ protocol.EXPECT().WriteFieldStop(),
+ protocol.EXPECT().WriteStructEnd(),
+ protocol.EXPECT().WriteMessageEnd(),
+ protocol.EXPECT().Flush(),
+ protocol.EXPECT().ReadMessageBegin().Return("unknown", thrift.REPLY, int32(1), nil),
+ )
+
+ client := errortest.NewErrorTestClientProtocol(transport, protocol, protocol)
+ _, err := client.TestString("test")
+ mockCtrl.Finish()
+ appErr, ok := err.(thrift.TApplicationException)
+ if !ok {
+ t.Fatal("Expected TApplicationException")
+ }
+ if appErr.TypeId() != thrift.WRONG_METHOD_NAME {
+ t.Fatal("Expected WRONG_METHOD_NAME error")
+ }
+}
+
+// TestCase: Wrong message type has been received in the client.
+func TestClientWrongMessageType(t *testing.T) {
+ mockCtrl := gomock.NewController(t)
+ transport := thrift.NewTMemoryBuffer()
+ protocol := NewMockTProtocol(mockCtrl)
+ gomock.InOrder(
+ protocol.EXPECT().WriteMessageBegin("testString", thrift.CALL, int32(1)),
+ protocol.EXPECT().WriteStructBegin("testString_args"),
+ protocol.EXPECT().WriteFieldBegin("s", thrift.TType(thrift.STRING), int16(1)),
+ protocol.EXPECT().WriteString("test"),
+ protocol.EXPECT().WriteFieldEnd(),
+ protocol.EXPECT().WriteFieldStop(),
+ protocol.EXPECT().WriteStructEnd(),
+ protocol.EXPECT().WriteMessageEnd(),
+ protocol.EXPECT().Flush(),
+ protocol.EXPECT().ReadMessageBegin().Return("testString", thrift.INVALID_TMESSAGE_TYPE, int32(1), nil),
+ )
+
+ client := errortest.NewErrorTestClientProtocol(transport, protocol, protocol)
+ _, err := client.TestString("test")
+ mockCtrl.Finish()
+ appErr, ok := err.(thrift.TApplicationException)
+ if !ok {
+ t.Fatal("Expected TApplicationException")
+ }
+ if appErr.TypeId() != thrift.INVALID_MESSAGE_TYPE_EXCEPTION {
+ t.Fatal("Expected INVALID_MESSAGE_TYPE_EXCEPTION error")
}
}