THRIFT-739. java: TCompactProtocol isn't suitable for reuse in partialDeserialize

This patch changes TProtocol to support a reset() method that should clear any internal state. Stateless protocols can ignore it; stateful ones should implement it. TDeserializer has been updated to take advantage of this method.

git-svn-id: https://svn.apache.org/repos/asf/incubator/thrift/trunk@926460 13f79535-47bb-0310-9956-ffa450edef68
diff --git a/lib/java/src/org/apache/thrift/TDeserializer.java b/lib/java/src/org/apache/thrift/TDeserializer.java
index e766e5d..0797e7a 100644
--- a/lib/java/src/org/apache/thrift/TDeserializer.java
+++ b/lib/java/src/org/apache/thrift/TDeserializer.java
@@ -63,8 +63,12 @@
    * @param bytes The array to read from
    */
   public void deserialize(TBase base, byte[] bytes) throws TException {
-    trans_.reset(bytes);
-    base.read(protocol_);
+    try {
+      trans_.reset(bytes);
+      base.read(protocol_);
+    } finally {
+      protocol_.reset();
+    }
   }
 
   /**
@@ -80,6 +84,8 @@
       deserialize(base, data.getBytes(charset));
     } catch (UnsupportedEncodingException uex) {
       throw new TException("JVM DOES NOT SUPPORT ENCODING: " + charset);
+    } finally {
+      protocol_.reset();
     }
   }
 
@@ -92,46 +98,52 @@
    * @throws TException 
    */
   public void partialDeserialize(TBase tb, byte[] bytes, TFieldIdEnum ... fieldIdPath) throws TException {
-    // if there are no elements in the path, then the user is looking for the 
-    // regular deserialize method
-    // TODO: it might be nice not to have to do this check every time to save
-    // some performance.
-    if (fieldIdPath.length == 0) {
-      deserialize(tb, bytes);
-      return;
-    }
-
-    trans_.reset(bytes);
-
-    // index into field ID path being currently searched for
-    int curPathIndex = 0;
-
-    protocol_.readStructBegin();
-
-    while (curPathIndex < fieldIdPath.length) {
-      TField field = protocol_.readFieldBegin();
-      // we can stop searching if we either see a stop or we go past the field 
-      // id we're looking for (since fields should now be serialized in asc
-      // order).
-      if (field.type == TType.STOP || field.id > fieldIdPath[curPathIndex].getThriftFieldId()) { 
+    try {
+      // if there are no elements in the path, then the user is looking for the 
+      // regular deserialize method
+      // TODO: it might be nice not to have to do this check every time to save
+      // some performance.
+      if (fieldIdPath.length == 0) {
+        deserialize(tb, bytes);
         return;
       }
 
-      if (field.id != fieldIdPath[curPathIndex].getThriftFieldId()) {
-        // Not the field we're looking for. Skip field.
-        TProtocolUtil.skip(protocol_, field.type);
-        protocol_.readFieldEnd();
-      } else {
-        // This field is the next step in the path. Step into field.
-        curPathIndex++;
-        if (curPathIndex < fieldIdPath.length) {
-          protocol_.readStructBegin();
+      trans_.reset(bytes);
+
+      // index into field ID path being currently searched for
+      int curPathIndex = 0;
+
+      protocol_.readStructBegin();
+
+      while (curPathIndex < fieldIdPath.length) {
+        TField field = protocol_.readFieldBegin();
+        // we can stop searching if we either see a stop or we go past the field 
+        // id we're looking for (since fields should now be serialized in asc
+        // order).
+        if (field.type == TType.STOP || field.id > fieldIdPath[curPathIndex].getThriftFieldId()) { 
+          return;
+        }
+
+        if (field.id != fieldIdPath[curPathIndex].getThriftFieldId()) {
+          // Not the field we're looking for. Skip field.
+          TProtocolUtil.skip(protocol_, field.type);
+          protocol_.readFieldEnd();
+        } else {
+          // This field is the next step in the path. Step into field.
+          curPathIndex++;
+          if (curPathIndex < fieldIdPath.length) {
+            protocol_.readStructBegin();
+          }
         }
       }
-    }
 
-    // when this line is reached, iprot will be positioned at the start of tb.
-    tb.read(protocol_);
+      // when this line is reached, iprot will be positioned at the start of tb.
+      tb.read(protocol_);
+    } catch (Exception e) {
+      throw new TException(e);
+    } finally {
+      protocol_.reset();
+    }
   }
 
   /**
@@ -145,4 +157,3 @@
     deserialize(base, data.getBytes());
   }
 }
-
diff --git a/lib/java/src/org/apache/thrift/protocol/TCompactProtocol.java b/lib/java/src/org/apache/thrift/protocol/TCompactProtocol.java
index 8c29b3c..74bfc13 100755
--- a/lib/java/src/org/apache/thrift/protocol/TCompactProtocol.java
+++ b/lib/java/src/org/apache/thrift/protocol/TCompactProtocol.java
@@ -41,7 +41,7 @@
   private final static TField TSTOP = new TField("", TType.STOP, (short)0);
 
   private final static byte[] ttypeToCompactType = new byte[16];
-  
+
   static {
     ttypeToCompactType[TType.STOP] = TType.STOP;
     ttypeToCompactType[TType.BOOL] = Types.BOOLEAN_TRUE;
@@ -56,24 +56,24 @@
     ttypeToCompactType[TType.MAP] = Types.MAP;
     ttypeToCompactType[TType.STRUCT] = Types.STRUCT;
   }
-  
+
   /**
    * TProtocolFactory that produces TCompactProtocols.
    */
   public static class Factory implements TProtocolFactory {
     public Factory() {}
-    
+
     public TProtocol getProtocol(TTransport trans) {
       return new TCompactProtocol(trans);
     }
   }
-  
+
   private static final byte PROTOCOL_ID = (byte)0x82;
   private static final byte VERSION = 1;
   private static final byte VERSION_MASK = 0x1f; // 0001 1111
   private static final byte TYPE_MASK = (byte)0xE0; // 1110 0000
   private static final int  TYPE_SHIFT_AMOUNT = 5;
-  
+
   /**
    * All of the on-wire type codes.
    */
@@ -91,27 +91,27 @@
     public static final byte MAP            = 0x0B;
     public static final byte STRUCT         = 0x0C;
   }
-  
+
   /** 
    * Used to keep track of the last field for the current and previous structs,
    * so we can do the delta stuff.
    */
   private Stack<Short> lastField_ = new Stack<Short>();
-  
+
   private short lastFieldId_ = 0;
-  
-  /** 
+
+  /**
    * If we encounter a boolean field begin, save the TField here so it can 
    * have the value incorporated.
    */
   private TField booleanField_ = null;
-  
+
   /**
    * If we read a field header, and it's a boolean field, save the boolean 
    * value here so that readBool can use it.
    */
   private Boolean boolValue_ = null;
-  
+
   /**
    * Create a TCompactProtocol.
    *
@@ -120,8 +120,13 @@
   public TCompactProtocol(TTransport transport) {
     super(transport);
   }
-  
-  
+
+  @Override
+  public void reset() {
+    lastField_.clear();
+    lastFieldId_ = 0;
+  }
+
   //
   // Public Writing methods.
   //
@@ -155,7 +160,7 @@
   public void writeStructEnd() throws TException {
     lastFieldId_ = lastField_.pop();
   }
-  
+
   /**
    * Write a field header containing the field id and field type. If the
    * difference between the current field id and the last one is small (< 15),
@@ -260,7 +265,7 @@
   public void writeI16(short i16) throws TException {
     writeVarint32(intToZigZag(i16));
   }
-  
+
   /**
    * Write an i32 as a zigzag varint.
    */
@@ -307,7 +312,7 @@
   // These methods are called by structs, but don't actually have any wire 
   // output or purpose.
   // 
-  
+
   public void writeMessageEnd() throws TException {}
   public void writeMapEnd() throws TException {}
   public void writeListEnd() throws TException {}
@@ -370,7 +375,7 @@
     }
     trans_.write(varint64out, 0, idx);
   }
