THRIFT-4346: Allow go ZlibTransportFactory to wrap other factories
Client: go
This closes #1375
diff --git a/lib/go/thrift/zlib_transport.go b/lib/go/thrift/zlib_transport.go
index 6f477ca..f2f0732 100644
--- a/lib/go/thrift/zlib_transport.go
+++ b/lib/go/thrift/zlib_transport.go
@@ -27,7 +27,8 @@
// TZlibTransportFactory is a factory for TZlibTransport instances
type TZlibTransportFactory struct {
- level int
+ level int
+ factory TTransportFactory
}
// TZlibTransport is a TTransport implementation that makes use of zlib compression.
@@ -39,12 +40,26 @@
// GetTransport constructs a new instance of NewTZlibTransport
func (p *TZlibTransportFactory) GetTransport(trans TTransport) (TTransport, error) {
+ if p.factory != nil {
+ // wrap other factory
+ var err error
+ trans, err = p.factory.GetTransport(trans)
+ if err != nil {
+ return nil, err
+ }
+ }
return NewTZlibTransport(trans, p.level)
}
// NewTZlibTransportFactory constructs a new instance of NewTZlibTransportFactory
func NewTZlibTransportFactory(level int) *TZlibTransportFactory {
- return &TZlibTransportFactory{level: level}
+ return &TZlibTransportFactory{level: level, factory: nil}
+}
+
+// NewTZlibTransportFactory constructs a new instance of TZlibTransportFactory
+// as a wrapper over existing transport factory
+func NewTZlibTransportFactoryWithFactory(level int, factory TTransportFactory) *TZlibTransportFactory {
+ return &TZlibTransportFactory{level: level, factory: factory}
}
// NewTZlibTransport constructs a new instance of TZlibTransport
diff --git a/lib/go/thrift/zlib_transport_test.go b/lib/go/thrift/zlib_transport_test.go
index f57610c..3c6f11e 100644
--- a/lib/go/thrift/zlib_transport_test.go
+++ b/lib/go/thrift/zlib_transport_test.go
@@ -31,3 +31,32 @@
}
TransportTest(t, trans, trans)
}
+
+type DummyTransportFactory struct{}
+
+func (p *DummyTransportFactory) GetTransport(trans TTransport) (TTransport, error) {
+ return NewTMemoryBuffer(), nil
+}
+
+func TestZlibFactoryTransportWithFactory(t *testing.T) {
+ factory := NewTZlibTransportFactoryWithFactory(
+ zlib.BestCompression,
+ &DummyTransportFactory{},
+ )
+ buffer := NewTMemoryBuffer()
+ trans, err := factory.GetTransport(buffer)
+ if err != nil {
+ t.Fatal(err)
+ }
+ TransportTest(t, trans, trans)
+}
+
+func TestZlibFactoryTransportWithoutFactory(t *testing.T) {
+ factory := NewTZlibTransportFactoryWithFactory(zlib.BestCompression, nil)
+ buffer := NewTMemoryBuffer()
+ trans, err := factory.GetTransport(buffer)
+ if err != nil {
+ t.Fatal(err)
+ }
+ TransportTest(t, trans, trans)
+}