THRIFT-5027 Implement remaining read bytes checks
Client: netstd
Patch: Jens Geyer
This closes #1946
diff --git a/lib/netstd/Thrift/Protocol/TBinaryProtocol.cs b/lib/netstd/Thrift/Protocol/TBinaryProtocol.cs
index f0772aa..a00c5c1 100644
--- a/lib/netstd/Thrift/Protocol/TBinaryProtocol.cs
+++ b/lib/netstd/Thrift/Protocol/TBinaryProtocol.cs
@@ -370,7 +370,7 @@
ValueType = (TType) await ReadByteAsync(cancellationToken),
Count = await ReadI32Async(cancellationToken)
};
-
+ CheckReadBytesAvailable(map);
return map;
}
@@ -394,7 +394,7 @@
ElementType = (TType) await ReadByteAsync(cancellationToken),
Count = await ReadI32Async(cancellationToken)
};
-
+ CheckReadBytesAvailable(list);
return list;
}
@@ -418,7 +418,7 @@
ElementType = (TType) await ReadByteAsync(cancellationToken),
Count = await ReadI32Async(cancellationToken)
};
-
+ CheckReadBytesAvailable(set);
return set;
}
@@ -507,6 +507,7 @@
}
var size = await ReadI32Async(cancellationToken);
+ Transport.CheckReadBytesAvailable(size);
var buf = new byte[size];
await Trans.ReadAllAsync(buf, 0, size, cancellationToken);
return buf;
@@ -536,11 +537,34 @@
return Encoding.UTF8.GetString(PreAllocatedBuffer, 0, size);
}
+ Transport.CheckReadBytesAvailable(size);
var buf = new byte[size];
await Trans.ReadAllAsync(buf, 0, size, cancellationToken);
return Encoding.UTF8.GetString(buf, 0, buf.Length);
}
+ /// <summary>
+ /// Returns the minimum number of bytes a value of <paramref name="type"/> consumes in the binary wire format.
+ /// Used as a lower bound when validating container element counts against the bytes still readable.
+ /// </summary>
+ public override int GetMinSerializedSize(TType type)
+ {
+     switch (type)
+     {
+         case TType.Stop: return 0;
+         case TType.Void: return 0;
+         case TType.Bool: return sizeof(byte);
+         case TType.Byte: return sizeof(byte);
+         case TType.Double: return sizeof(double);
+         case TType.I16: return sizeof(short);
+         case TType.I32: return sizeof(int);
+         case TType.I64: return sizeof(long);
+         case TType.String: return sizeof(int); // i32 length prefix of an empty string
+         // An empty struct is still terminated by its 1-byte STOP field marker. Returning 0 here would
+         // turn the container check into a no-op for list/set/map of structs (Count * 0 == 0).
+         case TType.Struct: return sizeof(byte);
+         case TType.Map: return sizeof(int); // i32 element count (type bytes not counted)
+         case TType.Set: return sizeof(int); // i32 element count (type byte not counted)
+         case TType.List: return sizeof(int); // i32 element count (type byte not counted)
+         default: throw new TTransportException(TTransportException.ExceptionType.Unknown, "unrecognized type code");
+     }
+ }
+
public class Factory : TProtocolFactory
{
protected bool StrictRead;
diff --git a/lib/netstd/Thrift/Protocol/TCompactProtocol.cs b/lib/netstd/Thrift/Protocol/TCompactProtocol.cs
index 921507c..a8a46f2 100644
--- a/lib/netstd/Thrift/Protocol/TCompactProtocol.cs
+++ b/lib/netstd/Thrift/Protocol/TCompactProtocol.cs
@@ -590,7 +590,9 @@
var size = (int) await ReadVarInt32Async(cancellationToken);
var keyAndValueType = size == 0 ? (byte) 0 : (byte) await ReadByteAsync(cancellationToken);
- return new TMap(GetTType((byte) (keyAndValueType >> 4)), GetTType((byte) (keyAndValueType & 0xf)), size);
+ var map = new TMap(GetTType((byte) (keyAndValueType >> 4)), GetTType((byte) (keyAndValueType & 0xf)), size);
+ CheckReadBytesAvailable(map);
+ return map;
}
public override async Task ReadMapEndAsync(CancellationToken cancellationToken)
@@ -703,6 +705,7 @@
return Encoding.UTF8.GetString(PreAllocatedBuffer, 0, length);
}
+ Transport.CheckReadBytesAvailable(length);
var buf = new byte[length];
await Trans.ReadAllAsync(buf, 0, length, cancellationToken);
return Encoding.UTF8.GetString(buf, 0, length);
@@ -718,6 +721,7 @@
}
// read data
+ Transport.CheckReadBytesAvailable(length);
var buf = new byte[length];
await Trans.ReadAllAsync(buf, 0, length, cancellationToken);
return buf;
@@ -745,7 +749,9 @@
}
var type = GetTType(sizeAndType);
- return new TList(type, size);
+ var list = new TList(type, size);
+ CheckReadBytesAvailable(list);
+ return list;
}
public override async Task ReadListEndAsync(CancellationToken cancellationToken)
@@ -856,6 +862,28 @@
return (uint) (n << 1) ^ (uint) (n >> 31);
}
+ /// <summary>
+ /// Returns the minimum number of bytes a value of <paramref name="type"/> consumes in the compact wire format.
+ /// Used as a lower bound when validating container element counts against the bytes still readable.
+ /// </summary>
+ public override int GetMinSerializedSize(TType type)
+ {
+     switch (type)
+     {
+         case TType.Stop: return 0;
+         case TType.Void: return 0;
+         case TType.Bool: return sizeof(byte);
+         case TType.Double: return 8; // doubles are always written as a fixed 8-byte value
+         case TType.Byte: return sizeof(byte);
+         case TType.I16: return sizeof(byte); // zigzag varint, at least 1 byte
+         case TType.I32: return sizeof(byte); // zigzag varint, at least 1 byte
+         case TType.I64: return sizeof(byte); // zigzag varint, at least 1 byte
+         case TType.String: return sizeof(byte); // varint length prefix of an empty string
+         // An empty struct is still terminated by its 1-byte STOP field marker. Returning 0 here would
+         // turn the container check into a no-op for list/set/map of structs (Count * 0 == 0).
+         case TType.Struct: return sizeof(byte);
+         case TType.Map: return sizeof(byte); // varint element count
+         case TType.Set: return sizeof(byte); // element count
+         case TType.List: return sizeof(byte); // element count
+         default: throw new TTransportException(TTransportException.ExceptionType.Unknown, "unrecognized type code");
+     }
+ }
+
public class Factory : TProtocolFactory
{
public override TProtocol GetProtocol(TTransport trans)
diff --git a/lib/netstd/Thrift/Protocol/TJSONProtocol.cs b/lib/netstd/Thrift/Protocol/TJSONProtocol.cs
index 464bd62..7bc7130 100644
--- a/lib/netstd/Thrift/Protocol/TJSONProtocol.cs
+++ b/lib/netstd/Thrift/Protocol/TJSONProtocol.cs
@@ -703,6 +703,7 @@
map.KeyType = TJSONProtocolHelper.GetTypeIdForTypeName(await ReadJsonStringAsync(false, cancellationToken));
map.ValueType = TJSONProtocolHelper.GetTypeIdForTypeName(await ReadJsonStringAsync(false, cancellationToken));
map.Count = (int) await ReadJsonIntegerAsync(cancellationToken);
+ CheckReadBytesAvailable(map);
await ReadJsonObjectStartAsync(cancellationToken);
return map;
}
@@ -719,6 +720,7 @@
await ReadJsonArrayStartAsync(cancellationToken);
list.ElementType = TJSONProtocolHelper.GetTypeIdForTypeName(await ReadJsonStringAsync(false, cancellationToken));
list.Count = (int) await ReadJsonIntegerAsync(cancellationToken);
+ CheckReadBytesAvailable(list);
return list;
}
@@ -733,6 +735,7 @@
await ReadJsonArrayStartAsync(cancellationToken);
set.ElementType = TJSONProtocolHelper.GetTypeIdForTypeName(await ReadJsonStringAsync(false, cancellationToken));
set.Count = (int) await ReadJsonIntegerAsync(cancellationToken);
+ CheckReadBytesAvailable(set);
return set;
}
@@ -782,6 +785,28 @@
return await ReadJsonBase64Async(cancellationToken);
}
+ /// <summary>
+ /// Smallest number of bytes a JSON-encoded value of <paramref name="type"/> can occupy on the wire.
+ /// </summary>
+ public override int GetMinSerializedSize(TType type)
+ {
+     switch (type)
+     {
+         case TType.Stop:
+         case TType.Void:
+             return 0;
+
+         // a single-digit numeric literal; bools are written as an integer too
+         case TType.Bool:
+         case TType.Byte:
+         case TType.Double:
+         case TType.I16:
+         case TType.I32:
+         case TType.I64:
+             return 1;
+
+         // an empty string, struct, map, set or list needs at least its two delimiters
+         case TType.String:
+         case TType.Struct:
+         case TType.Map:
+         case TType.Set:
+         case TType.List:
+             return 2;
+
+         default:
+             throw new TTransportException(TTransportException.ExceptionType.Unknown, "unrecognized type code");
+     }
+ }
+
/// <summary>
/// Factory for JSON protocol objects
/// </summary>
diff --git a/lib/netstd/Thrift/Protocol/TProtocol.cs b/lib/netstd/Thrift/Protocol/TProtocol.cs
index dca3f9e..5275c9c 100644
--- a/lib/netstd/Thrift/Protocol/TProtocol.cs
+++ b/lib/netstd/Thrift/Protocol/TProtocol.cs
@@ -77,6 +77,27 @@
_isDisposed = true;
}
+
+ /// <summary>
+ /// Asks the transport to verify that at least Count * (minimum element size) bytes can still be read
+ /// for <paramref name="set"/>, before any storage for its elements is allocated.
+ /// </summary>
+ protected void CheckReadBytesAvailable(TSet set)
+ {
+     // Multiply in long: an int product wraps for large counts (e.g. 0x20000001 * 8 == 8) and would slip past the check.
+     Transport.CheckReadBytesAvailable((long)set.Count * GetMinSerializedSize(set.ElementType));
+ }
+
+ /// <summary>
+ /// Asks the transport to verify that at least Count * (minimum element size) bytes can still be read
+ /// for <paramref name="list"/>, before any storage for its elements is allocated.
+ /// </summary>
+ protected void CheckReadBytesAvailable(TList list)
+ {
+     // Multiply in long: an int product wraps for large counts and would slip past the check.
+     Transport.CheckReadBytesAvailable((long)list.Count * GetMinSerializedSize(list.ElementType));
+ }
+
+ /// <summary>
+ /// Asks the transport to verify that at least Count * (minimum key size + minimum value size) bytes can
+ /// still be read for <paramref name="map"/>, before any storage for its entries is allocated.
+ /// </summary>
+ protected void CheckReadBytesAvailable(TMap map)
+ {
+     // Held as long so the multiplication below cannot overflow int and wrap around the check.
+     long elmSize = GetMinSerializedSize(map.KeyType) + GetMinSerializedSize(map.ValueType);
+     Transport.CheckReadBytesAvailable(map.Count * elmSize);
+ }
+
+ /// <summary>
+ /// Returns the minimum amount of bytes needed to store the smallest possible instance of <paramref name="type"/>
+ /// in this protocol's wire format. The container CheckReadBytesAvailable overloads multiply it by the element
+ /// count to obtain a lower bound on the bytes the transport must still be able to deliver.
+ /// </summary>
+ /// <param name="type">Thrift type whose minimal encoded size is requested.</param>
+ /// <returns>Lower bound, in bytes, for one encoded value of <paramref name="type"/>.</returns>
+ public abstract int GetMinSerializedSize(TType type);
+
+
public virtual async Task WriteMessageBeginAsync(TMessage message)
{
await WriteMessageBeginAsync(message, CancellationToken.None);
diff --git a/lib/netstd/Thrift/Protocol/TProtocolDecorator.cs b/lib/netstd/Thrift/Protocol/TProtocolDecorator.cs
index 845c827..b032e83 100644
--- a/lib/netstd/Thrift/Protocol/TProtocolDecorator.cs
+++ b/lib/netstd/Thrift/Protocol/TProtocolDecorator.cs
@@ -243,5 +243,13 @@
{
return await _wrappedProtocol.ReadBinaryAsync(cancellationToken);
}
+
+ /// <summary>
+ /// Delegates to the wrapped protocol, which performs the actual (de)serialization for this decorator.
+ /// </summary>
+ public override int GetMinSerializedSize(TType type) => _wrappedProtocol.GetMinSerializedSize(type);
+
+
}
}
diff --git a/lib/netstd/Thrift/Transport/Client/TMemoryBufferTransport.cs b/lib/netstd/Thrift/Transport/Client/TMemoryBufferTransport.cs
index abf8f14..290e50c 100644
--- a/lib/netstd/Thrift/Transport/Client/TMemoryBufferTransport.cs
+++ b/lib/netstd/Thrift/Transport/Client/TMemoryBufferTransport.cs
@@ -41,6 +41,7 @@
{
Bytes = (byte[])buf.Clone();
_bytesUsed = Bytes.Length;
+ UpdateKnownMessageSize(_bytesUsed);
}
public int Position { get; set; }
@@ -121,7 +122,6 @@
public override ValueTask<int> ReadAsync(byte[] buffer, int offset, int length, CancellationToken cancellationToken)
{
- CheckReadBytesAvailable(length);
var count = Math.Min(Length - Position, length);
Buffer.BlockCopy(Bytes, Position, buffer, offset, count);
Position += count;
diff --git a/lib/netstd/Thrift/Transport/Layered/TBufferedTransport.cs b/lib/netstd/Thrift/Transport/Layered/TBufferedTransport.cs
index 10cec3c..dee52dd 100644
--- a/lib/netstd/Thrift/Transport/Layered/TBufferedTransport.cs
+++ b/lib/netstd/Thrift/Transport/Layered/TBufferedTransport.cs
@@ -172,6 +172,17 @@
await InnerTransport.FlushAsync(cancellationToken);
}
+ /// <summary>
+ /// Bytes still sitting in the read buffer can be served immediately; only the shortfall (if any)
+ /// has to be verified against the inner transport.
+ /// </summary>
+ public override void CheckReadBytesAvailable(long numBytes)
+ {
+     var alreadyBuffered = ReadBuffer.Length - ReadBuffer.Position;
+     if (numBytes <= alreadyBuffered)
+     {
+         return;
+     }
+
+     InnerTransport.CheckReadBytesAvailable(numBytes - alreadyBuffered);
+ }
+
+
private void CheckNotDisposed()
{
if (IsDisposed)
diff --git a/lib/netstd/Thrift/Transport/Layered/TFramedTransport.cs b/lib/netstd/Thrift/Transport/Layered/TFramedTransport.cs
index 58b45f7..be1513f 100644
--- a/lib/netstd/Thrift/Transport/Layered/TFramedTransport.cs
+++ b/lib/netstd/Thrift/Transport/Layered/TFramedTransport.cs
@@ -155,6 +155,16 @@
WriteBuffer.Seek(0, SeekOrigin.End);
}
+ /// <summary>
+ /// The unread remainder of the current frame is already in memory; anything beyond it must come
+ /// from the inner transport, so only that excess is checked there.
+ /// </summary>
+ public override void CheckReadBytesAvailable(long numBytes)
+ {
+     var unreadInFrame = ReadBuffer.Length - ReadBuffer.Position;
+     if (numBytes <= unreadInFrame)
+     {
+         return;
+     }
+
+     InnerTransport.CheckReadBytesAvailable(numBytes - unreadInFrame);
+ }
+
private void CheckNotDisposed()
{
if (IsDisposed)
diff --git a/lib/netstd/Thrift/Transport/Layered/TLayeredTransport.cs b/lib/netstd/Thrift/Transport/Layered/TLayeredTransport.cs
index 59d98ff..2137ae4 100644
--- a/lib/netstd/Thrift/Transport/Layered/TLayeredTransport.cs
+++ b/lib/netstd/Thrift/Transport/Layered/TLayeredTransport.cs
@@ -19,5 +19,10 @@
{
InnerTransport.UpdateKnownMessageSize(size);
}
+
+ /// <summary>Forwards the read-budget check unchanged to the wrapped inner transport.</summary>
+ public override void CheckReadBytesAvailable(long numBytes) => InnerTransport.CheckReadBytesAvailable(numBytes);
}
}
diff --git a/lib/netstd/Thrift/Transport/TEndpointTransport.cs b/lib/netstd/Thrift/Transport/TEndpointTransport.cs
index 810f3f4..fa2ac6b 100644
--- a/lib/netstd/Thrift/Transport/TEndpointTransport.cs
+++ b/lib/netstd/Thrift/Transport/TEndpointTransport.cs
@@ -9,6 +9,7 @@
abstract public class TEndpointTransport : TTransport
{
protected long MaxMessageSize { get => Configuration.MaxMessageSize; }
+ protected long KnownMessageSize { get; private set; }
protected long RemainingMessageSize { get; private set; }
private readonly TConfiguration _configuration;
@@ -25,22 +26,33 @@
/// <summary>
/// Resets RemainingMessageSize to the configured maximum
/// </summary>
- protected void ResetConsumedMessageSize(long knownSize = -1)
+ /// <param name="newSize">
+ /// Negative (the default): full reset, KnownMessageSize and RemainingMessageSize both become MaxMessageSize.
+ /// Otherwise: the now-known size of the current message; it may shrink KnownMessageSize but never enlarge it.
+ /// </param>
+ /// <exception cref="TTransportException">
+ /// Thrown (EndOfFile, "MaxMessageSize reached") when <paramref name="newSize"/> is larger than KnownMessageSize.
+ /// </exception>
+ protected void ResetConsumedMessageSize(long newSize = -1)
{
- if(knownSize >= 0)
- RemainingMessageSize = Math.Min( MaxMessageSize, knownSize);
- else
+ // full reset
+ if (newSize < 0)
+ {
+ KnownMessageSize = MaxMessageSize;
RemainingMessageSize = MaxMessageSize;
+ return;
+ }
+
+ // update only: message size can shrink, but not grow
+ // KnownMessageSize is bounded by MaxMessageSize, so a larger newSize is reported as exceeding that limit.
+ Debug.Assert(KnownMessageSize <= MaxMessageSize);
+ if (newSize > KnownMessageSize)
+ throw new TTransportException(TTransportException.ExceptionType.EndOfFile, "MaxMessageSize reached");
+
+ // Bytes already consumed are NOT deducted here; UpdateKnownMessageSize re-applies them afterwards.
+ KnownMessageSize = newSize;
+ RemainingMessageSize = newSize;
}
/// <summary>
/// Updates RemainingMessageSize to reflect then known real message size (e.g. framed transport).
- /// Will throw if we already consumed too many bytes.
+ /// Will throw if we already consumed too many bytes or if the new size is larger than allowed.
/// </summary>
/// <param name="size"></param>
public override void UpdateKnownMessageSize(long size)
{
- var consumed = MaxMessageSize - RemainingMessageSize;
+ // Bytes of the current message read so far, measured against the previously known size
+ // (not MaxMessageSize, which is wrong once the known size has already been narrowed).
+ var consumed = KnownMessageSize - RemainingMessageSize;
+ // Narrow the budget to the new size (throws if it would grow)...
ResetConsumedMessageSize(size);
+ // ...then re-deduct what was already read; per the summary this throws if too many bytes were consumed.
CountConsumedMessageBytes(consumed);
}
@@ -49,7 +61,7 @@
/// Throws if there are not enough bytes in the input stream to satisfy a read of numBytes bytes of data
/// </summary>
/// <param name="numBytes"></param>
- protected void CheckReadBytesAvailable(long numBytes)
+ public override void CheckReadBytesAvailable(long numBytes)
{
if (RemainingMessageSize < numBytes)
throw new TTransportException(TTransportException.ExceptionType.EndOfFile, "MaxMessageSize reached");
diff --git a/lib/netstd/Thrift/Transport/TTransport.cs b/lib/netstd/Thrift/Transport/TTransport.cs
index 8f510dd..dedd51d 100644
--- a/lib/netstd/Thrift/Transport/TTransport.cs
+++ b/lib/netstd/Thrift/Transport/TTransport.cs
@@ -34,7 +34,7 @@
public abstract bool IsOpen { get; }
public abstract TConfiguration Configuration { get; }
public abstract void UpdateKnownMessageSize(long size);
-
+ /// <summary>
+ /// Verifies that at least <paramref name="numBytes"/> more bytes may be read from the current message.
+ /// Endpoint transports throw a TTransportException (EndOfFile) when the remaining budget is too small;
+ /// layered transports forward the (unbuffered part of the) check to their inner transport.
+ /// </summary>
+ /// <param name="numBytes">Number of bytes the caller is about to read or allocate for.</param>
+ public abstract void CheckReadBytesAvailable(long numBytes);
public void Dispose()
{
Dispose(true);