-  
+
   /**
    * Convert l into a zigzag long. This allows negative numbers to be 
    * represented compactly as a varint.
@@ -378,7 +383,7 @@
   private long longToZigzag(long l) {
     return (l << 1) ^ (l >> 63);
   }
-  
+
   /**
    * Convert n into a zigzag int. This allows negative numbers to be 
    * represented compactly as a varint.
@@ -386,7 +391,7 @@
   private int intToZigZag(int n) {
     return (n << 1) ^ (n >> 31);
   }
-  
+
   /**
    * Convert a long into little-endian bytes in buf starting at off and going 
    * until off+7.
@@ -650,11 +655,11 @@
   public void readMapEnd() throws TException {}
   public void readListEnd() throws TException {}
   public void readSetEnd() throws TException {}
-  
+
   //
   // Internal reading methods
   //
-  
+
   /**
    * Read an i32 from the wire as a varint. The MSB of each byte is set
    * if there is another byte to follow. This can read up to 5 bytes.
@@ -684,14 +689,14 @@
   //
   // encoding helpers
   //
-  
+
   /**
    * Convert from zigzag int to int.
    */
   private int zigzagToInt(int n) {
     return (n >>> 1) ^ -(n & 1);
   }
-  
+
   /** 
    * Convert from zigzag long to long.
    */
