THRIFT-1643 Denial of Service attack in TBinaryProtocol.readString
Patch: Niraj Tolia
Fix: add TCompactProtocol maxNetworkBytes
git-svn-id: https://svn.apache.org/repos/asf/thrift/trunk@1396186 13f79535-47bb-0310-9956-ffa450edef68
diff --git a/lib/java/src/org/apache/thrift/protocol/TCompactProtocol.java b/lib/java/src/org/apache/thrift/protocol/TCompactProtocol.java
index 3b1d886..3db256d 100644
--- a/lib/java/src/org/apache/thrift/protocol/TCompactProtocol.java
+++ b/lib/java/src/org/apache/thrift/protocol/TCompactProtocol.java
@@ -62,10 +62,18 @@
* TProtocolFactory that produces TCompactProtocols.
*/
public static class Factory implements TProtocolFactory {
- public Factory() {}
+ private final long maxNetworkBytes_;
+
+ public Factory() {
+ maxNetworkBytes_ = -1;
+ }
+
+ public Factory(int maxNetworkBytes) {
+ maxNetworkBytes_ = maxNetworkBytes;
+ }
public TProtocol getProtocol(TTransport trans) {
- return new TCompactProtocol(trans);
+ return new TCompactProtocol(trans, maxNetworkBytes_);
}
}
@@ -114,12 +122,31 @@
private Boolean boolValue_ = null;
/**
+ * The maximum number of bytes to read from the network for
+ * variable-length fields (such as strings or binary) or -1 for
+ * unlimited.
+ */
+ private final long maxNetworkBytes_;
+
+ /**
+ * Create a TCompactProtocol.
+ *
+ * @param transport the TTransport object to read from or write to.
+ * @param maxNetworkBytes the maximum number of bytes to read for
+ * variable-length fields.
+ */
+ public TCompactProtocol(TTransport transport, long maxNetworkBytes) {
+ super(transport);
+ maxNetworkBytes_ = maxNetworkBytes;
+ }
+
+ /**
* Create a TCompactProtocol.
*
* @param transport the TTransport object to read from or write to.
*/
public TCompactProtocol(TTransport transport) {
- super(transport);
+ this(transport, -1);
}
@Override
@@ -617,6 +644,10 @@
return "";
}
+ if (maxNetworkBytes_ != -1 && length > maxNetworkBytes_) {
+ throw new TException("Read size greater than max allowed.");
+ }
+
try {
if (trans_.getBytesRemainingInBuffer() >= length) {
String str = new String(trans_.getBuffer(), trans_.getBufferPosition(), length, "UTF-8");
@@ -637,6 +668,10 @@
int length = readVarint32();
if (length == 0) return ByteBuffer.wrap(new byte[0]);
+ if (maxNetworkBytes_ != -1 && length > maxNetworkBytes_) {
+ throw new TException("Read size greater than max allowed.");
+ }
+
byte[] buf = new byte[length];
trans_.readAll(buf, 0, length);
return ByteBuffer.wrap(buf);