THRIFT-4914: Add THeader to context for server reads
Client: go
This is the first part of THRIFT-4914, which handles the server reading
part in the requests (client -> server direction).
In TSimpleServer, when the protocol is THeaderProtocol automatically
add all present headers into the context object before passing
it to processor, so the processor code can access headers from the
context directly by using the new helper functions added in
header_context.go.
This closes #1840.
diff --git a/lib/go/thrift/header_context.go b/lib/go/thrift/header_context.go
new file mode 100644
index 0000000..5d9104b
--- /dev/null
+++ b/lib/go/thrift/header_context.go
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package thrift
+
+import (
+ "context"
+)
+
+// See https://godoc.org/context#WithValue on why do we need the unexported typedefs.
+type (
+ headerKey string
+ headerKeyList int
+)
+
+// Values for headerKeyList.
+const (
+ headerKeyListRead headerKeyList = iota
+)
+
+// SetHeader sets a header in the context.
+func SetHeader(ctx context.Context, key, value string) context.Context {
+ return context.WithValue(
+ ctx,
+ headerKey(key),
+ value,
+ )
+}
+
+// GetHeader returns a value of the given header from the context.
+func GetHeader(ctx context.Context, key string) (value string, ok bool) {
+ if v := ctx.Value(headerKey(key)); v != nil {
+ value, ok = v.(string)
+ }
+ return
+}
+
+// SetReadHeaderList sets the key list of read THeaders in the context.
+func SetReadHeaderList(ctx context.Context, keys []string) context.Context {
+ return context.WithValue(
+ ctx,
+ headerKeyListRead,
+ keys,
+ )
+}
+
+// GetReadHeaderList returns the key list of read THeaders from the context.
+func GetReadHeaderList(ctx context.Context) []string {
+ if v := ctx.Value(headerKeyListRead); v != nil {
+ if value, ok := v.([]string); ok {
+ return value
+ }
+ }
+ return nil
+}
+
+// AddReadTHeaderToContext adds the whole THeader headers into context.
+func AddReadTHeaderToContext(ctx context.Context, headers THeaderMap) context.Context {
+ keys := make([]string, 0, len(headers))
+ for key, value := range headers {
+ ctx = SetHeader(ctx, key, value)
+ keys = append(keys, key)
+ }
+ return SetReadHeaderList(ctx, keys)
+}
diff --git a/lib/go/thrift/header_context_test.go b/lib/go/thrift/header_context_test.go
new file mode 100644
index 0000000..33ac4ec
--- /dev/null
+++ b/lib/go/thrift/header_context_test.go
@@ -0,0 +1,97 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package thrift
+
+import (
+ "context"
+ "reflect"
+ "testing"
+)
+
+func TestSetGetHeader(t *testing.T) {
+ const (
+ key = "foo"
+ value = "bar"
+ )
+ ctx := context.Background()
+
+ ctx = SetHeader(ctx, key, value)
+
+ checkGet := func(t *testing.T, ctx context.Context) {
+ t.Helper()
+ got, ok := GetHeader(ctx, key)
+ if !ok {
+ t.Fatalf("Cannot get header %q back after setting it.", key)
+ }
+ if got != value {
+ t.Fatalf("Header value expected %q, got %q instead", value, got)
+ }
+ }
+
+ checkGet(t, ctx)
+
+ t.Run(
+ "NoConflicts",
+ func(t *testing.T) {
+ type otherType string
+ const otherValue = "bar2"
+
+ ctx = context.WithValue(ctx, otherType(key), otherValue)
+ checkGet(t, ctx)
+ },
+ )
+
+ t.Run(
+ "GetHeaderOnNonExistKey",
+ func(t *testing.T) {
+ const otherKey = "foo2"
+
+ if _, ok := GetHeader(ctx, otherKey); ok {
+ t.Errorf("GetHeader returned ok on non-existing key %q", otherKey)
+ }
+ },
+ )
+}
+
+func TestKeyList(t *testing.T) {
+ headers := THeaderMap{
+ "key1": "value1",
+ "key2": "value2",
+ }
+ ctx := context.Background()
+
+ ctx = AddReadTHeaderToContext(ctx, headers)
+
+ got := make(THeaderMap)
+ keys := GetReadHeaderList(ctx)
+ t.Logf("keys: %+v", keys)
+ for _, key := range keys {
+ value, ok := GetHeader(ctx, key)
+ if ok {
+ got[key] = value
+ } else {
+ t.Errorf("Cannot get key %q from context", key)
+ }
+ }
+
+ if !reflect.DeepEqual(headers, got) {
+ t.Errorf("Expected header map %+v, got %+v", headers, got)
+ }
+}
diff --git a/lib/go/thrift/header_protocol.go b/lib/go/thrift/header_protocol.go
index 0cf48f7..46205b2 100644
--- a/lib/go/thrift/header_protocol.go
+++ b/lib/go/thrift/header_protocol.go
@@ -188,6 +188,11 @@
return p.protocol.WriteBinary(value)
}
+// ReadFrame calls underlying THeaderTransport's ReadFrame function.
+func (p *THeaderProtocol) ReadFrame() error {
+ return p.transport.ReadFrame()
+}
+
func (p *THeaderProtocol) ReadMessageBegin() (name string, typeID TMessageType, seqID int32, err error) {
if err = p.transport.ReadFrame(); err != nil {
return
diff --git a/lib/go/thrift/header_transport_test.go b/lib/go/thrift/header_transport_test.go
index 7462dd5..e304768 100644
--- a/lib/go/thrift/header_transport_test.go
+++ b/lib/go/thrift/header_transport_test.go
@@ -21,6 +21,7 @@
import (
"context"
+ "io"
"io/ioutil"
"testing"
)
@@ -73,10 +74,21 @@
}
// Read
+
+ // Make sure multiple calls to ReadFrame is fine.
+ if err := reader.ReadFrame(); err != nil {
+ t.Errorf("reader.ReadFrame returned error: %v", err)
+ }
+ if err := reader.ReadFrame(); err != nil {
+ t.Errorf("reader.ReadFrame returned error: %v", err)
+ }
read, err := ioutil.ReadAll(reader)
if err != nil {
t.Errorf("Read returned error: %v", err)
}
+ if err := reader.ReadFrame(); err != nil && err != io.EOF {
+ t.Errorf("reader.ReadFrame returned error: %v", err)
+ }
if string(read) != payload1+payload2 {
t.Errorf(
"Read content expected %q, got %q",
diff --git a/lib/go/thrift/simple_server.go b/lib/go/thrift/simple_server.go
index 7db36c2..9155cfb 100644
--- a/lib/go/thrift/simple_server.go
+++ b/lib/go/thrift/simple_server.go
@@ -194,7 +194,8 @@
// for THeaderProtocol, we must use the same protocol instance for
// input and output so that the response is in the same dialect that
// the server detected the request was in.
- if _, ok := inputProtocol.(*THeaderProtocol); ok {
+ headerProtocol, ok := inputProtocol.(*THeaderProtocol)
+ if ok {
outputProtocol = inputProtocol
} else {
oTrans, err := p.outputTransportFactory.GetTransport(client)
@@ -222,7 +223,21 @@
return nil
}
- ok, err := processor.Process(defaultCtx, inputProtocol, outputProtocol)
+ ctx := defaultCtx
+ if headerProtocol != nil {
+ // We need to call ReadFrame here, otherwise we won't
+ // get any headers on the AddReadTHeaderToContext call.
+ //
+ // ReadFrame is safe to be called multiple times so it
+ // won't break when it's called again later when we
+ // actually start to read the message.
+ if err := headerProtocol.ReadFrame(); err != nil {
+ return err
+ }
+ ctx = AddReadTHeaderToContext(defaultCtx, headerProtocol.GetReadHeaders())
+ }
+
+ ok, err := processor.Process(ctx, inputProtocol, outputProtocol)
if err, ok := err.(TTransportException); ok && err.TypeId() == END_OF_FILE {
return nil
} else if err != nil {