THRIFT-710. java: TBinaryProtocol should access buffers directly when possible

This patch makes TBinaryProtocol use direct buffer access in the relevant methods. Performance testing indicates as much as 2x speed boost, though your mileage may vary.

git-svn-id: https://svn.apache.org/repos/asf/incubator/thrift/trunk@918147 13f79535-47bb-0310-9956-ffa450edef68
diff --git a/lib/java/build.xml b/lib/java/build.xml
index afe5afb..e8e5d78 100644
--- a/lib/java/build.xml
+++ b/lib/java/build.xml
@@ -168,6 +168,8 @@
       classpathref="test.classpath" failonerror="true" />
     <java classname="org.apache.thrift.test.TCompactProtocolTest"
       classpathref="test.classpath" failonerror="true" />
+    <java classname="org.apache.thrift.test.TBinaryProtocolTest"
+      classpathref="test.classpath" failonerror="true" />
     <java classname="org.apache.thrift.test.IdentityTest"
       classpathref="test.classpath" failonerror="true" />
     <java classname="org.apache.thrift.test.EqualityTest"
diff --git a/lib/java/src/org/apache/thrift/protocol/TBinaryProtocol.java b/lib/java/src/org/apache/thrift/protocol/TBinaryProtocol.java
index e9bd8b7..83c85e1 100644
--- a/lib/java/src/org/apache/thrift/protocol/TBinaryProtocol.java
+++ b/lib/java/src/org/apache/thrift/protocol/TBinaryProtocol.java
@@ -244,41 +244,75 @@
 
   private byte[] bin = new byte[1];
   public byte readByte() throws TException {
+    if (trans_.getBytesRemainingInBuffer() >= 1) {
+      byte b = trans_.getBuffer()[trans_.getBufferPosition()];
+      trans_.consumeBuffer(1);
+      return b;
+    }
     readAll(bin, 0, 1);
     return bin[0];
   }
 
   private byte[] i16rd = new byte[2];
   public short readI16() throws TException {
-    readAll(i16rd, 0, 2);
+    byte[] buf = i16rd;
+    int off = 0;
+
+    if (trans_.getBytesRemainingInBuffer() >= 2) {
+      buf = trans_.getBuffer();
+      off = trans_.getBufferPosition();
+      trans_.consumeBuffer(2);
+    } else {
+      readAll(i16rd, 0, 2);
+    }
+
     return
       (short)
-      (((i16rd[0] & 0xff) << 8) |
-       ((i16rd[1] & 0xff)));
+      (((buf[off] & 0xff) << 8) |
+       ((buf[off+1] & 0xff)));
   }
 
   private byte[] i32rd = new byte[4];
   public int readI32() throws TException {
-    readAll(i32rd, 0, 4);
+    byte[] buf = i32rd;
+    int off = 0;
+
+    if (trans_.getBytesRemainingInBuffer() >= 4) {
+      buf = trans_.getBuffer();
+      off = trans_.getBufferPosition();
+      trans_.consumeBuffer(4);
+    } else {
+      readAll(i32rd, 0, 4);
+    }
     return
-      ((i32rd[0] & 0xff) << 24) |
-      ((i32rd[1] & 0xff) << 16) |
-      ((i32rd[2] & 0xff) <<  8) |
-      ((i32rd[3] & 0xff));
+      ((buf[off] & 0xff) << 24) |
+      ((buf[off+1] & 0xff) << 16) |
+      ((buf[off+2] & 0xff) <<  8) |
+      ((buf[off+3] & 0xff));
   }
 
   private byte[] i64rd = new byte[8];
   public long readI64() throws TException {
-    readAll(i64rd, 0, 8);
+    byte[] buf = i64rd;
+    int off = 0;
+
+    if (trans_.getBytesRemainingInBuffer() >= 8) {
+      buf = trans_.getBuffer();
+      off = trans_.getBufferPosition();
+      trans_.consumeBuffer(8);
+    } else {
+      readAll(i64rd, 0, 8);
+    }
+
     return
-      ((long)(i64rd[0] & 0xff) << 56) |
-      ((long)(i64rd[1] & 0xff) << 48) |
-      ((long)(i64rd[2] & 0xff) << 40) |
-      ((long)(i64rd[3] & 0xff) << 32) |
-      ((long)(i64rd[4] & 0xff) << 24) |
-      ((long)(i64rd[5] & 0xff) << 16) |
-      ((long)(i64rd[6] & 0xff) <<  8) |
-      ((long)(i64rd[7] & 0xff));
+      ((long)(buf[off]   & 0xff) << 56) |
+      ((long)(buf[off+1] & 0xff) << 48) |
+      ((long)(buf[off+2] & 0xff) << 40) |
+      ((long)(buf[off+3] & 0xff) << 32) |
+      ((long)(buf[off+4] & 0xff) << 24) |
+      ((long)(buf[off+5] & 0xff) << 16) |
+      ((long)(buf[off+6] & 0xff) <<  8) |
+      ((long)(buf[off+7] & 0xff));
   }
 
   public double readDouble() throws TException {
@@ -287,6 +321,17 @@
 
   public String readString() throws TException {
     int size = readI32();
+
+    if (trans_.getBytesRemainingInBuffer() >= size) {
+      try {
+        String s = new String(trans_.getBuffer(), trans_.getBufferPosition(), size, "UTF-8");
+        trans_.consumeBuffer(size);
+        return s;
+      } catch (UnsupportedEncodingException e) {
+        throw new TException("JVM DOES NOT SUPPORT UTF-8");
+      }
+    }
+
     return readStringBody(size);
   }
 
diff --git a/lib/java/test/org/apache/thrift/test/ProtocolTestBase.java b/lib/java/test/org/apache/thrift/test/ProtocolTestBase.java
new file mode 100644
index 0000000..205f4fe
--- /dev/null
+++ b/lib/java/test/org/apache/thrift/test/ProtocolTestBase.java
@@ -0,0 +1,416 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.thrift.test;
+
+import java.util.Arrays;
+import java.util.List;
+
+import org.apache.thrift.TBase;
+import org.apache.thrift.TDeserializer;
+import org.apache.thrift.TException;
+import org.apache.thrift.TSerializer;
+import org.apache.thrift.protocol.TBinaryProtocol;
+import org.apache.thrift.protocol.TField;
+import org.apache.thrift.protocol.TMessage;
+import org.apache.thrift.protocol.TMessageType;
+import org.apache.thrift.protocol.TProtocol;
+import org.apache.thrift.protocol.TProtocolFactory;
+import org.apache.thrift.protocol.TStruct;
+import org.apache.thrift.protocol.TType;
+import org.apache.thrift.transport.TMemoryBuffer;
+
+import thrift.test.CompactProtoTestStruct;
+import thrift.test.HolyMoley;
+import thrift.test.Nesting;
+import thrift.test.OneOfEach;
+import thrift.test.Srv;
+
+public abstract class ProtocolTestBase {
+  
+  protected abstract TProtocolFactory getFactory();
+
+  public void main() throws Exception {
+    testNakedByte();
+    for (int i = 0; i < 128; i++) {
+      testByteField((byte)i);
+      testByteField((byte)-i);
+    }
+    
+    for (int s : Arrays.asList(0, 1, 7, 150, 15000, 0x7fff, -1, -7, -150, -15000, -0x7fff)) {
+      testNakedI16((short)s);
+      testI16Field((short)s);
+    }
+
+    for (int i : Arrays.asList(0, 1, 7, 150, 15000, 31337, 0xffff, 0xffffff, -1, -7, -150, -15000, -0xffff, -0xffffff)) {
+      testNakedI32(i);
+      testI32Field(i);
+    }
+
+    testNakedI64(0);
+    testI64Field(0);
+    for (int i = 0; i < 62; i++) {
+      testNakedI64(1L << i);
+      testNakedI64(-(1L << i));
+      testI64Field(1L << i);
+      testI64Field(-(1L << i));
+    }
+
+    testDouble();
+
+    for (String s : Arrays.asList("", "short", "borderlinetiny", "a bit longer than the smallest possible")) {
+      testNakedString(s);
+      testStringField(s);
+    }
+
+    for (byte[] b : Arrays.asList(new byte[0], new byte[]{0,1,2,3,4,5,6,7,8,9,10}, new byte[]{0,1,2,3,4,5,6,7,8,9,10,11,12,13,14}, new byte[128])) {
+      testNakedBinary(b);
+      testBinaryField(b);
+    }
+
+    testSerialization(OneOfEach.class, Fixtures.oneOfEach);
+    testSerialization(Nesting.class, Fixtures.nesting);
+    testSerialization(HolyMoley.class, Fixtures.holyMoley);
+    testSerialization(CompactProtoTestStruct.class, Fixtures.compactProtoTestStruct);
+
+    testMessage();
+
+    testServerRequest();
+
+    testTDeserializer();
+  }
+
+  public void testNakedByte() throws Exception {
+    TMemoryBuffer buf = new TMemoryBuffer(0);
+    TProtocol proto = getFactory().getProtocol(buf);
+    proto.writeByte((byte)123);
+    byte out = proto.readByte();
+    if (out != 123) {
+      throw new RuntimeException("Byte was supposed to be " + (byte)123 + " but was " + out);
+    }
+  }
+
+  public void testByteField(final byte b) throws Exception {
+    testStructField(new StructFieldTestCase(TType.BYTE, (short)15) {
+      public void writeMethod(TProtocol proto) throws TException {
+        proto.writeByte(b);
+      }
+      
+      public void readMethod(TProtocol proto) throws TException {
+        byte result = proto.readByte();
+        if (result != b) {
+          throw new RuntimeException("Byte was supposed to be " + (byte)b + " but was " + result);
+        }
+      }
+    });
+  }
+
+  public void testNakedI16(short n) throws Exception {
+    TMemoryBuffer buf = new TMemoryBuffer(0);
+    TProtocol proto = getFactory().getProtocol(buf);
+    proto.writeI16(n);
+    // System.out.println(buf.inspect());
+    int out = proto.readI16();
+    if (out != n) {
+      throw new RuntimeException("I16 was supposed to be " + n + " but was " + out);
+    }
+  }
+
+  public void testI16Field(final short n) throws Exception {
+    testStructField(new StructFieldTestCase(TType.I16, (short)15) {
+      public void writeMethod(TProtocol proto) throws TException {
+        proto.writeI16(n);
+      }
+
+      public void readMethod(TProtocol proto) throws TException {
+        short result = proto.readI16();
+        if (result != n) {
+          throw new RuntimeException("I16 was supposed to be " + n + " but was " + result);
+        }
+      }
+    });
+  }
+
+  public void testNakedI32(int n) throws Exception {
+    TMemoryBuffer buf = new TMemoryBuffer(0);
+    TProtocol proto = getFactory().getProtocol(buf);
+    proto.writeI32(n);
+    // System.out.println(buf.inspect());
+    int out = proto.readI32();
+    if (out != n) {
+      throw new RuntimeException("I32 was supposed to be " + n + " but was " + out);
+    }
+  }
+
+  public void testI32Field(final int n) throws Exception {
+    testStructField(new StructFieldTestCase(TType.I32, (short)15) {
+      public void writeMethod(TProtocol proto) throws TException {
+        proto.writeI32(n);
+      }
+
+      public void readMethod(TProtocol proto) throws TException {
+        int result = proto.readI32();
+        if (result != n) {
+          throw new RuntimeException("I32 was supposed to be " + n + " but was " + result);
+        }
+      }
+    });
+  }
+
+  public void testNakedI64(long n) throws Exception {
+    TMemoryBuffer buf = new TMemoryBuffer(0);
+    TProtocol proto = getFactory().getProtocol(buf);
+    proto.writeI64(n);
+    // System.out.println(buf.inspect());
+    long out = proto.readI64();
+    if (out != n) {
+      throw new RuntimeException("I64 was supposed to be " + n + " but was " + out);
+    }
+  }
+
+  public void testI64Field(final long n) throws Exception {
+    testStructField(new StructFieldTestCase(TType.I64, (short)15) {
+      public void writeMethod(TProtocol proto) throws TException {
+        proto.writeI64(n);
+      }
+
+      public void readMethod(TProtocol proto) throws TException {
+        long result = proto.readI64();
+        if (result != n) {
+          throw new RuntimeException("I64 was supposed to be " + n + " but was " + result);
+        }
+      }
+    });
+  }
+
+  public void testDouble() throws Exception {
+    TMemoryBuffer buf = new TMemoryBuffer(1000);
+    TProtocol proto = getFactory().getProtocol(buf);
+    proto.writeDouble(123.456);
+    double out = proto.readDouble();
+    if (out != 123.456) {
+      throw new RuntimeException("Double was supposed to be " + 123.456 + " but was " + out);
+    }
+  }
+
+  public void testNakedString(String str) throws Exception {
+    TMemoryBuffer buf = new TMemoryBuffer(0);
+    TProtocol proto = getFactory().getProtocol(buf);
+    proto.writeString(str);
+    // System.out.println(buf.inspect());
+    String out = proto.readString();
+    if (!str.equals(out)) {
+      throw new RuntimeException("String was supposed to be '" + str + "' but was '" + out + "'");
+    }
+  }
+  
+  public void testStringField(final String str) throws Exception {
+    testStructField(new StructFieldTestCase(TType.STRING, (short)15) {
+      public void writeMethod(TProtocol proto) throws TException {
+        proto.writeString(str);
+      }
+      
+      public void readMethod(TProtocol proto) throws TException {
+        String result = proto.readString();
+        if (!result.equals(str)) {
+          throw new RuntimeException("String was supposed to be " + str + " but was " + result);
+        }
+      }
+    });
+  }
+
+  public void testNakedBinary(byte[] data) throws Exception {
+    TMemoryBuffer buf = new TMemoryBuffer(0);
+    TProtocol proto = getFactory().getProtocol(buf);
+    proto.writeBinary(data);
+    // System.out.println(buf.inspect());
+    byte[] out = proto.readBinary();
+    if (!Arrays.equals(data, out)) {
+      throw new RuntimeException("Binary was supposed to be '" + data + "' but was '" + out + "'");
+    }
+  }
+
+  public void testBinaryField(final byte[] data) throws Exception {
+    testStructField(new StructFieldTestCase(TType.STRING, (short)15) {
+      public void writeMethod(TProtocol proto) throws TException {
+        proto.writeBinary(data);
+      }
+      
+      public void readMethod(TProtocol proto) throws TException {
+        byte[] result = proto.readBinary();
+        if (!Arrays.equals(data, result)) {
+          throw new RuntimeException("Binary was supposed to be '" + bytesToString(data) + "' but was '" + bytesToString(result) + "'");
+        }
+      }
+    });
+    
+  }
+
+  public <T extends TBase> void testSerialization(Class<T> klass, T obj) throws Exception {
+    TMemoryBuffer buf = new TMemoryBuffer(0);
+    TBinaryProtocol binproto = new TBinaryProtocol(buf);
+
+    try {
+      obj.write(binproto);
+      // System.out.println("Size in binary protocol: " + buf.length());
+
+      buf = new TMemoryBuffer(0);
+      TProtocol proto = getFactory().getProtocol(buf);
+
+      obj.write(proto);
+      System.out.println("Size in " +  proto.getClass().getSimpleName() + ": " + buf.length());
+      // System.out.println(buf.inspect());
+
+      T objRead = klass.newInstance();
+      objRead.read(proto);
+      if (!obj.equals(objRead)) {
+        System.out.println("Expected: " + obj.toString());
+        System.out.println("Actual: " + objRead.toString());
+        // System.out.println(buf.inspect());
+        throw new RuntimeException("Objects didn't match!");
+      }
+    } catch (Exception e) {
+      System.out.println(buf.inspect());
+      throw e;
+    }
+  }
+
+  public void testMessage() throws Exception {
+    List<TMessage> msgs = Arrays.asList(new TMessage[]{
+      new TMessage("short message name", TMessageType.CALL, 0),
+      new TMessage("1", TMessageType.REPLY, 12345),
+      new TMessage("loooooooooooooooooooooooooooooooooong", TMessageType.EXCEPTION, 1 << 16),
+      new TMessage("Janky", TMessageType.CALL, 0),
+    });
+
+    for (TMessage msg : msgs) {
+      TMemoryBuffer buf = new TMemoryBuffer(0);
+      TProtocol proto = getFactory().getProtocol(buf);
+      TMessage output = null;
+
+      proto.writeMessageBegin(msg);
+      proto.writeMessageEnd();
+
+      output = proto.readMessageBegin();
+
+      if (!msg.equals(output)) {
+        throw new RuntimeException("Message was supposed to be " + msg + " but was " + output);
+      }
+    }
+  }
+
+  public void testServerRequest() throws Exception {
+    Srv.Iface handler = new Srv.Iface() {
+      public int Janky(int i32arg) throws TException {
+        return i32arg * 2;
+      }
+
+      public int primitiveMethod() throws TException {
+        return 0;
+      }
+
+      public CompactProtoTestStruct structMethod() throws TException {
+        return null;
+      }
+
+      public void voidMethod() throws TException {
+      }
+
+      public void methodWithDefaultArgs(int something) throws TException {
+      }
+    };
+
+    Srv.Processor testProcessor = new Srv.Processor(handler);
+
+    TMemoryBuffer clientOutTrans = new TMemoryBuffer(0);
+    TProtocol clientOutProto = getFactory().getProtocol(clientOutTrans);
+    TMemoryBuffer clientInTrans = new TMemoryBuffer(0);
+    TProtocol clientInProto = getFactory().getProtocol(clientInTrans);
+
+    Srv.Client testClient = new Srv.Client(clientInProto, clientOutProto);
+
+    testClient.send_Janky(1);
+    // System.out.println(clientOutTrans.inspect());
+    testProcessor.process(clientOutProto, clientInProto);
+    // System.out.println(clientInTrans.inspect());
+    int result = testClient.recv_Janky();
+    if (result != 2) {
+      throw new RuntimeException("Got an unexpected result: " + result);
+    }
+  }
+
+  private void testTDeserializer() throws TException {
+    TSerializer ser = new TSerializer(getFactory());
+    byte[] bytes = ser.serialize(Fixtures.compactProtoTestStruct);
+
+    TDeserializer deser = new TDeserializer(getFactory());
+    CompactProtoTestStruct cpts = new CompactProtoTestStruct();
+    deser.deserialize(cpts, bytes);
+
+    if (!Fixtures.compactProtoTestStruct.equals(cpts)) {
+      throw new RuntimeException(Fixtures.compactProtoTestStruct + " and " + cpts + " do not match!");
+    }
+  }
+
+  //
+  // Helper methods
+  //
+
+  private static String bytesToString(byte[] bytes) {
+    String s = "";
+    for (int i = 0; i < bytes.length; i++) {
+      s += Integer.toHexString((int)bytes[i]) + " ";
+    }
+    return s;
+  }
+
+  private void testStructField(StructFieldTestCase testCase) throws Exception {
+    TMemoryBuffer buf = new TMemoryBuffer(0);
+    TProtocol proto = getFactory().getProtocol(buf);
+
+    TField field = new TField("test_field", testCase.type_, testCase.id_);
+    proto.writeStructBegin(new TStruct("test_struct"));
+    proto.writeFieldBegin(field);
+    testCase.writeMethod(proto);
+    proto.writeFieldEnd();
+    proto.writeStructEnd();
+
+    // System.out.println(buf.inspect());
+
+    proto.readStructBegin();
+    TField readField = proto.readFieldBegin();
+    // TODO: verify the field is as expected
+    if (!field.equals(readField)) {
+      throw new RuntimeException("Expected " + field + " but got " + readField);
+    }
+    testCase.readMethod(proto);
+    proto.readStructEnd();
+  }
+
+  public static abstract class StructFieldTestCase {
+    byte type_;
+    short id_;
+    public StructFieldTestCase(byte type, short id) {
+      type_ = type;
+      id_ = id;
+    }
+
+    public abstract void writeMethod(TProtocol proto) throws TException;
+    public abstract void readMethod(TProtocol proto) throws TException;
+  }
+}
diff --git a/lib/java/test/org/apache/thrift/test/SerializationBenchmark.java b/lib/java/test/org/apache/thrift/test/SerializationBenchmark.java
index 92c96d3..9ba7102 100644
--- a/lib/java/test/org/apache/thrift/test/SerializationBenchmark.java
+++ b/lib/java/test/org/apache/thrift/test/SerializationBenchmark.java
@@ -20,13 +20,16 @@
 
 package org.apache.thrift.test;
 
-import java.io.ByteArrayInputStream;
+import org.apache.thrift.TBase;
+import org.apache.thrift.protocol.TBinaryProtocol;
+import org.apache.thrift.protocol.TProtocol;
+import org.apache.thrift.protocol.TProtocolFactory;
+import org.apache.thrift.transport.TMemoryBuffer;
+import org.apache.thrift.transport.TMemoryInputTransport;
+import org.apache.thrift.transport.TTransport;
+import org.apache.thrift.transport.TTransportException;
 
-import org.apache.thrift.*;
-import org.apache.thrift.protocol.*;
-import org.apache.thrift.transport.*;
-
-import thrift.test.*;
+import thrift.test.OneOfEach;
 
 public class SerializationBenchmark {
   private final static int HOW_MANY = 10000000;
@@ -55,7 +58,7 @@
     }
     long endTime = System.currentTimeMillis();
     
-    System.out.println("Test time: " + (endTime - startTime) + " ms");
+    System.out.println("Serialization test time: " + (endTime - startTime) + " ms");
   }
   
   public static <T extends TBase> void testDeserialization(TProtocolFactory factory, T object, Class<T> klass) throws Exception {
@@ -63,14 +66,14 @@
     object.write(factory.getProtocol(buf));
     byte[] serialized = new byte[100*1024];
     buf.read(serialized, 0, 100*1024);
-    
+
     long startTime = System.currentTimeMillis();
     for (int i = 0; i < HOW_MANY; i++) {
       T o2 = klass.newInstance();
-      o2.read(factory.getProtocol(new TIOStreamTransport(new ByteArrayInputStream(serialized))));
+      o2.read(factory.getProtocol(new TMemoryInputTransport(serialized)));
     }
     long endTime = System.currentTimeMillis();
-    
-    System.out.println("Test time: " + (endTime - startTime) + " ms");
+
+    System.out.println("Deserialization test time: " + (endTime - startTime) + " ms");
   }
 }
