THRIFT-5021 Implement MAX_MESSAGE_SIZE and centralize limits into a TConfiguration class
Client: netstd
Patch: Jens Geyer
This closes #1943
diff --git a/lib/netstd/Tests/Thrift.IntegrationTests/Protocols/ProtocolsOperationsTests.cs b/lib/netstd/Tests/Thrift.IntegrationTests/Protocols/ProtocolsOperationsTests.cs
index b1f8418..b8df515 100644
--- a/lib/netstd/Tests/Thrift.IntegrationTests/Protocols/ProtocolsOperationsTests.cs
+++ b/lib/netstd/Tests/Thrift.IntegrationTests/Protocols/ProtocolsOperationsTests.cs
@@ -31,6 +31,7 @@
public class ProtocolsOperationsTests
{
private readonly CompareLogic _compareLogic = new CompareLogic();
+ private static readonly TConfiguration Configuration = null; // null: the transport then falls back to a default TConfiguration
[DataTestMethod]
[DataRow(typeof(TBinaryProtocol), TMessageType.Call)]
@@ -494,7 +495,7 @@
private static Tuple<Stream, TProtocol> GetProtocolInstance(Type protocolType)
{
var memoryStream = new MemoryStream();
- var streamClientTransport = new TStreamTransport(memoryStream, memoryStream);
+ var streamClientTransport = new TStreamTransport(memoryStream, memoryStream, Configuration);
var protocol = (TProtocol) Activator.CreateInstance(protocolType, streamClientTransport);
return new Tuple<Stream, TProtocol>(memoryStream, protocol);
}
diff --git a/lib/netstd/Tests/Thrift.Tests/Protocols/TJsonProtocolTests.cs b/lib/netstd/Tests/Thrift.Tests/Protocols/TJsonProtocolTests.cs
index 970ce7e..4054a29 100644
--- a/lib/netstd/Tests/Thrift.Tests/Protocols/TJsonProtocolTests.cs
+++ b/lib/netstd/Tests/Thrift.Tests/Protocols/TJsonProtocolTests.cs
@@ -21,7 +21,6 @@
using System.Threading;
using System.Threading.Tasks;
using Microsoft.VisualStudio.TestTools.UnitTesting;
-using NSubstitute;
using Thrift.Protocol;
using Thrift.Protocol.Entities;
using Thrift.Transport;
@@ -36,7 +35,7 @@
[TestMethod]
public void TJSONProtocol_Can_Create_Instance_Test()
{
- var httpClientTransport = Substitute.For<THttpTransport>(new Uri("http://localhost"), null, null);
+ var httpClientTransport = new THttpTransport(new Uri("http://localhost"), config: null, customRequestHeaders: null, userAgent: null);
var result = new TJSONProtocolWrapper(httpClientTransport);
@@ -45,7 +44,7 @@
Assert.IsNotNull(result.WrappedReader);
Assert.IsNotNull(result.Transport);
Assert.IsTrue(result.WrappedRecursionDepth == 0);
- Assert.IsTrue(result.WrappedRecursionLimit == TProtocol.DefaultRecursionDepth);
+ Assert.AreEqual(TConfiguration.DEFAULT_RECURSION_DEPTH, result.WrappedRecursionLimit);
Assert.IsTrue(result.Transport.Equals(httpClientTransport));
Assert.IsTrue(result.WrappedContext.GetType().Name.Equals("JSONBaseContext", StringComparison.OrdinalIgnoreCase));
diff --git a/lib/netstd/Thrift/Protocol/TProtocol.cs b/lib/netstd/Thrift/Protocol/TProtocol.cs
index 75edb11..dca3f9e 100644
--- a/lib/netstd/Thrift/Protocol/TProtocol.cs
+++ b/lib/netstd/Thrift/Protocol/TProtocol.cs
@@ -27,7 +27,8 @@
// ReSharper disable once InconsistentNaming
public abstract class TProtocol : IDisposable
{
- public const int DefaultRecursionDepth = 64;
+ [Obsolete("Use TConfiguration.DEFAULT_RECURSION_DEPTH instead")]
+ public const int DefaultRecursionDepth = TConfiguration.DEFAULT_RECURSION_DEPTH;
private bool _isDisposed;
protected int RecursionDepth;
@@ -36,7 +37,7 @@
protected TProtocol(TTransport trans)
{
Trans = trans;
- RecursionLimit = DefaultRecursionDepth;
+ RecursionLimit = trans?.Configuration?.RecursionLimit ?? TConfiguration.DEFAULT_RECURSION_DEPTH; // fall back to the default when there is no transport or no configuration
RecursionDepth = 0;
}
diff --git a/lib/netstd/Thrift/TConfiguration.cs b/lib/netstd/Thrift/TConfiguration.cs
new file mode 100644
index 0000000..c8dde10
--- /dev/null
+++ b/lib/netstd/Thrift/TConfiguration.cs
@@ -0,0 +1,32 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Thrift
+{
+    /// <summary>
+    /// Size and depth limits shared by transports and protocols.
+    /// </summary>
+    public class TConfiguration
+    {
+        /// <summary>Default for <see cref="MaxMessageSize"/>: 100 MiB.</summary>
+        public const int DEFAULT_MAX_MESSAGE_SIZE = 100 * 1024 * 1024;
+
+        /// <summary>Default for <see cref="MaxFrameSize"/>.</summary>
+        public const int DEFAULT_MAX_FRAME_SIZE = 16384000; // this value is used consistently across all Thrift libraries
+
+        /// <summary>Default for <see cref="RecursionLimit"/>.</summary>
+        public const int DEFAULT_RECURSION_DEPTH = 64;
+
+        /// <summary>Maximum number of bytes a single message may consume on an endpoint transport.</summary>
+        public int MaxMessageSize { get; set; } = DEFAULT_MAX_MESSAGE_SIZE;
+
+        /// <summary>Largest frame size, in bytes, accepted when reading a framed message.</summary>
+        public int MaxFrameSize { get; set; } = DEFAULT_MAX_FRAME_SIZE;
+
+        /// <summary>Recursion limit that a protocol copies from its transport's configuration.</summary>
+        public int RecursionLimit { get; set; } = DEFAULT_RECURSION_DEPTH;
+
+        // TODO(JensG): add connection and i/o timeouts
+    }
+}
diff --git a/lib/netstd/Thrift/Transport/Client/THttpTransport.cs b/lib/netstd/Thrift/Transport/Client/THttpTransport.cs
index 4f8454c..bbd94fa 100644
--- a/lib/netstd/Thrift/Transport/Client/THttpTransport.cs
+++ b/lib/netstd/Thrift/Transport/Client/THttpTransport.cs
@@ -28,7 +28,7 @@
namespace Thrift.Transport.Client
{
// ReSharper disable once InconsistentNaming
- public class THttpTransport : TTransport
+ public class THttpTransport : TEndpointTransport
{
private readonly X509Certificate[] _certificates;
private readonly Uri _uri;
@@ -39,13 +39,14 @@
private MemoryStream _outputStream = new MemoryStream();
private bool _isDisposed;
- public THttpTransport(Uri uri, IDictionary<string, string> customRequestHeaders = null, string userAgent = null)
- : this(uri, Enumerable.Empty<X509Certificate>(), customRequestHeaders, userAgent)
+ public THttpTransport(Uri uri, TConfiguration config, IDictionary<string, string> customRequestHeaders = null, string userAgent = null)
+ : this(uri, config, Enumerable.Empty<X509Certificate>(), customRequestHeaders, userAgent)
{
}
- public THttpTransport(Uri uri, IEnumerable<X509Certificate> certificates,
+ public THttpTransport(Uri uri, TConfiguration config, IEnumerable<X509Certificate> certificates,
IDictionary<string, string> customRequestHeaders, string userAgent = null)
+ : base(config)
{
_uri = uri;
_certificates = (certificates ?? Enumerable.Empty<X509Certificate>()).ToArray();
@@ -104,6 +105,8 @@
if (_inputStream == null)
throw new TTransportException(TTransportException.ExceptionType.NotOpen, "No request has been sent");
+ CheckReadBytesAvailable(length);
+
try
{
var ret = await _inputStream.ReadAsync(buffer, offset, length, cancellationToken);
@@ -112,6 +115,7 @@
throw new TTransportException(TTransportException.ExceptionType.EndOfFile, "No more data available");
}
+ CountConsumedMessageBytes(ret);
return ret;
}
catch (IOException iox)
@@ -196,9 +200,10 @@
finally
{
_outputStream = new MemoryStream();
+ ResetConsumedMessageSize();
}
}
// IDisposable
protected override void Dispose(bool disposing)
{
diff --git a/lib/netstd/Thrift/Transport/Client/TMemoryBufferTransport.cs b/lib/netstd/Thrift/Transport/Client/TMemoryBufferTransport.cs
index cdbbc0d..abf8f14 100644
--- a/lib/netstd/Thrift/Transport/Client/TMemoryBufferTransport.cs
+++ b/lib/netstd/Thrift/Transport/Client/TMemoryBufferTransport.cs
@@ -24,18 +24,20 @@
namespace Thrift.Transport.Client
{
// ReSharper disable once InconsistentNaming
- public class TMemoryBufferTransport : TTransport
+ public class TMemoryBufferTransport : TEndpointTransport
{
private bool IsDisposed;
private byte[] Bytes;
private int _bytesUsed;
- public TMemoryBufferTransport(int initialCapacity = 2048)
+ public TMemoryBufferTransport(TConfiguration config, int initialCapacity = 2048)
+ : base(config)
{
Bytes = new byte[initialCapacity];
}
- public TMemoryBufferTransport(byte[] buf)
+ public TMemoryBufferTransport(byte[] buf, TConfiguration config)
+ : base(config)
{
Bytes = (byte[])buf.Clone();
_bytesUsed = Bytes.Length;
@@ -112,13 +114,18 @@
if ((0 > newPos) || (newPos > _bytesUsed))
throw new ArgumentException(nameof(origin));
Position = newPos;
+
+ ResetConsumedMessageSize();
+ CountConsumedMessageBytes(Position);
}
public override ValueTask<int> ReadAsync(byte[] buffer, int offset, int length, CancellationToken cancellationToken)
{
+ CheckReadBytesAvailable(length);
var count = Math.Min(Length - Position, length);
Buffer.BlockCopy(Bytes, Position, buffer, offset, count);
Position += count;
+ CountConsumedMessageBytes(count);
return new ValueTask<int>(count);
}
@@ -142,6 +149,7 @@
{
await Task.FromCanceled(cancellationToken);
}
+ ResetConsumedMessageSize();
}
public byte[] GetBuffer()
@@ -157,7 +165,6 @@
return true;
}
-
// IDisposable
protected override void Dispose(bool disposing)
{
diff --git a/lib/netstd/Thrift/Transport/Client/TNamedPipeTransport.cs b/lib/netstd/Thrift/Transport/Client/TNamedPipeTransport.cs
index 1ae6074..f7f10b7 100644
--- a/lib/netstd/Thrift/Transport/Client/TNamedPipeTransport.cs
+++ b/lib/netstd/Thrift/Transport/Client/TNamedPipeTransport.cs
@@ -23,17 +23,18 @@
namespace Thrift.Transport.Client
{
// ReSharper disable once InconsistentNaming
- public class TNamedPipeTransport : TTransport
+ public class TNamedPipeTransport : TEndpointTransport
{
private NamedPipeClientStream PipeStream;
private readonly int ConnectTimeout;
- public TNamedPipeTransport(string pipe, int timeout = Timeout.Infinite)
- : this(".", pipe, timeout)
+ public TNamedPipeTransport(string pipe, TConfiguration config, int timeout = Timeout.Infinite)
+ : this(".", pipe, config, timeout)
{
}
- public TNamedPipeTransport(string server, string pipe, int timeout = Timeout.Infinite)
+ public TNamedPipeTransport(string server, string pipe, TConfiguration config, int timeout = Timeout.Infinite)
+ : base(config)
{
- var serverName = string.IsNullOrWhiteSpace(server) ? server : ".";
+ var serverName = string.IsNullOrWhiteSpace(server) ? "." : server; // no server given: use ".", the same value the pipe-only overload passes
ConnectTimeout = (timeout > 0) ? timeout : Timeout.Infinite;
@@ -51,6 +52,7 @@
}
await PipeStream.ConnectAsync( ConnectTimeout, cancellationToken);
+ ResetConsumedMessageSize();
}
public override void Close()
@@ -69,7 +71,10 @@
throw new TTransportException(TTransportException.ExceptionType.NotOpen);
}
- return await PipeStream.ReadAsync(buffer, offset, length, cancellationToken);
+ CheckReadBytesAvailable(length);
+ var numRead = await PipeStream.ReadAsync(buffer, offset, length, cancellationToken);
+ CountConsumedMessageBytes(numRead);
+ return numRead;
}
public override async Task WriteAsync(byte[] buffer, int offset, int length, CancellationToken cancellationToken)
@@ -98,8 +103,10 @@
{
await Task.FromCanceled(cancellationToken);
}
+ ResetConsumedMessageSize();
}
+
protected override void Dispose(bool disposing)
{
if(disposing)
diff --git a/lib/netstd/Thrift/Transport/Client/TSocketTransport.cs b/lib/netstd/Thrift/Transport/Client/TSocketTransport.cs
index dd506bc..d559154 100644
--- a/lib/netstd/Thrift/Transport/Client/TSocketTransport.cs
+++ b/lib/netstd/Thrift/Transport/Client/TSocketTransport.cs
@@ -30,13 +30,15 @@
private bool _isDisposed;
- public TSocketTransport(TcpClient client)
+ public TSocketTransport(TcpClient client, TConfiguration config)
+ : base(config)
{
TcpClient = client ?? throw new ArgumentNullException(nameof(client));
SetInputOutputStream();
}
- public TSocketTransport(IPAddress host, int port, int timeout = 0)
+ public TSocketTransport(IPAddress host, int port, TConfiguration config, int timeout = 0)
+ : base(config)
{
Host = host;
Port = port;
@@ -47,7 +49,8 @@
SetInputOutputStream();
}
- public TSocketTransport(string host, int port, int timeout = 0)
+ public TSocketTransport(string host, int port, TConfiguration config, int timeout = 0)
+ : base(config)
{
try
{
diff --git a/lib/netstd/Thrift/Transport/Client/TStreamTransport.cs b/lib/netstd/Thrift/Transport/Client/TStreamTransport.cs
index d8574d6..e04b3b3 100644
--- a/lib/netstd/Thrift/Transport/Client/TStreamTransport.cs
+++ b/lib/netstd/Thrift/Transport/Client/TStreamTransport.cs
@@ -22,15 +22,17 @@
namespace Thrift.Transport.Client
{
// ReSharper disable once InconsistentNaming
- public class TStreamTransport : TTransport
+ public class TStreamTransport : TEndpointTransport
{
private bool _isDisposed;
- protected TStreamTransport()
+ protected TStreamTransport(TConfiguration config)
+ : base(config)
{
}
- public TStreamTransport(Stream inputStream, Stream outputStream)
+ public TStreamTransport(Stream inputStream, Stream outputStream, TConfiguration config)
+ : base(config)
{
InputStream = inputStream;
OutputStream = outputStream;
@@ -38,7 +40,16 @@
protected Stream OutputStream { get; set; }
- protected Stream InputStream { get; set; }
+ private Stream _inputStream;
+ protected Stream InputStream
+ {
+ get => _inputStream;
+ set
+ {
+ _inputStream = value;
+ ResetConsumedMessageSize(); // a new input stream starts with a fresh read budget
+ }
+ }
public override bool IsOpen => true;
@@ -90,8 +101,10 @@
public override async Task FlushAsync(CancellationToken cancellationToken)
{
await OutputStream.FlushAsync(cancellationToken);
+ ResetConsumedMessageSize();
}
+
// IDisposable
protected override void Dispose(bool disposing)
{
diff --git a/lib/netstd/Thrift/Transport/Client/TTlsSocketTransport.cs b/lib/netstd/Thrift/Transport/Client/TTlsSocketTransport.cs
index a926a38..0980526 100644
--- a/lib/netstd/Thrift/Transport/Client/TTlsSocketTransport.cs
+++ b/lib/netstd/Thrift/Transport/Client/TTlsSocketTransport.cs
@@ -42,11 +42,12 @@
private SslStream _secureStream;
private int _timeout;
- public TTlsSocketTransport(TcpClient client,
+ public TTlsSocketTransport(TcpClient client, TConfiguration config,
X509Certificate2 certificate, bool isServer = false,
RemoteCertificateValidationCallback certValidator = null,
LocalCertificateSelectionCallback localCertificateSelectionCallback = null,
SslProtocols sslProtocols = SslProtocols.Tls12)
+ : base(config)
{
_client = client;
_certificate = certificate;
@@ -68,12 +69,12 @@
}
}
- public TTlsSocketTransport(IPAddress host, int port,
+ public TTlsSocketTransport(IPAddress host, int port, TConfiguration config,
string certificatePath,
RemoteCertificateValidationCallback certValidator = null,
LocalCertificateSelectionCallback localCertificateSelectionCallback = null,
SslProtocols sslProtocols = SslProtocols.Tls12)
- : this(host, port, 0,
+ : this(host, port, config, 0,
new X509Certificate2(certificatePath),
certValidator,
localCertificateSelectionCallback,
@@ -81,12 +82,12 @@
{
}
- public TTlsSocketTransport(IPAddress host, int port,
+ public TTlsSocketTransport(IPAddress host, int port, TConfiguration config,
X509Certificate2 certificate = null,
RemoteCertificateValidationCallback certValidator = null,
LocalCertificateSelectionCallback localCertificateSelectionCallback = null,
SslProtocols sslProtocols = SslProtocols.Tls12)
- : this(host, port, 0,
+ : this(host, port, config, 0,
certificate,
certValidator,
localCertificateSelectionCallback,
@@ -94,11 +95,12 @@
{
}
- public TTlsSocketTransport(IPAddress host, int port, int timeout,
+ public TTlsSocketTransport(IPAddress host, int port, TConfiguration config, int timeout,
X509Certificate2 certificate,
RemoteCertificateValidationCallback certValidator = null,
LocalCertificateSelectionCallback localCertificateSelectionCallback = null,
SslProtocols sslProtocols = SslProtocols.Tls12)
+ : base(config)
{
_host = host;
_port = port;
@@ -111,11 +113,12 @@
InitSocket();
}
- public TTlsSocketTransport(string host, int port, int timeout,
+ public TTlsSocketTransport(string host, int port, TConfiguration config, int timeout,
X509Certificate2 certificate,
RemoteCertificateValidationCallback certValidator = null,
LocalCertificateSelectionCallback localCertificateSelectionCallback = null,
SslProtocols sslProtocols = SslProtocols.Tls12)
+ : base(config)
{
try
{
diff --git a/lib/netstd/Thrift/Transport/TBufferedTransport.cs b/lib/netstd/Thrift/Transport/Layered/TBufferedTransport.cs
similarity index 92%
rename from lib/netstd/Thrift/Transport/TBufferedTransport.cs
rename to lib/netstd/Thrift/Transport/Layered/TBufferedTransport.cs
index e4fdd3a..10cec3c 100644
--- a/lib/netstd/Thrift/Transport/TBufferedTransport.cs
+++ b/lib/netstd/Thrift/Transport/Layered/TBufferedTransport.cs
@@ -24,12 +24,11 @@
namespace Thrift.Transport
{
// ReSharper disable once InconsistentNaming
- public class TBufferedTransport : TTransport
+ public class TBufferedTransport : TLayeredTransport
{
private readonly int DesiredBufferSize;
- private readonly Client.TMemoryBufferTransport ReadBuffer = new Client.TMemoryBufferTransport(1024);
- private readonly Client.TMemoryBufferTransport WriteBuffer = new Client.TMemoryBufferTransport(1024);
- private readonly TTransport InnerTransport;
+ private readonly Client.TMemoryBufferTransport ReadBuffer;
+ private readonly Client.TMemoryBufferTransport WriteBuffer;
private bool IsDisposed;
public class Factory : TTransportFactory
@@ -42,19 +41,20 @@
//TODO: should support only specified input transport?
public TBufferedTransport(TTransport transport, int bufSize = 1024)
+ : base(transport)
{
if (bufSize <= 0)
{
throw new ArgumentOutOfRangeException(nameof(bufSize), "Buffer size must be a positive number.");
}
- InnerTransport = transport ?? throw new ArgumentNullException(nameof(transport));
DesiredBufferSize = bufSize;
- if (DesiredBufferSize != ReadBuffer.Capacity)
- ReadBuffer.Capacity = DesiredBufferSize;
- if (DesiredBufferSize != WriteBuffer.Capacity)
- WriteBuffer.Capacity = DesiredBufferSize;
+ ReadBuffer = new Client.TMemoryBufferTransport(InnerTransport.Configuration, DesiredBufferSize);
+ WriteBuffer = new Client.TMemoryBufferTransport(InnerTransport.Configuration, DesiredBufferSize);
+
+ Debug.Assert(DesiredBufferSize == ReadBuffer.Capacity);
+ Debug.Assert(DesiredBufferSize == WriteBuffer.Capacity);
}
public TTransport UnderlyingTransport
diff --git a/lib/netstd/Thrift/Transport/TFramedTransport.cs b/lib/netstd/Thrift/Transport/Layered/TFramedTransport.cs
similarity index 90%
rename from lib/netstd/Thrift/Transport/TFramedTransport.cs
rename to lib/netstd/Thrift/Transport/Layered/TFramedTransport.cs
index de6df72..c842a16 100644
--- a/lib/netstd/Thrift/Transport/TFramedTransport.cs
+++ b/lib/netstd/Thrift/Transport/Layered/TFramedTransport.cs
@@ -23,13 +23,12 @@
namespace Thrift.Transport
{
// ReSharper disable once InconsistentNaming
- public class TFramedTransport : TTransport
+ public class TFramedTransport : TLayeredTransport
{
private const int HeaderSize = 4;
private readonly byte[] HeaderBuf = new byte[HeaderSize];
- private readonly Client.TMemoryBufferTransport ReadBuffer = new Client.TMemoryBufferTransport();
- private readonly Client.TMemoryBufferTransport WriteBuffer = new Client.TMemoryBufferTransport();
- private readonly TTransport InnerTransport;
+ private readonly Client.TMemoryBufferTransport ReadBuffer;
+ private readonly Client.TMemoryBufferTransport WriteBuffer;
private bool IsDisposed;
@@ -42,9 +41,10 @@
}
public TFramedTransport(TTransport transport)
+ : base(transport)
{
- InnerTransport = transport ?? throw new ArgumentNullException(nameof(transport));
-
+ ReadBuffer = new Client.TMemoryBufferTransport(Configuration);
+ WriteBuffer = new Client.TMemoryBufferTransport(Configuration);
InitWriteBuffer();
}
@@ -86,7 +86,13 @@
private async ValueTask ReadFrameAsync(CancellationToken cancellationToken)
{
await InnerTransport.ReadAllAsync(HeaderBuf, 0, HeaderSize, cancellationToken);
- var size = DecodeFrameSize(HeaderBuf);
+ int size = DecodeFrameSize(HeaderBuf);
+
+ if (size < 0) // a negative decoded size cannot be a valid frame length
+ throw new TTransportException(TTransportException.ExceptionType.Unknown, $"Invalid frame size ({size} bytes)");
+ if (size > Configuration.MaxFrameSize)
+ throw new TTransportException(TTransportException.ExceptionType.Unknown, $"Maximum frame size exceeded ({size} bytes)");
+ UpdateKnownMessageSize((long)size + HeaderSize); // widen first: the int sum could overflow when MaxFrameSize is near int.MaxValue
ReadBuffer.SetLength(size);
ReadBuffer.Seek(0, SeekOrigin.Begin);
diff --git a/lib/netstd/Thrift/Transport/Layered/TLayeredTransport.cs b/lib/netstd/Thrift/Transport/Layered/TLayeredTransport.cs
new file mode 100644
index 0000000..59d98ff
--- /dev/null
+++ b/lib/netstd/Thrift/Transport/Layered/TLayeredTransport.cs
@@ -0,0 +1,32 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Thrift.Transport
+{
+    /// <summary>
+    /// Base class for transports that wrap another transport. The wrapped transport's
+    /// configuration is used, and known message sizes are forwarded to it.
+    /// </summary>
+    public abstract class TLayeredTransport : TTransport
+    {
+        /// <summary>The wrapped transport; never null.</summary>
+        public readonly TTransport InnerTransport;
+
+        /// <summary>Returns the configuration of <see cref="InnerTransport"/>.</summary>
+        public override TConfiguration Configuration => InnerTransport.Configuration;
+
+        /// <summary>Wraps <paramref name="transport"/>.</summary>
+        /// <exception cref="ArgumentNullException"><paramref name="transport"/> is null.</exception>
+        protected TLayeredTransport(TTransport transport)
+        {
+            InnerTransport = transport ?? throw new ArgumentNullException(nameof(transport));
+        }
+
+        /// <summary>Forwards the known message size to <see cref="InnerTransport"/>.</summary>
+        public override void UpdateKnownMessageSize(long size)
+        {
+            InnerTransport.UpdateKnownMessageSize(size);
+        }
+    }
+}
diff --git a/lib/netstd/Thrift/Transport/Server/THttpServerTransport.cs b/lib/netstd/Thrift/Transport/Server/THttpServerTransport.cs
index 2a40db3..7271f50 100644
--- a/lib/netstd/Thrift/Transport/Server/THttpServerTransport.cs
+++ b/lib/netstd/Thrift/Transport/Server/THttpServerTransport.cs
@@ -42,27 +42,31 @@
protected TTransportFactory OutputTransportFactory;
protected ITAsyncProcessor Processor;
+ protected TConfiguration Configuration; // never null once the constructor has run
public THttpServerTransport(
ITAsyncProcessor processor,
+ TConfiguration config,
RequestDelegate next = null,
ILoggerFactory loggerFactory = null)
- : this(processor, new TBinaryProtocol.Factory(), null, next, loggerFactory)
+ : this(processor, config, new TBinaryProtocol.Factory(), null, next, loggerFactory)
{
}
public THttpServerTransport(
ITAsyncProcessor processor,
+ TConfiguration config,
TProtocolFactory protocolFactory,
TTransportFactory transFactory = null,
RequestDelegate next = null,
ILoggerFactory loggerFactory = null)
- : this(processor, protocolFactory, protocolFactory, transFactory, transFactory, next, loggerFactory)
+ : this(processor, config, protocolFactory, protocolFactory, transFactory, transFactory, next, loggerFactory)
{
}
public THttpServerTransport(
ITAsyncProcessor processor,
+ TConfiguration config,
TProtocolFactory inputProtocolFactory,
TProtocolFactory outputProtocolFactory,
TTransportFactory inputTransFactory = null,
@@ -73,6 +77,8 @@
// loggerFactory == null is not illegal anymore
Processor = processor ?? throw new ArgumentNullException(nameof(processor));
+ Configuration = config ?? new TConfiguration(); // null is accepted: fall back to defaults once, as TServerTransport does
+
InputProtocolFactory = inputProtocolFactory ?? throw new ArgumentNullException(nameof(inputProtocolFactory));
OutputProtocolFactory = outputProtocolFactory ?? throw new ArgumentNullException(nameof(outputProtocolFactory));
@@ -91,7 +97,7 @@
public async Task ProcessRequestAsync(HttpContext context, CancellationToken cancellationToken)
{
- var transport = new TStreamTransport(context.Request.Body, context.Response.Body);
+ var transport = new TStreamTransport(context.Request.Body, context.Response.Body, Configuration);
try
{
diff --git a/lib/netstd/Thrift/Transport/Server/TNamedPipeServerTransport.cs b/lib/netstd/Thrift/Transport/Server/TNamedPipeServerTransport.cs
index b2f29b4..a8b64c4 100644
--- a/lib/netstd/Thrift/Transport/Server/TNamedPipeServerTransport.cs
+++ b/lib/netstd/Thrift/Transport/Server/TNamedPipeServerTransport.cs
@@ -38,7 +38,8 @@
private volatile bool _isPending = true;
private NamedPipeServerStream _stream = null;
- public TNamedPipeServerTransport(string pipeAddress)
+ public TNamedPipeServerTransport(string pipeAddress, TConfiguration config)
+ : base(config)
{
_pipeAddress = pipeAddress;
}
@@ -224,7 +225,7 @@
await _stream.WaitForConnectionAsync(cancellationToken);
- var trans = new ServerTransport(_stream);
+ var trans = new ServerTransport(_stream, Configuration);
_stream = null; // pass ownership to ServerTransport
//_isPending = false;
@@ -243,11 +244,12 @@
}
}
- private class ServerTransport : TTransport
+ private class ServerTransport : TEndpointTransport
{
private readonly NamedPipeServerStream PipeStream;
- public ServerTransport(NamedPipeServerStream stream)
+ public ServerTransport(NamedPipeServerStream stream, TConfiguration config)
+ : base(config)
{
PipeStream = stream;
}
@@ -274,7 +276,10 @@
throw new TTransportException(TTransportException.ExceptionType.NotOpen);
}
- return await PipeStream.ReadAsync(buffer, offset, length, cancellationToken);
+ CheckReadBytesAvailable(length);
+ var numRead = await PipeStream.ReadAsync(buffer, offset, length, cancellationToken);
+ CountConsumedMessageBytes(numRead);
+ return numRead;
}
public override async Task WriteAsync(byte[] buffer, int offset, int length, CancellationToken cancellationToken)
@@ -303,6 +308,8 @@
{
await Task.FromCanceled(cancellationToken);
}
+
+ ResetConsumedMessageSize();
}
protected override void Dispose(bool disposing)
diff --git a/lib/netstd/Thrift/Transport/Server/TServerSocketTransport.cs b/lib/netstd/Thrift/Transport/Server/TServerSocketTransport.cs
index 86d82e3..6656b64 100644
--- a/lib/netstd/Thrift/Transport/Server/TServerSocketTransport.cs
+++ b/lib/netstd/Thrift/Transport/Server/TServerSocketTransport.cs
@@ -31,14 +31,15 @@
private readonly int _clientTimeout;
private TcpListener _server;
- public TServerSocketTransport(TcpListener listener, int clientTimeout = 0)
+ public TServerSocketTransport(TcpListener listener, TConfiguration config, int clientTimeout = 0)
+ : base(config)
{
_server = listener;
_clientTimeout = clientTimeout;
}
- public TServerSocketTransport(int port, int clientTimeout = 0)
- : this(null, clientTimeout)
+ public TServerSocketTransport(int port, TConfiguration config, int clientTimeout = 0)
+ : this(null, config, clientTimeout)
{
try
{
@@ -93,7 +94,7 @@
try
{
- tSocketTransport = new TSocketTransport(tcpClient)
+ tSocketTransport = new TSocketTransport(tcpClient, Configuration)
{
Timeout = _clientTimeout
};
diff --git a/lib/netstd/Thrift/Transport/Server/TServerTransport.cs b/lib/netstd/Thrift/Transport/Server/TServerTransport.cs
index dd60f6a..31f578d 100644
--- a/lib/netstd/Thrift/Transport/Server/TServerTransport.cs
+++ b/lib/netstd/Thrift/Transport/Server/TServerTransport.cs
@@ -23,6 +23,15 @@
// ReSharper disable once InconsistentNaming
public abstract class TServerTransport
{
+ /// <summary>Configuration that derived server transports pass on to the transports they accept; never null.</summary>
+ public readonly TConfiguration Configuration;
+
+ /// <param name="config">Configuration to use, or null for a default <see cref="TConfiguration"/>.</param>
+ protected TServerTransport(TConfiguration config)
+ {
+ Configuration = config ?? new TConfiguration();
+ }
+
public abstract void Listen();
public abstract void Close();
public abstract bool IsClientPending();
diff --git a/lib/netstd/Thrift/Transport/Server/TTlsServerSocketTransport.cs b/lib/netstd/Thrift/Transport/Server/TTlsServerSocketTransport.cs
index 231b83f..9f74562 100644
--- a/lib/netstd/Thrift/Transport/Server/TTlsServerSocketTransport.cs
+++ b/lib/netstd/Thrift/Transport/Server/TTlsServerSocketTransport.cs
@@ -39,10 +39,12 @@
public TTlsServerSocketTransport(
TcpListener listener,
+ TConfiguration config,
X509Certificate2 certificate,
RemoteCertificateValidationCallback clientCertValidator = null,
LocalCertificateSelectionCallback localCertificateSelectionCallback = null,
SslProtocols sslProtocols = SslProtocols.Tls12)
+ : base(config)
{
if (!certificate.HasPrivateKey)
{
@@ -59,11 +61,12 @@
public TTlsServerSocketTransport(
int port,
+ TConfiguration config,
X509Certificate2 certificate,
RemoteCertificateValidationCallback clientCertValidator = null,
LocalCertificateSelectionCallback localCertificateSelectionCallback = null,
SslProtocols sslProtocols = SslProtocols.Tls12)
- : this(null, certificate, clientCertValidator, localCertificateSelectionCallback, sslProtocols)
+ : this(null, config, certificate, clientCertValidator, localCertificateSelectionCallback, sslProtocols)
{
try
{
@@ -117,8 +120,8 @@
client.SendTimeout = client.ReceiveTimeout = _clientTimeout;
//wrap the client in an SSL Socket passing in the SSL cert
- var tTlsSocket = new TTlsSocketTransport(
- client,
+ var tTlsSocket = new TTlsSocketTransport(
+ client, Configuration,
_serverCertificate, true, _clientCertValidator,
_localCertificateSelectionCallback, _sslProtocols);
diff --git a/lib/netstd/Thrift/Transport/TEndpointTransport.cs b/lib/netstd/Thrift/Transport/TEndpointTransport.cs
new file mode 100644
index 0000000..810f3f4
--- /dev/null
+++ b/lib/netstd/Thrift/Transport/TEndpointTransport.cs
@@ -0,0 +1,93 @@
+using System;
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.Text;
+
+namespace Thrift.Transport
+{
+    /// <summary>
+    /// Base class for transports that do not wrap another transport. Owns the <see cref="TConfiguration"/>
+    /// and keeps a per-message read budget so that a single message cannot consume more than
+    /// <see cref="TConfiguration.MaxMessageSize"/> bytes.
+    /// </summary>
+    public abstract class TEndpointTransport : TTransport
+    {
+        /// <summary>Upper bound for a single message, read from <see cref="Configuration"/>.</summary>
+        protected long MaxMessageSize { get => Configuration.MaxMessageSize; }
+
+        /// <summary>
+        /// Size the current read budget was last reset to: <see cref="MaxMessageSize"/> after a full
+        /// reset, otherwise the (capped) size passed to <see cref="UpdateKnownMessageSize"/>.
+        /// </summary>
+        protected long KnownMessageSize { get; private set; }
+
+        /// <summary>Bytes of the current message that may still be read.</summary>
+        protected long RemainingMessageSize { get; private set; }
+
+        private readonly TConfiguration _configuration;
+        public override TConfiguration Configuration { get => _configuration; }
+
+        /// <param name="config">Configuration to use, or null for a default <see cref="TConfiguration"/>.</param>
+        protected TEndpointTransport(TConfiguration config)
+        {
+            _configuration = config ?? new TConfiguration();
+            Debug.Assert(Configuration != null);
+
+            ResetConsumedMessageSize();
+        }
+
+        /// <summary>
+        /// Starts a new read budget: <see cref="MaxMessageSize"/> when <paramref name="knownSize"/> is negative
+        /// (the default), otherwise <paramref name="knownSize"/> capped at <see cref="MaxMessageSize"/>.
+        /// </summary>
+        protected void ResetConsumedMessageSize(long knownSize = -1)
+        {
+            KnownMessageSize = (knownSize >= 0) ? Math.Min(MaxMessageSize, knownSize) : MaxMessageSize;
+            RemainingMessageSize = KnownMessageSize;
+        }
+
+        /// <summary>
+        /// Updates RemainingMessageSize to reflect the now known real message size (e.g. framed transport),
+        /// carrying over the bytes already consumed from the current message.
+        /// Will throw if we already consumed too many bytes.
+        /// </summary>
+        /// <param name="size">Total size of the current message in bytes.</param>
+        public override void UpdateKnownMessageSize(long size)
+        {
+            // Measure consumption against the size the budget was last reset to; using
+            // MaxMessageSize would overcount whenever a known size was already in effect.
+            var consumed = KnownMessageSize - RemainingMessageSize;
+            ResetConsumedMessageSize(size);
+            CountConsumedMessageBytes(consumed);
+        }
+
+        /// <summary>
+        /// Throws if there are not enough bytes in the input stream to satisfy a read of numBytes bytes of data
+        /// </summary>
+        /// <param name="numBytes">Number of bytes about to be read.</param>
+        /// <exception cref="TTransportException">The read would exceed <see cref="RemainingMessageSize"/>.</exception>
+        protected void CheckReadBytesAvailable(long numBytes)
+        {
+            if (RemainingMessageSize < numBytes)
+                throw new TTransportException(TTransportException.ExceptionType.EndOfFile, "MaxMessageSize reached");
+        }
+
+        /// <summary>
+        /// Consumes numBytes from the RemainingMessageSize.
+        /// </summary>
+        /// <param name="numBytes">Number of bytes actually read.</param>
+        /// <exception cref="TTransportException">More bytes were read than remained; the budget is then set to 0.</exception>
+        protected void CountConsumedMessageBytes(long numBytes)
+        {
+            if (RemainingMessageSize >= numBytes)
+            {
+                RemainingMessageSize -= numBytes;
+            }
+            else
+            {
+                RemainingMessageSize = 0;
+                throw new TTransportException(TTransportException.ExceptionType.EndOfFile, "MaxMessageSize reached");
+            }
+        }
+    }
+}
diff --git a/lib/netstd/Thrift/Transport/TTransport.cs b/lib/netstd/Thrift/Transport/TTransport.cs
index 7998012..8f510dd 100644
--- a/lib/netstd/Thrift/Transport/TTransport.cs
+++ b/lib/netstd/Thrift/Transport/TTransport.cs
@@ -30,7 +30,12 @@
//TODO: think how to avoid peek byte
private readonly byte[] _peekBuffer = new byte[1];
private bool _hasPeekByte;
+
public abstract bool IsOpen { get; }
+ /// <summary>Size and depth limits that apply to this transport.</summary>
+ public abstract TConfiguration Configuration { get; }
+ /// <summary>Reports the real size of the message being read once it is known (e.g. from a frame header).</summary>
+ public abstract void UpdateKnownMessageSize(long size);
public void Dispose()
{