THRIFT-711. java: TFramedTransport should support direct buffer access

This patch adds direct buffer read access to TFramedTransport as well as a simple test for reading, direct buffer reading, and writing.

git-svn-id: https://svn.apache.org/repos/asf/incubator/thrift/trunk@918142 13f79535-47bb-0310-9956-ffa450edef68
diff --git a/lib/java/build.xml b/lib/java/build.xml
index 9750963..afe5afb 100644
--- a/lib/java/build.xml
+++ b/lib/java/build.xml
@@ -186,6 +186,8 @@
       classpathref="test.classpath" failonerror="true" />
     <java classname="org.apache.thrift.test.PartialDeserializeTest"
       classpathref="test.classpath" failonerror="true" />
+    <java classname="org.apache.thrift.test.transport.TFramedTransportTest"
+      classpathref="test.classpath" failonerror="true" />
   </target>
 
   <target name="testclient" description="Run a test client">
diff --git a/lib/java/src/org/apache/thrift/transport/TFramedTransport.java b/lib/java/src/org/apache/thrift/transport/TFramedTransport.java
index f266cc1..74e2c97 100644
--- a/lib/java/src/org/apache/thrift/transport/TFramedTransport.java
+++ b/lib/java/src/org/apache/thrift/transport/TFramedTransport.java
@@ -19,8 +19,6 @@
 
 package org.apache.thrift.transport;
 
-import java.io.ByteArrayInputStream;
-
 import org.apache.thrift.TByteArrayOutputStream;
 
 /**
@@ -43,7 +41,7 @@
   /**
    * Buffer for input
    */