@@ -767,5 +772,4 @@
   private byte getCompactType(byte ttype) {
     return ttypeToCompactType[ttype];
   }
-  
 }
diff --git a/lib/java/src/org/apache/thrift/protocol/TJSONProtocol.java b/lib/java/src/org/apache/thrift/protocol/TJSONProtocol.java
index 631c6a5..89ba9b4 100644
--- a/lib/java/src/org/apache/thrift/protocol/TJSONProtocol.java
+++ b/lib/java/src/org/apache/thrift/protocol/TJSONProtocol.java
@@ -301,6 +301,13 @@
     super(trans);
   }
 
+  @Override
+  public void reset() {
+    contextStack_.clear();
+    context_ = new JSONBaseContext();
+    reader_ = new LookaheadReader();
+  }
+
   // Temporary buffer used by several methods
   private byte[] tmpbuf_ = new byte[4];
 
diff --git a/lib/java/src/org/apache/thrift/protocol/TProtocol.java b/lib/java/src/org/apache/thrift/protocol/TProtocol.java
index 65b6f4b..78b07f9 100644
--- a/lib/java/src/org/apache/thrift/protocol/TProtocol.java
+++ b/lib/java/src/org/apache/thrift/protocol/TProtocol.java
@@ -142,4 +142,10 @@
   public abstract String readString() throws TException;
 
   public abstract byte[] readBinary() throws TException;
+
+  /**
+   * Reset any internal state back to a blank slate. This method only needs to
+   * be implemented for stateful protocols.
+   */
+  public void reset() {}
 }
diff --git a/lib/java/test/org/apache/thrift/test/PartialDeserializeTest.java b/lib/java/test/org/apache/thrift/test/PartialDeserializeTest.java
index a7fa59b..831a419 100644
--- a/lib/java/test/org/apache/thrift/test/PartialDeserializeTest.java
+++ b/lib/java/test/org/apache/thrift/test/PartialDeserializeTest.java
@@ -75,9 +75,14 @@
 
   public static void testPartialDeserialize(TProtocolFactory protocolFactory, TBase input, TBase output, TBase expected, TFieldIdEnum ... fieldIdPath) throws TException {
     byte[] record = new TSerializer(protocolFactory).serialize(input);
-    new TDeserializer(protocolFactory).partialDeserialize(output, record, fieldIdPath);
-    if(!output.equals(expected))
-      throw new RuntimeException("with " + protocolFactory.toString() + ", expected " + expected + " but got " + output);
+    TDeserializer deserializer = new TDeserializer(protocolFactory);
+    for (int i = 0; i < 2; i++) {
+      TBase outputCopy = output.deepCopy();
+      deserializer.partialDeserialize(outputCopy, record, fieldIdPath);
+      if(!outputCopy.equals(expected)) {
+        throw new RuntimeException("on attempt " + i + ", with " + protocolFactory.toString() + ", expected " + expected + " but got " + outputCopy);
+      }
+    }
   }
 }