THRIFT-5278: Allow set protoID in go THeader transport/protocol
Client: go
In Go library code, allow setting the underlying protoID to a
non-default (TCompactProtocol) one for THeaderTransport/THeaderProtocol.
diff --git a/lib/go/thrift/header_protocol.go b/lib/go/thrift/header_protocol.go
index 428b261..f86d558 100644
--- a/lib/go/thrift/header_protocol.go
+++ b/lib/go/thrift/header_protocol.go
@@ -37,31 +37,73 @@
}
// NewTHeaderProtocol creates a new THeaderProtocol from the underlying
-// transport. The passed in transport will be wrapped with THeaderTransport.
+// transport with default protocol ID.
+//
+// The passed in transport will be wrapped with THeaderTransport.
//
// Note that THeaderTransport handles frame and zlib by itself,
// so the underlying transport should be a raw socket transports (TSocket or TSSLSocket),
// instead of rich transports like TZlibTransport or TFramedTransport.
func NewTHeaderProtocol(trans TTransport) *THeaderProtocol {
- t := NewTHeaderTransport(trans)
- p, _ := THeaderProtocolDefault.GetProtocol(t)
+ p, err := newTHeaderProtocolWithProtocolID(trans, THeaderProtocolDefault)
+ if err != nil {
+ // Since we used THeaderProtocolDefault this should never happen,
+ // but put a sanity check here just in case.
+ panic(err)
+ }
+ return p
+}
+
+func newTHeaderProtocolWithProtocolID(trans TTransport, protoID THeaderProtocolID) (*THeaderProtocol, error) {
+ t, err := NewTHeaderTransportWithProtocolID(trans, protoID)
+ if err != nil {
+ return nil, err
+ }
+ p, err := t.protocolID.GetProtocol(t)
+ if err != nil {
+ return nil, err
+ }
return &THeaderProtocol{
transport: t,
protocol: p,
+ }, nil
+}
+
+type tHeaderProtocolFactory struct {
+ protoID THeaderProtocolID
+}
+
+func (f tHeaderProtocolFactory) GetProtocol(trans TTransport) TProtocol {
+ p, err := newTHeaderProtocolWithProtocolID(trans, f.protoID)
+ if err != nil {
+ // Currently there's no way for external users to construct a
+ // valid factory with invalid protoID, so this should never
+ // happen. But put a sanity check here just in case in the
+ // future a bug made that possible.
+ panic(err)
}
+ return p
}
-type tHeaderProtocolFactory struct{}
-
-func (tHeaderProtocolFactory) GetProtocol(trans TTransport) TProtocol {
- return NewTHeaderProtocol(trans)
-}
-
-// NewTHeaderProtocolFactory creates a factory for THeader.
+// NewTHeaderProtocolFactory creates a factory for THeader with default protocol
+// ID.
//
// It's a wrapper for NewTHeaderProtocol
func NewTHeaderProtocolFactory() TProtocolFactory {
- return tHeaderProtocolFactory{}
+ return tHeaderProtocolFactory{
+ protoID: THeaderProtocolDefault,
+ }
+}
+
+// NewTHeaderProtocolFactoryWithProtocolID creates a factory for THeader with
+// given protocol ID.
+func NewTHeaderProtocolFactoryWithProtocolID(protoID THeaderProtocolID) (TProtocolFactory, error) {
+ if err := protoID.Validate(); err != nil {
+ return nil, err
+ }
+ return tHeaderProtocolFactory{
+ protoID: protoID,
+ }, nil
}
// Transport returns the underlying transport.
diff --git a/lib/go/thrift/header_protocol_test.go b/lib/go/thrift/header_protocol_test.go
index 9b6019b..f66ea64 100644
--- a/lib/go/thrift/header_protocol_test.go
+++ b/lib/go/thrift/header_protocol_test.go
@@ -24,5 +24,21 @@
)
func TestReadWriteHeaderProtocol(t *testing.T) {
- ReadWriteProtocolTest(t, NewTHeaderProtocolFactory())
+ t.Run(
+ "default",
+ func(t *testing.T) {
+ ReadWriteProtocolTest(t, NewTHeaderProtocolFactory())
+ },
+ )
+
+ t.Run(
+ "compact",
+ func(t *testing.T) {
+ f, err := NewTHeaderProtocolFactoryWithProtocolID(THeaderProtocolCompact)
+ if err != nil {
+ t.Fatal(err)
+ }
+ ReadWriteProtocolTest(t, f)
+ },
+ )
}
diff --git a/lib/go/thrift/header_transport.go b/lib/go/thrift/header_transport.go
index e208034..562d02f 100644
--- a/lib/go/thrift/header_transport.go
+++ b/lib/go/thrift/header_transport.go
@@ -75,6 +75,15 @@
THeaderProtocolDefault = THeaderProtocolBinary
)
+// Declared globally to avoid repetitive allocations, not really used.
+var globalMemoryBuffer = NewTMemoryBuffer()
+
+// Validate checks whether the THeaderProtocolID is a valid/supported one.
+func (id THeaderProtocolID) Validate() error {
+ _, err := id.GetProtocol(globalMemoryBuffer)
+ return err
+}
+
// GetProtocol gets the corresponding TProtocol from the wrapped protocol id.
func (id THeaderProtocolID) GetProtocol(trans TTransport) (TProtocol, error) {
switch id {
@@ -84,7 +93,7 @@
fmt.Sprintf("THeader protocol id %d not supported", id),
)
case THeaderProtocolBinary:
- return NewTBinaryProtocolFactoryDefault().GetProtocol(trans), nil
+ return NewTBinaryProtocolTransport(trans), nil
case THeaderProtocolCompact:
return NewTCompactProtocol(trans), nil
}
@@ -93,11 +102,12 @@
// THeaderTransformID defines the numeric id of the transform used.
type THeaderTransformID int32
-// THeaderTransformID values
+// THeaderTransformID values.
+//
+// Values not defined here are not currently supported, namely HMAC and Snappy.
const (
TransformNone THeaderTransformID = iota // 0, no special handling
TransformZlib // 1, zlib
- // Rest of the values are not currently supported, namely HMAC and Snappy.
)
var supportedTransformIDs = map[THeaderTransformID]bool{
@@ -285,6 +295,34 @@
}
}
+// NewTHeaderTransportWithProtocolID creates THeaderTransport from the
+// underlying transport, with given protocol ID set.
+//
+// If trans is already a *THeaderTransport, it will be returned as is,
+// but with protocol ID overridden by the value passed in.
+//
+// If the passed in protocol ID is an invalid/unsupported one,
+// this function returns error.
+//
+// The protocol ID overridden is only useful for client transports.
+// For servers,
+// the protocol ID will be overridden again to the one set by the client,
+// to ensure that servers always speak the same dialect as the client.
+func NewTHeaderTransportWithProtocolID(trans TTransport, protoID THeaderProtocolID) (*THeaderTransport, error) {
+ if err := protoID.Validate(); err != nil {
+ return nil, err
+ }
+ if ht, ok := trans.(*THeaderTransport); ok {
+ return ht, nil
+ }
+ return &THeaderTransport{
+ transport: trans,
+ reader: bufio.NewReader(trans),
+ writeHeaders: make(THeaderMap),
+ protocolID: protoID,
+ }, nil
+}
+
// Open calls the underlying transport's Open function.
func (t *THeaderTransport) Open() error {
return t.transport.Open()
diff --git a/lib/go/thrift/header_transport_test.go b/lib/go/thrift/header_transport_test.go
index 320fb2a..5b47680 100644
--- a/lib/go/thrift/header_transport_test.go
+++ b/lib/go/thrift/header_transport_test.go
@@ -28,10 +28,13 @@
"testing/quick"
)
-func TestTHeaderHeadersReadWrite(t *testing.T) {
+func testTHeaderHeadersReadWriteProtocolID(t *testing.T, protoID THeaderProtocolID) {
trans := NewTMemoryBuffer()
reader := NewTHeaderTransport(trans)
- writer := NewTHeaderTransport(trans)
+ writer, err := NewTHeaderTransportWithProtocolID(trans, protoID)
+ if err != nil {
+ t.Fatal(err)
+ }
const key1 = "key1"
const value1 = "value1"
@@ -98,10 +101,10 @@
read,
)
}
- if prot := reader.Protocol(); prot != THeaderProtocolBinary {
+ if prot := reader.Protocol(); prot != protoID {
t.Errorf(
"reader.Protocol() expected %d, got %d",
- THeaderProtocolBinary,
+ protoID,
prot,
)
}
@@ -121,6 +124,18 @@
}
}
+func TestTHeaderHeadersReadWrite(t *testing.T) {
+ for label, id := range map[string]THeaderProtocolID{
+ "default": THeaderProtocolDefault,
+ "binary": THeaderProtocolBinary,
+ "compact": THeaderProtocolCompact,
+ } {
+ t.Run(label, func(t *testing.T) {
+ testTHeaderHeadersReadWriteProtocolID(t, id)
+ })
+ }
+}
+
func TestTHeaderTransportNoDoubleWrapping(t *testing.T) {
trans := NewTMemoryBuffer()
orig := NewTHeaderTransport(trans)