\ No newline at end of file
diff --git a/lib/java/test/org/apache/thrift/test/TBinaryProtocolTest.java b/lib/java/test/org/apache/thrift/test/TBinaryProtocolTest.java
new file mode 100644
index 0000000..71839fe
--- /dev/null
+++ b/lib/java/test/org/apache/thrift/test/TBinaryProtocolTest.java
@@ -0,0 +1,17 @@
+package org.apache.thrift.test;
+
+import org.apache.thrift.protocol.TBinaryProtocol;
+import org.apache.thrift.protocol.TProtocolFactory;
+
+public class TBinaryProtocolTest extends ProtocolTestBase {
+
+  public static void main(String[] args) throws Exception {
+    new TBinaryProtocolTest().main();
+  }
+  
+  @Override
+  protected TProtocolFactory getFactory() {
+    return new TBinaryProtocol.Factory();
+  }
+
+}
diff --git a/lib/java/test/org/apache/thrift/test/TCompactProtocolTest.java b/lib/java/test/org/apache/thrift/test/TCompactProtocolTest.java
index 86ea57c..1642c42 100755
--- a/lib/java/test/org/apache/thrift/test/TCompactProtocolTest.java
+++ b/lib/java/test/org/apache/thrift/test/TCompactProtocolTest.java
@@ -20,447 +20,18 @@
 
 package org.apache.thrift.test;
 
