THRIFT-3300:Reimplement TZlibTransport in Java using streams
Client: Java Lib
Patch: Paul Magrath
This closes #590
commit c01aff7038adb9fa2098c02d0092757834fd4df4
Author: Paul Magrath <paul@swiftkey.com>
Date: 2015-08-17T17:25:24Z
THRIFT-3300 Reimplement TZlibTransport in Java using streams
diff --git a/lib/java/src/org/apache/thrift/transport/TZlibTransport.java b/lib/java/src/org/apache/thrift/transport/TZlibTransport.java
index 06965c5..df4de13 100644
--- a/lib/java/src/org/apache/thrift/transport/TZlibTransport.java
+++ b/lib/java/src/org/apache/thrift/transport/TZlibTransport.java
@@ -18,31 +18,20 @@
*/
package org.apache.thrift.transport;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
import java.util.zip.Deflater;
+import java.util.zip.DeflaterOutputStream;
import java.util.zip.Inflater;
-import org.apache.thrift.TByteArrayOutputStream;
-import org.apache.thrift.transport.TMemoryInputTransport;
-import org.apache.thrift.transport.TTransport;
-import org.apache.thrift.transport.TTransportException;
-import org.apache.thrift.transport.TTransportFactory;
+import java.util.zip.InflaterInputStream;
/**
* TZlibTransport deflates on write and inflates on read.
*/
-public class TZlibTransport extends TTransport {
- //Class constants
- public static final int INFLATE_BUF_SIZE = 1024;
- public static final int READ_BUF_SIZE = 1024;
- public static final int INIT_WRITE_BUF_SIZE = 1024;
- //Client rw buffers and underlying transport
- private TByteArrayOutputStream writeBuffer_ = new TByteArrayOutputStream(INIT_WRITE_BUF_SIZE);
- private TMemoryInputTransport readBuffer_ = new TMemoryInputTransport(new byte[0]);
+public class TZlibTransport extends TIOStreamTransport {
+
private TTransport transport_ = null;
- //Zip objects and buffers
- private byte[] inflateBuf = new byte[INFLATE_BUF_SIZE];
- private byte[] readBuf = new byte[READ_BUF_SIZE];
- private Inflater decompresser = new Inflater(false);
- private Deflater compresser = new Deflater(Deflater.BEST_COMPRESSION, false);
public static class Factory extends TTransportFactory {
public Factory() {
@@ -59,7 +48,7 @@
* @param transport the underlying transport to read from and write to
*/
public TZlibTransport(TTransport transport) {
- transport_ = transport;
+ this(transport, Deflater.BEST_COMPRESSION);
}
/**
@@ -69,12 +58,8 @@
*/
public TZlibTransport(TTransport transport, int compressionLevel) {
transport_ = transport;
- compresser = new Deflater(compressionLevel, false);
- }
-
- @Override
- public void open() throws TTransportException {
- transport_.open();
+ inputStream_ = new InflaterInputStream(new TTransportInputStream(transport_), new Inflater());
+ outputStream_ = new DeflaterOutputStream(new TTransportOutputStream(transport_), new Deflater(compressionLevel, false), true);
}
@Override
@@ -83,99 +68,80 @@
}
@Override
+ public void open() throws TTransportException {
+ transport_.open();
+ }
+
+ @Override
public void close() {
- readBuffer_.reset(new byte[0]);
- writeBuffer_.reset();
- compresser.reset();
- decompresser.reset();
- transport_.close();
+ if (transport_.isOpen()) {
+ transport_.close();
+ }
+ }
+
+}
+
+class TTransportInputStream extends InputStream {
+
+ private TTransport transport = null;
+
+ public TTransportInputStream(TTransport transport) {
+ this.transport = transport;
}
@Override
- public int read(byte[] buf, int off, int len) throws TTransportException {
- int bytesRead = readBuffer_.read(buf, off, len);
- if (bytesRead > 0) {
- return bytesRead;
- }
-
- while (true) {
- if (readComp() > 0) {
- break;
- }
- }
-
- return readBuffer_.read(buf, off, len);
- }
-
- private int readComp() throws TTransportException {
- //If low level read buffer is exhausted, read more bytes from underlying transport
- if (decompresser.needsInput()) {
- int bytesRead = transport_.read(readBuf, 0, READ_BUF_SIZE);
- decompresser.setInput(readBuf, 0, bytesRead);
- }
- //Decompress bytes into high level client read buffer
+ public int read() throws IOException {
try {
- int InflatedBytes = decompresser.inflate(inflateBuf);
- if (InflatedBytes <= 0) {
- return 0;
- }
-
- byte[] old = new byte[readBuffer_.getBytesRemainingInBuffer()];
- readBuffer_.read(old, 0, readBuffer_.getBytesRemainingInBuffer());
- byte[] all = new byte[old.length + InflatedBytes];
- System.arraycopy(old, 0, all, 0, old.length);
- System.arraycopy(inflateBuf, 0, all, old.length, InflatedBytes);
-
- readBuffer_.reset(all);
- return all.length;
- } catch (java.util.zip.DataFormatException ex) {
- throw new TTransportException(ex);
+ byte[] buf = new byte[1];
+ transport.read(buf, 0, 1);
+ return buf[0];
+ } catch (TTransportException e) {
+ throw new IOException(e);
}
}
@Override
- public byte[] getBuffer() {
- return readBuffer_.getBuffer();
- }
-
- @Override
- public int getBufferPosition() {
- return readBuffer_.getBufferPosition();
- }
-
- @Override
- public int getBytesRemainingInBuffer() {
- return readBuffer_.getBytesRemainingInBuffer();
- }
-
- @Override
- public void consumeBuffer(int len) {
- readBuffer_.consumeBuffer(len);
- }
-
- @Override
- public void write(byte[] buf, int off, int len) throws TTransportException {
- writeBuffer_.write(buf, off, len);
- }
-
- /**
- * Compress write buffer and send it to underlying transport.
- */
- @Override
- public void flush() throws TTransportException {
- byte[] buf = writeBuffer_.get();
- int bufLength = writeBuffer_.len();
- writeBuffer_.reset();
- compresser.setInput(buf, 0, bufLength);
-
- byte[] compBuf = new byte[buf.length * 2];
- int compressedDataLength = compresser.deflate(compBuf, 0, compBuf.length, Deflater.SYNC_FLUSH);
- if (compressedDataLength >= compBuf.length) {
- throw new TTransportException("Compression error, compressed output exceeds buffer size");
+ public int read(byte b[], int off, int len) throws IOException {
+ try {
+ return transport.read(b, off, len);
+ } catch (TTransportException e) {
+ throw new IOException(e);
}
- if (compressedDataLength > 0) {
- transport_.write(compBuf, 0, compressedDataLength);
- transport_.flush();
+ }
+}
+
+class TTransportOutputStream extends OutputStream {
+
+ private TTransport transport = null;
+
+ public TTransportOutputStream(TTransport transport) {
+ this.transport = transport;
+ }
+
+ @Override
+ public void write(final int b) throws IOException {
+ try {
+ transport.write(new byte[]{(byte) b});
+ } catch (TTransportException e) {
+ throw new IOException(e);
+ }
+ }
+
+ @Override
+ public void write(byte b[], int off, int len) throws IOException {
+ try {
+ transport.write(b, off, len);
+ } catch (TTransportException e) {
+ throw new IOException(e);
+ }
+ }
+
+ @Override
+ public void flush() throws IOException {
+ try {
+ transport.flush();
+ } catch (TTransportException e) {
+ throw new IOException(e);
}
}
}
diff --git a/lib/java/test/org/apache/thrift/transport/TestTZlibTransport.java b/lib/java/test/org/apache/thrift/transport/TestTZlibTransport.java
index 74817b1..fe8dd51 100644
--- a/lib/java/test/org/apache/thrift/transport/TestTZlibTransport.java
+++ b/lib/java/test/org/apache/thrift/transport/TestTZlibTransport.java
@@ -18,20 +18,14 @@
*/
package org.apache.thrift.transport;
-import java.io.BufferedOutputStream;
-import java.io.ByteArrayInputStream;
-import java.io.ByteArrayOutputStream;
-import java.io.DataInputStream;
-import java.io.DataOutputStream;
-import java.io.IOException;
+import junit.framework.TestCase;
+
+import java.io.*;
import java.util.Arrays;
import java.util.zip.DataFormatException;
import java.util.zip.DeflaterOutputStream;
-import java.util.zip.Inflater;
import java.util.zip.InflaterInputStream;
-import junit.framework.TestCase;
-
public class TestTZlibTransport extends TestCase {
protected TTransport getTransport(TTransport underlying) {
@@ -46,6 +40,16 @@
return result;
}
+ public void testClose() throws TTransportException {
+ ByteArrayOutputStream baos = new ByteArrayOutputStream();
+ WriteCountingTransport countingTrans = new WriteCountingTransport(new TIOStreamTransport(new BufferedOutputStream
+ (baos)));
+ TTransport trans = getTransport(countingTrans);
+ trans.write(byteSequence(0, 245));
+ countingTrans.close();
+ trans.close();
+ }
+
public void testRead() throws IOException, TTransportException {
ByteArrayOutputStream baos = new ByteArrayOutputStream();
DeflaterOutputStream deflaterOutputStream = new DeflaterOutputStream(baos);
@@ -85,17 +89,17 @@
TTransport trans = getTransport(countingTrans);
trans.write(byteSequence(0, 100));
- assertEquals(0, countingTrans.writeCount);
+ assertEquals(1, countingTrans.writeCount);
trans.write(byteSequence(101, 200));
trans.write(byteSequence(201, 255));
- assertEquals(0, countingTrans.writeCount);
+ assertEquals(1, countingTrans.writeCount);
trans.flush();
- assertEquals(1, countingTrans.writeCount);
+ assertEquals(2, countingTrans.writeCount);
trans.write(byteSequence(0, 245));
trans.flush();
- assertEquals(2, countingTrans.writeCount);
+ assertEquals(3, countingTrans.writeCount);
DataInputStream din = new DataInputStream(new InflaterInputStream(new ByteArrayInputStream(baos.toByteArray())));
byte[] buf = new byte[256];