THRIFT-5527: Don't swallow idl exceptions in Process function
Client: go
This allows ProcessorMiddlewares to access such exceptions, unless
there's a network error writing the response (which takes priority).
While I'm here, also make the indentation of Process function more
consistent, and make it consistent on returning false and an error when
the reading/writing fails.
diff --git a/compiler/cpp/src/thrift/generate/t_go_generator.cc b/compiler/cpp/src/thrift/generate/t_go_generator.cc
index 7897b62..3b885f1 100644
--- a/compiler/cpp/src/thrift/generate/t_go_generator.cc
+++ b/compiler/cpp/src/thrift/generate/t_go_generator.cc
@@ -959,8 +959,8 @@
// If not writing constants, and there are enums, need extra imports.
if (!consts && get_program()->get_enums().size() > 0) {
system_packages.push_back("database/sql/driver");
- system_packages.push_back("errors");
}
+ system_packages.push_back("errors");
system_packages.push_back("fmt");
system_packages.push_back("time");
// For the thrift import, always do rename import to make sure it's called thrift.
@@ -980,6 +980,7 @@
"// (needed to ensure safety because of naive import list construction.)\n"
"var _ = thrift.ZERO\n"
"var _ = fmt.Printf\n"
+ "var _ = errors.New\n"
"var _ = context.Background\n"
"var _ = time.Now\n"
"var _ = bytes.Equal\n\n");
@@ -2964,21 +2965,27 @@
<< ") Process(ctx context.Context, seqId int32, iprot, oprot thrift.TProtocol) (success bool, err "
"thrift.TException) {" << endl;
indent_up();
+ string write_err;
+ if (!tfunction->is_oneway()) {
+ write_err = tmp("_write_err");
+ f_types_ << indent() << "var " << write_err << " error" << endl;
+ }
f_types_ << indent() << "args := " << argsname << "{}" << endl;
- f_types_ << indent() << "var err2 error" << endl;
- f_types_ << indent() << "if err2 = args." << read_method_name_ << "(ctx, iprot); err2 != nil {" << endl;
- f_types_ << indent() << " iprot.ReadMessageEnd(ctx)" << endl;
+ f_types_ << indent() << "if err2 := args." << read_method_name_ << "(ctx, iprot); err2 != nil {" << endl;
+ indent_up();
+ f_types_ << indent() << "iprot.ReadMessageEnd(ctx)" << endl;
if (!tfunction->is_oneway()) {
f_types_ << indent()
- << " x := thrift.NewTApplicationException(thrift.PROTOCOL_ERROR, err2.Error())"
+ << "x := thrift.NewTApplicationException(thrift.PROTOCOL_ERROR, err2.Error())"
<< endl;
- f_types_ << indent() << " oprot.WriteMessageBegin(ctx, \"" << escape_string(tfunction->get_name())
+ f_types_ << indent() << "oprot.WriteMessageBegin(ctx, \"" << escape_string(tfunction->get_name())
<< "\", thrift.EXCEPTION, seqId)" << endl;
- f_types_ << indent() << " x.Write(ctx, oprot)" << endl;
- f_types_ << indent() << " oprot.WriteMessageEnd(ctx)" << endl;
- f_types_ << indent() << " oprot.Flush(ctx)" << endl;
+ f_types_ << indent() << "x.Write(ctx, oprot)" << endl;
+ f_types_ << indent() << "oprot.WriteMessageEnd(ctx)" << endl;
+ f_types_ << indent() << "oprot.Flush(ctx)" << endl;
}
- f_types_ << indent() << " return false, thrift.WrapTException(err2)" << endl;
+ f_types_ << indent() << "return false, thrift.WrapTException(err2)" << endl;
+ indent_down();
f_types_ << indent() << "}" << endl;
f_types_ << indent() << "iprot.ReadMessageEnd(ctx)" << endl << endl;
@@ -3037,9 +3044,6 @@
f_types_ << indent() << "result := " << resultname << "{}" << endl;
}
bool need_reference = type_need_reference(tfunction->get_returntype());
- if (!tfunction->is_oneway() && !tfunction->get_returntype()->is_void()) {
- f_types_ << indent() << "var retval " << type_to_go_type(tfunction->get_returntype()) << endl;
- }
f_types_ << indent() << "if ";
@@ -3053,7 +3057,7 @@
t_struct* arg_struct = tfunction->get_arglist();
const std::vector<t_field*>& fields = arg_struct->get_members();
vector<t_field*>::const_iterator f_iter;
- f_types_ << "err2 = p.handler." << publicize(tfunction->get_name()) << "(";
+ f_types_ << "err2 := p.handler." << publicize(tfunction->get_name()) << "(";
bool first = true;
f_types_ << "ctx";
@@ -3069,7 +3073,9 @@
}
f_types_ << "); err2 != nil {" << endl;
- f_types_ << indent() << " tickerCancel()" << endl;
+ indent_up();
+ f_types_ << indent() << "tickerCancel()" << endl;
+ f_types_ << indent() << "err = thrift.WrapTException(err2)" << endl;
t_struct* exceptions = tfunction->get_xceptions();
const vector<t_field*>& x_fields = exceptions->get_members();
@@ -3079,36 +3085,74 @@
vector<t_field*>::const_iterator xf_iter;
for (xf_iter = x_fields.begin(); xf_iter != x_fields.end(); ++xf_iter) {
- f_types_ << indent() << " case " << type_to_go_type(((*xf_iter)->get_type())) << ":"
+ f_types_ << indent() << "case " << type_to_go_type(((*xf_iter)->get_type())) << ":"
<< endl;
+ indent_up();
f_types_ << indent() << "result." << publicize((*xf_iter)->get_name()) << " = v" << endl;
+ indent_down();
}
- f_types_ << indent() << " default:" << endl;
+ f_types_ << indent() << "default:" << endl;
+ indent_up();
}
if (!tfunction->is_oneway()) {
// Avoid writing the error to the wire if it's ErrAbandonRequest
- f_types_ << indent() << " if err2 == thrift.ErrAbandonRequest {" << endl;
- f_types_ << indent() << " return false, thrift.WrapTException(err2)" << endl;
- f_types_ << indent() << " }" << endl;
+ f_types_ << indent() << "if errors.Is(err2, thrift.ErrAbandonRequest) {" << endl;
+ indent_up();
+ f_types_ << indent() << "return false, thrift.WrapTException(err2)" << endl;
+ indent_down();
+ f_types_ << indent() << "}" << endl;
- f_types_ << indent() << " x := thrift.NewTApplicationException(thrift.INTERNAL_ERROR, "
+ string exc(tmp("_exc"));
+ f_types_ << indent() << exc << " := thrift.NewTApplicationException(thrift.INTERNAL_ERROR, "
"\"Internal error processing " << escape_string(tfunction->get_name())
<< ": \" + err2.Error())" << endl;
- f_types_ << indent() << " oprot.WriteMessageBegin(ctx, \"" << escape_string(tfunction->get_name())
- << "\", thrift.EXCEPTION, seqId)" << endl;
- f_types_ << indent() << " x.Write(ctx, oprot)" << endl;
- f_types_ << indent() << " oprot.WriteMessageEnd(ctx)" << endl;
- f_types_ << indent() << " oprot.Flush(ctx)" << endl;
- }
- f_types_ << indent() << " return true, thrift.WrapTException(err2)" << endl;
+ f_types_ << indent() << "if err2 := oprot.WriteMessageBegin(ctx, \"" << escape_string(tfunction->get_name())
+ << "\", thrift.EXCEPTION, seqId); err2 != nil {" << endl;
+ indent_up();
+ f_types_ << indent() << write_err << " = thrift.WrapTException(err2)" << endl;
+ indent_down();
+ f_types_ << indent() << "}" << endl;
+
+ f_types_ << indent() << "if err2 := " << exc << ".Write(ctx, oprot); "
+ << write_err << " == nil && err2 != nil {" << endl;
+ indent_up();
+ f_types_ << indent() << write_err << " = thrift.WrapTException(err2)" << endl;
+ indent_down();
+ f_types_ << indent() << "}" << endl;
+
+ f_types_ << indent() << "if err2 := oprot.WriteMessageEnd(ctx); "
+ << write_err << " == nil && err2 != nil {" << endl;
+ indent_up();
+ f_types_ << indent() << write_err << " = thrift.WrapTException(err2)" << endl;
+ indent_down();
+ f_types_ << indent() << "}" << endl;
+
+ f_types_ << indent() << "if err2 := oprot.Flush(ctx); "
+ << write_err << " == nil && err2 != nil {" << endl;
+ indent_up();
+ f_types_ << indent() << write_err << " = thrift.WrapTException(err2)" << endl;
+ indent_down();
+ f_types_ << indent() << "}" << endl;
+
+ f_types_ << indent() << "if " << write_err << " != nil {" << endl;
+ indent_up();
+ f_types_ << indent() << "return false, thrift.WrapTException(" << write_err << ")" << endl;
+ indent_down();
+ f_types_ << indent() << "}" << endl;
+
+ // return success=true as long as writing to the wire was successful.
+ f_types_ << indent() << "return true, err" << endl;
+ }
if (!x_fields.empty()) {
- f_types_ << indent() << "}" << endl;
+ indent_down();
+ f_types_ << indent() << "}" << endl; // closes switch
}
+ indent_down();
f_types_ << indent() << "}"; // closes err2 != nil
if (!tfunction->is_oneway()) {
@@ -3126,29 +3170,47 @@
f_types_ << endl;
}
f_types_ << indent() << "tickerCancel()" << endl;
- f_types_ << indent() << "if err2 = oprot.WriteMessageBegin(ctx, \""
+
+ f_types_ << indent() << "if err2 := oprot.WriteMessageBegin(ctx, \""
<< escape_string(tfunction->get_name()) << "\", thrift.REPLY, seqId); err2 != nil {"
<< endl;
- f_types_ << indent() << " err = thrift.WrapTException(err2)" << endl;
+ indent_up();
+ f_types_ << indent() << write_err << " = thrift.WrapTException(err2)" << endl;
+ indent_down();
f_types_ << indent() << "}" << endl;
- f_types_ << indent() << "if err2 = result." << write_method_name_ << "(ctx, oprot); err == nil && err2 != nil {" << endl;
- f_types_ << indent() << " err = thrift.WrapTException(err2)" << endl;
+
+ f_types_ << indent() << "if err2 := result." << write_method_name_ << "(ctx, oprot); "
+ << write_err << " == nil && err2 != nil {" << endl;
+ indent_up();
+ f_types_ << indent() << write_err << " = thrift.WrapTException(err2)" << endl;
+ indent_down();
f_types_ << indent() << "}" << endl;
- f_types_ << indent() << "if err2 = oprot.WriteMessageEnd(ctx); err == nil && err2 != nil {"
- << endl;
- f_types_ << indent() << " err = thrift.WrapTException(err2)" << endl;
+
+ f_types_ << indent() << "if err2 := oprot.WriteMessageEnd(ctx); "
+ << write_err << " == nil && err2 != nil {" << endl;
+ indent_up();
+ f_types_ << indent() << write_err << " = thrift.WrapTException(err2)" << endl;
+ indent_down();
f_types_ << indent() << "}" << endl;
- f_types_ << indent() << "if err2 = oprot.Flush(ctx); err == nil && err2 != nil {" << endl;
- f_types_ << indent() << " err = thrift.WrapTException(err2)" << endl;
+
+ f_types_ << indent() << "if err2 := oprot.Flush(ctx); " << write_err << " == nil && err2 != nil {" << endl;
+ indent_up();
+ f_types_ << indent() << write_err << " = thrift.WrapTException(err2)" << endl;
+ indent_down();
f_types_ << indent() << "}" << endl;
- f_types_ << indent() << "if err != nil {" << endl;
- f_types_ << indent() << " return" << endl;
+
+ f_types_ << indent() << "if " << write_err << " != nil {" << endl;
+ indent_up();
+ f_types_ << indent() << "return false, thrift.WrapTException(" << write_err << ")" << endl;
+ indent_down();
f_types_ << indent() << "}" << endl;
+
+ // return success=true as long as writing to the wire was successful.
f_types_ << indent() << "return true, err" << endl;
} else {
f_types_ << endl;
f_types_ << indent() << "tickerCancel()" << endl;
- f_types_ << indent() << "return true, nil" << endl;
+ f_types_ << indent() << "return true, err" << endl;
}
indent_down();
f_types_ << indent() << "}" << endl << endl;
diff --git a/lib/go/test/Makefile.am b/lib/go/test/Makefile.am
index 4b3ecda..2cca411 100644
--- a/lib/go/test/Makefile.am
+++ b/lib/go/test/Makefile.am
@@ -52,7 +52,8 @@
EqualsTest.thrift \
ConflictArgNamesTest.thrift \
ConstOptionalFieldImport.thrift \
- ConstOptionalField.thrift
+ ConstOptionalField.thrift \
+ ProcessorMiddlewareTest.thrift
mkdir -p gopath/src
grep -v list.*map.*list.*map $(THRIFTTEST) | grep -v 'set<Insanity>' > ThriftTest.thrift
$(THRIFT) $(THRIFTARGS) -r IncludesTest.thrift
@@ -84,6 +85,7 @@
$(THRIFT) $(THRIFTARGS) EqualsTest.thrift
$(THRIFT) $(THRIFTARGS) ConflictArgNamesTest.thrift
$(THRIFT) $(THRIFTARGS) -r ConstOptionalField.thrift
+ $(THRIFT) $(THRIFTARGS) ProcessorMiddlewareTest.thrift
ln -nfs ../../tests gopath/src/tests
cp -r ./dontexportrwtest gopath/src
touch gopath
@@ -106,7 +108,8 @@
./gopath/src/servicestest/container_test-remote \
./gopath/src/duplicateimportstest \
./gopath/src/equalstest \
- ./gopath/src/conflictargnamestest
+ ./gopath/src/conflictargnamestest \
+ ./gopath/src/processormiddlewaretest
$(GO) test -mod=mod github.com/apache/thrift/lib/go/thrift
$(GO) test -mod=mod ./gopath/src/tests ./gopath/src/dontexportrwtest
@@ -145,6 +148,7 @@
NamesTest.thrift \
OnewayTest.thrift \
OptionalFieldsTest.thrift \
+ ProcessorMiddlewareTest.thrift \
RefAnnotationFieldsTest.thrift \
RequiredFieldTest.thrift \
ServicesTest.thrift \
diff --git a/lib/go/test/ProcessorMiddlewareTest.thrift b/lib/go/test/ProcessorMiddlewareTest.thrift
new file mode 100644
index 0000000..2d4f5f4
--- /dev/null
+++ b/lib/go/test/ProcessorMiddlewareTest.thrift
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ *
+ * Contains some contributions under the Thrift Software License.
+ * Please see doc/old-thrift-license.txt in the Thrift distribution for
+ * details.
+ */
+
+exception Error {
+ 1: optional string foo,
+}
+
+service Service {
+ void ping() throws (
+ 1: Error error,
+ );
+}
diff --git a/lib/go/test/tests/processor_middleware_test.go b/lib/go/test/tests/processor_middleware_test.go
new file mode 100644
index 0000000..1bd911c
--- /dev/null
+++ b/lib/go/test/tests/processor_middleware_test.go
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package tests
+
+import (
+ "context"
+ "errors"
+ "sync"
+ "testing"
+ "time"
+
+ "github.com/apache/thrift/lib/go/test/gopath/src/processormiddlewaretest"
+ "github.com/apache/thrift/lib/go/thrift"
+)
+
+const errorMessage = "foo error"
+
+type serviceImpl struct{}
+
+func (serviceImpl) Ping(_ context.Context) (err error) {
+ return &processormiddlewaretest.Error{
+ Foo: thrift.StringPtr(errorMessage),
+ }
+}
+
+func middleware(t *testing.T) thrift.ProcessorMiddleware {
+ return func(name string, next thrift.TProcessorFunction) thrift.TProcessorFunction {
+ return thrift.WrappedTProcessorFunction{
+ Wrapped: func(ctx context.Context, seqId int32, in, out thrift.TProtocol) (_ bool, err thrift.TException) {
+ defer func() {
+ checkError(t, err)
+ }()
+ return next.Process(ctx, seqId, in, out)
+ },
+ }
+ }
+}
+
+func checkError(tb testing.TB, err error) {
+ tb.Helper()
+
+ var idlErr *processormiddlewaretest.Error
+ if !errors.As(err, &idlErr) {
+ tb.Errorf("expected error to be of type *processormiddlewaretest.Error, actual %T, %#v", err, err)
+ return
+ }
+ if actual := idlErr.GetFoo(); actual != errorMessage {
+ tb.Errorf("expected error message to be %q, actual %q", errorMessage, actual)
+ }
+}
+
+func TestProcessorMiddleware(t *testing.T) {
+ const timeout = time.Second
+
+ processor := processormiddlewaretest.NewServiceProcessor(&serviceImpl{})
+ serverTransport, err := thrift.NewTServerSocket("127.0.0.1:0")
+ if err != nil {
+ t.Fatalf("Could not find available server port: %v", err)
+ }
+ server := thrift.NewTSimpleServer4(
+ thrift.WrapProcessor(processor, middleware(t)),
+ serverTransport,
+ thrift.NewTHeaderTransportFactoryConf(nil, nil),
+ thrift.NewTHeaderProtocolFactoryConf(nil),
+ )
+ defer server.Stop()
+ var wg sync.WaitGroup
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ server.Serve()
+ }()
+
+ time.Sleep(10 * time.Millisecond)
+
+ cfg := &thrift.TConfiguration{
+ ConnectTimeout: timeout,
+ SocketTimeout: timeout,
+ }
+ transport := thrift.NewTSocketFromAddrConf(serverTransport.Addr(), cfg)
+ if err := transport.Open(); err != nil {
+ t.Fatalf("Could not open client transport: %v", err)
+ }
+ defer transport.Close()
+ protocol := thrift.NewTHeaderProtocolConf(transport, nil)
+
+ client := processormiddlewaretest.NewServiceClient(thrift.NewTStandardClient(protocol, protocol))
+
+ err = client.Ping(context.Background())
+ checkError(t, err)
+}