-import java.util.Arrays;
-import java.util.List;
-
-import org.apache.thrift.TBase;
-import org.apache.thrift.TDeserializer;
-import org.apache.thrift.TException;
-import org.apache.thrift.TSerializer;
-import org.apache.thrift.protocol.TBinaryProtocol;
 import org.apache.thrift.protocol.TCompactProtocol;
-import org.apache.thrift.protocol.TField;
-import org.apache.thrift.protocol.TMessage;
-import org.apache.thrift.protocol.TMessageType;
-import org.apache.thrift.protocol.TProtocol;
 import org.apache.thrift.protocol.TProtocolFactory;
-import org.apache.thrift.protocol.TStruct;
-import org.apache.thrift.protocol.TType;
-import org.apache.thrift.transport.TMemoryBuffer;
 
-import thrift.test.CompactProtoTestStruct;
-import thrift.test.HolyMoley;
-import thrift.test.Nesting;
-import thrift.test.OneOfEach;
-import thrift.test.Srv;
-
-public class TCompactProtocolTest {
-
-  static TProtocolFactory factory = new TCompactProtocol.Factory();
-
+public class TCompactProtocolTest extends ProtocolTestBase {
+  
   public static void main(String[] args) throws Exception {
-    testNakedByte();
-    for (int i = 0; i < 128; i++) {
-      testByteField((byte)i);
-      testByteField((byte)-i);
-    }
-    
-    testNakedI16((short)0);
-    testNakedI16((short)1);
-    testNakedI16((short)15000);
-    testNakedI16((short)0x7fff);
-    testNakedI16((short)-1);
-    testNakedI16((short)-15000);
-    testNakedI16((short)-0x7fff);
-    
-    testI16Field((short)0);
-    testI16Field((short)1);
-    testI16Field((short)7);
-    testI16Field((short)150);
-    testI16Field((short)15000);
-    testI16Field((short)0x7fff);
-    testI16Field((short)-1);
-    testI16Field((short)-7);
-    testI16Field((short)-150);
-    testI16Field((short)-15000);
-    testI16Field((short)-0x7fff);
-    
-    testNakedI32(0);
-    testNakedI32(1);
-    testNakedI32(15000);
-    testNakedI32(0xffff);
-    testNakedI32(-1);
-    testNakedI32(-15000);
-    testNakedI32(-0xffff);
-    
-    testI32Field(0);
-    testI32Field(1);
-    testI32Field(7);
-    testI32Field(150);
-    testI32Field(15000);
-    testI32Field(31337);
-    testI32Field(0xffff);
-    testI32Field(0xffffff);
-    testI32Field(-1);
-    testI32Field(-7);
-    testI32Field(-150);
-    testI32Field(-15000);
-    testI32Field(-0xffff);
-    testI32Field(-0xffffff);
-    
-    testNakedI64(0);
-    for (int i = 0; i < 62; i++) {
-      testNakedI64(1L << i);
-      testNakedI64(-(1L << i));
-    }
-
-    testI64Field(0);
-    for (int i = 0; i < 62; i++) {
-      testI64Field(1L << i);
-      testI64Field(-(1L << i));
-    }
-
-    testDouble();
-    
-    testNakedString("");
-    testNakedString("short");
-    testNakedString("borderlinetiny");
-    testNakedString("a bit longer than the smallest possible");
-    
-    testStringField("");
-    testStringField("short");
-    testStringField("borderlinetiny");
-    testStringField("a bit longer than the smallest possible");
-    
-    testNakedBinary(new byte[]{});
-    testNakedBinary(new byte[]{0,1,2,3,4,5,6,7,8,9,10});
-    testNakedBinary(new byte[]{0,1,2,3,4,5,6,7,8,9,10,11,12,13,14});
-    testNakedBinary(new byte[128]);
-    
-    testBinaryField(new byte[]{});
-    testBinaryField(new byte[]{0,1,2,3,4,5,6,7,8,9,10});
-    testBinaryField(new byte[]{0,1,2,3,4,5,6,7,8,9,10,11,12,13,14});
-    testBinaryField(new byte[128]);
-    
-    testSerialization(OneOfEach.class, Fixtures.oneOfEach);
-    testSerialization(Nesting.class, Fixtures.nesting);
-    testSerialization(HolyMoley.class, Fixtures.holyMoley);
-    testSerialization(CompactProtoTestStruct.class, Fixtures.compactProtoTestStruct);
-    
-    testMessage();
-    
-    testServerRequest();
-    
-    testTDeserializer();
+    new TCompactProtocolTest().main();
   }
+
   
-  public static void testNakedByte() throws Exception {
-    TMemoryBuffer buf = new TMemoryBuffer(0);
-    TProtocol proto = factory.getProtocol(buf);
-    proto.writeByte((byte)123);
-    byte out = proto.readByte();
-    if (out != 123) {
-      throw new RuntimeException("Byte was supposed to be " + (byte)123 + " but was " + out);
-    }
-  }
-  
-  public static void testByteField(final byte b) throws Exception {
-    testStructField(new StructFieldTestCase(TType.BYTE, (short)15) {
-      public void writeMethod(TProtocol proto) throws TException {
-        proto.writeByte(b);
-      }
-      
-      public void readMethod(TProtocol proto) throws TException {
-        byte result = proto.readByte();
-        if (result != b) {
-          throw new RuntimeException("Byte was supposed to be " + (byte)b + " but was " + result);
-        }
-      }
-    });
-  }
-
-  public static void testNakedI16(short n) throws Exception {
-    TMemoryBuffer buf = new TMemoryBuffer(0);
-    TProtocol proto = factory.getProtocol(buf);
-    proto.writeI16(n);
-    // System.out.println(buf.inspect());
-    int out = proto.readI16();
-    if (out != n) {
-      throw new RuntimeException("I16 was supposed to be " + n + " but was " + out);
-    }
-  }
-
-  public static void testI16Field(final short n) throws Exception {
-    testStructField(new StructFieldTestCase(TType.I16, (short)15) {
-      public void writeMethod(TProtocol proto) throws TException {
-        proto.writeI16(n);
-      }
-      
-      public void readMethod(TProtocol proto) throws TException {
-        short result = proto.readI16();
-        if (result != n) {
-          throw new RuntimeException("I16 was supposed to be " + n + " but was " + result);
-        }
-      }
-    });
-  }
-  
-  public static void testNakedI32(int n) throws Exception {
-    TMemoryBuffer buf = new TMemoryBuffer(0);
-    TProtocol proto = factory.getProtocol(buf);
-    proto.writeI32(n);
-    // System.out.println(buf.inspect());
-    int out = proto.readI32();
-    if (out != n) {
-      throw new RuntimeException("I32 was supposed to be " + n + " but was " + out);
-    }
-  }
-  
-  public static void testI32Field(final int n) throws Exception {
-    testStructField(new StructFieldTestCase(TType.I32, (short)15) {
-      public void writeMethod(TProtocol proto) throws TException {
-        proto.writeI32(n);
-      }
-      
-      public void readMethod(TProtocol proto) throws TException {
-        int result = proto.readI32();
-        if (result != n) {
-          throw new RuntimeException("I32 was supposed to be " + n + " but was " + result);
-        }
-      }
-    });
-    
-  }
-
-  public static void testNakedI64(long n) throws Exception {
-    TMemoryBuffer buf = new TMemoryBuffer(0);
-    TProtocol proto = factory.getProtocol(buf);
-    proto.writeI64(n);
-    // System.out.println(buf.inspect());
-    long out = proto.readI64();
-    if (out != n) {
-      throw new RuntimeException("I64 was supposed to be " + n + " but was " + out);
-    }
-  }
-  
-  public static void testI64Field(final long n) throws Exception {
-    testStructField(new StructFieldTestCase(TType.I64, (short)15) {
-      public void writeMethod(TProtocol proto) throws TException {
-        proto.writeI64(n);
-      }
-      
-      public void readMethod(TProtocol proto) throws TException {
-        long result = proto.readI64();
-        if (result != n) {
-          throw new RuntimeException("I64 was supposed to be " + n + " but was " + result);
-        }
-      }
-    });
-  }
-    
-  public static void testDouble() throws Exception {
-    TMemoryBuffer buf = new TMemoryBuffer(1000);
-    TProtocol proto = factory.getProtocol(buf);
-    proto.writeDouble(123.456);
-    double out = proto.readDouble();
-    if (out != 123.456) {
-      throw new RuntimeException("Double was supposed to be " + 123.456 + " but was " + out);
-    }
-  }
-    
-  public static void testNakedString(String str) throws Exception {
-    TMemoryBuffer buf = new TMemoryBuffer(0);
-    TProtocol proto = factory.getProtocol(buf);
-    proto.writeString(str);
-    // System.out.println(buf.inspect());
-    String out = proto.readString();
-    if (!str.equals(out)) {
-      throw new RuntimeException("String was supposed to be '" + str + "' but was '" + out + "'");
-    }
-  }
-  
-  public static void testStringField(final String str) throws Exception {
-    testStructField(new StructFieldTestCase(TType.STRING, (short)15) {
-      public void writeMethod(TProtocol proto) throws TException {
-        proto.writeString(str);
-      }
-      
-      public void readMethod(TProtocol proto) throws TException {
-        String result = proto.readString();
-        if (!result.equals(str)) {
-          throw new RuntimeException("String was supposed to be " + str + " but was " + result);
-        }
-      }
-    });
-  }
-
-  public static void testNakedBinary(byte[] data) throws Exception {
-    TMemoryBuffer buf = new TMemoryBuffer(0);
-    TProtocol proto = factory.getProtocol(buf);
-    proto.writeBinary(data);
-    // System.out.println(buf.inspect());
-    byte[] out = proto.readBinary();
-    if (!Arrays.equals(data, out)) {
-      throw new RuntimeException("Binary was supposed to be '" + data + "' but was '" + out + "'");
-    }
-  }
-
-  public static void testBinaryField(final byte[] data) throws Exception {
-    testStructField(new StructFieldTestCase(TType.STRING, (short)15) {
-      public void writeMethod(TProtocol proto) throws TException {
-        proto.writeBinary(data);
-      }
-      
-      public void readMethod(TProtocol proto) throws TException {
-        byte[] result = proto.readBinary();
-        if (!Arrays.equals(data, result)) {
-          throw new RuntimeException("Binary was supposed to be '" + bytesToString(data) + "' but was '" + bytesToString(result) + "'");
-        }
-      }
-    });
-    
-  }
-
-  public static <T extends TBase> void testSerialization(Class<T> klass, T obj) throws Exception {
-    TMemoryBuffer buf = new TMemoryBuffer(0);
-    TBinaryProtocol binproto = new TBinaryProtocol(buf);
-    
-    try {
-      obj.write(binproto);
-      // System.out.println("Size in binary protocol: " + buf.length());
-    
-      buf = new TMemoryBuffer(0);
-      TProtocol proto = factory.getProtocol(buf);
-    
-      obj.write(proto);
-      System.out.println("Size in compact protocol: " + buf.length());
-      // System.out.println(buf.inspect());
-    
-      T objRead = klass.newInstance();
-      objRead.read(proto);
-      if (!obj.equals(objRead)) {
-        System.out.println("Expected: " + obj.toString());
-        System.out.println("Actual: " + objRead.toString());
-        // System.out.println(buf.inspect());
-        throw new RuntimeException("Objects didn't match!");
-      }
-    } catch (Exception e) {
-      System.out.println(buf.inspect());
-      throw e;
-    }
-  }
-
-  public static void testMessage() throws Exception {
-    List<TMessage> msgs = Arrays.asList(new TMessage[]{
-      new TMessage("short message name", TMessageType.CALL, 0),
-      new TMessage("1", TMessageType.REPLY, 12345),
-      new TMessage("loooooooooooooooooooooooooooooooooong", TMessageType.EXCEPTION, 1 << 16),
-      new TMessage("Janky", TMessageType.CALL, 0),
-    });
-    
-    for (TMessage msg : msgs) {
-      TMemoryBuffer buf = new TMemoryBuffer(0);
-      TProtocol proto = factory.getProtocol(buf);
-      TMessage output = null;
-      
-      proto.writeMessageBegin(msg);
-      proto.writeMessageEnd();
-
-      output = proto.readMessageBegin();
-
-      if (!msg.equals(output)) {
-        throw new RuntimeException("Message was supposed to be " + msg + " but was " + output);
-      }
-    }
-  }
-
-  public static void testServerRequest() throws Exception {
-    Srv.Iface handler = new Srv.Iface() {
-      public int Janky(int i32arg) throws TException {
-        return i32arg * 2;
-      }
-
-      public int primitiveMethod() throws TException {
-        return 0;
-      }
-
-      public CompactProtoTestStruct structMethod() throws TException {
-        return null;
-      }
-
-      public void voidMethod() throws TException {
-      }
-
-      public void methodWithDefaultArgs(int something) throws TException {
-      }
-    };
-    
-    Srv.Processor testProcessor = new Srv.Processor(handler);
-
-    TMemoryBuffer clientOutTrans = new TMemoryBuffer(0);
-    TProtocol clientOutProto = factory.getProtocol(clientOutTrans);
-    TMemoryBuffer clientInTrans = new TMemoryBuffer(0);
-    TProtocol clientInProto = factory.getProtocol(clientInTrans);
-    
-    Srv.Client testClient = new Srv.Client(clientInProto, clientOutProto);
-    
-    testClient.send_Janky(1);
-    // System.out.println(clientOutTrans.inspect());
-    testProcessor.process(clientOutProto, clientInProto);
-    // System.out.println(clientInTrans.inspect());
-    int result = testClient.recv_Janky();
-    if (result != 2) {
-      throw new RuntimeException("Got an unexpected result: " + result);
-    }
-  }
-
-  //
-  // Helper methods
-  //
-  
-  private static String bytesToString(byte[] bytes) {
-    String s = "";
-    for (int i = 0; i < bytes.length; i++) {
-      s += Integer.toHexString((int)bytes[i]) + " ";
-    }
-    return s;
-  }
-
-  private static void testStructField(StructFieldTestCase testCase) throws Exception {
-    TMemoryBuffer buf = new TMemoryBuffer(0);
-    TProtocol proto = factory.getProtocol(buf);
-    
-    TField field = new TField("test_field", testCase.type_, testCase.id_);
-    proto.writeStructBegin(new TStruct("test_struct"));
-    proto.writeFieldBegin(field);
-    testCase.writeMethod(proto);
-    proto.writeFieldEnd();
-    proto.writeStructEnd();
-    
-    // System.out.println(buf.inspect());
-
-    proto.readStructBegin();
-    TField readField = proto.readFieldBegin();
-    // TODO: verify the field is as expected
-    if (!field.equals(readField)) {
-      throw new RuntimeException("Expected " + field + " but got " + readField);
-    }
-    testCase.readMethod(proto);
-    proto.readStructEnd();
-  }
-  
-  public static abstract class StructFieldTestCase {
-    byte type_;
-    short id_;
-    public StructFieldTestCase(byte type, short id) {
-      type_ = type;
-      id_ = id;
-    }
-    
-    public abstract void writeMethod(TProtocol proto) throws TException;
-    public abstract void readMethod(TProtocol proto) throws TException;
-  }
-  
-  private static void testTDeserializer() throws TException {
-    TSerializer ser = new TSerializer(new TCompactProtocol.Factory());
-    byte[] bytes = ser.serialize(Fixtures.compactProtoTestStruct);
-    
-    TDeserializer deser = new TDeserializer(new TCompactProtocol.Factory());
-    CompactProtoTestStruct cpts = new CompactProtoTestStruct();
-    deser.deserialize(cpts, bytes);
-    
-    if (!Fixtures.compactProtoTestStruct.equals(cpts)) {
-      throw new RuntimeException(Fixtures.compactProtoTestStruct + " and " + cpts + " do not match!");
-    }
+  @Override
+  protected TProtocolFactory getFactory() {
+    return new TCompactProtocol.Factory();
   }
 }
\ No newline at end of file