THRIFT-4346: Allow go ZlibTransportFactory to wrap other factories
Client: go

This closes #1375
diff --git a/lib/go/thrift/zlib_transport.go b/lib/go/thrift/zlib_transport.go
index 6f477ca..f2f0732 100644
--- a/lib/go/thrift/zlib_transport.go
+++ b/lib/go/thrift/zlib_transport.go
@@ -27,7 +27,8 @@
 
 // TZlibTransportFactory is a factory for TZlibTransport instances
 type TZlibTransportFactory struct {
-	level int
+	level   int
+	factory TTransportFactory
 }
 
 // TZlibTransport is a TTransport implementation that makes use of zlib compression.
@@ -39,12 +40,26 @@
 
 // GetTransport constructs a new instance of NewTZlibTransport
 func (p *TZlibTransportFactory) GetTransport(trans TTransport) (TTransport, error) {
+	if p.factory != nil {
+		// wrap other factory
+		var err error
+		trans, err = p.factory.GetTransport(trans)
+		if err != nil {
+			return nil, err
+		}
+	}
 	return NewTZlibTransport(trans, p.level)
 }
 
 // NewTZlibTransportFactory constructs a new instance of NewTZlibTransportFactory
 func NewTZlibTransportFactory(level int) *TZlibTransportFactory {
-	return &TZlibTransportFactory{level: level}
+	return &TZlibTransportFactory{level: level, factory: nil}
+}
+
+// NewTZlibTransportFactory constructs a new instance of TZlibTransportFactory
+// as a wrapper over existing transport factory
+func NewTZlibTransportFactoryWithFactory(level int, factory TTransportFactory) *TZlibTransportFactory {
+	return &TZlibTransportFactory{level: level, factory: factory}
 }
 
 // NewTZlibTransport constructs a new instance of TZlibTransport
diff --git a/lib/go/thrift/zlib_transport_test.go b/lib/go/thrift/zlib_transport_test.go
index f57610c..3c6f11e 100644
--- a/lib/go/thrift/zlib_transport_test.go
+++ b/lib/go/thrift/zlib_transport_test.go
@@ -31,3 +31,32 @@
 	}
 	TransportTest(t, trans, trans)
 }
+
+type DummyTransportFactory struct{}
+
+func (p *DummyTransportFactory) GetTransport(trans TTransport) (TTransport, error) {
+	return NewTMemoryBuffer(), nil
+}
+
+func TestZlibFactoryTransportWithFactory(t *testing.T) {
+	factory := NewTZlibTransportFactoryWithFactory(
+		zlib.BestCompression,
+		&DummyTransportFactory{},
+	)
+	buffer := NewTMemoryBuffer()
+	trans, err := factory.GetTransport(buffer)
+	if err != nil {
+		t.Fatal(err)
+	}
+	TransportTest(t, trans, trans)
+}
+
+func TestZlibFactoryTransportWithoutFactory(t *testing.T) {
+	factory := NewTZlibTransportFactoryWithFactory(zlib.BestCompression, nil)
+	buffer := NewTMemoryBuffer()
+	trans, err := factory.GetTransport(buffer)
+	if err != nil {
+		t.Fatal(err)
+	}
+	TransportTest(t, trans, trans)
+}