THRIFT-5164: Add middleware framework for Go clients
This commit adds a simple middleware framework for Go clients.
It provides:
* A `ClientMiddleware` function interface used to define the actual middleware
* `WrapClient`, the function that you use to wrap a `TClient` in a list of middleware
* A helper `WrappedTClient` struct to help with developing middleware
Client: go
diff --git a/lib/go/thrift/example_client_middleware_test.go b/lib/go/thrift/example_client_middleware_test.go
new file mode 100644
index 0000000..e2e11c3
--- /dev/null
+++ b/lib/go/thrift/example_client_middleware_test.go
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package thrift
+
+import (
+ "context"
+ "log"
+)
+
+// BEGIN THRIFT GENERATED CODE SECTION
+//
+// In real code this section should be from thrift generated code instead,
+// but for this example we just define some placeholders here.
+
+type MyEndpointRequest struct{}
+
+type MyEndpointResponse struct{}
+
+type MyService interface {
+ MyEndpoint(ctx context.Context, req *MyEndpointRequest) (*MyEndpointResponse, error)
+}
+
+func NewMyServiceClient(_ TClient) MyService {
+ // In real code this certainly won't return nil.
+ return nil
+}
+
+// END THRIFT GENERATED CODE SECTION
+
+func simpleClientLoggingMiddleware(next TClient) TClient {
+ return WrappedTClient{
+ Wrapped: func(ctx context.Context, method string, args, result TStruct) error {
+ log.Printf("Before: %q", method)
+ log.Printf("Args: %#v", args)
+ err := next.Call(ctx, method, args, result)
+ log.Printf("After: %q", method)
+ log.Printf("Result: %#v", result)
+ if err != nil {
+ log.Printf("Error: %v", err)
+ }
+ return err
+ },
+ }
+}
+
+func ExampleClientMiddleware() {
+ var (
+ trans TTransport
+ protoFactory TProtocolFactory
+ )
+ var client TClient
+ client = NewTStandardClient(
+ protoFactory.GetProtocol(trans),
+ protoFactory.GetProtocol(trans),
+ )
+ client = WrapClient(client, simpleClientLoggingMiddleware)
+ myServiceClient := NewMyServiceClient(client)
+ myServiceClient.MyEndpoint(context.Background(), &MyEndpointRequest{})
+}
diff --git a/lib/go/thrift/example_middleware_test.go b/lib/go/thrift/example_processor_middleware_test.go
similarity index 89%
rename from lib/go/thrift/example_middleware_test.go
rename to lib/go/thrift/example_processor_middleware_test.go
index 4706110..844358f 100644
--- a/lib/go/thrift/example_middleware_test.go
+++ b/lib/go/thrift/example_processor_middleware_test.go
@@ -24,7 +24,7 @@
"log"
)
-func simpleLoggingMiddleware(name string, next TProcessorFunction) TProcessorFunction {
+func simpleProcessorLoggingMiddleware(name string, next TProcessorFunction) TProcessorFunction {
return WrappedTProcessorFunction{
Wrapped: func(ctx context.Context, seqId int32, in, out TProtocol) (bool, TException) {
log.Printf("Before: %q", name)
@@ -46,7 +46,7 @@
transFactory TTransportFactory
protoFactory TProtocolFactory
)
- processor = WrapProcessor(processor, simpleLoggingMiddleware)
+ processor = WrapProcessor(processor, simpleProcessorLoggingMiddleware)
server := NewTSimpleServer4(processor, trans, transFactory, protoFactory)
log.Fatal(server.Serve())
}
diff --git a/lib/go/thrift/middleware.go b/lib/go/thrift/middleware.go
index 18f2b99..b575e16 100644
--- a/lib/go/thrift/middleware.go
+++ b/lib/go/thrift/middleware.go
@@ -68,3 +68,42 @@
_ TProcessorFunction = WrappedTProcessorFunction{}
_ TProcessorFunction = (*WrappedTProcessorFunction)(nil)
)
+
+// ClientMiddleware can be passed to WrapClient in order to wrap TClient calls
+// with custom middleware.
+type ClientMiddleware func(TClient) TClient
+
+// WrappedTClient is a convenience struct that implements the TClient interface
+// using inner Wrapped function.
+//
+// This is provided to aid in developing ClientMiddleware.
+type WrappedTClient struct {
+ Wrapped func(ctx context.Context, method string, args, result TStruct) error
+}
+
+// Call implements the TClient interface by calling and returning c.Wrapped.
+func (c WrappedTClient) Call(ctx context.Context, method string, args, result TStruct) error {
+ return c.Wrapped(ctx, method, args, result)
+}
+
+// verify that WrappedTClient implements TClient
+var (
+ _ TClient = WrappedTClient{}
+ _ TClient = (*WrappedTClient)(nil)
+)
+
+// WrapClient wraps the given TClient in the given middlewares.
+//
+// Middlewares will be called in the order that they are defined:
+//
+// 1. Middlewares[0]
+// 2. Middlewares[1]
+// ...
+// N. Middlewares[n]
+func WrapClient(client TClient, middlewares ...ClientMiddleware) TClient {
+ // Add middlewares in reverse so the first in the list is the outermost.
+ for i := len(middlewares) - 1; i >= 0; i-- {
+ client = middlewares[i](client)
+ }
+ return client
+}
diff --git a/lib/go/thrift/middleware_test.go b/lib/go/thrift/middleware_test.go
index 81cbc7b..2a4d1f9 100644
--- a/lib/go/thrift/middleware_test.go
+++ b/lib/go/thrift/middleware_test.go
@@ -32,7 +32,15 @@
c.count++
}
-func testMiddleware(c *counter) ProcessorMiddleware {
+func newCounter(t *testing.T) *counter {
+ c := counter{}
+ if c.count != 0 {
+ t.Fatal("Unexpected initial count.")
+ }
+ return &c
+}
+
+func testProcessorMiddleware(c *counter) ProcessorMiddleware {
return func(name string, next TProcessorFunction) TProcessorFunction {
return WrappedTProcessorFunction{
Wrapped: func(ctx context.Context, seqId int32, in, out TProtocol) (bool, TException) {
@@ -43,12 +51,15 @@
}
}
-func newCounter(t *testing.T) *counter {
- c := counter{}
- if c.count != 0 {
- t.Fatal("Unexpected initial count.")
+func testClientMiddleware(c *counter) ClientMiddleware {
+ return func(next TClient) TClient {
+ return WrappedTClient{
+ Wrapped: func(ctx context.Context, method string, args, result TStruct) error {
+ c.incr()
+ return next.Call(ctx, method, args, result)
+ },
+ }
}
- return &c
}
func TestWrapProcessor(t *testing.T) {
@@ -64,7 +75,7 @@
}
c := newCounter(t)
ctx := setMockWrappableProcessorName(context.Background(), name)
- wrapped := WrapProcessor(processor, testMiddleware(c))
+ wrapped := WrapProcessor(processor, testProcessorMiddleware(c))
wrapped.Process(ctx, nil, nil)
if c.count != 1 {
t.Fatalf("Unexpected count value %v", c.count)
@@ -94,7 +105,7 @@
},
},
})
- wrapped := WrapProcessor(processor, testMiddleware(c))
+ wrapped := WrapProcessor(processor, testProcessorMiddleware(c))
ctx := setMockWrappableProcessorName(context.Background(), name)
in := NewStoredMessageProtocol(nil, name, 1, 1)
wrapped.Process(ctx, in, nil)
@@ -108,3 +119,17 @@
t.Fatalf("Unexpected count value %v", c.count)
}
}
+
+func TestWrapClient(t *testing.T) {
+ client := WrappedTClient{
+ Wrapped: func(ctx context.Context, method string, args, result TStruct) error {
+ return nil
+ },
+ }
+ c := newCounter(t)
+ wrapped := WrapClient(client, testClientMiddleware(c))
+ wrapped.Call(context.Background(), "test", nil, nil)
+ if c.count != 1 {
+ t.Fatalf("Unexpected count value %v", c.count)
+ }
+}