THRIFT-5605: Client middleware to extract exceptions

Client: go

Provide the ExtractIDLExceptionClientMiddleware client middleware and the
ExtractExceptionFromResult helper, which extract exceptions defined in the
thrift IDL into the err return of TClient.Call so that they are accessible
to other client middlewares.
diff --git a/lib/go/test/ClientMiddlewareExceptionTest.thrift b/lib/go/test/ClientMiddlewareExceptionTest.thrift
new file mode 100644
index 0000000..647c433
--- /dev/null
+++ b/lib/go/test/ClientMiddlewareExceptionTest.thrift
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+exception Exception1 {
+}
+
+exception Exception2 {
+}
+
+// Special case: FooResponse is itself an exception type, so the result's Success
+// field holds an exception that the middleware must not mistake for a thrown error.
+exception FooResponse {
+}
+
+service ClientMiddlewareExceptionTest {
+  FooResponse foo() throws(
+      1: Exception1 error1,
+      2: Exception2 error2,
+  )
+}
diff --git a/lib/go/test/Makefile.am b/lib/go/test/Makefile.am
index 4392ebe..d938449 100644
--- a/lib/go/test/Makefile.am
+++ b/lib/go/test/Makefile.am
@@ -63,7 +63,8 @@
 				ConflictArgNamesTest.thrift \
 				ConstOptionalFieldImport.thrift \
 				ConstOptionalField.thrift \
-				ProcessorMiddlewareTest.thrift
+				ProcessorMiddlewareTest.thrift \
+				ClientMiddlewareExceptionTest.thrift
 	mkdir -p gopath/src
 	grep -v list.*map.*list.*map $(THRIFTTEST) | grep -v 'set<Insanity>' > ThriftTest.thrift
 	$(THRIFT) $(THRIFTARGS) -r IncludesTest.thrift
@@ -96,6 +97,7 @@
 	$(THRIFT) $(THRIFTARGS) ConflictArgNamesTest.thrift
 	$(THRIFT) $(THRIFTARGS) -r ConstOptionalField.thrift
 	$(THRIFT) $(THRIFTARGS_SKIP_REMOTE) ProcessorMiddlewareTest.thrift
+	$(THRIFT) $(THRIFTARGS) ClientMiddlewareExceptionTest.thrift
 	ln -nfs ../../tests gopath/src/tests
 	cp -r ./dontexportrwtest gopath/src
 	touch gopath
@@ -119,7 +121,8 @@
 				./gopath/src/duplicateimportstest \
 				./gopath/src/equalstest \
 				./gopath/src/conflictargnamestest \
-				./gopath/src/processormiddlewaretest
+				./gopath/src/processormiddlewaretest \
+				./gopath/src/clientmiddlewareexceptiontest
 	$(GO) test -mod=mod github.com/apache/thrift/lib/go/thrift
 	$(GO) test -mod=mod ./gopath/src/tests ./gopath/src/dontexportrwtest
 
@@ -134,6 +137,7 @@
 	tests \
 	common \
 	BinaryKeyTest.thrift \
+	ClientMiddlewareExceptionTest.thrift \
 	ConflictArgNamesTest.thrift \
 	ConflictNamespaceServiceTest.thrift \
 	ConflictNamespaceTestA.thrift \
