THRIFT-5326: Expand TException interface in go library
Client: go
Add TExceptionType enum type, and add
TExceptionType() TExceptionType
function to TException definition.
Also make TProtocolException unwrap-able.
diff --git a/CHANGES.md b/CHANGES.md
index 663c4c1..8e4d08e 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -17,6 +17,7 @@
- [THRIFT-5233](https://issues.apache.org/jira/browse/THRIFT-5233) - go: Now all Read*, Write* and Skip functions in TProtocol accept context arg
- [THRIFT-5152](https://issues.apache.org/jira/browse/THRIFT-5152) - go: TSocket and TSSLSocket now have separated connect timeout and socket timeout
- c++: dropped support for Windows XP
+- [THRIFT-5326](https://issues.apache.org/jira/browse/THRIFT-5326) - go: TException interface now has a new function: TExceptionType
### Java
diff --git a/compiler/cpp/src/thrift/generate/t_go_generator.cc b/compiler/cpp/src/thrift/generate/t_go_generator.cc
index 3bb2a5c..49d8bc1 100644
--- a/compiler/cpp/src/thrift/generate/t_go_generator.cc
+++ b/compiler/cpp/src/thrift/generate/t_go_generator.cc
@@ -1493,8 +1493,15 @@
if (is_exception) {
out << indent() << "func (p *" << tstruct_name << ") Error() string {" << endl;
- out << indent() << " return p.String()" << endl;
+ out << indent() << indent() << "return p.String()" << endl;
out << indent() << "}" << endl << endl;
+
+ out << indent() << "func (" << tstruct_name << ") TExceptionType() thrift.TExceptionType {" << endl;
+ out << indent() << indent() << "return thrift.TExceptionTypeCompiled" << endl;
+ out << indent() << "}" << endl << endl;
+
+ out << indent() << "var _ thrift.TException = (*" << tstruct_name << ")(nil)"
+ << endl << endl;
}
}
@@ -2700,8 +2707,8 @@
f_types_ << indent() << "func (p *" << serviceName
<< "Processor) Process(ctx context.Context, iprot, oprot thrift.TProtocol) (success bool, err "
"thrift.TException) {" << endl;
- f_types_ << indent() << " name, _, seqId, err := iprot.ReadMessageBegin(ctx)" << endl;
- f_types_ << indent() << " if err != nil { return false, err }" << endl;
+ f_types_ << indent() << " name, _, seqId, err2 := iprot.ReadMessageBegin(ctx)" << endl;
+ f_types_ << indent() << " if err2 != nil { return false, thrift.WrapTException(err2) }" << endl;
f_types_ << indent() << " if processor, ok := p.GetProcessorFunction(name); ok {" << endl;
f_types_ << indent() << " return processor.Process(ctx, seqId, iprot, oprot)" << endl;
f_types_ << indent() << " }" << endl;
@@ -2767,11 +2774,12 @@
"thrift.TException) {" << endl;
indent_up();
f_types_ << indent() << "args := " << argsname << "{}" << endl;
- f_types_ << indent() << "if err = args." << read_method_name_ << "(ctx, iprot); err != nil {" << endl;
+ f_types_ << indent() << "var err2 error" << endl;
+ f_types_ << indent() << "if err2 = args." << read_method_name_ << "(ctx, iprot); err2 != nil {" << endl;
f_types_ << indent() << " iprot.ReadMessageEnd(ctx)" << endl;
if (!tfunction->is_oneway()) {
f_types_ << indent()
- << " x := thrift.NewTApplicationException(thrift.PROTOCOL_ERROR, err.Error())"
+ << " x := thrift.NewTApplicationException(thrift.PROTOCOL_ERROR, err2.Error())"
<< endl;
f_types_ << indent() << " oprot.WriteMessageBegin(ctx, \"" << escape_string(tfunction->get_name())
<< "\", thrift.EXCEPTION, seqId)" << endl;
@@ -2779,7 +2787,7 @@
f_types_ << indent() << " oprot.WriteMessageEnd(ctx)" << endl;
f_types_ << indent() << " oprot.Flush(ctx)" << endl;
}
- f_types_ << indent() << " return false, err" << endl;
+ f_types_ << indent() << " return false, thrift.WrapTException(err2)" << endl;
f_types_ << indent() << "}" << endl;
f_types_ << indent() << "iprot.ReadMessageEnd(ctx)" << endl << endl;
@@ -2842,7 +2850,6 @@
f_types_ << indent() << "var retval " << type_to_go_type(tfunction->get_returntype()) << endl;
}
- f_types_ << indent() << "var err2 error" << endl;
f_types_ << indent() << "if ";
if (!tfunction->is_oneway()) {
@@ -2892,7 +2899,7 @@
if (!tfunction->is_oneway()) {
// Avoid writing the error to the wire if it's ErrAbandonRequest
f_types_ << indent() << " if err2 == thrift.ErrAbandonRequest {" << endl;
- f_types_ << indent() << " return false, err2" << endl;
+ f_types_ << indent() << " return false, thrift.WrapTException(err2)" << endl;
f_types_ << indent() << " }" << endl;
f_types_ << indent() << " x := thrift.NewTApplicationException(thrift.INTERNAL_ERROR, "
@@ -2905,7 +2912,7 @@
f_types_ << indent() << " oprot.Flush(ctx)" << endl;
}
- f_types_ << indent() << " return true, err2" << endl;
+ f_types_ << indent() << " return true, thrift.WrapTException(err2)" << endl;
if (!x_fields.empty()) {
f_types_ << indent() << "}" << endl;
@@ -2931,17 +2938,17 @@
f_types_ << indent() << "if err2 = oprot.WriteMessageBegin(ctx, \""
<< escape_string(tfunction->get_name()) << "\", thrift.REPLY, seqId); err2 != nil {"
<< endl;
- f_types_ << indent() << " err = err2" << endl;
+ f_types_ << indent() << " err = thrift.WrapTException(err2)" << endl;
f_types_ << indent() << "}" << endl;
f_types_ << indent() << "if err2 = result." << write_method_name_ << "(ctx, oprot); err == nil && err2 != nil {" << endl;
- f_types_ << indent() << " err = err2" << endl;
+ f_types_ << indent() << " err = thrift.WrapTException(err2)" << endl;
f_types_ << indent() << "}" << endl;
f_types_ << indent() << "if err2 = oprot.WriteMessageEnd(ctx); err == nil && err2 != nil {"
<< endl;
- f_types_ << indent() << " err = err2" << endl;
+ f_types_ << indent() << " err = thrift.WrapTException(err2)" << endl;
f_types_ << indent() << "}" << endl;
f_types_ << indent() << "if err2 = oprot.Flush(ctx); err == nil && err2 != nil {" << endl;
- f_types_ << indent() << " err = err2" << endl;
+ f_types_ << indent() << " err = thrift.WrapTException(err2)" << endl;
f_types_ << indent() << "}" << endl;
f_types_ << indent() << "if err != nil {" << endl;
f_types_ << indent() << " return" << endl;
diff --git a/lib/go/test/tests/thrifttest_handler.go b/lib/go/test/tests/thrifttest_handler.go
index 31b9ee2..7b115ec 100644
--- a/lib/go/test/tests/thrifttest_handler.go
+++ b/lib/go/test/tests/thrifttest_handler.go
@@ -179,7 +179,7 @@
x.Message = arg
return x
} else if arg == "TException" {
- return thrift.TException(errors.New(arg))
+ return thrift.WrapTException(errors.New(arg))
} else {
return nil
}
diff --git a/lib/go/thrift/application_exception.go b/lib/go/thrift/application_exception.go
index 6de37ee..32d5b01 100644
--- a/lib/go/thrift/application_exception.go
+++ b/lib/go/thrift/application_exception.go
@@ -64,6 +64,12 @@
type_ int32
}
+var _ TApplicationException = (*tApplicationException)(nil)
+
+func (tApplicationException) TExceptionType() TExceptionType {
+ return TExceptionTypeApplication
+}
+
func (e tApplicationException) Error() string {
if e.message != "" {
return e.message
diff --git a/lib/go/thrift/compact_protocol.go b/lib/go/thrift/compact_protocol.go
index 25e6d0c..a49225d 100644
--- a/lib/go/thrift/compact_protocol.go
+++ b/lib/go/thrift/compact_protocol.go
@@ -845,7 +845,7 @@
case COMPACT_STRUCT:
return STRUCT, nil
}
- return STOP, TException(fmt.Errorf("don't know what type: %v", t&0x0f))
+ return STOP, NewTProtocolException(fmt.Errorf("don't know what type: %v", t&0x0f))
}
// Given a TType value, find the appropriate TCompactProtocol.Types constant.
diff --git a/lib/go/thrift/exception.go b/lib/go/thrift/exception.go
index ea8d6f6..b6885fa 100644
--- a/lib/go/thrift/exception.go
+++ b/lib/go/thrift/exception.go
@@ -26,19 +26,86 @@
// Generic Thrift exception
type TException interface {
error
+
+ TExceptionType() TExceptionType
}
// Prepends additional information to an error without losing the Thrift exception interface
func PrependError(prepend string, err error) error {
- if t, ok := err.(TTransportException); ok {
- return NewTTransportException(t.TypeId(), prepend+t.Error())
- }
- if t, ok := err.(TProtocolException); ok {
- return NewTProtocolExceptionWithType(t.TypeId(), errors.New(prepend+err.Error()))
- }
- if t, ok := err.(TApplicationException); ok {
- return NewTApplicationException(t.TypeId(), prepend+t.Error())
+ msg := prepend + err.Error()
+
+ if te, ok := err.(TException); ok {
+ switch te.TExceptionType() {
+ case TExceptionTypeTransport:
+ if t, ok := err.(TTransportException); ok {
+ return NewTTransportException(t.TypeId(), msg)
+ }
+ case TExceptionTypeProtocol:
+ if t, ok := err.(TProtocolException); ok {
+ return NewTProtocolExceptionWithType(t.TypeId(), errors.New(msg))
+ }
+ case TExceptionTypeApplication:
+ if t, ok := err.(TApplicationException); ok {
+ return NewTApplicationException(t.TypeId(), msg)
+ }
+ }
+
+ return wrappedTException{
+ err: errors.New(msg),
+ tExceptionType: te.TExceptionType(),
+ }
}
- return errors.New(prepend + err.Error())
+ return errors.New(msg)
}
+
+// TExceptionType is an enum type to categorize different "subclasses" of TExceptions.
+type TExceptionType byte
+
+// TExceptionType values
+const (
+ TExceptionTypeUnknown TExceptionType = iota
+ TExceptionTypeCompiled // TExceptions defined in thrift files and generated by thrift compiler
+ TExceptionTypeApplication // TApplicationExceptions
+ TExceptionTypeProtocol // TProtocolExceptions
+ TExceptionTypeTransport // TTransportExceptions
+)
+
+// WrapTException wraps an error into TException.
+//
+// If err is nil or already TException, it's returned as-is.
+// Otherwise it will be wraped into TException with TExceptionType() returning
+// TExceptionTypeUnknown, and Unwrap() returning the original error.
+func WrapTException(err error) TException {
+ if err == nil {
+ return nil
+ }
+
+ if te, ok := err.(TException); ok {
+ return te
+ }
+
+ return wrappedTException{
+ err: err,
+ tExceptionType: TExceptionTypeUnknown,
+ }
+}
+
+type wrappedTException struct {
+ err error
+ tExceptionType TExceptionType
+}
+
+func (w wrappedTException) Error() string {
+ return w.err.Error()
+}
+
+func (w wrappedTException) TExceptionType() TExceptionType {
+ return w.tExceptionType
+}
+
+func (w wrappedTException) Unwrap() error {
+ return w.err
+}
+
+var _ TException = wrappedTException{}
diff --git a/lib/go/thrift/multiplexed_protocol.go b/lib/go/thrift/multiplexed_protocol.go
index 2f7997e..cacbf6b 100644
--- a/lib/go/thrift/multiplexed_protocol.go
+++ b/lib/go/thrift/multiplexed_protocol.go
@@ -192,10 +192,10 @@
func (t *TMultiplexedProcessor) Process(ctx context.Context, in, out TProtocol) (bool, TException) {
name, typeId, seqid, err := in.ReadMessageBegin(ctx)
if err != nil {
- return false, err
+ return false, NewTProtocolException(err)
}
if typeId != CALL && typeId != ONEWAY {
- return false, fmt.Errorf("Unexpected message type %v", typeId)
+ return false, NewTProtocolException(fmt.Errorf("Unexpected message type %v", typeId))
}
//extract the service name
v := strings.SplitN(name, MULTIPLEXED_SEPARATOR, 2)
@@ -204,11 +204,17 @@
smb := NewStoredMessageProtocol(in, name, typeId, seqid)
return t.DefaultProcessor.Process(ctx, smb, out)
}
- return false, fmt.Errorf("Service name not found in message name: %s. Did you forget to use a TMultiplexProtocol in your client?", name)
+ return false, NewTProtocolException(fmt.Errorf(
+ "Service name not found in message name: %s. Did you forget to use a TMultiplexProtocol in your client?",
+ name,
+ ))
}
actualProcessor, ok := t.serviceProcessorMap[v[0]]
if !ok {
- return false, fmt.Errorf("Service name not found: %s. Did you forget to call registerProcessor()?", v[0])
+ return false, NewTProtocolException(fmt.Errorf(
+ "Service name not found: %s. Did you forget to call registerProcessor()?",
+ v[0],
+ ))
}
smb := NewStoredMessageProtocol(in, v[1], typeId, seqid)
return actualProcessor.Process(ctx, smb, out)
diff --git a/lib/go/thrift/protocol_exception.go b/lib/go/thrift/protocol_exception.go
index 29ab75d..b088caf 100644
--- a/lib/go/thrift/protocol_exception.go
+++ b/lib/go/thrift/protocol_exception.go
@@ -40,8 +40,14 @@
)
type tProtocolException struct {
- typeId int
- message string
+ typeId int
+ err error
+}
+
+var _ TProtocolException = (*tProtocolException)(nil)
+
+func (tProtocolException) TExceptionType() TExceptionType {
+ return TExceptionTypeProtocol
}
func (p *tProtocolException) TypeId() int {
@@ -49,11 +55,15 @@
}
func (p *tProtocolException) String() string {
- return p.message
+ return p.err.Error()
}
func (p *tProtocolException) Error() string {
- return p.message
+ return p.err.Error()
+}
+
+func (p *tProtocolException) Unwrap() error {
+ return p.err
}
func NewTProtocolException(err error) TProtocolException {
@@ -64,14 +74,23 @@
return e
}
if _, ok := err.(base64.CorruptInputError); ok {
- return &tProtocolException{INVALID_DATA, err.Error()}
+ return &tProtocolException{
+ typeId: INVALID_DATA,
+ err: err,
+ }
}
- return &tProtocolException{UNKNOWN_PROTOCOL_EXCEPTION, err.Error()}
+ return &tProtocolException{
+ typeId: UNKNOWN_PROTOCOL_EXCEPTION,
+ err: err,
+ }
}
func NewTProtocolExceptionWithType(errType int, err error) TProtocolException {
if err == nil {
return nil
}
- return &tProtocolException{errType, err.Error()}
+ return &tProtocolException{
+ typeId: errType,
+ err: err,
+ }
}
diff --git a/lib/go/thrift/simple_server.go b/lib/go/thrift/simple_server.go
index e9fea86..ca0e61d 100644
--- a/lib/go/thrift/simple_server.go
+++ b/lib/go/thrift/simple_server.go
@@ -315,7 +315,9 @@
}
ok, err := processor.Process(ctx, inputProtocol, outputProtocol)
- if err == ErrAbandonRequest {
+ // Once we dropped support for pre-go1.13 this can be replaced by:
+ // errors.Is(err, ErrAbandonRequest)
+ if unwrapError(err) == ErrAbandonRequest {
return client.Close()
}
if _, ok := err.(TTransportException); ok && err != nil {
@@ -330,3 +332,17 @@
}
return nil
}
+
+type unwrapper interface {
+ Unwrap() error
+}
+
+func unwrapError(err error) error {
+ for {
+ if u, ok := err.(unwrapper); ok {
+ err = u.Unwrap()
+ } else {
+ return err
+ }
+ }
+}
diff --git a/lib/go/thrift/transport_exception.go b/lib/go/thrift/transport_exception.go
index 16193ee..cf2cc00 100644
--- a/lib/go/thrift/transport_exception.go
+++ b/lib/go/thrift/transport_exception.go
@@ -48,6 +48,12 @@
err error
}
+var _ TTransportException = (*tTransportException)(nil)
+
+func (tTransportException) TExceptionType() TExceptionType {
+ return TExceptionTypeTransport
+}
+
func (p *tTransportException) TypeId() int {
return p.typeId
}
diff --git a/lib/go/thrift/transport_exception_test.go b/lib/go/thrift/transport_exception_test.go
index fb1dc26..57386cb 100644
--- a/lib/go/thrift/transport_exception_test.go
+++ b/lib/go/thrift/transport_exception_test.go
@@ -36,10 +36,6 @@
return fmt.Sprintf("Timeout: %v", t.timedout)
}
-type unwrapper interface {
- Unwrap() error
-}
-
func TestTExceptionTimeout(t *testing.T) {
timeout := &timeout{true}
exception := NewTTransportExceptionFromError(timeout)