THRIFT-5322: Implement TConfiguration in Go library
Client: go
Define TConfiguration following the spec, and also move the following
configurations scattered around different TTransport/TProtocol into it:
- connect and socket timeouts for TSocket and TSSLSocket
- tls config for TSSLSocket
- max frame size for TFramedTransport
- strict read and strict write for TBinaryProtocol
- proto id for THeaderTransport
Also add TConfiguration support for the following and their factories:
- THeaderTransport and THeaderProtocol
- TBinaryProtocol
- TCompactProtocol
- TFramedTransport
- TSocket
- TSSLSocket
Also define TConfigurationSetter interface for easier TConfiguration
propagation between wrapped TTransports/TProtocols , and add
implementations to the following for propagation
(they don't use anything from TConfiguration themselves):
- StreamTransport
- TBufferedTransport
- TDebugProtocol
- TJSONProtocol
- TSimpleJSONProtocol
- TZlibTransport
TConfigurationSetter are not implemented by the factories of the
"propagation only" TTransports/TProtocols, if they have a factory. For
those use cases, TTransportFactoryConf and TProtocolFactoryConf are
provided to wrap a factory with the ability to propagate TConfiguration.
Also add simple sanity check for TBinaryProtocol and TCompactProtocol's
ReadString and ReadBinary functions. Currently it only report error if
the header length is larger than MaxMessageSize configured in
TConfiguration, for simplicity.
diff --git a/CHANGES.md b/CHANGES.md
index 65ed07f..663c4c1 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -29,6 +29,7 @@
- [THRIFT-5164](https://issues.apache.org/jira/browse/THRIFT-5164) - Add ProcessorMiddleware function type and WrapProcessor function to support wrapping a TProcessor with middleware functions.
- [THRIFT-5233](https://issues.apache.org/jira/browse/THRIFT-5233) - Add context deadline check to ReadMessageBegin in TBinaryProtocol, TCompactProtocol, and THeaderProtocol.
- [THRIFT-5240](https://issues.apache.org/jira/browse/THRIFT-5240) - The context passed into server handler implementations will be canceled when we detected that the client closed the connection.
+- [THRIFT-5322](https://issues.apache.org/jira/browse/THRIFT-5322) - Add support to TConfiguration, and also fix a bug that could cause excessive memory usage when reading malformed messages from TCompactProtocol.
## 0.13.0
diff --git a/lib/go/thrift/binary_protocol.go b/lib/go/thrift/binary_protocol.go
index 58956f6..45c880d 100644
--- a/lib/go/thrift/binary_protocol.go
+++ b/lib/go/thrift/binary_protocol.go
@@ -32,22 +32,37 @@
type TBinaryProtocol struct {
trans TRichTransport
origTransport TTransport
- strictRead bool
- strictWrite bool
+ cfg *TConfiguration
buffer [64]byte
}
type TBinaryProtocolFactory struct {
- strictRead bool
- strictWrite bool
+ cfg *TConfiguration
}
+// Deprecated: Use NewTBinaryProtocolConf instead.
func NewTBinaryProtocolTransport(t TTransport) *TBinaryProtocol {
- return NewTBinaryProtocol(t, false, true)
+ return NewTBinaryProtocolConf(t, &TConfiguration{
+ noPropagation: true,
+ })
}
+// Deprecated: Use NewTBinaryProtocolConf instead.
func NewTBinaryProtocol(t TTransport, strictRead, strictWrite bool) *TBinaryProtocol {
- p := &TBinaryProtocol{origTransport: t, strictRead: strictRead, strictWrite: strictWrite}
+ return NewTBinaryProtocolConf(t, &TConfiguration{
+ TBinaryStrictRead: &strictRead,
+ TBinaryStrictWrite: &strictWrite,
+
+ noPropagation: true,
+ })
+}
+
+func NewTBinaryProtocolConf(t TTransport, conf *TConfiguration) *TBinaryProtocol {
+ PropagateTConfiguration(t, conf)
+ p := &TBinaryProtocol{
+ origTransport: t,
+ cfg: conf,
+ }
if et, ok := t.(TRichTransport); ok {
p.trans = et
} else {
@@ -56,16 +71,35 @@
return p
}
+// Deprecated: Use NewTBinaryProtocolFactoryConf instead.
func NewTBinaryProtocolFactoryDefault() *TBinaryProtocolFactory {
- return NewTBinaryProtocolFactory(false, true)
+ return NewTBinaryProtocolFactoryConf(&TConfiguration{
+ noPropagation: true,
+ })
}
+// Deprecated: Use NewTBinaryProtocolFactoryConf instead.
func NewTBinaryProtocolFactory(strictRead, strictWrite bool) *TBinaryProtocolFactory {
- return &TBinaryProtocolFactory{strictRead: strictRead, strictWrite: strictWrite}
+ return NewTBinaryProtocolFactoryConf(&TConfiguration{
+ TBinaryStrictRead: &strictRead,
+ TBinaryStrictWrite: &strictWrite,
+
+ noPropagation: true,
+ })
+}
+
+func NewTBinaryProtocolFactoryConf(conf *TConfiguration) *TBinaryProtocolFactory {
+ return &TBinaryProtocolFactory{
+ cfg: conf,
+ }
}
func (p *TBinaryProtocolFactory) GetProtocol(t TTransport) TProtocol {
- return NewTBinaryProtocol(t, p.strictRead, p.strictWrite)
+ return NewTBinaryProtocolConf(t, p.cfg)
+}
+
+func (p *TBinaryProtocolFactory) SetTConfiguration(conf *TConfiguration) {
+ p.cfg = conf
}
/**
@@ -73,7 +107,7 @@
*/
func (p *TBinaryProtocol) WriteMessageBegin(ctx context.Context, name string, typeId TMessageType, seqId int32) error {
- if p.strictWrite {
+ if p.cfg.GetTBinaryStrictWrite() {
version := uint32(VERSION_1) | uint32(typeId)
e := p.WriteI32(ctx, int32(version))
if e != nil {
@@ -253,7 +287,7 @@
}
return name, typeId, seqId, nil
}
- if p.strictRead {
+ if p.cfg.GetTBinaryStrictRead() {
return name, typeId, seqId, NewTProtocolExceptionWithType(BAD_VERSION, fmt.Errorf("Missing version in ReadMessageBegin"))
}
name, e2 := p.readStringBody(size)
@@ -428,6 +462,10 @@
if e != nil {
return "", e
}
+ err = checkSizeForProtocol(size, p.cfg)
+ if err != nil {
+ return
+ }
if size < 0 {
err = invalidDataLength
return
@@ -450,8 +488,8 @@
if e != nil {
return nil, e
}
- if size < 0 {
- return nil, invalidDataLength
+ if err := checkSizeForProtocol(size, p.cfg); err != nil {
+ return nil, err
}
buf, err := safeReadBytes(size, p.trans)
@@ -491,6 +529,17 @@
return string(buf), NewTProtocolException(err)
}
+func (p *TBinaryProtocol) SetTConfiguration(conf *TConfiguration) {
+ PropagateTConfiguration(p.trans, conf)
+ PropagateTConfiguration(p.origTransport, conf)
+ p.cfg = conf
+}
+
+var (
+ _ TConfigurationSetter = (*TBinaryProtocolFactory)(nil)
+ _ TConfigurationSetter = (*TBinaryProtocol)(nil)
+)
+
// This function is shared between TBinaryProtocol and TCompactProtocol.
//
// It tries to read size bytes from trans, in a way that prevents large
diff --git a/lib/go/thrift/buffered_transport.go b/lib/go/thrift/buffered_transport.go
index 9670206..aa551b4 100644
--- a/lib/go/thrift/buffered_transport.go
+++ b/lib/go/thrift/buffered_transport.go
@@ -90,3 +90,10 @@
func (p *TBufferedTransport) RemainingBytes() (num_bytes uint64) {
return p.tp.RemainingBytes()
}
+
+// SetTConfiguration implements TConfigurationSetter for propagation.
+func (p *TBufferedTransport) SetTConfiguration(conf *TConfiguration) {
+ PropagateTConfiguration(p.tp, conf)
+}
+
+var _ TConfigurationSetter = (*TBufferedTransport)(nil)
diff --git a/lib/go/thrift/compact_protocol.go b/lib/go/thrift/compact_protocol.go
index 424906d..25e6d0c 100644
--- a/lib/go/thrift/compact_protocol.go
+++ b/lib/go/thrift/compact_protocol.go
@@ -75,20 +75,37 @@
}
}
-type TCompactProtocolFactory struct{}
+type TCompactProtocolFactory struct {
+ cfg *TConfiguration
+}
+// Deprecated: Use NewTCompactProtocolFactoryConf instead.
func NewTCompactProtocolFactory() *TCompactProtocolFactory {
- return &TCompactProtocolFactory{}
+ return NewTCompactProtocolFactoryConf(&TConfiguration{
+ noPropagation: true,
+ })
+}
+
+func NewTCompactProtocolFactoryConf(conf *TConfiguration) *TCompactProtocolFactory {
+ return &TCompactProtocolFactory{
+ cfg: conf,
+ }
}
func (p *TCompactProtocolFactory) GetProtocol(trans TTransport) TProtocol {
- return NewTCompactProtocol(trans)
+ return NewTCompactProtocolConf(trans, p.cfg)
+}
+
+func (p *TCompactProtocolFactory) SetTConfiguration(conf *TConfiguration) {
+ p.cfg = conf
}
type TCompactProtocol struct {
trans TRichTransport
origTransport TTransport
+ cfg *TConfiguration
+
// Used to keep track of the last field for the current and previous structs,
// so we can do the delta stuff.
lastField []int
@@ -107,9 +124,19 @@
buffer [64]byte
}
-// Create a TCompactProtocol given a TTransport
+// Deprecated: Use NewTCompactProtocolConf instead.
func NewTCompactProtocol(trans TTransport) *TCompactProtocol {
- p := &TCompactProtocol{origTransport: trans, lastField: []int{}}
+ return NewTCompactProtocolConf(trans, &TConfiguration{
+ noPropagation: true,
+ })
+}
+
+func NewTCompactProtocolConf(trans TTransport, conf *TConfiguration) *TCompactProtocol {
+ PropagateTConfiguration(trans, conf)
+ p := &TCompactProtocol{
+ origTransport: trans,
+ cfg: conf,
+ }
if et, ok := trans.(TRichTransport); ok {
p.trans = et
} else {
@@ -117,7 +144,6 @@
}
return p
-
}
//
@@ -576,8 +602,9 @@
if e != nil {
return "", NewTProtocolException(e)
}
- if length < 0 {
- return "", invalidDataLength
+ err = checkSizeForProtocol(length, p.cfg)
+ if err != nil {
+ return
}
if length == 0 {
return "", nil
@@ -599,12 +626,13 @@
if e != nil {
return nil, NewTProtocolException(e)
}
+ err = checkSizeForProtocol(length, p.cfg)
+ if err != nil {
+ return
+ }
if length == 0 {
return []byte{}, nil
}
- if length < 0 {
- return nil, invalidDataLength
- }
buf, e := safeReadBytes(length, p.trans)
return buf, NewTProtocolException(e)
@@ -824,3 +852,14 @@
func (p *TCompactProtocol) getCompactType(t TType) tCompactType {
return ttypeToCompactType[t]
}
+
+func (p *TCompactProtocol) SetTConfiguration(conf *TConfiguration) {
+ PropagateTConfiguration(p.trans, conf)
+ PropagateTConfiguration(p.origTransport, conf)
+ p.cfg = conf
+}
+
+var (
+ _ TConfigurationSetter = (*TCompactProtocolFactory)(nil)
+ _ TConfigurationSetter = (*TCompactProtocol)(nil)
+)
diff --git a/lib/go/thrift/configuration.go b/lib/go/thrift/configuration.go
new file mode 100644
index 0000000..454d9f3
--- /dev/null
+++ b/lib/go/thrift/configuration.go
@@ -0,0 +1,378 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package thrift
+
+import (
+ "crypto/tls"
+ "fmt"
+ "time"
+)
+
+// Default TConfiguration values.
+const (
+ DEFAULT_MAX_MESSAGE_SIZE = 100 * 1024 * 1024
+ DEFAULT_MAX_FRAME_SIZE = 16384000
+
+ DEFAULT_TBINARY_STRICT_READ = false
+ DEFAULT_TBINARY_STRICT_WRITE = true
+
+ DEFAULT_CONNECT_TIMEOUT = 0
+ DEFAULT_SOCKET_TIMEOUT = 0
+)
+
+// TConfiguration defines some configurations shared between TTransport,
+// TProtocol, TTransportFactory, TProtocolFactory, and other implementations.
+//
+// When constructing TConfiguration, you only need to specify the non-default
+// fields. All zero values have sane default values.
+//
+// Not all configurations defined are applicable to all implementations.
+// Implementations are free to ignore the configurations not applicable to them.
+//
+// All functions attached to this type are nil-safe.
+//
+// See [1] for spec.
+//
+// NOTE: When using TConfiguration, fill in all the configurations you want to
+// set across the stack, not only the ones you want to set in the immediate
+// TTransport/TProtocol.
+//
+// For example, say you want to migrate this old code into using TConfiguration:
+//
+// sccket := thrift.NewTSocketTimeout("host:port", time.Second)
+// transFactory := thrift.NewTFramedTransportFactoryMaxLength(
+// thrift.NewTTransportFactory(),
+// 1024 * 1024 * 256,
+// )
+// protoFactory := thrift.NewTBinaryProtocolFactory(true, true)
+//
+// This is the wrong way to do it because in the end the TConfiguration used by
+// socket and transFactory will be overwritten by the one used by protoFactory
+// because of TConfiguration propagation:
+//
+// // bad example, DO NOT USE
+// sccket := thrift.NewTSocketConf("host:port", &thrift.TConfiguration{
+// ConnectTimeout: time.Second,
+// SocketTimeout: time.Second,
+// })
+// transFactory := thrift.NewTFramedTransportFactoryConf(
+// thrift.NewTTransportFactory(),
+// &thrift.TConfiguration{
+// MaxFrameSize: 1024 * 1024 * 256,
+// },
+// )
+// protoFactory := thrift.NewTBinaryProtocolFactoryConf(&thrift.TConfiguration{
+// TBinaryStrictRead: thrift.BoolPtr(true),
+// TBinaryStrictWrite: thrift.BoolPtr(true),
+// })
+//
+// This is the correct way to do it:
+//
+// conf := &thrift.TConfiguration{
+// ConnectTimeout: time.Second,
+// SocketTimeout: time.Second,
+//
+// MaxFrameSize: 1024 * 1024 * 256,
+//
+// TBinaryStrictRead: thrift.BoolPtr(true),
+// TBinaryStrictWrite: thrift.BoolPtr(true),
+// }
+// sccket := thrift.NewTSocketConf("host:port", conf)
+// transFactory := thrift.NewTFramedTransportFactoryConf(thrift.NewTTransportFactory(), conf)
+// protoFactory := thrift.NewTBinaryProtocolFactoryConf(conf)
+//
+// [1]: https://github.com/apache/thrift/blob/master/doc/specs/thrift-tconfiguration.md
+type TConfiguration struct {
+ // If <= 0, DEFAULT_MAX_MESSAGE_SIZE will be used instead.
+ MaxMessageSize int32
+
+ // If <= 0, DEFAULT_MAX_FRAME_SIZE will be used instead.
+ //
+ // Also if MaxMessageSize < MaxFrameSize,
+ // MaxMessageSize will be used instead.
+ MaxFrameSize int32
+
+ // Connect and socket timeouts to be used by TSocket and TSSLSocket.
+ //
+ // 0 means no timeout.
+ //
+ // If <0, DEFAULT_CONNECT_TIMEOUT and DEFAULT_SOCKET_TIMEOUT will be
+ // used.
+ ConnectTimeout time.Duration
+ SocketTimeout time.Duration
+
+ // TLS config to be used by TSSLSocket.
+ TLSConfig *tls.Config
+
+ // Strict read/write configurations for TBinaryProtocol.
+ //
+ // BoolPtr helper function is available to use literal values.
+ TBinaryStrictRead *bool
+ TBinaryStrictWrite *bool
+
+ // The wrapped protocol id to be used in THeader transport/protocol.
+ //
+ // THeaderProtocolIDPtr and THeaderProtocolIDPtrMust helper functions
+ // are provided to help filling this value.
+ THeaderProtocolID *THeaderProtocolID
+
+ // Used internally by deprecated constructors, to avoid overriding
+ // underlying TTransport/TProtocol's cfg by accidental propagations.
+ //
+ // For external users this is always false.
+ noPropagation bool
+}
+
+// GetMaxMessageSize returns the max message size an implementation should
+// follow.
+//
+// It's nil-safe. DEFAULT_MAX_MESSAGE_SIZE will be returned if tc is nil.
+func (tc *TConfiguration) GetMaxMessageSize() int32 {
+ if tc == nil || tc.MaxMessageSize <= 0 {
+ return DEFAULT_MAX_MESSAGE_SIZE
+ }
+ return tc.MaxMessageSize
+}
+
+// GetMaxFrameSize returns the max frame size an implementation should follow.
+//
+// It's nil-safe. DEFAULT_MAX_FRAME_SIZE will be returned if tc is nil.
+//
+// If the configured max message size is smaller than the configured max frame
+// size, the smaller one will be returned instead.
+func (tc *TConfiguration) GetMaxFrameSize() int32 {
+ if tc == nil {
+ return DEFAULT_MAX_FRAME_SIZE
+ }
+ maxFrameSize := tc.MaxFrameSize
+ if maxFrameSize <= 0 {
+ maxFrameSize = DEFAULT_MAX_FRAME_SIZE
+ }
+ if maxMessageSize := tc.GetMaxMessageSize(); maxMessageSize < maxFrameSize {
+ return maxMessageSize
+ }
+ return maxFrameSize
+}
+
+// GetConnectTimeout returns the connect timeout should be used by TSocket and
+// TSSLSocket.
+//
+// It's nil-safe. If tc is nil, DEFAULT_CONNECT_TIMEOUT will be returned instead.
+func (tc *TConfiguration) GetConnectTimeout() time.Duration {
+ if tc == nil || tc.ConnectTimeout < 0 {
+ return DEFAULT_CONNECT_TIMEOUT
+ }
+ return tc.ConnectTimeout
+}
+
+// GetSocketTimeout returns the socket timeout should be used by TSocket and
+// TSSLSocket.
+//
+// It's nil-safe. If tc is nil, DEFAULT_SOCKET_TIMEOUT will be returned instead.
+func (tc *TConfiguration) GetSocketTimeout() time.Duration {
+ if tc == nil || tc.SocketTimeout < 0 {
+ return DEFAULT_SOCKET_TIMEOUT
+ }
+ return tc.SocketTimeout
+}
+
+// GetTLSConfig returns the tls config should be used by TSSLSocket.
+//
+// It's nil-safe. If tc is nil, nil will be returned instead.
+func (tc *TConfiguration) GetTLSConfig() *tls.Config {
+ if tc == nil {
+ return nil
+ }
+ return tc.TLSConfig
+}
+
+// GetTBinaryStrictRead returns the strict read configuration TBinaryProtocol
+// should follow.
+//
+// It's nil-safe. DEFAULT_TBINARY_STRICT_READ will be returned if either tc or
+// tc.TBinaryStrictRead is nil.
+func (tc *TConfiguration) GetTBinaryStrictRead() bool {
+ if tc == nil || tc.TBinaryStrictRead == nil {
+ return DEFAULT_TBINARY_STRICT_READ
+ }
+ return *tc.TBinaryStrictRead
+}
+
+// GetTBinaryStrictWrite returns the strict read configuration TBinaryProtocol
+// should follow.
+//
+// It's nil-safe. DEFAULT_TBINARY_STRICT_WRITE will be returned if either tc or
+// tc.TBinaryStrictWrite is nil.
+func (tc *TConfiguration) GetTBinaryStrictWrite() bool {
+ if tc == nil || tc.TBinaryStrictWrite == nil {
+ return DEFAULT_TBINARY_STRICT_WRITE
+ }
+ return *tc.TBinaryStrictWrite
+}
+
+// GetTHeaderProtocolID returns the THeaderProtocolID should be used by
+// THeaderProtocol clients (for servers, they always use the same one as the
+// client instead).
+//
+// It's nil-safe. If either tc or tc.THeaderProtocolID is nil,
+// THeaderProtocolDefault will be returned instead.
+// THeaderProtocolDefault will also be returned if configured value is invalid.
+func (tc *TConfiguration) GetTHeaderProtocolID() THeaderProtocolID {
+ if tc == nil || tc.THeaderProtocolID == nil {
+ return THeaderProtocolDefault
+ }
+ protoID := *tc.THeaderProtocolID
+ if err := protoID.Validate(); err != nil {
+ return THeaderProtocolDefault
+ }
+ return protoID
+}
+
+// THeaderProtocolIDPtr validates and returns the pointer to id.
+//
+// If id is not a valid THeaderProtocolID, a pointer to THeaderProtocolDefault
+// and the validation error will be returned.
+func THeaderProtocolIDPtr(id THeaderProtocolID) (*THeaderProtocolID, error) {
+ err := id.Validate()
+ if err != nil {
+ id = THeaderProtocolDefault
+ }
+ return &id, err
+}
+
+// THeaderProtocolIDPtrMust validates and returns the pointer to id.
+//
+// It's similar to THeaderProtocolIDPtr, but it panics on validation errors
+// instead of returning them.
+func THeaderProtocolIDPtrMust(id THeaderProtocolID) *THeaderProtocolID {
+ ptr, err := THeaderProtocolIDPtr(id)
+ if err != nil {
+ panic(err)
+ }
+ return ptr
+}
+
+// TConfigurationSetter is an optional interface TProtocol, TTransport,
+// TProtocolFactory, TTransportFactory, and other implementations can implement.
+//
+// It's intended to be called during intializations.
+// The behavior of calling SetTConfiguration on a TTransport/TProtocol in the
+// middle of a message is undefined:
+// It may or may not change the behavior of the current processing message,
+// and it may even cause the current message to fail.
+//
+// Note for implementations: SetTConfiguration might be called multiple times
+// with the same value in quick successions due to the implementation of the
+// propagation. Implementations should make SetTConfiguration as simple as
+// possible (usually just overwrite the stored configuration and propagate it to
+// the wrapped TTransports/TProtocols).
+type TConfigurationSetter interface {
+ SetTConfiguration(*TConfiguration)
+}
+
+// PropagateTConfiguration propagates cfg to impl if impl implements
+// TConfigurationSetter and cfg is non-nil, otherwise it does nothing.
+//
+// NOTE: nil cfg is not propagated. If you want to propagate a TConfiguration
+// with everything being default value, use &TConfiguration{} explicitly instead.
+func PropagateTConfiguration(impl interface{}, cfg *TConfiguration) {
+ if cfg == nil || cfg.noPropagation {
+ return
+ }
+
+ if setter, ok := impl.(TConfigurationSetter); ok {
+ setter.SetTConfiguration(cfg)
+ }
+}
+
+func checkSizeForProtocol(size int32, cfg *TConfiguration) error {
+ if size < 0 {
+ return NewTProtocolExceptionWithType(
+ NEGATIVE_SIZE,
+ fmt.Errorf("negative size: %d", size),
+ )
+ }
+ if size > cfg.GetMaxMessageSize() {
+ return NewTProtocolExceptionWithType(
+ SIZE_LIMIT,
+ fmt.Errorf("size exceeded max allowed: %d", size),
+ )
+ }
+ return nil
+}
+
+type tTransportFactoryConf struct {
+ delegate TTransportFactory
+ cfg *TConfiguration
+}
+
+func (f *tTransportFactoryConf) GetTransport(orig TTransport) (TTransport, error) {
+ trans, err := f.delegate.GetTransport(orig)
+ if err == nil {
+ PropagateTConfiguration(orig, f.cfg)
+ PropagateTConfiguration(trans, f.cfg)
+ }
+ return trans, err
+}
+
+func (f *tTransportFactoryConf) SetTConfiguration(cfg *TConfiguration) {
+ PropagateTConfiguration(f.delegate, f.cfg)
+ f.cfg = cfg
+}
+
+// TTransportFactoryConf wraps a TTransportFactory to propagate
+// TConfiguration on the factory's GetTransport calls.
+func TTransportFactoryConf(delegate TTransportFactory, conf *TConfiguration) TTransportFactory {
+ return &tTransportFactoryConf{
+ delegate: delegate,
+ cfg: conf,
+ }
+}
+
+type tProtocolFactoryConf struct {
+ delegate TProtocolFactory
+ cfg *TConfiguration
+}
+
+func (f *tProtocolFactoryConf) GetProtocol(trans TTransport) TProtocol {
+ proto := f.delegate.GetProtocol(trans)
+ PropagateTConfiguration(trans, f.cfg)
+ PropagateTConfiguration(proto, f.cfg)
+ return proto
+}
+
+func (f *tProtocolFactoryConf) SetTConfiguration(cfg *TConfiguration) {
+ PropagateTConfiguration(f.delegate, f.cfg)
+ f.cfg = cfg
+}
+
+// TProtocolFactoryConf wraps a TProtocolFactory to propagate
+// TConfiguration on the factory's GetProtocol calls.
+func TProtocolFactoryConf(delegate TProtocolFactory, conf *TConfiguration) TProtocolFactory {
+ return &tProtocolFactoryConf{
+ delegate: delegate,
+ cfg: conf,
+ }
+}
+
+var (
+ _ TConfigurationSetter = (*tTransportFactoryConf)(nil)
+ _ TConfigurationSetter = (*tProtocolFactoryConf)(nil)
+)
diff --git a/lib/go/thrift/configuration_test.go b/lib/go/thrift/configuration_test.go
new file mode 100644
index 0000000..f747842
--- /dev/null
+++ b/lib/go/thrift/configuration_test.go
@@ -0,0 +1,338 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package thrift
+
+import (
+ "crypto/tls"
+ "testing"
+ "time"
+)
+
+func TestTConfiguration(t *testing.T) {
+ invalidProtoID := THeaderProtocolID(-1)
+ if invalidProtoID.Validate() == nil {
+ t.Fatalf("Expected %v to be an invalid THeaderProtocolID, it passes the validation", invalidProtoID)
+ }
+
+ tlsConfig := &tls.Config{
+ Time: time.Now,
+ }
+
+ for _, c := range []struct {
+ label string
+ cfg *TConfiguration
+ expectedMessageSize int32
+ expectedFrameSize int32
+ expectedConnectTimeout time.Duration
+ expectedSocketTimeout time.Duration
+ expectedTLSConfig *tls.Config
+ expectedBinaryRead bool
+ expectedBinaryWrite bool
+ expectedProtoID THeaderProtocolID
+ }{
+ {
+ label: "nil",
+ cfg: nil,
+ expectedMessageSize: DEFAULT_MAX_MESSAGE_SIZE,
+ expectedFrameSize: DEFAULT_MAX_FRAME_SIZE,
+ expectedConnectTimeout: DEFAULT_CONNECT_TIMEOUT,
+ expectedSocketTimeout: DEFAULT_SOCKET_TIMEOUT,
+ expectedTLSConfig: nil,
+ expectedBinaryRead: DEFAULT_TBINARY_STRICT_READ,
+ expectedBinaryWrite: DEFAULT_TBINARY_STRICT_WRITE,
+ expectedProtoID: THeaderProtocolDefault,
+ },
+ {
+ label: "empty",
+ cfg: &TConfiguration{},
+ expectedMessageSize: DEFAULT_MAX_MESSAGE_SIZE,
+ expectedFrameSize: DEFAULT_MAX_FRAME_SIZE,
+ expectedConnectTimeout: DEFAULT_CONNECT_TIMEOUT,
+ expectedSocketTimeout: DEFAULT_SOCKET_TIMEOUT,
+ expectedTLSConfig: nil,
+ expectedBinaryRead: DEFAULT_TBINARY_STRICT_READ,
+ expectedBinaryWrite: DEFAULT_TBINARY_STRICT_WRITE,
+ expectedProtoID: THeaderProtocolDefault,
+ },
+ {
+ label: "normal",
+ cfg: &TConfiguration{
+ MaxMessageSize: 1024,
+ MaxFrameSize: 1024,
+ ConnectTimeout: time.Millisecond,
+ SocketTimeout: time.Millisecond * 2,
+ TLSConfig: tlsConfig,
+ TBinaryStrictRead: BoolPtr(true),
+ TBinaryStrictWrite: BoolPtr(false),
+ THeaderProtocolID: THeaderProtocolIDPtrMust(THeaderProtocolCompact),
+ },
+ expectedMessageSize: 1024,
+ expectedFrameSize: 1024,
+ expectedConnectTimeout: time.Millisecond,
+ expectedSocketTimeout: time.Millisecond * 2,
+ expectedTLSConfig: tlsConfig,
+ expectedBinaryRead: true,
+ expectedBinaryWrite: false,
+ expectedProtoID: THeaderProtocolCompact,
+ },
+ {
+ label: "message<frame",
+ cfg: &TConfiguration{
+ MaxMessageSize: 1024,
+ MaxFrameSize: 4096,
+ },
+ expectedMessageSize: 1024,
+ expectedFrameSize: 1024,
+ expectedConnectTimeout: DEFAULT_CONNECT_TIMEOUT,
+ expectedSocketTimeout: DEFAULT_SOCKET_TIMEOUT,
+ expectedTLSConfig: nil,
+ expectedBinaryRead: DEFAULT_TBINARY_STRICT_READ,
+ expectedBinaryWrite: DEFAULT_TBINARY_STRICT_WRITE,
+ expectedProtoID: THeaderProtocolDefault,
+ },
+ {
+ label: "frame<message",
+ cfg: &TConfiguration{
+ MaxMessageSize: 4096,
+ MaxFrameSize: 1024,
+ },
+ expectedMessageSize: 4096,
+ expectedFrameSize: 1024,
+ expectedConnectTimeout: DEFAULT_CONNECT_TIMEOUT,
+ expectedSocketTimeout: DEFAULT_SOCKET_TIMEOUT,
+ expectedTLSConfig: nil,
+ expectedBinaryRead: DEFAULT_TBINARY_STRICT_READ,
+ expectedBinaryWrite: DEFAULT_TBINARY_STRICT_WRITE,
+ expectedProtoID: THeaderProtocolDefault,
+ },
+ {
+ label: "negative-message-size",
+ cfg: &TConfiguration{
+ MaxMessageSize: -1,
+ },
+ expectedMessageSize: DEFAULT_MAX_MESSAGE_SIZE,
+ expectedFrameSize: DEFAULT_MAX_FRAME_SIZE,
+ expectedConnectTimeout: DEFAULT_CONNECT_TIMEOUT,
+ expectedSocketTimeout: DEFAULT_SOCKET_TIMEOUT,
+ expectedTLSConfig: nil,
+ expectedBinaryRead: DEFAULT_TBINARY_STRICT_READ,
+ expectedBinaryWrite: DEFAULT_TBINARY_STRICT_WRITE,
+ expectedProtoID: THeaderProtocolDefault,
+ },
+ {
+ label: "negative-frame-size",
+ cfg: &TConfiguration{
+ MaxFrameSize: -1,
+ },
+ expectedMessageSize: DEFAULT_MAX_MESSAGE_SIZE,
+ expectedFrameSize: DEFAULT_MAX_FRAME_SIZE,
+ expectedConnectTimeout: DEFAULT_CONNECT_TIMEOUT,
+ expectedSocketTimeout: DEFAULT_SOCKET_TIMEOUT,
+ expectedTLSConfig: nil,
+ expectedBinaryRead: DEFAULT_TBINARY_STRICT_READ,
+ expectedBinaryWrite: DEFAULT_TBINARY_STRICT_WRITE,
+ expectedProtoID: THeaderProtocolDefault,
+ },
+ {
+ label: "negative-connect-timeout",
+ cfg: &TConfiguration{
+ ConnectTimeout: -1,
+ SocketTimeout: time.Millisecond,
+ },
+ expectedMessageSize: DEFAULT_MAX_MESSAGE_SIZE,
+ expectedFrameSize: DEFAULT_MAX_FRAME_SIZE,
+ expectedConnectTimeout: DEFAULT_CONNECT_TIMEOUT,
+ expectedSocketTimeout: time.Millisecond,
+ expectedTLSConfig: nil,
+ expectedBinaryRead: DEFAULT_TBINARY_STRICT_READ,
+ expectedBinaryWrite: DEFAULT_TBINARY_STRICT_WRITE,
+ expectedProtoID: THeaderProtocolDefault,
+ },
+ {
+ label: "negative-socket-timeout",
+ cfg: &TConfiguration{
+ SocketTimeout: -1,
+ },
+ expectedMessageSize: DEFAULT_MAX_MESSAGE_SIZE,
+ expectedFrameSize: DEFAULT_MAX_FRAME_SIZE,
+ expectedConnectTimeout: DEFAULT_CONNECT_TIMEOUT,
+ expectedSocketTimeout: DEFAULT_SOCKET_TIMEOUT,
+ expectedTLSConfig: nil,
+ expectedBinaryRead: DEFAULT_TBINARY_STRICT_READ,
+ expectedBinaryWrite: DEFAULT_TBINARY_STRICT_WRITE,
+ expectedProtoID: THeaderProtocolDefault,
+ },
+ {
+ label: "invalid-proto-id",
+ cfg: &TConfiguration{
+ THeaderProtocolID: &invalidProtoID,
+ },
+ expectedMessageSize: DEFAULT_MAX_MESSAGE_SIZE,
+ expectedFrameSize: DEFAULT_MAX_FRAME_SIZE,
+ expectedConnectTimeout: DEFAULT_CONNECT_TIMEOUT,
+ expectedSocketTimeout: DEFAULT_SOCKET_TIMEOUT,
+ expectedTLSConfig: nil,
+ expectedBinaryRead: DEFAULT_TBINARY_STRICT_READ,
+ expectedBinaryWrite: DEFAULT_TBINARY_STRICT_WRITE,
+ expectedProtoID: THeaderProtocolDefault,
+ },
+ } {
+ t.Run(c.label, func(t *testing.T) {
+ t.Run("GetMaxMessageSize", func(t *testing.T) {
+ actual := c.cfg.GetMaxMessageSize()
+ if actual != c.expectedMessageSize {
+ t.Errorf(
+ "Expected %v, got %v",
+ c.expectedMessageSize,
+ actual,
+ )
+ }
+ })
+ t.Run("GetMaxFrameSize", func(t *testing.T) {
+ actual := c.cfg.GetMaxFrameSize()
+ if actual != c.expectedFrameSize {
+ t.Errorf(
+ "Expected %v, got %v",
+ c.expectedFrameSize,
+ actual,
+ )
+ }
+ })
+ t.Run("GetConnectTimeout", func(t *testing.T) {
+ actual := c.cfg.GetConnectTimeout()
+ if actual != c.expectedConnectTimeout {
+ t.Errorf(
+ "Expected %v, got %v",
+ c.expectedConnectTimeout,
+ actual,
+ )
+ }
+ })
+ t.Run("GetSocketTimeout", func(t *testing.T) {
+ actual := c.cfg.GetSocketTimeout()
+ if actual != c.expectedSocketTimeout {
+ t.Errorf(
+ "Expected %v, got %v",
+ c.expectedSocketTimeout,
+ actual,
+ )
+ }
+ })
+ t.Run("GetTLSConfig", func(t *testing.T) {
+ actual := c.cfg.GetTLSConfig()
+ if actual != c.expectedTLSConfig {
+ t.Errorf(
+ "Expected %p(%#v), got %p(%#v)",
+ c.expectedTLSConfig,
+ c.expectedTLSConfig,
+ actual,
+ actual,
+ )
+ }
+ })
+ t.Run("GetTBinaryStrictRead", func(t *testing.T) {
+ actual := c.cfg.GetTBinaryStrictRead()
+ if actual != c.expectedBinaryRead {
+ t.Errorf(
+ "Expected %v, got %v",
+ c.expectedBinaryRead,
+ actual,
+ )
+ }
+ })
+ t.Run("GetTBinaryStrictWrite", func(t *testing.T) {
+ actual := c.cfg.GetTBinaryStrictWrite()
+ if actual != c.expectedBinaryWrite {
+ t.Errorf(
+ "Expected %v, got %v",
+ c.expectedBinaryWrite,
+ actual,
+ )
+ }
+ })
+ t.Run("GetTHeaderProtocolID", func(t *testing.T) {
+ actual := c.cfg.GetTHeaderProtocolID()
+ if actual != c.expectedProtoID {
+ t.Errorf(
+ "Expected %v, got %v",
+ c.expectedProtoID,
+ actual,
+ )
+ }
+ })
+ })
+ }
+}
+
+func TestTHeaderProtocolIDPtr(t *testing.T) {
+ var invalidProtoID = THeaderProtocolID(-1)
+ if invalidProtoID.Validate() == nil {
+ t.Fatalf("Expected %v to be an invalid THeaderProtocolID, it passes the validation", invalidProtoID)
+ }
+
+ ptr, err := THeaderProtocolIDPtr(invalidProtoID)
+ if err == nil {
+ t.Error("Expected error on invalid proto id, got nil")
+ }
+ if ptr == nil {
+ t.Fatal("Expected non-nil pointer on invalid proto id, got nil")
+ }
+ if *ptr != THeaderProtocolDefault {
+ t.Errorf("Expected pointer to %v, got %v", THeaderProtocolDefault, *ptr)
+ }
+}
+
+func TestTHeaderProtocolIDPtrMust(t *testing.T) {
+ const expected = THeaderProtocolCompact
+ ptr := THeaderProtocolIDPtrMust(expected)
+ if *ptr != expected {
+ t.Errorf("Expected pointer to %v, got %v", expected, *ptr)
+ }
+}
+
+func TestTHeaderProtocolIDPtrMustPanic(t *testing.T) {
+ var invalidProtoID = THeaderProtocolID(-1)
+ if invalidProtoID.Validate() == nil {
+ t.Fatalf("Expected %v to be an invalid THeaderProtocolID, it passes the validation", invalidProtoID)
+ }
+
+ defer func() {
+ if recovered := recover(); recovered == nil {
+ t.Error("Expected panic on invalid proto id, did not happen.")
+ }
+ }()
+
+ THeaderProtocolIDPtrMust(invalidProtoID)
+}
+
+func TestPropagateTConfiguration(t *testing.T) {
+ cfg := &TConfiguration{}
+ // Just make sure it won't cause panics on some nil
+ // TProtocol/TTransport/TProtocolFactory/TTransportFactory values.
+ PropagateTConfiguration(nil, cfg)
+ var proto TProtocol
+ PropagateTConfiguration(proto, cfg)
+ var protoFactory TProtocolFactory
+ PropagateTConfiguration(protoFactory, cfg)
+ var trans TTransport
+ PropagateTConfiguration(trans, cfg)
+ var transFactory TTransportFactory
+ PropagateTConfiguration(transFactory, cfg)
+}
diff --git a/lib/go/thrift/debug_protocol.go b/lib/go/thrift/debug_protocol.go
index 875844b..fdf9bfe 100644
--- a/lib/go/thrift/debug_protocol.go
+++ b/lib/go/thrift/debug_protocol.go
@@ -437,3 +437,11 @@
func (tdp *TDebugProtocol) Transport() TTransport {
return tdp.Delegate.Transport()
}
+
+// SetTConfiguration implements TConfigurationSetter for propagation.
+func (tdp *TDebugProtocol) SetTConfiguration(conf *TConfiguration) {
+ PropagateTConfiguration(tdp.Delegate, conf)
+ PropagateTConfiguration(tdp.DuplicateTo, conf)
+}
+
+var _ TConfigurationSetter = (*TDebugProtocol)(nil)
diff --git a/lib/go/thrift/framed_transport.go b/lib/go/thrift/framed_transport.go
index f192075..f683e7f 100644
--- a/lib/go/thrift/framed_transport.go
+++ b/lib/go/thrift/framed_transport.go
@@ -28,11 +28,13 @@
"io"
)
+// Deprecated: Use DEFAULT_MAX_FRAME_SIZE instead.
const DEFAULT_MAX_LENGTH = 16384000
type TFramedTransport struct {
transport TTransport
- maxLength uint32
+
+ cfg *TConfiguration
writeBuf bytes.Buffer
@@ -43,32 +45,75 @@
}
type tFramedTransportFactory struct {
- factory TTransportFactory
- maxLength uint32
+ factory TTransportFactory
+ cfg *TConfiguration
}
+// Deprecated: Use NewTFramedTransportFactoryConf instead.
func NewTFramedTransportFactory(factory TTransportFactory) TTransportFactory {
- return &tFramedTransportFactory{factory: factory, maxLength: DEFAULT_MAX_LENGTH}
+ return NewTFramedTransportFactoryConf(factory, &TConfiguration{
+ MaxFrameSize: DEFAULT_MAX_LENGTH,
+
+ noPropagation: true,
+ })
}
+// Deprecated: Use NewTFramedTransportFactoryConf instead.
func NewTFramedTransportFactoryMaxLength(factory TTransportFactory, maxLength uint32) TTransportFactory {
- return &tFramedTransportFactory{factory: factory, maxLength: maxLength}
+ return NewTFramedTransportFactoryConf(factory, &TConfiguration{
+ MaxFrameSize: int32(maxLength),
+
+ noPropagation: true,
+ })
+}
+
+func NewTFramedTransportFactoryConf(factory TTransportFactory, conf *TConfiguration) TTransportFactory {
+ PropagateTConfiguration(factory, conf)
+ return &tFramedTransportFactory{
+ factory: factory,
+ cfg: conf,
+ }
}
func (p *tFramedTransportFactory) GetTransport(base TTransport) (TTransport, error) {
+ PropagateTConfiguration(base, p.cfg)
tt, err := p.factory.GetTransport(base)
if err != nil {
return nil, err
}
- return NewTFramedTransportMaxLength(tt, p.maxLength), nil
+ return NewTFramedTransportConf(tt, p.cfg), nil
}
+func (p *tFramedTransportFactory) SetTConfiguration(cfg *TConfiguration) {
+ PropagateTConfiguration(p.factory, cfg)
+ p.cfg = cfg
+}
+
+// Deprecated: Use NewTFramedTransportConf instead.
func NewTFramedTransport(transport TTransport) *TFramedTransport {
- return &TFramedTransport{transport: transport, reader: bufio.NewReader(transport), maxLength: DEFAULT_MAX_LENGTH}
+ return NewTFramedTransportConf(transport, &TConfiguration{
+ MaxFrameSize: DEFAULT_MAX_LENGTH,
+
+ noPropagation: true,
+ })
}
+// Deprecated: Use NewTFramedTransportConf instead.
func NewTFramedTransportMaxLength(transport TTransport, maxLength uint32) *TFramedTransport {
- return &TFramedTransport{transport: transport, reader: bufio.NewReader(transport), maxLength: maxLength}
+ return NewTFramedTransportConf(transport, &TConfiguration{
+ MaxFrameSize: int32(maxLength),
+
+ noPropagation: true,
+ })
+}
+
+func NewTFramedTransportConf(transport TTransport, conf *TConfiguration) *TFramedTransport {
+ PropagateTConfiguration(transport, conf)
+ return &TFramedTransport{
+ transport: transport,
+ reader: bufio.NewReader(transport),
+ cfg: conf,
+ }
}
func (p *TFramedTransport) Open() error {
@@ -155,7 +200,7 @@
return err
}
size := binary.BigEndian.Uint32(buf)
- if size < 0 || size > p.maxLength {
+ if size < 0 || size > uint32(p.cfg.GetMaxFrameSize()) {
return NewTTransportException(UNKNOWN_TRANSPORT_EXCEPTION, fmt.Sprintf("Incorrect frame size (%d)", size))
}
_, err := io.CopyN(&p.readBuf, p.reader, int64(size))
@@ -165,3 +210,14 @@
func (p *TFramedTransport) RemainingBytes() (num_bytes uint64) {
return uint64(p.readBuf.Len())
}
+
+// SetTConfiguration implements TConfigurationSetter.
+func (p *TFramedTransport) SetTConfiguration(cfg *TConfiguration) {
+ PropagateTConfiguration(p.transport, cfg)
+ p.cfg = cfg
+}
+
+var (
+ _ TConfigurationSetter = (*tFramedTransportFactory)(nil)
+ _ TConfigurationSetter = (*TFramedTransport)(nil)
+)
diff --git a/lib/go/thrift/header_protocol.go b/lib/go/thrift/header_protocol.go
index f86d558..5ad48e4 100644
--- a/lib/go/thrift/header_protocol.go
+++ b/lib/go/thrift/header_protocol.go
@@ -34,76 +34,65 @@
// Will be initialized on first read/write.
protocol TProtocol
+
+ cfg *TConfiguration
}
-// NewTHeaderProtocol creates a new THeaderProtocol from the underlying
-// transport with default protocol ID.
+// Deprecated: Use NewTHeaderProtocolConf instead.
+func NewTHeaderProtocol(trans TTransport) *THeaderProtocol {
+ return newTHeaderProtocolConf(trans, &TConfiguration{
+ noPropagation: true,
+ })
+}
+
+// NewTHeaderProtocolConf creates a new THeaderProtocol from the underlying
+// transport with given TConfiguration.
//
// The passed in transport will be wrapped with THeaderTransport.
//
// Note that THeaderTransport handles frame and zlib by itself,
// so the underlying transport should be a raw socket transports (TSocket or TSSLSocket),
// instead of rich transports like TZlibTransport or TFramedTransport.
-func NewTHeaderProtocol(trans TTransport) *THeaderProtocol {
- p, err := newTHeaderProtocolWithProtocolID(trans, THeaderProtocolDefault)
- if err != nil {
- // Since we used THeaderProtocolDefault this should never happen,
- // but put a sanity check here just in case.
- panic(err)
- }
- return p
+func NewTHeaderProtocolConf(trans TTransport, conf *TConfiguration) *THeaderProtocol {
+ return newTHeaderProtocolConf(trans, conf)
}
-func newTHeaderProtocolWithProtocolID(trans TTransport, protoID THeaderProtocolID) (*THeaderProtocol, error) {
- t, err := NewTHeaderTransportWithProtocolID(trans, protoID)
- if err != nil {
- return nil, err
- }
- p, err := t.protocolID.GetProtocol(t)
- if err != nil {
- return nil, err
- }
+func newTHeaderProtocolConf(trans TTransport, cfg *TConfiguration) *THeaderProtocol {
+ t := NewTHeaderTransportConf(trans, cfg)
+ p, _ := t.cfg.GetTHeaderProtocolID().GetProtocol(t)
+ PropagateTConfiguration(p, cfg)
return &THeaderProtocol{
transport: t,
protocol: p,
- }, nil
+ cfg: cfg,
+ }
}
type tHeaderProtocolFactory struct {
- protoID THeaderProtocolID
+ cfg *TConfiguration
}
func (f tHeaderProtocolFactory) GetProtocol(trans TTransport) TProtocol {
- p, err := newTHeaderProtocolWithProtocolID(trans, f.protoID)
- if err != nil {
- // Currently there's no way for external users to construct a
- // valid factory with invalid protoID, so this should never
- // happen. But put a sanity check here just in case in the
- // future a bug made that possible.
- panic(err)
- }
- return p
+ return newTHeaderProtocolConf(trans, f.cfg)
}
-// NewTHeaderProtocolFactory creates a factory for THeader with default protocol
-// ID.
-//
-// It's a wrapper for NewTHeaderProtocol
+func (f *tHeaderProtocolFactory) SetTConfiguration(cfg *TConfiguration) {
+ f.cfg = cfg
+}
+
+// Deprecated: Use NewTHeaderProtocolFactoryConf instead.
func NewTHeaderProtocolFactory() TProtocolFactory {
- return tHeaderProtocolFactory{
- protoID: THeaderProtocolDefault,
- }
+ return NewTHeaderProtocolFactoryConf(&TConfiguration{
+ noPropagation: true,
+ })
}
-// NewTHeaderProtocolFactoryWithProtocolID creates a factory for THeader with
-// given protocol ID.
-func NewTHeaderProtocolFactoryWithProtocolID(protoID THeaderProtocolID) (TProtocolFactory, error) {
- if err := protoID.Validate(); err != nil {
- return nil, err
- }
+// NewTHeaderProtocolFactoryConf creates a factory for THeader with given
+// TConfiguration.
+func NewTHeaderProtocolFactoryConf(conf *TConfiguration) TProtocolFactory {
return tHeaderProtocolFactory{
- protoID: protoID,
- }, nil
+ cfg: conf,
+ }
}
// Transport returns the underlying transport.
@@ -142,6 +131,7 @@
if err != nil {
return err
}
+ PropagateTConfiguration(newProto, p.cfg)
p.protocol = newProto
p.transport.SequenceID = seqID
return p.protocol.WriteMessageBegin(ctx, name, typeID, seqID)
@@ -261,6 +251,7 @@
}
return
}
+ PropagateTConfiguration(newProto, p.cfg)
p.protocol = newProto
return p.protocol.ReadMessageBegin(ctx)
@@ -346,6 +337,13 @@
return p.protocol.Skip(ctx, fieldType)
}
+// SetTConfiguration implements TConfigurationSetter.
+func (p *THeaderProtocol) SetTConfiguration(cfg *TConfiguration) {
+ PropagateTConfiguration(p.transport, cfg)
+ PropagateTConfiguration(p.protocol, cfg)
+ p.cfg = cfg
+}
+
// GetResponseHeadersFromClient is a helper function to get the read THeaderMap
// from the last response received from the given client.
//
@@ -359,3 +357,8 @@
}
return nil
}
+
+var (
+ _ TConfigurationSetter = (*tHeaderProtocolFactory)(nil)
+ _ TConfigurationSetter = (*THeaderProtocol)(nil)
+)
diff --git a/lib/go/thrift/header_protocol_test.go b/lib/go/thrift/header_protocol_test.go
index f66ea64..48a69bf 100644
--- a/lib/go/thrift/header_protocol_test.go
+++ b/lib/go/thrift/header_protocol_test.go
@@ -34,11 +34,9 @@
t.Run(
"compact",
func(t *testing.T) {
- f, err := NewTHeaderProtocolFactoryWithProtocolID(THeaderProtocolCompact)
- if err != nil {
- t.Fatal(err)
- }
- ReadWriteProtocolTest(t, f)
+ ReadWriteProtocolTest(t, NewTHeaderProtocolFactoryConf(&TConfiguration{
+ THeaderProtocolID: THeaderProtocolIDPtrMust(THeaderProtocolCompact),
+ }))
},
)
}
diff --git a/lib/go/thrift/header_transport.go b/lib/go/thrift/header_transport.go
index 562d02f..1e8e302 100644
--- a/lib/go/thrift/header_transport.go
+++ b/lib/go/thrift/header_transport.go
@@ -264,7 +264,7 @@
writeTransforms []THeaderTransformID
clientType clientType
- protocolID THeaderProtocolID
+ cfg *TConfiguration
// buffer is used in the following scenarios to avoid repetitive
// allocations, while 4 is big enough for all those scenarios:
@@ -276,51 +276,35 @@
var _ TTransport = (*THeaderTransport)(nil)
-// NewTHeaderTransport creates THeaderTransport from the underlying transport.
-//
-// Please note that THeaderTransport handles framing and zlib by itself,
-// so the underlying transport should be the raw socket transports (TSocket or TSSLSocket),
-// instead of rich transports like TZlibTransport or TFramedTransport.
-//
-// If trans is already a *THeaderTransport, it will be returned as is.
+// Deprecated: Use NewTHeaderTransportConf instead.
func NewTHeaderTransport(trans TTransport) *THeaderTransport {
- if ht, ok := trans.(*THeaderTransport); ok {
- return ht
- }
- return &THeaderTransport{
- transport: trans,
- reader: bufio.NewReader(trans),
- writeHeaders: make(THeaderMap),
- protocolID: THeaderProtocolDefault,
- }
+ return NewTHeaderTransportConf(trans, &TConfiguration{
+ noPropagation: true,
+ })
}
-// NewTHeaderTransportWithProtocolID creates THeaderTransport from the
-// underlying transport, with given protocol ID set.
+// NewTHeaderTransportConf creates THeaderTransport from the
+// underlying transport, with given TConfiguration attached.
//
// If trans is already a *THeaderTransport, it will be returned as is,
-// but with protocol ID overridden by the value passed in.
+// but with TConfiguration overridden by the value passed in.
//
-// If the passed in protocol ID is an invalid/unsupported one,
-// this function returns error.
-//
-// The protocol ID overridden is only useful for client transports.
+// The protocol ID in TConfiguration is only useful for client transports.
// For servers,
// the protocol ID will be overridden again to the one set by the client,
// to ensure that servers always speak the same dialect as the client.
-func NewTHeaderTransportWithProtocolID(trans TTransport, protoID THeaderProtocolID) (*THeaderTransport, error) {
- if err := protoID.Validate(); err != nil {
- return nil, err
- }
+func NewTHeaderTransportConf(trans TTransport, conf *TConfiguration) *THeaderTransport {
if ht, ok := trans.(*THeaderTransport); ok {
- return ht, nil
+ ht.SetTConfiguration(conf)
+ return ht
}
+ PropagateTConfiguration(trans, conf)
return &THeaderTransport{
transport: trans,
reader: bufio.NewReader(trans),
writeHeaders: make(THeaderMap),
- protocolID: protoID,
- }, nil
+ cfg: conf,
+ }
}
// Open calls the underlying transport's Open function.
@@ -375,7 +359,7 @@
// At this point it should be a framed message,
// sanity check on frameSize then discard the peeked part.
- if frameSize > THeaderMaxFrameSize {
+ if frameSize > THeaderMaxFrameSize || frameSize > uint32(t.cfg.GetMaxFrameSize()) {
return NewTProtocolExceptionWithType(
SIZE_LIMIT,
errors.New("frame too large"),
@@ -451,6 +435,7 @@
return err
}
hp := NewTCompactProtocol(headerBuf)
+ hp.SetTConfiguration(t.cfg)
// At this point the header is already read into headerBuf,
// and t.frameBuffer starts from the actual payload.
@@ -458,7 +443,17 @@
if err != nil {
return err
}
- t.protocolID = THeaderProtocolID(protoID)
+ idPtr, err := THeaderProtocolIDPtr(THeaderProtocolID(protoID))
+ if err != nil {
+ return err
+ }
+ if t.cfg == nil {
+ t.cfg = &TConfiguration{
+ noPropagation: true,
+ }
+ }
+ t.cfg.THeaderProtocolID = idPtr
+
var transformCount int32
transformCount, err = hp.readVarint32()
if err != nil {
@@ -601,7 +596,8 @@
case clientHeaders:
headers := NewTMemoryBuffer()
hp := NewTCompactProtocol(headers)
- if _, err := hp.writeVarint32(int32(t.protocolID)); err != nil {
+ hp.SetTConfiguration(t.cfg)
+ if _, err := hp.writeVarint32(int32(t.cfg.GetTHeaderProtocolID())); err != nil {
return NewTTransportExceptionFromError(err)
}
if _, err := hp.writeVarint32(int32(len(t.writeTransforms))); err != nil {
@@ -746,7 +742,7 @@
func (t *THeaderTransport) Protocol() THeaderProtocolID {
switch t.clientType {
default:
- return t.protocolID
+ return t.cfg.GetTHeaderProtocolID()
case clientFramedBinary, clientUnframedBinary:
return THeaderProtocolBinary
case clientFramedCompact, clientUnframedCompact:
@@ -763,17 +759,37 @@
}
}
+// SetTConfiguration implements TConfigurationSetter.
+func (t *THeaderTransport) SetTConfiguration(cfg *TConfiguration) {
+ PropagateTConfiguration(t.transport, cfg)
+ t.cfg = cfg
+}
+
// THeaderTransportFactory is a TTransportFactory implementation to create
// THeaderTransport.
+//
+// It also implements TConfigurationSetter.
type THeaderTransportFactory struct {
// The underlying factory, could be nil.
Factory TTransportFactory
+
+ cfg *TConfiguration
}
-// NewTHeaderTransportFactory creates a new *THeaderTransportFactory.
+// Deprecated: Use NewTHeaderTransportFactoryConf instead.
func NewTHeaderTransportFactory(factory TTransportFactory) TTransportFactory {
+ return NewTHeaderTransportFactoryConf(factory, &TConfiguration{
+ noPropagation: true,
+ })
+}
+
+// NewTHeaderTransportFactoryConf creates a new *THeaderTransportFactory with
+// the given *TConfiguration.
+func NewTHeaderTransportFactoryConf(factory TTransportFactory, conf *TConfiguration) TTransportFactory {
return &THeaderTransportFactory{
Factory: factory,
+
+ cfg: conf,
}
}
@@ -784,7 +800,18 @@
if err != nil {
return nil, err
}
- return NewTHeaderTransport(t), nil
+ return NewTHeaderTransportConf(t, f.cfg), nil
}
- return NewTHeaderTransport(trans), nil
+ return NewTHeaderTransportConf(trans, f.cfg), nil
}
+
+// SetTConfiguration implements TConfigurationSetter.
+func (f *THeaderTransportFactory) SetTConfiguration(cfg *TConfiguration) {
+ PropagateTConfiguration(f.Factory, f.cfg)
+ f.cfg = cfg
+}
+
+var (
+ _ TConfigurationSetter = (*THeaderTransportFactory)(nil)
+ _ TConfigurationSetter = (*THeaderTransport)(nil)
+)
diff --git a/lib/go/thrift/header_transport_test.go b/lib/go/thrift/header_transport_test.go
index 5b47680..41efb18 100644
--- a/lib/go/thrift/header_transport_test.go
+++ b/lib/go/thrift/header_transport_test.go
@@ -21,6 +21,7 @@
import (
"context"
+ "fmt"
"io"
"io/ioutil"
"strings"
@@ -31,10 +32,9 @@
func testTHeaderHeadersReadWriteProtocolID(t *testing.T, protoID THeaderProtocolID) {
trans := NewTMemoryBuffer()
reader := NewTHeaderTransport(trans)
- writer, err := NewTHeaderTransportWithProtocolID(trans, protoID)
- if err != nil {
- t.Fatal(err)
- }
+ writer := NewTHeaderTransportConf(trans, &TConfiguration{
+ THeaderProtocolID: &protoID,
+ })
const key1 = "key1"
const value1 = "value1"
@@ -265,3 +265,19 @@
t.Error(err)
}
}
+
+func BenchmarkTHeaderProtocolIDValidate(b *testing.B) {
+ for _, c := range []THeaderProtocolID{
+ THeaderProtocolBinary,
+ THeaderProtocolCompact,
+ -1,
+ } {
+ b.Run(fmt.Sprintf("%2v", c), func(b *testing.B) {
+ b.RunParallel(func(pb *testing.PB) {
+ for pb.Next() {
+ c.Validate()
+ }
+ })
+ })
+ }
+}
diff --git a/lib/go/thrift/iostream_transport.go b/lib/go/thrift/iostream_transport.go
index 0b1775d..1c47799 100644
--- a/lib/go/thrift/iostream_transport.go
+++ b/lib/go/thrift/iostream_transport.go
@@ -212,3 +212,11 @@
const maxSize = ^uint64(0)
return maxSize // the truth is, we just don't know unless framed is used
}
+
+// SetTConfiguration implements TConfigurationSetter for propagation.
+func (p *StreamTransport) SetTConfiguration(conf *TConfiguration) {
+ PropagateTConfiguration(p.Reader, conf)
+ PropagateTConfiguration(p.Writer, conf)
+}
+
+var _ TConfigurationSetter = (*StreamTransport)(nil)
diff --git a/lib/go/thrift/json_protocol.go b/lib/go/thrift/json_protocol.go
index edc49cc..8e59d16 100644
--- a/lib/go/thrift/json_protocol.go
+++ b/lib/go/thrift/json_protocol.go
@@ -587,3 +587,5 @@
e := fmt.Errorf("Unknown type identifier: %s", fieldType)
return TType(STOP), NewTProtocolExceptionWithType(INVALID_DATA, e)
}
+
+var _ TConfigurationSetter = (*TJSONProtocol)(nil)
diff --git a/lib/go/thrift/simple_json_protocol.go b/lib/go/thrift/simple_json_protocol.go
index e94b44b..d1a8154 100644
--- a/lib/go/thrift/simple_json_protocol.go
+++ b/lib/go/thrift/simple_json_protocol.go
@@ -1364,3 +1364,10 @@
}
return n, err
}
+
+// SetTConfiguration implements TConfigurationSetter for propagation.
+func (p *TSimpleJSONProtocol) SetTConfiguration(conf *TConfiguration) {
+ PropagateTConfiguration(p.trans, conf)
+}
+
+var _ TConfigurationSetter = (*TSimpleJSONProtocol)(nil)
diff --git a/lib/go/thrift/socket.go b/lib/go/thrift/socket.go
index af75dd1..e911bf1 100644
--- a/lib/go/thrift/socket.go
+++ b/lib/go/thrift/socket.go
@@ -26,57 +26,116 @@
)
type TSocket struct {
- conn *socketConn
- addr net.Addr
+ conn *socketConn
+ addr net.Addr
+ cfg *TConfiguration
+
connectTimeout time.Duration
socketTimeout time.Duration
}
-// NewTSocket creates a net.Conn-backed TTransport, given a host and port
-//
-// Example:
-// trans, err := thrift.NewTSocket("localhost:9090")
+// Deprecated: Use NewTSocketConf instead.
func NewTSocket(hostPort string) (*TSocket, error) {
- return NewTSocketTimeout(hostPort, 0, 0)
+ return NewTSocketConf(hostPort, &TConfiguration{
+ noPropagation: true,
+ })
}
-// NewTSocketTimeout creates a net.Conn-backed TTransport, given a host and port
-// it also accepts a timeout as a time.Duration
-func NewTSocketTimeout(hostPort string, connTimeout time.Duration, soTimeout time.Duration) (*TSocket, error) {
- //conn, err := net.DialTimeout(network, address, timeout)
+// NewTSocketConf creates a net.Conn-backed TTransport, given a host and port.
+//
+// Example:
+//
+// trans, err := thrift.NewTSocketConf("localhost:9090", &TConfiguration{
+// ConnectTimeout: time.Second, // Use 0 for no timeout
+// SocketTimeout: time.Second, // Use 0 for no timeout
+// })
+func NewTSocketConf(hostPort string, conf *TConfiguration) (*TSocket, error) {
addr, err := net.ResolveTCPAddr("tcp", hostPort)
if err != nil {
return nil, err
}
- return NewTSocketFromAddrTimeout(addr, connTimeout, soTimeout), nil
+ return NewTSocketFromAddrConf(addr, conf), nil
}
-// Creates a TSocket from a net.Addr
+// Deprecated: Use NewTSocketConf instead.
+func NewTSocketTimeout(hostPort string, connTimeout time.Duration, soTimeout time.Duration) (*TSocket, error) {
+ return NewTSocketConf(hostPort, &TConfiguration{
+ ConnectTimeout: connTimeout,
+ SocketTimeout: soTimeout,
+
+ noPropagation: true,
+ })
+}
+
+// NewTSocketFromAddrConf creates a TSocket from a net.Addr
+func NewTSocketFromAddrConf(addr net.Addr, conf *TConfiguration) *TSocket {
+ return &TSocket{
+ addr: addr,
+ cfg: conf,
+ }
+}
+
+// Deprecated: Use NewTSocketFromAddrConf instead.
func NewTSocketFromAddrTimeout(addr net.Addr, connTimeout time.Duration, soTimeout time.Duration) *TSocket {
- return &TSocket{addr: addr, connectTimeout: connTimeout, socketTimeout: soTimeout}
+ return NewTSocketFromAddrConf(addr, &TConfiguration{
+ ConnectTimeout: connTimeout,
+ SocketTimeout: soTimeout,
+
+ noPropagation: true,
+ })
}
-// Creates a TSocket from an existing net.Conn
+// NewTSocketFromConnConf creates a TSocket from an existing net.Conn.
+func NewTSocketFromConnConf(conn net.Conn, conf *TConfiguration) *TSocket {
+ return &TSocket{
+ conn: wrapSocketConn(conn),
+ addr: conn.RemoteAddr(),
+ cfg: conf,
+ }
+}
+
+// Deprecated: Use NewTSocketFromConnConf instead.
func NewTSocketFromConnTimeout(conn net.Conn, socketTimeout time.Duration) *TSocket {
- return &TSocket{conn: wrapSocketConn(conn), addr: conn.RemoteAddr(), socketTimeout: socketTimeout}
+ return NewTSocketFromConnConf(conn, &TConfiguration{
+ SocketTimeout: socketTimeout,
+
+ noPropagation: true,
+ })
+}
+
+// SetTConfiguration implements TConfigurationSetter.
+//
+// It can be used to set connect and socket timeouts.
+func (p *TSocket) SetTConfiguration(conf *TConfiguration) {
+ p.cfg = conf
}
// Sets the connect timeout
func (p *TSocket) SetConnTimeout(timeout time.Duration) error {
- p.connectTimeout = timeout
+ if p.cfg == nil {
+ p.cfg = &TConfiguration{
+ noPropagation: true,
+ }
+ }
+ p.cfg.ConnectTimeout = timeout
return nil
}
// Sets the socket timeout
func (p *TSocket) SetSocketTimeout(timeout time.Duration) error {
- p.socketTimeout = timeout
+ if p.cfg == nil {
+ p.cfg = &TConfiguration{
+ noPropagation: true,
+ }
+ }
+ p.cfg.SocketTimeout = timeout
return nil
}
func (p *TSocket) pushDeadline(read, write bool) {
var t time.Time
- if p.socketTimeout > 0 {
- t = time.Now().Add(time.Duration(p.socketTimeout))
+ if timeout := p.cfg.GetSocketTimeout(); timeout > 0 {
+ t = time.Now().Add(time.Duration(timeout))
}
if read && write {
p.conn.SetDeadline(t)
@@ -105,7 +164,7 @@
if p.conn, err = createSocketConnFromReturn(net.DialTimeout(
p.addr.Network(),
p.addr.String(),
- p.connectTimeout,
+ p.cfg.GetConnectTimeout(),
)); err != nil {
return NewTTransportException(NOT_OPEN, err.Error())
}
@@ -175,3 +234,5 @@
const maxSize = ^uint64(0)
return maxSize // the truth is, we just don't know unless framed is used
}
+
+var _ TConfigurationSetter = (*TSocket)(nil)
diff --git a/lib/go/thrift/ssl_socket.go b/lib/go/thrift/ssl_socket.go
index 15ae96f..6359a74 100644
--- a/lib/go/thrift/ssl_socket.go
+++ b/lib/go/thrift/ssl_socket.go
@@ -34,70 +34,115 @@
// addr is nil when hostPort is not "", and is only used when the
// TSSLSocket is constructed from a net.Addr.
addr net.Addr
- cfg *tls.Config
- connectTimeout time.Duration
- socketTimeout time.Duration
+ cfg *TConfiguration
}
-// NewTSSLSocket creates a net.Conn-backed TTransport, given a host and port and tls Configuration
+// NewTSSLSocketConf creates a net.Conn-backed TTransport, given a host and port.
//
// Example:
-// trans, err := thrift.NewTSSLSocket("localhost:9090", nil)
-func NewTSSLSocket(hostPort string, cfg *tls.Config) (*TSSLSocket, error) {
- return NewTSSLSocketTimeout(hostPort, cfg, 0, 0)
-}
-
-// NewTSSLSocketTimeout creates a net.Conn-backed TTransport, given a host and port
-// it also accepts a tls Configuration and connect/socket timeouts as time.Duration
-func NewTSSLSocketTimeout(hostPort string, cfg *tls.Config, connectTimeout, socketTimeout time.Duration) (*TSSLSocket, error) {
- if cfg.MinVersion == 0 {
+//
+// trans, err := thrift.NewTSSLSocketConf("localhost:9090", nil, &TConfiguration{
+// ConnectTimeout: time.Second, // Use 0 for no timeout
+// SocketTimeout: time.Second, // Use 0 for no timeout
+// })
+func NewTSSLSocketConf(hostPort string, conf *TConfiguration) (*TSSLSocket, error) {
+ if cfg := conf.GetTLSConfig(); cfg != nil && cfg.MinVersion == 0 {
cfg.MinVersion = tls.VersionTLS10
}
return &TSSLSocket{
- hostPort: hostPort,
- cfg: cfg,
- connectTimeout: connectTimeout,
- socketTimeout: socketTimeout,
+ hostPort: hostPort,
+ cfg: conf,
}, nil
}
-// Creates a TSSLSocket from a net.Addr
-func NewTSSLSocketFromAddrTimeout(addr net.Addr, cfg *tls.Config, connectTimeout, socketTimeout time.Duration) *TSSLSocket {
+// Deprecated: Use NewTSSLSocketConf instead.
+func NewTSSLSocket(hostPort string, cfg *tls.Config) (*TSSLSocket, error) {
+ return NewTSSLSocketConf(hostPort, &TConfiguration{
+ TLSConfig: cfg,
+
+ noPropagation: true,
+ })
+}
+
+// Deprecated: Use NewTSSLSocketConf instead.
+func NewTSSLSocketTimeout(hostPort string, cfg *tls.Config, connectTimeout, socketTimeout time.Duration) (*TSSLSocket, error) {
+ return NewTSSLSocketConf(hostPort, &TConfiguration{
+ ConnectTimeout: connectTimeout,
+ SocketTimeout: socketTimeout,
+ TLSConfig: cfg,
+
+ noPropagation: true,
+ })
+}
+
+// NewTSSLSocketFromAddrConf creates a TSSLSocket from a net.Addr.
+func NewTSSLSocketFromAddrConf(addr net.Addr, conf *TConfiguration) *TSSLSocket {
return &TSSLSocket{
- addr: addr,
- cfg: cfg,
- connectTimeout: connectTimeout,
- socketTimeout: socketTimeout,
+ addr: addr,
+ cfg: conf,
}
}
-// Creates a TSSLSocket from an existing net.Conn
-func NewTSSLSocketFromConnTimeout(conn net.Conn, cfg *tls.Config, socketTimeout time.Duration) *TSSLSocket {
+// Deprecated: Use NewTSSLSocketFromAddrConf instead.
+func NewTSSLSocketFromAddrTimeout(addr net.Addr, cfg *tls.Config, connectTimeout, socketTimeout time.Duration) *TSSLSocket {
+ return NewTSSLSocketFromAddrConf(addr, &TConfiguration{
+ ConnectTimeout: connectTimeout,
+ SocketTimeout: socketTimeout,
+ TLSConfig: cfg,
+
+ noPropagation: true,
+ })
+}
+
+// NewTSSLSocketFromConnConf creates a TSSLSocket from an existing net.Conn.
+func NewTSSLSocketFromConnConf(conn net.Conn, conf *TConfiguration) *TSSLSocket {
return &TSSLSocket{
- conn: wrapSocketConn(conn),
- addr: conn.RemoteAddr(),
- cfg: cfg,
- socketTimeout: socketTimeout,
+ conn: wrapSocketConn(conn),
+ addr: conn.RemoteAddr(),
+ cfg: conf,
}
}
+// Deprecated: Use NewTSSLSocketFromConnConf instead.
+func NewTSSLSocketFromConnTimeout(conn net.Conn, cfg *tls.Config, socketTimeout time.Duration) *TSSLSocket {
+ return NewTSSLSocketFromConnConf(conn, &TConfiguration{
+ SocketTimeout: socketTimeout,
+ TLSConfig: cfg,
+
+ noPropagation: true,
+ })
+}
+
+// SetTConfiguration implements TConfigurationSetter.
+//
+// It can be used to change connect and socket timeouts.
+func (p *TSSLSocket) SetTConfiguration(conf *TConfiguration) {
+ p.cfg = conf
+}
+
// Sets the connect timeout
func (p *TSSLSocket) SetConnTimeout(timeout time.Duration) error {
- p.connectTimeout = timeout
+ if p.cfg == nil {
+ p.cfg = &TConfiguration{}
+ }
+ p.cfg.ConnectTimeout = timeout
return nil
}
// Sets the socket timeout
func (p *TSSLSocket) SetSocketTimeout(timeout time.Duration) error {
- p.socketTimeout = timeout
+ if p.cfg == nil {
+ p.cfg = &TConfiguration{}
+ }
+ p.cfg.SocketTimeout = timeout
return nil
}
func (p *TSSLSocket) pushDeadline(read, write bool) {
var t time.Time
- if p.socketTimeout > 0 {
- t = time.Now().Add(time.Duration(p.socketTimeout))
+ if timeout := p.cfg.GetSocketTimeout(); timeout > 0 {
+ t = time.Now().Add(time.Duration(timeout))
}
if read && write {
p.conn.SetDeadline(t)
@@ -116,11 +161,11 @@
if p.hostPort != "" {
if p.conn, err = createSocketConnFromReturn(tls.DialWithDialer(
&net.Dialer{
- Timeout: p.connectTimeout,
+ Timeout: p.cfg.GetConnectTimeout(),
},
"tcp",
p.hostPort,
- p.cfg,
+ p.cfg.GetTLSConfig(),
)); err != nil {
return NewTTransportException(NOT_OPEN, err.Error())
}
@@ -139,11 +184,11 @@
}
if p.conn, err = createSocketConnFromReturn(tls.DialWithDialer(
&net.Dialer{
- Timeout: p.connectTimeout,
+ Timeout: p.cfg.GetConnectTimeout(),
},
p.addr.Network(),
p.addr.String(),
- p.cfg,
+ p.cfg.GetTLSConfig(),
)); err != nil {
return NewTTransportException(NOT_OPEN, err.Error())
}
@@ -209,3 +254,5 @@
const maxSize = ^uint64(0)
return maxSize // the truth is, we just don't know unless framed is used
}
+
+var _ TConfigurationSetter = (*TSSLSocket)(nil)
diff --git a/lib/go/thrift/zlib_transport.go b/lib/go/thrift/zlib_transport.go
index e7efdfb..259943a 100644
--- a/lib/go/thrift/zlib_transport.go
+++ b/lib/go/thrift/zlib_transport.go
@@ -128,3 +128,10 @@
func (z *TZlibTransport) Write(p []byte) (int, error) {
return z.writer.Write(p)
}
+
+// SetTConfiguration implements TConfigurationSetter for propagation.
+func (z *TZlibTransport) SetTConfiguration(conf *TConfiguration) {
+ PropagateTConfiguration(z.transport, conf)
+}
+
+var _ TConfigurationSetter = (*TZlibTransport)(nil)