THRIFT-2660 Validate the bytes received in TSaslTransport
Client: Java
Patch: Harsh J
diff --git a/lib/java/src/org/apache/thrift/transport/TSaslTransport.java b/lib/java/src/org/apache/thrift/transport/TSaslTransport.java
index 947ee13..fc6d9b8 100644
--- a/lib/java/src/org/apache/thrift/transport/TSaslTransport.java
+++ b/lib/java/src/org/apache/thrift/transport/TSaslTransport.java
@@ -178,13 +178,22 @@
underlyingTransport.readAll(messageHeader, 0, messageHeader.length);
byte statusByte = messageHeader[0];
- byte[] payload = new byte[EncodingUtils.decodeBigEndian(messageHeader, STATUS_BYTES)];
- underlyingTransport.readAll(payload, 0, payload.length);
NegotiationStatus status = NegotiationStatus.byValue(statusByte);
if (status == null) {
throw sendAndThrowMessage(NegotiationStatus.ERROR, "Invalid status " + statusByte);
- } else if (status == NegotiationStatus.BAD || status == NegotiationStatus.ERROR) {
+ }
+
+ int payloadBytes = EncodingUtils.decodeBigEndian(messageHeader, STATUS_BYTES);
+ if (payloadBytes <= 0 || payloadBytes > 104857600 /* 100 MB */) {
+ throw sendAndThrowMessage(
+ NegotiationStatus.ERROR, "Invalid payload header length: " + payloadBytes);
+ }
+
+ byte[] payload = new byte[payloadBytes];
+ underlyingTransport.readAll(payload, 0, payload.length);
+
+ if (status == NegotiationStatus.BAD || status == NegotiationStatus.ERROR) {
try {
String remoteMessage = new String(payload, "UTF-8");
throw new TTransportException("Peer indicated failure: " + remoteMessage);
diff --git a/lib/java/test/org/apache/thrift/transport/TestTSaslTransports.java b/lib/java/test/org/apache/thrift/transport/TestTSaslTransports.java
index 80e53b9..b627ccf 100644
--- a/lib/java/test/org/apache/thrift/transport/TestTSaslTransports.java
+++ b/lib/java/test/org/apache/thrift/transport/TestTSaslTransports.java
@@ -412,4 +412,64 @@
put("SaslServerFactory.ANONYMOUS", SaslAnonymousFactory.class.getName());
}
}
+
+ private static class MockTTransport extends TTransport {
+
+ byte[] badHeader = null;
+ private TMemoryInputTransport readBuffer = new TMemoryInputTransport();
+
+ public MockTTransport(int mode) {
+ if (mode==1) {
+ // Invalid status byte
+ badHeader = new byte[] { (byte)0xFF, (byte)0x00, (byte)0x00, (byte)0x00, (byte)0x05 };
+ } else if (mode == 2) {
+ // Valid status byte, negative payload length
+ badHeader = new byte[] { (byte)0x01, (byte)0xFF, (byte)0xFF, (byte)0xFF, (byte)0xFF };
+ } else if (mode == 3) {
+ // Valid status byte, excessively large, bogus payload length
+ badHeader = new byte[] { (byte)0x01, (byte)0x64, (byte)0x00, (byte)0x00, (byte)0x00 };
+ }
+ readBuffer.reset(badHeader);
+ }
+
+ @Override
+ public boolean isOpen() {
+ return true;
+ }
+
+ @Override
+ public void open() throws TTransportException {}
+
+ @Override
+ public void close() {}
+
+ @Override
+ public int read(byte[] buf, int off, int len) throws TTransportException {
+ return readBuffer.read(buf, off, len);
+ }
+
+ @Override
+ public void write(byte[] buf, int off, int len) throws TTransportException {}
+ }
+
+ public void testBadHeader() {
+ TSaslTransport saslTransport = new TSaslServerTransport(new MockTTransport(1));
+ try {
+ saslTransport.receiveSaslMessage();
+ fail("Should have gotten an error due to incorrect status byte value.");
+ } catch (TTransportException e) {
+ }
+ saslTransport = new TSaslServerTransport(new MockTTransport(2));
+ try {
+ saslTransport.receiveSaslMessage();
+ fail("Should have gotten an error due to negative payload length.");
+ } catch (TTransportException e) {
+ }
+ saslTransport = new TSaslServerTransport(new MockTTransport(3));
+ try {
+ saslTransport.receiveSaslMessage();
+ fail("Should have gotten an error due to bogus (large) payload length.");
+ } catch (TTransportException e) {
+ }
+ }
}