diff --git a/lib/go/test/tests/client_middleware_exception_test.go b/lib/go/test/tests/client_middleware_exception_test.go
new file mode 100644
index 0000000..5cb42ab
--- /dev/null
+++ b/lib/go/test/tests/client_middleware_exception_test.go
@@ -0,0 +1,187 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package tests
+
+import (
+	"context"
+	"errors"
+	"testing"
+
+	"github.com/apache/thrift/lib/go/test/gopath/src/clientmiddlewareexceptiontest"
+	"github.com/apache/thrift/lib/go/thrift"
+)
+
+// fakeClientMiddlewareExceptionTestHandler adapts a plain function to the
+// generated ClientMiddlewareExceptionTest service handler interface.
+type fakeClientMiddlewareExceptionTestHandler func(ctx context.Context) (*clientmiddlewareexceptiontest.FooResponse, error)
+
+func (f fakeClientMiddlewareExceptionTestHandler) Foo(ctx context.Context) (*clientmiddlewareexceptiontest.FooResponse, error) {
+	return f(ctx)
+}
+
+// clientMiddlewareErrorChecker returns a non-nil error describing the mismatch
+// when err is not what the test case expects.
+type clientMiddlewareErrorChecker func(err error) error
+
+// clientMiddlewareExceptionCases lists the handler behaviors under test and
+// the error that both a client middleware and the caller must observe.
+var clientMiddlewareExceptionCases = []struct {
+	label   string
+	handler fakeClientMiddlewareExceptionTestHandler
+	checker clientMiddlewareErrorChecker
+}{
+	{
+		label: "no-error",
+		handler: func(_ context.Context) (*clientmiddlewareexceptiontest.FooResponse, error) {
+			return new(clientmiddlewareexceptiontest.FooResponse), nil
+		},
+		checker: func(err error) error {
+			if err != nil {
+				return errors.New("expected err to be nil")
+			}
+			return nil
+		},
+	},
+	{
+		label: "exception-1",
+		handler: func(_ context.Context) (*clientmiddlewareexceptiontest.FooResponse, error) {
+			return nil, new(clientmiddlewareexceptiontest.Exception1)
+		},
+		checker: func(err error) error {
+			if !errors.As(err, new(*clientmiddlewareexceptiontest.Exception1)) {
+				return errors.New("expected err to be of type *clientmiddlewareexceptiontest.Exception1")
+			}
+			return nil
+		},
+	},
+	{
+		label: "exception-2",
+		handler: func(_ context.Context) (*clientmiddlewareexceptiontest.FooResponse, error) {
+			return nil, new(clientmiddlewareexceptiontest.Exception2)
+		},
+		checker: func(err error) error {
+			if !errors.As(err, new(*clientmiddlewareexceptiontest.Exception2)) {
+				return errors.New("expected err to be of type *clientmiddlewareexceptiontest.Exception2")
+			}
+			return nil
+		},
+	},
+}
+
+// checkingClientMiddleware returns a ClientMiddleware that reports, via t, any
+// error seen at its position in the chain that checker rejects. When extract
+// is true, IDL exceptions are first pulled out of result with
+// thrift.ExtractExceptionFromResult, and the extracted error is also returned
+// to the next layer out.
+func checkingClientMiddleware(t *testing.T, checker clientMiddlewareErrorChecker, extract bool) thrift.ClientMiddleware {
+	return func(next thrift.TClient) thrift.TClient {
+		return thrift.WrappedTClient{
+			Wrapped: func(ctx context.Context, method string, args, result thrift.TStruct) (_ thrift.ResponseMeta, err error) {
+				defer func() {
+					if extract && err == nil {
+						err = thrift.ExtractExceptionFromResult(result)
+					}
+					if checkErr := checker(err); checkErr != nil {
+						t.Errorf("middleware result unexpected: %v (result=%#v, err=%#v)", checkErr, result, err)
+					}
+				}()
+				return next.Call(ctx, method, args, result)
+			},
+		}
+	}
+}
+
+// runClientMiddlewareExceptionCase serves handler from a TSimpleServer on a
+// random local port, calls Foo through a TStandardClient wrapped with
+// middlewares, and checks the error returned to the caller with checker.
+// The server and the client socket are released via t.Cleanup.
+func runClientMiddlewareExceptionCase(
+	t *testing.T,
+	handler fakeClientMiddlewareExceptionTestHandler,
+	checker clientMiddlewareErrorChecker,
+	middlewares ...thrift.ClientMiddleware,
+) {
+	t.Helper()
+
+	serverSocket, err := thrift.NewTServerSocket(":0")
+	if err != nil {
+		t.Fatalf("failed to create server socket: %v", err)
+	}
+	processor := clientmiddlewareexceptiontest.NewClientMiddlewareExceptionTestProcessor(handler)
+	server := thrift.NewTSimpleServer2(processor, serverSocket)
+	if err := server.Listen(); err != nil {
+		t.Fatalf("failed to listen server: %v", err)
+	}
+	addr := serverSocket.Addr().String()
+	go server.Serve()
+	t.Cleanup(func() {
+		server.Stop()
+	})
+
+	var cfg *thrift.TConfiguration
+	socket := thrift.NewTSocketConf(addr, cfg)
+	if err := socket.Open(); err != nil {
+		t.Fatalf("failed to create client connection: %v", err)
+	}
+	t.Cleanup(func() {
+		socket.Close()
+	})
+	inProtocol := thrift.NewTBinaryProtocolConf(socket, cfg)
+	outProtocol := thrift.NewTBinaryProtocolConf(socket, cfg)
+	client := thrift.WrapClient(
+		thrift.NewTStandardClient(inProtocol, outProtocol),
+		middlewares...,
+	)
+	result, err := clientmiddlewareexceptiontest.NewClientMiddlewareExceptionTestClient(client).Foo(context.Background())
+	if checkErr := checker(err); checkErr != nil {
+		t.Errorf("final result unexpected: %v (result=%#v, err=%#v)", checkErr, result, err)
+	}
+}
+
+// TestClientMiddlewareException checks that ExtractIDLExceptionClientMiddleware
+// surfaces IDL exceptions as err to the middlewares listed before it.
+func TestClientMiddlewareException(t *testing.T) {
+	for _, c := range clientMiddlewareExceptionCases {
+		t.Run(c.label, func(t *testing.T) {
+			runClientMiddlewareExceptionCase(
+				t,
+				c.handler,
+				c.checker,
+				checkingClientMiddleware(t, c.checker, false),
+				thrift.ExtractIDLExceptionClientMiddleware,
+			)
+		})
+	}
+}
+
+// TestExtractExceptionFromResult checks that a middleware can surface IDL
+// exceptions itself by calling ExtractExceptionFromResult on result.
+func TestExtractExceptionFromResult(t *testing.T) {
+	for _, c := range clientMiddlewareExceptionCases {
+		t.Run(c.label, func(t *testing.T) {
+			runClientMiddlewareExceptionCase(
+				t,
+				c.handler,
+				c.checker,
+				checkingClientMiddleware(t, c.checker, true),
+			)
+		})
+	}
+}
diff --git a/lib/go/thrift/exception.go b/lib/go/thrift/exception.go
index 53bf862..e2f1728 100644
--- a/lib/go/thrift/exception.go
+++ b/lib/go/thrift/exception.go
@@ -21,6 +21,7 @@
 
 import (
 	"errors"
+	"reflect"
 )
 
 // Generic Thrift exception
