THRIFT-3300:Reimplement TZlibTransport in Java using streams
Client: Java Lib
Patch: Paul Magrath

This closes #590
commit c01aff7038adb9fa2098c02d0092757834fd4df4
Author: Paul Magrath <paul@swiftkey.com>
Date: 2015-08-17T17:25:24Z
THRIFT-3300 Reimplement TZlibTransport in Java using streams
diff --git a/lib/java/src/org/apache/thrift/transport/TZlibTransport.java b/lib/java/src/org/apache/thrift/transport/TZlibTransport.java
index 06965c5..df4de13 100644
--- a/lib/java/src/org/apache/thrift/transport/TZlibTransport.java
+++ b/lib/java/src/org/apache/thrift/transport/TZlibTransport.java
@@ -18,31 +18,20 @@
  */
 package org.apache.thrift.transport;
 
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
 import java.util.zip.Deflater;
+import java.util.zip.DeflaterOutputStream;
 import java.util.zip.Inflater;
-import org.apache.thrift.TByteArrayOutputStream;
-import org.apache.thrift.transport.TMemoryInputTransport;
-import org.apache.thrift.transport.TTransport;
-import org.apache.thrift.transport.TTransportException;
-import org.apache.thrift.transport.TTransportFactory;
+import java.util.zip.InflaterInputStream;
 
 /**
  * TZlibTransport deflates on write and inflates on read.
  */
