THRIFT-2572 Add string/collection length limit checks (from C++) to java protocol readers

Client: Java

This closes #138

Patch: Andrew Cox
diff --git a/lib/java/src/org/apache/thrift/protocol/TBinaryProtocol.java b/lib/java/src/org/apache/thrift/protocol/TBinaryProtocol.java
index 32a761f..65b3353 100644
--- a/lib/java/src/org/apache/thrift/protocol/TBinaryProtocol.java
+++ b/lib/java/src/org/apache/thrift/protocol/TBinaryProtocol.java
@@ -31,31 +31,53 @@
  */
 public class TBinaryProtocol extends TProtocol {
   private static final TStruct ANONYMOUS_STRUCT = new TStruct();
+  private static final long NO_LENGTH_LIMIT = -1;
 
   protected static final int VERSION_MASK = 0xffff0000;
   protected static final int VERSION_1 = 0x80010000;
 
-  protected boolean strictRead_ = false;
-  protected boolean strictWrite_ = true;
+  /**
+   * The maximum number of bytes to read from the transport for
+   * variable-length fields (such as strings or binary) or {@link #NO_LENGTH_LIMIT} for
+   * unlimited.
+   */
+  private final long stringLengthLimit_;
+
+  /**
+   * The maximum number of elements to read from the network for
+   * containers (maps, sets, lists), or {@link #NO_LENGTH_LIMIT} for unlimited.
+   */
+  private final long containerLengthLimit_;
+
+  protected boolean strictRead_;
+  protected boolean strictWrite_;
 
   /**
    * Factory
    */
   public static class Factory implements TProtocolFactory {
-    protected boolean strictRead_ = false;
-    protected boolean strictWrite_ = true;
+    protected long stringLengthLimit_;
+    protected long containerLengthLimit_;
+    protected boolean strictRead_;
+    protected boolean strictWrite_;
 
     public Factory() {
       this(false, true);
     }
 
     public Factory(boolean strictRead, boolean strictWrite) {
+      this(strictRead, strictWrite, NO_LENGTH_LIMIT, NO_LENGTH_LIMIT);
+    }
+
+    public Factory(boolean strictRead, boolean strictWrite, long stringLengthLimit, long containerLengthLimit) {
+      stringLengthLimit_ = stringLengthLimit;
+      containerLengthLimit_ = containerLengthLimit;
       strictRead_ = strictRead;
       strictWrite_ = strictWrite;
     }
 
     public TProtocol getProtocol(TTransport trans) {
-      return new TBinaryProtocol(trans, strictRead_, strictWrite_);
+      return new TBinaryProtocol(trans, stringLengthLimit_, containerLengthLimit_, strictRead_, strictWrite_);
     }
   }
 
@@ -67,7 +89,13 @@
   }
 
   public TBinaryProtocol(TTransport trans, boolean strictRead, boolean strictWrite) {
+    this(trans, NO_LENGTH_LIMIT, NO_LENGTH_LIMIT, strictRead, strictWrite);
+  }
+
+  public TBinaryProtocol(TTransport trans, long stringLengthLimit, long containerLengthLimit, boolean strictRead, boolean strictWrite) {
     super(trans);
+    stringLengthLimit_ = stringLengthLimit;
+    containerLengthLimit_ = containerLengthLimit;
     strictRead_ = strictRead;
     strictWrite_ = strictWrite;
   }
@@ -220,19 +248,25 @@
   public void readFieldEnd() {}
 
   public TMap readMapBegin() throws TException {
-    return new TMap(readByte(), readByte(), readI32());
+    TMap map = new TMap(readByte(), readByte(), readI32());
+    checkContainerReadLength(map.size);
+    return map;
   }
 
   public void readMapEnd() {}
 
   public TList readListBegin() throws TException {
-    return new TList(readByte(), readI32());
+    TList list = new TList(readByte(), readI32());
+    checkContainerReadLength(list.size);
+    return list;
   }
 
   public void readListEnd() {}
 
   public TSet readSetBegin() throws TException {
-    return new TSet(readByte(), readI32());
+    TSet set = new TSet(readByte(), readI32());
+    checkContainerReadLength(set.size);
+    return set;
   }
 
   public void readSetEnd() {}
