THRIFT-3292 Integrate new Zlib transport for Go into test suite
Client: Go
Patch: Paul Magrath <paul@swiftkey.com>
This closes #580
diff --git a/lib/go/thrift/zlib_transport.go b/lib/go/thrift/zlib_transport.go
new file mode 100644
index 0000000..e47455f
--- /dev/null
+++ b/lib/go/thrift/zlib_transport.go
@@ -0,0 +1,117 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements. See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership. The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+* KIND, either express or implied. See the License for the
+* specific language governing permissions and limitations
+* under the License.
+ */
+
+package thrift
+
+import (
+ "compress/zlib"
+ "io"
+ "log"
+)
+
+// TZlibTransportFactory is a factory for TZlibTransport instances
+type TZlibTransportFactory struct {
+ level int
+}
+
+// TZlibTransport is a TTransport implementation that makes use of zlib compression.
+type TZlibTransport struct {
+ reader io.ReadCloser
+ transport TTransport
+ writer *zlib.Writer
+}
+
+// GetTransport constructs a new instance of NewTZlibTransport
+func (p *TZlibTransportFactory) GetTransport(trans TTransport) TTransport {
+ t, _ := NewTZlibTransport(trans, p.level)
+ return t
+}
+
+// NewTZlibTransportFactory constructs a new instance of NewTZlibTransportFactory
+func NewTZlibTransportFactory(level int) *TZlibTransportFactory {
+ return &TZlibTransportFactory{level: level}
+}
+
+// NewTZlibTransport constructs a new instance of TZlibTransport
+func NewTZlibTransport(trans TTransport, level int) (*TZlibTransport, error) {
+ w, err := zlib.NewWriterLevel(trans, level)
+ if err != nil {
+ log.Println(err)
+ return nil, err
+ }
+
+ return &TZlibTransport{
+ writer: w,
+ transport: trans,
+ }, nil
+}
+
+// Close closes the reader and writer (flushing any unwritten data) and closes
+// the underlying transport.
+func (z *TZlibTransport) Close() error {
+ if z.reader != nil {
+ if err := z.reader.Close(); err != nil {
+ return err
+ }
+ }
+ if err := z.writer.Close(); err != nil {
+ return err
+ }
+ return z.transport.Close()
+}
+
+// Flush flushes the writer and its underlying transport.
+func (z *TZlibTransport) Flush() error {
+ if err := z.writer.Flush(); err != nil {
+ return err
+ }
+ return z.transport.Flush()
+}
+
+// IsOpen returns true if the transport is open
+func (z *TZlibTransport) IsOpen() bool {
+ return z.transport.IsOpen()
+}
+
+// Open opens the transport for communication
+func (z *TZlibTransport) Open() error {
+ return z.transport.Open()
+}
+
+func (z *TZlibTransport) Read(p []byte) (int, error) {
+ if z.reader == nil {
+ r, err := zlib.NewReader(z.transport)
+ if err != nil {
+ return 0, NewTTransportExceptionFromError(err)
+ }
+ z.reader = r
+ }
+
+ return z.reader.Read(p)
+}
+
+// RemainingBytes returns the size in bytes of the data that is still to be
+// read.
+func (z *TZlibTransport) RemainingBytes() uint64 {
+ return z.transport.RemainingBytes()
+}
+
+func (z *TZlibTransport) Write(p []byte) (int, error) {
+ return z.writer.Write(p)
+}
diff --git a/lib/go/thrift/zlib_transport_test.go b/lib/go/thrift/zlib_transport_test.go
new file mode 100644
index 0000000..f57610c
--- /dev/null
+++ b/lib/go/thrift/zlib_transport_test.go
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package thrift
+
+import (
+ "compress/zlib"
+ "testing"
+)
+
+func TestZlibTransport(t *testing.T) {
+ trans, err := NewTZlibTransport(NewTMemoryBuffer(), zlib.BestCompression)
+ if err != nil {
+ t.Fatal(err)
+ }
+ TransportTest(t, trans, trans)
+}
diff --git a/test/go/src/bin/testclient/main.go b/test/go/src/bin/testclient/main.go
index c48df0e..94b5c61 100644
--- a/test/go/src/bin/testclient/main.go
+++ b/test/go/src/bin/testclient/main.go
@@ -31,7 +31,7 @@
var host = flag.String("host", "localhost", "Host to connect")
var port = flag.Int64("port", 9090, "Port number to connect")
var domain_socket = flag.String("domain-socket", "", "Domain Socket (e.g. /tmp/thrifttest.thrift), instead of host and port")
-var transport = flag.String("transport", "buffered", "Transport: buffered, framed, http")
+var transport = flag.String("transport", "buffered", "Transport: buffered, framed, http, zlib")
var protocol = flag.String("protocol", "binary", "Protocol: binary, compact, json")
var ssl = flag.Bool("ssl", false, "Encrypted Transport using SSL")
var testloops = flag.Int("testloops", 1, "Number of Tests")
diff --git a/test/go/src/bin/testserver/main.go b/test/go/src/bin/testserver/main.go
index ebcd8e5..291dff5 100644
--- a/test/go/src/bin/testserver/main.go
+++ b/test/go/src/bin/testserver/main.go
@@ -28,7 +28,7 @@
var host = flag.String("host", "localhost", "Host to connect")
var port = flag.Int64("port", 9090, "Port number to connect")
var domain_socket = flag.String("domain-socket", "", "Domain Socket (e.g. /tmp/ThriftTest.thrift), instead of host and port")
-var transport = flag.String("transport", "buffered", "Transport: buffered, framed, http")
+var transport = flag.String("transport", "buffered", "Transport: buffered, framed, http, zlib")
var protocol = flag.String("protocol", "binary", "Protocol: binary, compact, json")
var ssl = flag.Bool("ssl", false, "Encrypted Transport using SSL")
var certPath = flag.String("certPath", "keys", "Directory that contains SSL certificates")
diff --git a/test/go/src/common/client.go b/test/go/src/common/client.go
index 267273e..e55dc6d 100644
--- a/test/go/src/common/client.go
+++ b/test/go/src/common/client.go
@@ -20,6 +20,7 @@
package common
import (
+ "compress/zlib"
"crypto/tls"
"flag"
"fmt"
@@ -82,6 +83,11 @@
trans = thrift.NewTFramedTransport(trans)
case "buffered":
trans = thrift.NewTBufferedTransport(trans, 8192)
+ case "zlib":
+ trans, err = thrift.NewTZlibTransport(trans, zlib.BestCompression)
+ if err != nil {
+ return nil, err
+ }
case "":
trans = trans
default:
diff --git a/test/go/src/common/server.go b/test/go/src/common/server.go
index d354b32..dc380b2 100644
--- a/test/go/src/common/server.go
+++ b/test/go/src/common/server.go
@@ -20,6 +20,7 @@
package common
import (
+ "compress/zlib"
"crypto/tls"
"flag"
"fmt"
@@ -99,6 +100,8 @@
transportFactory = thrift.NewTFramedTransportFactory(transportFactory)
case "buffered":
transportFactory = thrift.NewTBufferedTransportFactory(8192)
+ case "zlib":
+ transportFactory = thrift.NewTZlibTransportFactory(zlib.BestCompression)
case "":
transportFactory = thrift.NewTTransportFactory()
default: