THRIFT-5605: Client middleware to extract exceptions
Client: go
Provide the ExtractIDLExceptionClientMiddleware client middleware and the
ExtractExceptionFromResult helper, which extract exceptions defined in the
Thrift IDL from the call result into the returned err so that other client
middlewares can access them.
diff --git a/lib/go/test/ClientMiddlewareExceptionTest.thrift b/lib/go/test/ClientMiddlewareExceptionTest.thrift
new file mode 100644
index 0000000..647c433
--- /dev/null
+++ b/lib/go/test/ClientMiddlewareExceptionTest.thrift
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+exception Exception1 {
+}
+
+exception Exception2 {
+}
+
+// FooResponse is itself an exception type but is used as foo's return value:
+// make sure the middleware doesn't accidentally pull the result as an error.
+exception FooResponse {
+}
+
+service ClientMiddlewareExceptionTest {
+ FooResponse foo() throws(
+ 1: Exception1 error1,
+ 2: Exception2 error2,
+ )
+}
diff --git a/lib/go/test/Makefile.am b/lib/go/test/Makefile.am
index 4392ebe..d938449 100644
--- a/lib/go/test/Makefile.am
+++ b/lib/go/test/Makefile.am
@@ -63,7 +63,8 @@
ConflictArgNamesTest.thrift \
ConstOptionalFieldImport.thrift \
ConstOptionalField.thrift \
- ProcessorMiddlewareTest.thrift
+ ProcessorMiddlewareTest.thrift \
+ ClientMiddlewareExceptionTest.thrift
mkdir -p gopath/src
grep -v list.*map.*list.*map $(THRIFTTEST) | grep -v 'set<Insanity>' > ThriftTest.thrift
$(THRIFT) $(THRIFTARGS) -r IncludesTest.thrift
@@ -96,6 +97,7 @@
$(THRIFT) $(THRIFTARGS) ConflictArgNamesTest.thrift
$(THRIFT) $(THRIFTARGS) -r ConstOptionalField.thrift
$(THRIFT) $(THRIFTARGS_SKIP_REMOTE) ProcessorMiddlewareTest.thrift
+ $(THRIFT) $(THRIFTARGS) ClientMiddlewareExceptionTest.thrift
ln -nfs ../../tests gopath/src/tests
cp -r ./dontexportrwtest gopath/src
touch gopath
@@ -119,7 +121,8 @@
./gopath/src/duplicateimportstest \
./gopath/src/equalstest \
./gopath/src/conflictargnamestest \
- ./gopath/src/processormiddlewaretest
+ ./gopath/src/processormiddlewaretest \
+ ./gopath/src/clientmiddlewareexceptiontest
$(GO) test -mod=mod github.com/apache/thrift/lib/go/thrift
$(GO) test -mod=mod ./gopath/src/tests ./gopath/src/dontexportrwtest
@@ -134,6 +137,7 @@
tests \
common \
BinaryKeyTest.thrift \
+ ClientMiddlewareExceptionTest.thrift \
ConflictArgNamesTest.thrift \
ConflictNamespaceServiceTest.thrift \
ConflictNamespaceTestA.thrift \
diff --git a/lib/go/test/tests/client_middleware_exception_test.go b/lib/go/test/tests/client_middleware_exception_test.go
new file mode 100644
index 0000000..5cb42ab
--- /dev/null
+++ b/lib/go/test/tests/client_middleware_exception_test.go
@@ -0,0 +1,199 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package tests
+
+import (
+	"context"
+	"errors"
+	"testing"
+
+	"github.com/apache/thrift/lib/go/test/gopath/src/clientmiddlewareexceptiontest"
+	"github.com/apache/thrift/lib/go/thrift"
+)
+
+// fakeClientMiddlewareExceptionTestHandler adapts a plain function to the
+// handler interface expected by
+// clientmiddlewareexceptiontest.NewClientMiddlewareExceptionTestProcessor.
+type fakeClientMiddlewareExceptionTestHandler func(ctx context.Context) (*clientmiddlewareexceptiontest.FooResponse, error)
+
+// Foo implements the service's foo method by delegating to f.
+func (f fakeClientMiddlewareExceptionTestHandler) Foo(ctx context.Context) (*clientmiddlewareexceptiontest.FooResponse, error) {
+	return f(ctx)
+}
+
+// clientMiddlewareErrorChecker inspects the error observed for a call and
+// returns a non-nil error describing the mismatch, or nil if err is as
+// expected.
+type clientMiddlewareErrorChecker func(err error) error
+
+// clientMiddlewareExceptionCases pairs each server handler behavior with a
+// checker for the error that both a client middleware and the final caller
+// are expected to observe.
+var clientMiddlewareExceptionCases = []struct {
+	label   string
+	handler fakeClientMiddlewareExceptionTestHandler
+	checker clientMiddlewareErrorChecker
+}{
+	{
+		// FooResponse is itself an exception type; a successful response must
+		// not be mistaken for an error.
+		label: "no-error",
+		handler: func(_ context.Context) (*clientmiddlewareexceptiontest.FooResponse, error) {
+			return new(clientmiddlewareexceptiontest.FooResponse), nil
+		},
+		checker: func(err error) error {
+			if err != nil {
+				return errors.New("expected err to be nil")
+			}
+			return nil
+		},
+	},
+	{
+		label: "exception-1",
+		handler: func(_ context.Context) (*clientmiddlewareexceptiontest.FooResponse, error) {
+			return nil, new(clientmiddlewareexceptiontest.Exception1)
+		},
+		checker: func(err error) error {
+			if !errors.As(err, new(*clientmiddlewareexceptiontest.Exception1)) {
+				return errors.New("expected err to be of type *clientmiddlewareexceptiontest.Exception1")
+			}
+			return nil
+		},
+	},
+	{
+		label: "exception-2",
+		handler: func(_ context.Context) (*clientmiddlewareexceptiontest.FooResponse, error) {
+			return nil, new(clientmiddlewareexceptiontest.Exception2)
+		},
+		checker: func(err error) error {
+			if !errors.As(err, new(*clientmiddlewareexceptiontest.Exception2)) {
+				return errors.New("expected err to be of type *clientmiddlewareexceptiontest.Exception2")
+			}
+			return nil
+		},
+	},
+}
+
+// newClientMiddlewareExceptionTestClient starts a ClientMiddlewareExceptionTest
+// server backed by handler on an OS-assigned port and returns a generated
+// client connected to it over the binary protocol, whose TClient is wrapped
+// with middlewares via thrift.WrapClient. The server and the client socket are
+// shut down through t.Cleanup; any setup failure aborts the test.
+func newClientMiddlewareExceptionTestClient(
+	t *testing.T,
+	handler fakeClientMiddlewareExceptionTestHandler,
+	middlewares ...thrift.ClientMiddleware,
+) *clientmiddlewareexceptiontest.ClientMiddlewareExceptionTestClient {
+	t.Helper()
+
+	serverSocket, err := thrift.NewTServerSocket(":0")
+	if err != nil {
+		t.Fatalf("failed to create server socket: %v", err)
+	}
+	processor := clientmiddlewareexceptiontest.NewClientMiddlewareExceptionTestProcessor(handler)
+	server := thrift.NewTSimpleServer2(processor, serverSocket)
+	if err := server.Listen(); err != nil {
+		t.Fatalf("failed to listen server: %v", err)
+	}
+	addr := serverSocket.Addr().String()
+	go server.Serve()
+	t.Cleanup(func() {
+		server.Stop()
+	})
+
+	var cfg *thrift.TConfiguration
+	socket := thrift.NewTSocketConf(addr, cfg)
+	if err := socket.Open(); err != nil {
+		t.Fatalf("failed to create client connection: %v", err)
+	}
+	t.Cleanup(func() {
+		socket.Close()
+	})
+	inProtocol := thrift.NewTBinaryProtocolConf(socket, cfg)
+	outProtocol := thrift.NewTBinaryProtocolConf(socket, cfg)
+	client := thrift.WrapClient(
+		thrift.NewTStandardClient(inProtocol, outProtocol),
+		middlewares...,
+	)
+	return clientmiddlewareexceptiontest.NewClientMiddlewareExceptionTestClient(client)
+}
+
+// TestClientMiddlewareException checks that, with
+// thrift.ExtractIDLExceptionClientMiddleware in the chain, IDL exceptions
+// reach an outer client middleware as the err returned by Call, and that the
+// final caller observes the same kind of error.
+func TestClientMiddlewareException(t *testing.T) {
+	for _, c := range clientMiddlewareExceptionCases {
+		t.Run(c.label, func(t *testing.T) {
+			// Listed before ExtractIDLExceptionClientMiddleware so that it wraps
+			// it and therefore sees the extracted exception as err.
+			middleware := func(next thrift.TClient) thrift.TClient {
+				return thrift.WrappedTClient{
+					Wrapped: func(ctx context.Context, method string, args, result thrift.TStruct) (_ thrift.ResponseMeta, err error) {
+						defer func() {
+							if checkErr := c.checker(err); checkErr != nil {
+								t.Errorf("middleware result unexpected: %v (result=%#v, err=%#v)", checkErr, result, err)
+							}
+						}()
+						return next.Call(ctx, method, args, result)
+					},
+				}
+			}
+			client := newClientMiddlewareExceptionTestClient(t, c.handler, middleware, thrift.ExtractIDLExceptionClientMiddleware)
+			result, err := client.Foo(context.Background())
+			if checkErr := c.checker(err); checkErr != nil {
+				t.Errorf("final result unexpected: %v (result=%#v, err=%#v)", checkErr, result, err)
+			}
+		})
+	}
+}
+
+// TestExtractExceptionFromResult checks that, without
+// ExtractIDLExceptionClientMiddleware, a client middleware can surface IDL
+// exceptions itself via thrift.ExtractExceptionFromResult, and that the final
+// caller observes the same kind of error.
+func TestExtractExceptionFromResult(t *testing.T) {
+	for _, c := range clientMiddlewareExceptionCases {
+		t.Run(c.label, func(t *testing.T) {
+			middleware := func(next thrift.TClient) thrift.TClient {
+				return thrift.WrappedTClient{
+					Wrapped: func(ctx context.Context, method string, args, result thrift.TStruct) (_ thrift.ResponseMeta, err error) {
+						defer func() {
+							// Consult result only when Call itself succeeded; assigning
+							// the named return replaces the err the caller receives.
+							if err == nil {
+								err = thrift.ExtractExceptionFromResult(result)
+							}
+							if checkErr := c.checker(err); checkErr != nil {
+								t.Errorf("middleware result unexpected: %v (result=%#v, err=%#v)", checkErr, result, err)
+							}
+						}()
+						return next.Call(ctx, method, args, result)
+					},
+				}
+			}
+			client := newClientMiddlewareExceptionTestClient(t, c.handler, middleware)
+			result, err := client.Foo(context.Background())
+			if checkErr := c.checker(err); checkErr != nil {
+				t.Errorf("final result unexpected: %v (result=%#v, err=%#v)", checkErr, result, err)
+			}
+		})
+	}
+}