@@ -321,6 +355,12 @@
   public String readString() throws TException {
     int size = readI32();
 
+    checkStringReadLength(size);
+    if (stringLengthLimit_ > 0 && size > stringLengthLimit_) {
+      throw new TProtocolException(TProtocolException.SIZE_LIMIT,
+                                   "String field exceeded string size limit");
+    }
+
     if (trans_.getBytesRemainingInBuffer() >= size) {
       try {
         String s = new String(trans_.getBuffer(), trans_.getBufferPosition(), size, "UTF-8");
@@ -347,6 +387,11 @@
   public ByteBuffer readBinary() throws TException {
     int size = readI32();
 
+    if (stringLengthLimit_ > 0 && size > stringLengthLimit_) {
+      throw new TProtocolException(TProtocolException.SIZE_LIMIT,
+                                   "Binary field exceeded string size limit");
+    }
+
     if (trans_.getBytesRemainingInBuffer() >= size) {
       ByteBuffer bb = ByteBuffer.wrap(trans_.getBuffer(), trans_.getBufferPosition(), size);
       trans_.consumeBuffer(size);
@@ -358,6 +403,28 @@
     return ByteBuffer.wrap(buf);
   }
 
+  private void checkStringReadLength(int length) throws TProtocolException {
+    if (length < 0) {
+      throw new TProtocolException(TProtocolException.NEGATIVE_SIZE,
+                                   "Negative length: " + length);
+    }
+    if (stringLengthLimit_ != NO_LENGTH_LIMIT && length > stringLengthLimit_) {
+      throw new TProtocolException(TProtocolException.SIZE_LIMIT,
+                                   "Length exceeded max allowed: " + length);
+    }
+  }
+
+  private void checkContainerReadLength(int length) throws TProtocolException {
+    if (length < 0) {
+      throw new TProtocolException(TProtocolException.NEGATIVE_SIZE,
+                                   "Negative length: " + length);
+    }
+    if (containerLengthLimit_ != NO_LENGTH_LIMIT && length > containerLengthLimit_) {
+      throw new TProtocolException(TProtocolException.SIZE_LIMIT,
+                                   "Length exceeded max allowed: " + length);
+    }
+  }
+
   private int readAll(byte[] buf, int off, int len) throws TException {
     return trans_.readAll(buf, off, len);
   }
diff --git a/lib/java/src/org/apache/thrift/protocol/TCompactProtocol.java b/lib/java/src/org/apache/thrift/protocol/TCompactProtocol.java
index 7b273c5..75300b8 100644
--- a/lib/java/src/org/apache/thrift/protocol/TCompactProtocol.java
+++ b/lib/java/src/org/apache/thrift/protocol/TCompactProtocol.java
@@ -29,15 +29,17 @@
 
 /**
  * TCompactProtocol2 is the Java implementation of the compact protocol specified
- * in THRIFT-110. The fundamental approach to reducing the overhead of 
+ * in THRIFT-110. The fundamental approach to reducing the overhead of
  * structures is a) use variable-length integers all over the place and b) make
- * use of unused bits wherever possible. Your savings will obviously vary 
- * based on the specific makeup of your structs, but in general, the more 
+ * use of unused bits wherever possible. Your savings will obviously vary
+ * based on the specific makeup of your structs, but in general, the more
  * fields, nested structures, short strings and collections, and low-value i32
  * and i64 fields you have, the more benefit you'll see.
  */
 public class TCompactProtocol extends TProtocol {
 
+  private final static long NO_LENGTH_LIMIT = -1;
+
   private final static TStruct ANONYMOUS_STRUCT = new TStruct("");
   private final static TField TSTOP = new TField("", TType.STOP, (short)0);
 
@@ -62,18 +64,24 @@
    * TProtocolFactory that produces TCompactProtocols.
    */
   public static class Factory implements TProtocolFactory {
-    private final long maxNetworkBytes_;
+    private final long stringLengthLimit_;
+    private final long containerLengthLimit_;
 
     public Factory() {
-      maxNetworkBytes_ = -1;
+      this(NO_LENGTH_LIMIT, NO_LENGTH_LIMIT);
     }
 
-    public Factory(int maxNetworkBytes) {
-      maxNetworkBytes_ = maxNetworkBytes;
+    public Factory(long stringLengthLimit) {
+      this(stringLengthLimit, NO_LENGTH_LIMIT);
+    }
+
+    public Factory(long stringLengthLimit, long containerLengthLimit) {
+      this.containerLengthLimit_ = containerLengthLimit;
+      this.stringLengthLimit_ = stringLengthLimit;
     }
 
     public TProtocol getProtocol(TTransport trans) {
-      return new TCompactProtocol(trans, maxNetworkBytes_);
+      return new TCompactProtocol(trans, stringLengthLimit_, containerLengthLimit_);
     }
   }
 
@@ -101,7 +109,7 @@
     public static final byte STRUCT         = 0x0C;
   }
 
-  /** 
+  /**
    * Used to keep track of the last field for the current and previous structs,
    * so we can do the delta stuff.
    */
@@ -110,34 +118,56 @@
   private short lastFieldId_ = 0;
 
   /**
-   * If we encounter a boolean field begin, save the TField here so it can 
+   * If we encounter a boolean field begin, save the TField here so it can
    * have the value incorporated.
    */
   private TField booleanField_ = null;
 
   /**
-   * If we read a field header, and it's a boolean field, save the boolean 
+   * If we read a field header, and it's a boolean field, save the boolean
    * value here so that readBool can use it.
    */
   private Boolean boolValue_ = null;
 
   /**
-   * The maximum number of bytes to read from the network for
-   * variable-length fields (such as strings or binary) or -1 for
+   * The maximum number of bytes to read from the transport for
+   * variable-length fields (such as strings or binary) or {@link #NO_LENGTH_LIMIT} for
    * unlimited.
    */
-  private final long maxNetworkBytes_;
+  private final long stringLengthLimit_;
+
+  /**
+   * The maximum number of elements to read from the network for
+   * containers (maps, sets, lists), or {@link #NO_LENGTH_LIMIT} for unlimited.
+   */
+  private final long containerLengthLimit_;
 
   /**
    * Create a TCompactProtocol.
    *
    * @param transport the TTransport object to read from or write to.
-   * @param maxNetworkBytes the maximum number of bytes to read for
+   * @param stringLengthLimit the maximum number of bytes to read for
    *     variable-length fields.
+   * @param containerLengthLimit the maximum number of elements to read
+   *     for containers.
    */
-  public TCompactProtocol(TTransport transport, long maxNetworkBytes) {
+  public TCompactProtocol(TTransport transport, long stringLengthLimit, long containerLengthLimit) {
     super(transport);
-    maxNetworkBytes_ = maxNetworkBytes;
+    this.stringLengthLimit_ = stringLengthLimit;
+    this.containerLengthLimit_ = containerLengthLimit;
+  }
+
+  /**
+   * Create a TCompactProtocol.
+   *
+   * @param transport the TTransport object to read from or write to.
+   * @param stringLengthLimit the maximum number of bytes to read for
+   *     variable-length fields.
+   * @deprecated Use constructor specifying both string limit and container limit instead
+   */
+  @Deprecated
+  public TCompactProtocol(TTransport transport, long stringLengthLimit) {
+    this(transport, stringLengthLimit, NO_LENGTH_LIMIT);
   }
 
   /**
@@ -146,7 +176,7 @@
    * @param transport the TTransport object to read from or write to.
    */
   public TCompactProtocol(TTransport transport) {
-    this(transport, -1);
+    this(transport, NO_LENGTH_LIMIT, NO_LENGTH_LIMIT);
   }
 
   @Override
@@ -171,7 +201,7 @@
   }
 
   /**
-   * Write a struct begin. This doesn't actually put anything on the wire. We 
+   * Write a struct begin. This doesn't actually put anything on the wire. We
    * use it as an opportunity to put special placeholder markers on the field
    * stack so we can get the field id deltas correct.
    */
@@ -194,7 +224,7 @@
    * difference between the current field id and the last one is small (< 15),
    * then the field id will be encoded in the 4 MSB as a delta. Otherwise, the
    * field id will follow the type header as a zigzag varint.
-   */ 
+   */
   public void writeFieldBegin(TField field) throws TException {
     if (field.type == TType.BOOL) {
       // we want to possibly include the value, so we'll wait.
@@ -205,8 +235,8 @@
   }
 
   /**
-   * The workhorse of writeFieldBegin. It has the option of doing a 
-   * 'type override' of the type header. This is used specifically in the 
+   * The workhorse of writeFieldBegin. It has the option of doing a
+   * 'type override' of the type header. This is used specifically in the
    * boolean field case.
    */
   private void writeFieldBeginInternal(TField field, byte typeOverride) throws TException {
@@ -237,7 +267,7 @@
   }
 
   /**
-   * Write a map header. If the map is empty, omit the key and value type 
+   * Write a map header. If the map is empty, omit the key and value type
    * headers, as we don't need any additional information to skip it.
    */
   public void writeMapBegin(TMap map) throws TException {
@@ -248,8 +278,8 @@
       writeByteDirect(getCompactType(map.keyType) << 4 | getCompactType(map.valueType));
     }
   }
-  
-  /** 
+
+  /**
    * Write a list header.
    */
   public void writeListBegin(TList list) throws TException {
@@ -264,9 +294,9 @@
   }
 
   /**
-   * Write a boolean value. Potentially, this could be a boolean field, in 
+   * Write a boolean value. Potentially, this could be a boolean field, in
    * which case the field header info isn't written yet. If so, decide what the
-   * right type header is for the value and then write the field header. 
+   * right type header is for the value and then write the field header.
    * Otherwise, write a single byte.
    */
   public void writeBool(boolean b) throws TException {
@@ -280,7 +310,7 @@
     }
   }
 
-  /** 
+  /**
    * Write a byte. Nothing to see here!
    */
   public void writeByte(byte b) throws TException {
@@ -310,7 +340,7 @@
 
   /**
    * Write a double to the wire as 8 bytes.
-   */ 
+   */
   public void writeDouble(double dub) throws TException {
     byte[] data = new byte[]{0, 0, 0, 0, 0, 0, 0, 0};
     fixedLongToBytes(Double.doubleToLongBits(dub), data, 0);
@@ -330,7 +360,7 @@
   }
 
   /**
-   * Write a byte array, using a varint for the size. 
+   * Write a byte array, using a varint for the size.
    */
   public void writeBinary(ByteBuffer bin) throws TException {
     int length = bin.limit() - bin.position();
@@ -343,9 +373,9 @@
   }
 
   //
-  // These methods are called by structs, but don't actually have any wire 
+  // These methods are called by structs, but don't actually have any wire
   // output or purpose.
-  // 
+  //
 
   public void writeMessageEnd() throws TException {}
   public void writeMapEnd() throws TException {}
@@ -358,7 +388,7 @@
   //
 
   /**
-   * Abstract method for writing the start of lists and sets. List and sets on 
+   * Abstract method for writing the start of lists and sets. List and sets on
    * the wire differ only by the type indicator.
    */
   protected void writeCollectionBegin(byte elemType, int size) throws TException {
@@ -411,7 +441,7 @@
   }
 
   /**
-   * Convert l into a zigzag long. This allows negative numbers to be 
+   * Convert l into a zigzag long. This allows negative numbers to be
    * represented compactly as a varint.
    */
   private long longToZigzag(long l) {
@@ -419,7 +449,7 @@
   }
 
   /**
-   * Convert n into a zigzag int. This allows negative numbers to be 
+   * Convert n into a zigzag int. This allows negative numbers to be
    * represented compactly as a varint.
    */
   private int intToZigZag(int n) {
@@ -427,7 +457,7 @@
   }
 
   /**
-   * Convert a long into little-endian bytes in buf starting at off and going 
+   * Convert a long into little-endian bytes in buf starting at off and going
    * until off+7.
    */
   private void fixedLongToBytes(long n, byte[] buf, int off) {
@@ -441,8 +471,8 @@
     buf[off+7] = (byte)((n >> 56) & 0xff);
   }
 
-  /** 
-   * Writes a byte without any possibility of all that field header nonsense. 
+  /**
+   * Writes a byte without any possibility of all that field header nonsense.
    * Used internally by other writing methods that know they need to write a byte.
    */
   private byte[] byteDirectBuffer = new byte[1];
@@ -451,7 +481,7 @@
     trans_.write(byteDirectBuffer);
   }
 
-  /** 
+  /**
    * Writes a byte without any possibility of all that field header nonsense.
    */
   private void writeByteDirect(int n) throws TException {
@@ -459,12 +489,12 @@
   }
 
 
-  // 
+  //
   // Reading methods.
-  // 
+  //
 
   /**
-   * Read a message header. 
+   * Read a message header.
    */
   public TMessage readMessageBegin() throws TException {
     byte protocolId = readByte();
@@ -493,16 +523,16 @@
   }
 
   /**
-   * Doesn't actually consume any wire data, just removes the last field for 
+   * Doesn't actually consume any wire data, just removes the last field for
    * this struct from the field stack.
    */
   public void readStructEnd() throws TException {
     // consume the last field we read off the wire.
     lastFieldId_ = lastField_.pop();
   }
-  
+
   /**
-   * Read a field header off the wire. 
+   * Read a field header off the wire.
    */
   public TField readFieldBegin() throws TException {
     byte type = readByte();
@@ -530,26 +560,27 @@
     if (isBoolType(type)) {
       // save the boolean value in a special instance variable.
       boolValue_ = (byte)(type & 0x0f) == Types.BOOLEAN_TRUE ? Boolean.TRUE : Boolean.FALSE;
-    } 
+    }
 
     // push the new field onto the field stack so we can keep the deltas going.
     lastFieldId_ = field.id;
     return field;
   }
 
-  /** 
+  /**
    * Read a map header off the wire. If the size is zero, skip reading the key
    * and value type. This means that 0-length maps will yield TMaps without the
    * "correct" types.
    */
   public TMap readMapBegin() throws TException {
     int size = readVarint32();
+    checkContainerReadLength(size);
     byte keyAndValueType = size == 0 ? 0 : readByte();
     return new TMap(getTType((byte)(keyAndValueType >> 4)), getTType((byte)(keyAndValueType & 0xf)), size);
   }
 
   /**
-   * Read a list header off the wire. If the list size is 0-14, the size will 
+   * Read a list header off the wire. If the list size is 0-14, the size will
    * be packed into the element type header. If it's a longer list, the 4 MSB
    * of the element type header will be 0xF, and a varint will follow with the
    * true size.
@@ -560,12 +591,13 @@
     if (size == 15) {
       size = readVarint32();
     }
+    checkContainerReadLength(size);
     byte type = getTType(size_and_type);
     return new TList(type, size);
   }
 
   /**
-   * Read a set header off the wire. If the set size is 0-14, the size will 
+   * Read a set header off the wire. If the set size is 0-14, the size will
    * be packed into the element type header. If it's a longer set, the 4 MSB
    * of the element type header will be 0xF, and a varint will follow with the
    * true size.
@@ -639,7 +671,7 @@
    */
   public String readString() throws TException {
     int length = readVarint32();
-    checkReadLength(length);
+    checkStringReadLength(length);
 
     if (length == 0) {
       return "";
@@ -659,11 +691,11 @@
   }
 
   /**
-   * Read a byte[] from the wire. 
+   * Read a byte[] from the wire.
    */
   public ByteBuffer readBinary() throws TException {
     int length = readVarint32();
-    checkReadLength(length);
+    checkStringReadLength(length);
     if (length == 0) return ByteBuffer.wrap(new byte[0]);
 
     byte[] buf = new byte[length];
@@ -672,7 +704,7 @@
   }
 
   /**
-   * Read a byte[] of a known length from the wire. 
+   * Read a byte[] of a known length from the wire.
    */
   private byte[] readBinary(int length) throws TException {
     if (length == 0) return new byte[0];
@@ -682,17 +714,30 @@
     return buf;
   }
 
-  private void checkReadLength(int length) throws TProtocolException {
+  private void checkStringReadLength(int length) throws TProtocolException {
     if (length < 0) {
-      throw new TProtocolException("Negative length: " + length);
+      throw new TProtocolException(TProtocolException.NEGATIVE_SIZE,
+                                   "Negative length: " + length);
     }
-    if (maxNetworkBytes_ != -1 && length > maxNetworkBytes_) {
-      throw new TProtocolException("Length exceeded max allowed: " + length);
+    if (stringLengthLimit_ != NO_LENGTH_LIMIT && length > stringLengthLimit_) {
+      throw new TProtocolException(TProtocolException.SIZE_LIMIT,
+                                   "Length exceeded max allowed: " + length);
+    }
+  }
+
+  private void checkContainerReadLength(int length) throws TProtocolException {
+    if (length < 0) {
+      throw new TProtocolException(TProtocolException.NEGATIVE_SIZE,
+                                   "Negative length: " + length);
+    }
+    if (containerLengthLimit_ != NO_LENGTH_LIMIT && length > containerLengthLimit_) {
+      throw new TProtocolException(TProtocolException.SIZE_LIMIT,
+                                   "Length exceeded max allowed: " + length);
     }
   }
 
   //
-  // These methods are here for the struct to call, but don't have any wire 
+  // These methods are here for the struct to call, but don't have any wire
   // encoding.
   //
   public void readMessageEnd() throws TException {}
@@ -736,7 +781,7 @@
   }
 
   /**
-   * Read an i64 from the wire as a proper varint. The MSB of each byte is set 
+   * Read an i64 from the wire as a proper varint. The MSB of each byte is set
    * if there is another byte to follow. This can read up to 10 bytes.
    */
   private long readVarint64() throws TException {
@@ -776,7 +821,7 @@
     return (n >>> 1) ^ -(n & 1);
   }
 
-  /** 
+  /**
    * Convert from zigzag long to long.
    */
   private long zigzagToLong(long n) {
@@ -784,7 +829,7 @@
   }
 
   /**
-   * Note that it's important that the mask bytes are long literals, 
+   * Note that it's important that the mask bytes are long literals,
    * otherwise they'll default to ints, and when you shift an int left 56 bits,
    * you just get a messed up int.
    */
@@ -810,7 +855,7 @@
   }
 
   /**
-   * Given a TCompactProtocol.Types constant, convert it to its corresponding 
+   * Given a TCompactProtocol.Types constant, convert it to its corresponding
    * TType value.
    */
   private byte getTType(byte type) throws TProtocolException {