THRIFT-5322: THeaderTransport protocol id fix
Client: go
This fixes a bug introduced in
https://github.com/apache/thrift/pull/2296, that we mixed the preferred
proto id and the detected proto id, which was a bad idea.
This change separates them, so when we propagate TConfiguration, we only
change the preferred one, which will only be used for new connections,
and leave the detected one from existing connections untouched.
Also add a test for it.
diff --git a/lib/go/thrift/header_transport.go b/lib/go/thrift/header_transport.go
index f1dc99c..f5736df 100644
--- a/lib/go/thrift/header_transport.go
+++ b/lib/go/thrift/header_transport.go
@@ -264,6 +264,7 @@
writeTransforms []THeaderTransformID
clientType clientType
+ protocolID THeaderProtocolID
cfg *TConfiguration
// buffer is used in the following scenarios to avoid repetitive
@@ -303,6 +304,7 @@
transport: trans,
reader: bufio.NewReader(trans),
writeHeaders: make(THeaderMap),
+ protocolID: conf.GetTHeaderProtocolID(),
cfg: conf,
}
}
@@ -443,16 +445,7 @@
if err != nil {
return err
}
- idPtr, err := THeaderProtocolIDPtr(THeaderProtocolID(protoID))
- if err != nil {
- return err
- }
- if t.cfg == nil {
- t.cfg = &TConfiguration{
- noPropagation: true,
- }
- }
- t.cfg.THeaderProtocolID = idPtr
+ t.protocolID = THeaderProtocolID(protoID)
var transformCount int32
transformCount, err = hp.readVarint32()
@@ -597,7 +590,7 @@
headers := NewTMemoryBuffer()
hp := NewTCompactProtocol(headers)
hp.SetTConfiguration(t.cfg)
- if _, err := hp.writeVarint32(int32(t.cfg.GetTHeaderProtocolID())); err != nil {
+ if _, err := hp.writeVarint32(int32(t.protocolID)); err != nil {
return NewTTransportExceptionFromError(err)
}
if _, err := hp.writeVarint32(int32(len(t.writeTransforms))); err != nil {
@@ -742,7 +735,7 @@
func (t *THeaderTransport) Protocol() THeaderProtocolID {
switch t.clientType {
default:
- return t.cfg.GetTHeaderProtocolID()
+ return t.protocolID
case clientFramedBinary, clientUnframedBinary:
return THeaderProtocolBinary
case clientFramedCompact, clientUnframedCompact:
diff --git a/lib/go/thrift/header_transport_test.go b/lib/go/thrift/header_transport_test.go
index 41efb18..65e69ee 100644
--- a/lib/go/thrift/header_transport_test.go
+++ b/lib/go/thrift/header_transport_test.go
@@ -281,3 +281,27 @@
})
}
}
+
+func TestSetTHeaderTransportProtocolID(t *testing.T) {
+ const expected = THeaderProtocolCompact
+ factory := NewTHeaderTransportFactoryConf(nil, &TConfiguration{
+ THeaderProtocolID: THeaderProtocolIDPtrMust(expected),
+ })
+ buf := NewTMemoryBuffer()
+ trans, err := factory.GetTransport(buf)
+ if err != nil {
+ t.Fatalf("Failed to get transport from factory: %v", err)
+ }
+ ht, ok := trans.(*THeaderTransport)
+ if !ok {
+ t.Fatalf("Transport is not *THeaderTransport: %#v", trans)
+ }
+ if actual := ht.Protocol(); actual != expected {
+ t.Errorf("Expected protocol id %v, got %v", expected, actual)
+ }
+
+ ht.SetTConfiguration(&TConfiguration{})
+ if actual := ht.Protocol(); actual != expected {
+ t.Errorf("Expected protocol id %v, got %v", expected, actual)
+ }
+}