-  private ByteArrayInputStream readBuffer_ = null;
+  private TMemoryInputTransport readBuffer_ = new TMemoryInputTransport(new byte[0]);
 
   public static class Factory extends TTransportFactory {
     public Factory() {
@@ -87,8 +85,24 @@
     return readBuffer_.read(buf, off, len);
   }
 
+  public byte[] getBuffer() {
+    return readBuffer_.getBuffer();
+  }
+
+  public int getBufferPosition() {
+    return readBuffer_.getBufferPosition();
+  }
+
+  public int getBytesRemainingInBuffer() {
+    return readBuffer_.getBytesRemainingInBuffer();
+  }
+
+  public void consumeBuffer(int len) {
+    readBuffer_.consumeBuffer(len);
+  }
+
+  private final byte[] i32rd = new byte[4];
   private void readFrame() throws TTransportException {
-    byte[] i32rd = new byte[4];
     transport_.readAll(i32rd, 0, 4);
     int size =
       ((i32rd[0] & 0xff) << 24) |
@@ -99,10 +113,10 @@
     if (size < 0) {
       throw new TTransportException("Read a negative frame size (" + size + ")!");
     }
-    
+
     byte[] buff = new byte[size];
     transport_.readAll(buff, 0, size);
-    readBuffer_ = new ByteArrayInputStream(buff);
+    readBuffer_.reset(buff);
   }
 
   public void write(byte[] buf, int off, int len) throws TTransportException {
diff --git a/lib/java/test/org/apache/thrift/test/transport/TFramedTransportTest.java b/lib/java/test/org/apache/thrift/test/transport/TFramedTransportTest.java
new file mode 100644
index 0000000..b2169de
--- /dev/null
+++ b/lib/java/test/org/apache/thrift/test/transport/TFramedTransportTest.java
@@ -0,0 +1,164 @@
+package org.apache.thrift.test.transport;
+
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.DataInputStream;
+import java.io.DataOutputStream;
+import java.io.IOException;
+import java.util.Arrays;
+
+import org.apache.thrift.transport.TFramedTransport;
+import org.apache.thrift.transport.TIOStreamTransport;
+import org.apache.thrift.transport.TMemoryBuffer;
+import org.apache.thrift.transport.TTransport;
+import org.apache.thrift.transport.TTransportException;
+
+public class TFramedTransportTest {
+  public static class WriteCountingTransport extends TTransport {
+    private int writeCount = 0;
+    private final TTransport trans;
+
+    public WriteCountingTransport(TTransport underlying) {
+      trans = underlying;
+    }
+
+    @Override
+    public void close() {}
+
+    @Override
+    public boolean isOpen() {return true;}
+
+    @Override
+    public void open() throws TTransportException {}
+
+    @Override
+    public int read(byte[] buf, int off, int len) throws TTransportException {
+      return 0;
+    }
+
+    @Override
+    public void write(byte[] buf, int off, int len) throws TTransportException {
+      writeCount ++;
+      trans.write(buf, off, len);
+    }
+  }
+
+  public static class ReadCountingTransport extends TTransport {
+    public int readCount = 0;
+    private TTransport trans;
+
+    public ReadCountingTransport(TTransport underlying) {
+      trans = underlying;
+    }
+
+    @Override
+    public void close() {}
+
+    @Override
+    public boolean isOpen() {return true;}
+
+    @Override
+    public void open() throws TTransportException {}
+
+    @Override
+    public int read(byte[] buf, int off, int len) throws TTransportException {
+      readCount++;
+      return trans.read(buf, off, len);
+    }
+
+    @Override
+    public void write(byte[] buf, int off, int len) throws TTransportException {}
+  }
+
+  public static void main(String[] args) throws TTransportException, IOException {
+    testWrite();
+    testRead();
+    testDirectRead();
+  }
+
+  private static void testWrite() throws TTransportException, IOException {
+    ByteArrayOutputStream baos = new ByteArrayOutputStream();
+    WriteCountingTransport countingTrans = new WriteCountingTransport(new TIOStreamTransport(baos));
+    TTransport trans = new TFramedTransport(countingTrans);
+
+    trans.write(byteSequence(0,100));
+    failUnless(countingTrans.writeCount == 0);
+    trans.write(byteSequence(101,200));
+    trans.write(byteSequence(201,255));
+    failUnless(countingTrans.writeCount == 0);
+
+    trans.flush();
+    failUnless(countingTrans.writeCount == 2);
+
+    DataInputStream din = new DataInputStream(new ByteArrayInputStream(baos.toByteArray()));
+    failUnless(din.readInt() == 256);
+
+    byte[] buf = new byte[256];
+    din.read(buf, 0, 256);
+    failUnless(Arrays.equals(byteSequence(0,255), buf));
+  }
+
+  private static void testRead() throws IOException, TTransportException {
+    ByteArrayOutputStream baos = new ByteArrayOutputStream();
+    DataOutputStream dos = new DataOutputStream(baos);
+    dos.writeInt(50);
+    dos.write(byteSequence(0, 49));
+
+    TMemoryBuffer membuf = new TMemoryBuffer(0);
+    membuf.write(baos.toByteArray());
+
+    ReadCountingTransport countTrans = new ReadCountingTransport(membuf);
+    TFramedTransport trans = new TFramedTransport(countTrans);
+
+    byte[] readBuf = new byte[10];
+    trans.read(readBuf, 0, 10);
+    failUnless(Arrays.equals(readBuf, byteSequence(0,9)));
+
+    trans.read(readBuf, 0, 10);
+    failUnless(Arrays.equals(readBuf, byteSequence(10,19)));
+
+    failUnless(countTrans.readCount == 2);
+  }
+
+  private static void testDirectRead() throws IOException, TTransportException {
+    ByteArrayOutputStream baos = new ByteArrayOutputStream();
+    DataOutputStream dos = new DataOutputStream(baos);
+    dos.writeInt(50);
+    dos.write(byteSequence(0, 49));
+
+    TMemoryBuffer membuf = new TMemoryBuffer(0);
+    membuf.write(baos.toByteArray());
+
+    ReadCountingTransport countTrans = new ReadCountingTransport(membuf);
+    TFramedTransport trans = new TFramedTransport(countTrans);
+
+    failUnless(trans.getBytesRemainingInBuffer() == 0);
+
+    byte[] readBuf = new byte[10];
+    trans.read(readBuf, 0, 10);
+    failUnless(Arrays.equals(readBuf, byteSequence(0,9)));
+
+    failUnless(trans.getBytesRemainingInBuffer() == 40);
+    failUnless(trans.getBufferPosition() == 10);
+
+    trans.consumeBuffer(5);
+    failUnless(trans.getBytesRemainingInBuffer() == 35);
+    failUnless(trans.getBufferPosition() == 15);
+
+    failUnless(countTrans.readCount == 2);
+  }
+
+  private static void failUnless(boolean b) {
+    if (!b) {
+      throw new RuntimeException();
+    }
+  }
+
+  private static byte[] byteSequence(int start, int end) {
+    byte[] result = new byte[end-start+1];
+    for (int i = 0; i <= (end-start); i++) {
+      result[i] = (byte)(start+i);
+    }
+    return result;
+  }
+}