@@ -114,3 +115,47 @@
 }
 
 var _ TException = wrappedTException{}
+
+// ExtractExceptionFromResult extracts exceptions defined in thrift IDL from
+// the result TStruct used in TClient.Call.
+//
+// For an endpoint defined in thrift IDL like this:
+//
+//     service MyService {
+//       FooResponse foo(1: FooRequest request) throws (
+//         1: Exception1 error1,
+//         2: Exception2 error2,
+//       )
+//     }
+//
+// the thrift compiler generates a result TStruct like this:
+//
+//     type MyServiceFooResult struct {
+//       Success *FooResponse `thrift:"success,0" db:"success" json:"success,omitempty"`
+//       Error1 *Exception1 `thrift:"error1,1" db:"error1" json:"error1,omitempty"`
+//       Error2 *Exception2 `thrift:"error2,2" db:"error2" json:"error2,omitempty"`
+//     }
+//
+// This function returns the first non-nil exported field of *MyServiceFooResult,
+// other than Success, that holds a TException of type TExceptionTypeCompiled.
+// It returns nil when there is none, or when result is not a struct (pointer).
+func ExtractExceptionFromResult(result TStruct) error {
+	v := reflect.Indirect(reflect.ValueOf(result))
+	if v.Kind() != reflect.Struct {
+		return nil
+	}
+	typ := v.Type()
+	for i := 0; i < v.NumField(); i++ {
+		field := v.Field(i)
+		// Success may itself be an IDL exception type and must never be reported;
+		// unexported fields would make Interface panic; zero fields are unset.
+		if typ.Field(i).Name == "Success" || !field.CanInterface() || field.IsZero() {
+			continue
+		}
+		tExc, ok := field.Interface().(TException)
+		if ok && tExc != nil && tExc.TExceptionType() == TExceptionTypeCompiled {
+			return tExc
+		}
+	}
+	return nil
+}
diff --git a/lib/go/thrift/middleware.go b/lib/go/thrift/middleware.go
index 8a788df..85c7e06 100644
--- a/lib/go/thrift/middleware.go
+++ b/lib/go/thrift/middleware.go
@@ -19,7 +19,9 @@
 
 package thrift
 
-import "context"
+import (
+	"context"
+)
 
 // ProcessorMiddleware is a function that can be passed to WrapProcessor to wrap the
 // TProcessorFunctions for that TProcessor.
@@ -107,3 +109,40 @@
 	}
 	return client
 }
+
+// ExtractIDLExceptionClientMiddleware is a ClientMiddleware implementation that
+// extracts exceptions defined in thrift IDL into the error return of
+// TClient.Call. It uses ExtractExceptionFromResult under the hood.
+//
+// By default, if a client call gets an exception defined in the thrift IDL,
+// for example:
+//
+//     service MyService {
+//       FooResponse foo(1: FooRequest request) throws (
+//         1: Exception1 error1,
+//         2: Exception2 error2,
+//       )
+//     }
+//
+// then Exception1 and Exception2 are not returned as the err of TClient.Call;
+// they are only set in the result TStruct, out of easy reach of middlewares.
+// If you have a ClientMiddleware that needs to see them, pass this middleware
+// to WrapClient *after* the middlewares that need them (later middlewares wrap
+// closer to the underlying client); those middlewares will then receive the
+// exceptions through the err return. An error already returned by the wrapped
+// client is passed through unchanged.
+//
+// Alternatively you can call ExtractExceptionFromResult directly from your own
+// client middleware.
+func ExtractIDLExceptionClientMiddleware(next TClient) TClient {
+	return WrappedTClient{
+		Wrapped: func(ctx context.Context, method string, args, result TStruct) (ResponseMeta, error) {
+			meta, err := next.Call(ctx, method, args, result)
+			// An error from next takes precedence over any IDL exception in result.
+			if err == nil {
+				err = ExtractExceptionFromResult(result)
+			}
+			return meta, err
+		},
+	}
+}