THRIFT-2660 Validate the bytes received in TSaslTransport

Client: Java
Patch: Harsh J
diff --git a/lib/java/src/org/apache/thrift/transport/TSaslTransport.java b/lib/java/src/org/apache/thrift/transport/TSaslTransport.java
index 947ee13..fc6d9b8 100644
--- a/lib/java/src/org/apache/thrift/transport/TSaslTransport.java
+++ b/lib/java/src/org/apache/thrift/transport/TSaslTransport.java
@@ -178,13 +178,22 @@
     underlyingTransport.readAll(messageHeader, 0, messageHeader.length);
 
     byte statusByte = messageHeader[0];
-    byte[] payload = new byte[EncodingUtils.decodeBigEndian(messageHeader, STATUS_BYTES)];
-    underlyingTransport.readAll(payload, 0, payload.length);
 
     NegotiationStatus status = NegotiationStatus.byValue(statusByte);
     if (status == null) {
       throw sendAndThrowMessage(NegotiationStatus.ERROR, "Invalid status " + statusByte);
-    } else if (status == NegotiationStatus.BAD || status == NegotiationStatus.ERROR) {
+    }
+
+    int payloadBytes = EncodingUtils.decodeBigEndian(messageHeader, STATUS_BYTES);
+    if (payloadBytes <= 0 || payloadBytes > 104857600 /* 100 MB */) {
+      throw sendAndThrowMessage(
+        NegotiationStatus.ERROR, "Invalid payload header length: " + payloadBytes);
+    }
+
+    byte[] payload = new byte[payloadBytes];
+    underlyingTransport.readAll(payload, 0, payload.length);
+
+    if (status == NegotiationStatus.BAD || status == NegotiationStatus.ERROR) {
       try {
         String remoteMessage = new String(payload, "UTF-8");
         throw new TTransportException("Peer indicated failure: " + remoteMessage);
diff --git a/lib/java/test/org/apache/thrift/transport/TestTSaslTransports.java b/lib/java/test/org/apache/thrift/transport/TestTSaslTransports.java
index 80e53b9..b627ccf 100644
--- a/lib/java/test/org/apache/thrift/transport/TestTSaslTransports.java
+++ b/lib/java/test/org/apache/thrift/transport/TestTSaslTransports.java
@@ -412,4 +412,64 @@
       put("SaslServerFactory.ANONYMOUS", SaslAnonymousFactory.class.getName());
     }
   }
+
+  private static class MockTTransport extends TTransport {
+
+    byte[] badHeader = null;
+    private TMemoryInputTransport readBuffer = new TMemoryInputTransport();
+
+    public MockTTransport(int mode) {
+      if (mode==1) {
+        // Invalid status byte
+        badHeader = new byte[] { (byte)0xFF, (byte)0x00, (byte)0x00, (byte)0x00, (byte)0x05 };
+      } else if (mode == 2) {
+        // Valid status byte, negative payload length
+        badHeader = new byte[] { (byte)0x01, (byte)0xFF, (byte)0xFF, (byte)0xFF, (byte)0xFF };
+      } else if (mode == 3) {
+        // Valid status byte, excessively large, bogus payload length
+        badHeader = new byte[] { (byte)0x01, (byte)0x64, (byte)0x00, (byte)0x00, (byte)0x00 };
+      }
+      readBuffer.reset(badHeader);
+    }
+
+    @Override
+    public boolean isOpen() {
+      return true;
+    }
+
+    @Override
+    public void open() throws TTransportException {}
+
+    @Override
+    public void close() {}
+
+    @Override
+    public int read(byte[] buf, int off, int len) throws TTransportException {
+      return readBuffer.read(buf, off, len);
+    }
+
+    @Override
+    public void write(byte[] buf, int off, int len) throws TTransportException {}
+  }
+
+  public void testBadHeader() {
+    TSaslTransport saslTransport = new TSaslServerTransport(new MockTTransport(1));
+    try {
+      saslTransport.receiveSaslMessage();
+      fail("Should have gotten an error due to incorrect status byte value.");
+    } catch (TTransportException e) {
+    }
+    saslTransport = new TSaslServerTransport(new MockTTransport(2));
+    try {
+      saslTransport.receiveSaslMessage();
+      fail("Should have gotten an error due to negative payload length.");
+    } catch (TTransportException e) {
+    }
+    saslTransport = new TSaslServerTransport(new MockTTransport(3));
+    try {
+      saslTransport.receiveSaslMessage();
+      fail("Should have gotten an error due to bogus (large) payload length.");
+    } catch (TTransportException e) {
+    }
+  }
 }