-public class TZlibTransport extends TTransport {
-    //Class constants
-    public static final int INFLATE_BUF_SIZE = 1024;
-    public static final int READ_BUF_SIZE = 1024;
-    public static final int INIT_WRITE_BUF_SIZE = 1024;
-    //Client rw buffers and underlying transport
-    private TByteArrayOutputStream writeBuffer_ = new TByteArrayOutputStream(INIT_WRITE_BUF_SIZE);
-    private TMemoryInputTransport readBuffer_ = new TMemoryInputTransport(new byte[0]);
+public class TZlibTransport extends TIOStreamTransport {
+
     private TTransport transport_ = null;
-    //Zip objects and buffers
-    private byte[] inflateBuf = new byte[INFLATE_BUF_SIZE];
-    private byte[] readBuf = new byte[READ_BUF_SIZE];
-    private Inflater decompresser = new Inflater(false);
-    private Deflater compresser = new Deflater(Deflater.BEST_COMPRESSION, false);
 
     public static class Factory extends TTransportFactory {
         public Factory() {
@@ -59,7 +48,7 @@
      * @param  transport the underlying transport to read from and write to
      */
     public TZlibTransport(TTransport transport) {
-        transport_ = transport;
+        this(transport, Deflater.BEST_COMPRESSION);
     }
 
     /**
@@ -69,12 +58,8 @@
      */
     public TZlibTransport(TTransport transport, int compressionLevel) {
         transport_ = transport;
-        compresser = new Deflater(compressionLevel, false);
-    }
-
-    @Override
-    public void open() throws TTransportException {
-        transport_.open();
+        inputStream_ = new InflaterInputStream(new TTransportInputStream(transport_), new Inflater());
+        outputStream_ = new DeflaterOutputStream(new TTransportOutputStream(transport_), new Deflater(compressionLevel, false), true);
     }
 
     @Override
@@ -83,99 +68,80 @@
     }
 
     @Override
+    public void open() throws TTransportException {
+        transport_.open();
+    }
+
+    @Override
     public void close() {
-        readBuffer_.reset(new byte[0]);
-        writeBuffer_.reset();
-        compresser.reset();
-        decompresser.reset();
-        transport_.close();
+        if (transport_.isOpen()) {
+            transport_.close();
+        }
+    }
+
+}
+
+class TTransportInputStream extends InputStream {
+
+    private TTransport transport = null;
+
+    public TTransportInputStream(TTransport transport) {
+        this.transport = transport;
     }
 
     @Override
-    public int read(byte[] buf, int off, int len) throws TTransportException {
-        int bytesRead = readBuffer_.read(buf, off, len);
-        if (bytesRead > 0) {
-            return bytesRead;
-        }
-
-        while (true) {
-            if (readComp() > 0) {
-                break;
-            }
-        }
-
-        return readBuffer_.read(buf, off, len);
-    }
-
-    private int readComp() throws TTransportException {
-        //If low level read buffer is exhausted, read more bytes from underlying transport
-        if (decompresser.needsInput()) {
-            int bytesRead = transport_.read(readBuf, 0, READ_BUF_SIZE);
-            decompresser.setInput(readBuf, 0, bytesRead);
-        }
-        //Decompress bytes into high level client read buffer
+    public int read() throws IOException {
         try {
-            int InflatedBytes = decompresser.inflate(inflateBuf);
-            if (InflatedBytes <= 0) {
-                return 0;
-            }
-
-            byte[] old = new byte[readBuffer_.getBytesRemainingInBuffer()];
-            readBuffer_.read(old, 0, readBuffer_.getBytesRemainingInBuffer());
-            byte[] all = new byte[old.length + InflatedBytes];
-            System.arraycopy(old, 0, all, 0, old.length);
-            System.arraycopy(inflateBuf, 0, all, old.length, InflatedBytes);
-
-            readBuffer_.reset(all);
-            return all.length;
-        } catch (java.util.zip.DataFormatException ex) {
-            throw new TTransportException(ex);
+            byte[] buf = new byte[1];
+            transport.read(buf, 0, 1);
+            return buf[0];
+        } catch (TTransportException e) {
+            throw new IOException(e);
         }
     }
 
     @Override
-    public byte[] getBuffer() {
-        return readBuffer_.getBuffer();
-    }
-
-    @Override
-    public int getBufferPosition() {
-        return readBuffer_.getBufferPosition();
-    }
-
-    @Override
-    public int getBytesRemainingInBuffer() {
-        return readBuffer_.getBytesRemainingInBuffer();
-    }
-
-    @Override
-    public void consumeBuffer(int len) {
-        readBuffer_.consumeBuffer(len);
-    }
-
-    @Override
-    public void write(byte[] buf, int off, int len) throws TTransportException {
-        writeBuffer_.write(buf, off, len);
-    }
-
-    /**
-     * Compress write buffer and send it to underlying transport.
-     */
-    @Override
-    public void flush() throws TTransportException {
-        byte[] buf = writeBuffer_.get();
-        int bufLength = writeBuffer_.len();
-        writeBuffer_.reset();
-        compresser.setInput(buf, 0, bufLength);
-
-        byte[] compBuf = new byte[buf.length * 2];
-        int compressedDataLength = compresser.deflate(compBuf, 0, compBuf.length, Deflater.SYNC_FLUSH);
-        if (compressedDataLength >= compBuf.length) {
-            throw new TTransportException("Compression error, compressed output exceeds buffer size");
+    public int read(byte b[], int off, int len) throws IOException {
+        try {
+            return transport.read(b, off, len);
+        } catch (TTransportException e) {
+            throw new IOException(e);
         }
-        if (compressedDataLength > 0) {
-            transport_.write(compBuf, 0, compressedDataLength);
-            transport_.flush();
+    }
+}
+
+class TTransportOutputStream extends OutputStream {
+
+    private TTransport transport = null;
+
+    public TTransportOutputStream(TTransport transport) {
+        this.transport = transport;
+    }
+
+    @Override
+    public void write(final int b) throws IOException {
+        try {
+            transport.write(new byte[]{(byte) b});
+        } catch (TTransportException e) {
+            throw new IOException(e);
+        }
+    }
+
+    @Override
+    public void write(byte b[], int off, int len) throws IOException {
+        try {
+            transport.write(b, off, len);
+        } catch (TTransportException e) {
+            throw new IOException(e);
+        }
+    }
+
+    @Override
+    public void flush() throws IOException {
+        try {
+            transport.flush();
+        } catch (TTransportException e) {
+            throw new IOException(e);
         }
     }
 }
diff --git a/lib/java/test/org/apache/thrift/transport/TestTZlibTransport.java b/lib/java/test/org/apache/thrift/transport/TestTZlibTransport.java
index 74817b1..fe8dd51 100644
--- a/lib/java/test/org/apache/thrift/transport/TestTZlibTransport.java
+++ b/lib/java/test/org/apache/thrift/transport/TestTZlibTransport.java
@@ -18,20 +18,14 @@
  */
 package org.apache.thrift.transport;
 
-import java.io.BufferedOutputStream;
-import java.io.ByteArrayInputStream;
-import java.io.ByteArrayOutputStream;
-import java.io.DataInputStream;
-import java.io.DataOutputStream;
-import java.io.IOException;
+import junit.framework.TestCase;
+
+import java.io.*;
 import java.util.Arrays;
 import java.util.zip.DataFormatException;
 import java.util.zip.DeflaterOutputStream;
-import java.util.zip.Inflater;
 import java.util.zip.InflaterInputStream;
 
-import junit.framework.TestCase;
-
 public class TestTZlibTransport extends TestCase {
 
   protected TTransport getTransport(TTransport underlying) {
@@ -46,6 +40,16 @@
     return result;
   }
 
+  public void testClose() throws TTransportException {
+    ByteArrayOutputStream baos = new ByteArrayOutputStream();
+    WriteCountingTransport countingTrans = new WriteCountingTransport(new TIOStreamTransport(new BufferedOutputStream
+        (baos)));
+    TTransport trans = getTransport(countingTrans);
+    trans.write(byteSequence(0, 245));
+    countingTrans.close();
+    trans.close();
+  }
+
   public void testRead() throws IOException, TTransportException {
     ByteArrayOutputStream baos = new ByteArrayOutputStream();
     DeflaterOutputStream deflaterOutputStream = new DeflaterOutputStream(baos);
@@ -85,17 +89,17 @@
     TTransport trans = getTransport(countingTrans);
 
     trans.write(byteSequence(0, 100));
-    assertEquals(0, countingTrans.writeCount);
+    assertEquals(1, countingTrans.writeCount);
     trans.write(byteSequence(101, 200));
     trans.write(byteSequence(201, 255));
-    assertEquals(0, countingTrans.writeCount);
+    assertEquals(1, countingTrans.writeCount);
 
     trans.flush();
-    assertEquals(1, countingTrans.writeCount);
+    assertEquals(2, countingTrans.writeCount);
 
     trans.write(byteSequence(0, 245));
     trans.flush();
-    assertEquals(2, countingTrans.writeCount);
+    assertEquals(3, countingTrans.writeCount);
 
     DataInputStream din = new DataInputStream(new InflaterInputStream(new ByteArrayInputStream(baos.toByteArray())));
     byte[] buf = new byte[256];