THRIFT-3361 Improve C# library
Client: C#
Patch: Nobuaki Sukegawa <nsukeg@gmail.com>
This closes #630
diff --git a/lib/csharp/src/Transport/TBufferedTransport.cs b/lib/csharp/src/Transport/TBufferedTransport.cs
index 89b9ca7..e88800f 100644
--- a/lib/csharp/src/Transport/TBufferedTransport.cs
+++ b/lib/csharp/src/Transport/TBufferedTransport.cs
@@ -22,105 +22,158 @@
namespace Thrift.Transport
{
- public class TBufferedTransport : TTransport, IDisposable
+ /// <summary>
+ /// Buffers reads from and writes to an underlying <see cref="TTransport"/>
+ /// in chunks of up to bufSize bytes.
+ /// </summary>
+ public class TBufferedTransport : TTransport, IDisposable
{
- private BufferedStream inputBuffer;
- private BufferedStream outputBuffer;
- private int bufSize;
- private TStreamTransport transport;
+ private readonly int bufSize;
+ private readonly MemoryStream inputBuffer = new MemoryStream(0);
+ private readonly MemoryStream outputBuffer = new MemoryStream(0);
+ private readonly TTransport transport;
- public TBufferedTransport(TStreamTransport transport)
- :this(transport, 1024)
+ /// <param name="transport">Underlying transport; must not be null.</param>
+ /// <param name="bufSize">Buffer size in bytes; must be positive. Defaults to 1024.</param>
+ public TBufferedTransport(TTransport transport, int bufSize = 1024)
{
-
- }
-
- public TBufferedTransport(TStreamTransport transport, int bufSize)
- {
- this.bufSize = bufSize;
+ if (transport == null)
+ throw new ArgumentNullException("transport");
+ if (bufSize <= 0)
+ throw new ArgumentException("Buffer size must be a positive number.", "bufSize");
this.transport = transport;
- InitBuffers();
- }
-
- private void InitBuffers()
- {
- if (transport.InputStream != null)
- {
- inputBuffer = new BufferedStream(transport.InputStream, bufSize);
- }
- if (transport.OutputStream != null)
- {
- outputBuffer = new BufferedStream(transport.OutputStream, bufSize);
- }
- }
-
- private void CloseBuffers()
- {
- if (inputBuffer != null && inputBuffer.CanRead)
- {
- inputBuffer.Close();
- }
- if (outputBuffer != null && outputBuffer.CanWrite)
- {
- outputBuffer.Close();
- }
+ this.bufSize = bufSize;
}
public TTransport UnderlyingTransport
{
- get { return transport; }
+ get
+ {
+ CheckNotDisposed();
+ return transport;
+ }
}
public override bool IsOpen
{
- get { return transport.IsOpen; }
+ get
+ {
+ // Deliberately lenient: report "not open" after disposal instead of throwing.
+ // CheckNotDisposed();
+ return !_IsDisposed && transport.IsOpen;
+ }
}
public override void Open()
{
+ CheckNotDisposed();
transport.Open();
- InitBuffers();
}
public override void Close()
{
- CloseBuffers();
+ CheckNotDisposed();
+ // Discard bytes buffered for the old connection so they are neither
+ // returned by Read nor sent by Flush after a later Open().
+ inputBuffer.SetLength(0);
+ outputBuffer.SetLength(0);
transport.Close();
}
public override int Read(byte[] buf, int off, int len)
{
- return inputBuffer.Read(buf, off, len);
+ CheckNotDisposed();
+ ValidateBufferArgs(buf, off, len);
+ if (!IsOpen)
+ throw new InvalidOperationException("Transport is not open.");
+ // A zero-length read must not trigger a refill: that would call
+ // transport.Read and overwrite any bytes still buffered.
+ if (len == 0)
+ return 0;
+ if (inputBuffer.Capacity < bufSize)
+ inputBuffer.Capacity = bufSize;
+ int got = inputBuffer.Read(buf, off, len);
+ if (got > 0)
+ return got;
+
+ inputBuffer.Seek(0, SeekOrigin.Begin);
+ inputBuffer.SetLength(inputBuffer.Capacity);
+ int filled = transport.Read(inputBuffer.GetBuffer(), 0, (int)inputBuffer.Length);
+ inputBuffer.SetLength(filled);
+ if (filled == 0)
+ return 0;
+ return Read(buf, off, len);
}
public override void Write(byte[] buf, int off, int len)
{
- outputBuffer.Write(buf, off, len);
+ CheckNotDisposed();
+ ValidateBufferArgs(buf, off, len);
+ if (!IsOpen)
+ throw new InvalidOperationException("Transport is not open.");
+ // Relative offset from "off" argument
+ int offset = 0;
+ if (outputBuffer.Length > 0)
+ {
+ int capa = (int)(outputBuffer.Capacity - outputBuffer.Length);
+ int writeSize = capa <= len ? capa : len;
+ outputBuffer.Write(buf, off, writeSize);
+ offset += writeSize;
+ if (writeSize == capa)
+ {
+ transport.Write(outputBuffer.GetBuffer(), 0, (int)outputBuffer.Length);
+ outputBuffer.SetLength(0);
+ }
+ }
+ while (len - offset >= bufSize)
+ {
+ transport.Write(buf, off + offset, bufSize);
+ offset += bufSize;
+ }
+ int remain = len - offset;
+ if (remain > 0)
+ {
+ if (outputBuffer.Capacity < bufSize)
+ outputBuffer.Capacity = bufSize;
+ outputBuffer.Write(buf, off + offset, remain);
+ }
}
public override void Flush()
{
- outputBuffer.Flush();
+ CheckNotDisposed();
+ if (!IsOpen)
+ throw new InvalidOperationException("Transport is not open.");
+ if (outputBuffer.Length > 0)
+ {
+ transport.Write(outputBuffer.GetBuffer(), 0, (int)outputBuffer.Length);
+ outputBuffer.SetLength(0);
+ }
+ transport.Flush();
}
- #region " IDisposable Support "
- private bool _IsDisposed;
-
- // IDisposable
- protected override void Dispose(bool disposing)
- {
- if (!_IsDisposed)
- {
- if (disposing)
+ private void CheckNotDisposed()
{
- if (inputBuffer != null)
- inputBuffer.Dispose();
- if (outputBuffer != null)
- outputBuffer.Dispose();
+ if (_IsDisposed)
+ throw new ObjectDisposedException("TBufferedTransport");
}
- }
- _IsDisposed = true;
+
+ #region " IDisposable Support "
+ private bool _IsDisposed;
+
+ // IDisposable
+ protected override void Dispose(bool disposing)
+ {
+ if (!_IsDisposed)
+ {
+ if (disposing)
+ {
+ inputBuffer.Dispose();
+ outputBuffer.Dispose();
+ }
+ }
+ _IsDisposed = true;
+ }
+ #endregion
}
- #endregion
- }
}
diff --git a/lib/csharp/src/Transport/TFramedTransport.cs b/lib/csharp/src/Transport/TFramedTransport.cs
index 8af227f..9c6a794 100644
--- a/lib/csharp/src/Transport/TFramedTransport.cs
+++ b/lib/csharp/src/Transport/TFramedTransport.cs
@@ -21,14 +21,17 @@
namespace Thrift.Transport
{
- public class TFramedTransport : TTransport, IDisposable
+ /// <summary>
+ /// Transport that prefixes each message sent by Flush() with a 4-byte frame length and reads input one length-prefixed frame at a time.
+ /// </summary>
+ public class TFramedTransport : TTransport, IDisposable
{
- protected TTransport transport = null;
- protected MemoryStream writeBuffer;
- protected MemoryStream readBuffer = null;
+ private readonly TTransport transport;
+ private readonly MemoryStream writeBuffer = new MemoryStream(1024);
+ private readonly MemoryStream readBuffer = new MemoryStream(1024);
- private const int header_size = 4;
- private static byte[] header_dummy = new byte[header_size]; // used as header placeholder while initilizing new write buffer
+ private const int HeaderSize = 4;
+ private readonly byte[] headerBuf = new byte[HeaderSize];
public class Factory : TTransportFactory
{
@@ -38,18 +41,17 @@
}
}
- protected TFramedTransport()
+ public TFramedTransport(TTransport transport)
{
- InitWriteBuffer();
- }
-
- public TFramedTransport(TTransport transport) : this()
- {
+ if (transport == null)
+ throw new ArgumentNullException("transport");
this.transport = transport;
+ InitWriteBuffer();
}
public override void Open()
{
+ CheckNotDisposed();
transport.Open();
}
@@ -57,24 +59,32 @@
{
get
{
- return transport.IsOpen;
+ // Deliberately lenient: report "not open" after disposal instead of throwing.
+ // CheckNotDisposed();
+ return !_IsDisposed && transport.IsOpen;
}
}
public override void Close()
{
+ CheckNotDisposed();
transport.Close();
}
public override int Read(byte[] buf, int off, int len)
{
- if (readBuffer != null)
+ CheckNotDisposed();
+ ValidateBufferArgs(buf, off, len);
+ if (!IsOpen)
+ throw new InvalidOperationException("Transport is not open.");
+ // A zero-length read must not fall through to reading another frame,
+ // which would discard the unread remainder of the current one.
+ if (len == 0)
+ return 0;
+ int got = readBuffer.Read(buf, off, len);
+ if (got > 0)
{
- int got = readBuffer.Read(buf, off, len);
- if (got > 0)
- {
- return got;
- }
+ return got;
}
// Read another frame of data
@@ -85,49 +95,64 @@
private void ReadFrame()
{
- byte[] i32rd = new byte[header_size];
- transport.ReadAll(i32rd, 0, header_size);
- int size = DecodeFrameSize(i32rd);
+ transport.ReadAll(headerBuf, 0, HeaderSize);
+ int size = DecodeFrameSize(headerBuf);
- byte[] buff = new byte[size];
+ readBuffer.SetLength(size);
+ readBuffer.Seek(0, SeekOrigin.Begin);
+ byte[] buff = readBuffer.GetBuffer();
transport.ReadAll(buff, 0, size);
- readBuffer = new MemoryStream(buff);
}
public override void Write(byte[] buf, int off, int len)
{
+ CheckNotDisposed();
+ ValidateBufferArgs(buf, off, len);
+ if (!IsOpen)
+ throw new InvalidOperationException("Transport is not open.");
+ if (writeBuffer.Length + (long)len > (long)int.MaxValue)
+ Flush();
writeBuffer.Write(buf, off, len);
}
public override void Flush()
{
+ CheckNotDisposed();
+ if (!IsOpen)
+ throw new InvalidOperationException("Transport is not open.");
byte[] buf = writeBuffer.GetBuffer();
int len = (int)writeBuffer.Length;
- int data_len = len - header_size;
+ int data_len = len - HeaderSize;
if ( data_len < 0 )
throw new System.InvalidOperationException (); // logic error actually
- InitWriteBuffer();
-
// Inject message header into the reserved buffer space
- EncodeFrameSize(data_len,ref buf);
+ EncodeFrameSize(data_len, buf);
- // Send the entire message at once
- transport.Write(buf, 0, len);
+ try
+ {
+ // Send the entire message at once
+ transport.Write(buf, 0, len);
+ }
+ finally
+ {
+ // Reset even if the write throws, so this frame's bytes are never
+ // resent as part of the next frame.
+ InitWriteBuffer();
+ }
+
transport.Flush();
}
private void InitWriteBuffer ()
{
- // Create new buffer instance
- writeBuffer = new MemoryStream(1024);
-
// Reserve space for message header to be put right before sending it out
- writeBuffer.Write ( header_dummy, 0, header_size );
+ writeBuffer.SetLength(HeaderSize);
+ writeBuffer.Seek(0, SeekOrigin.End);
}
- private static void EncodeFrameSize(int frameSize, ref byte[] buf)
+ private static void EncodeFrameSize(int frameSize, byte[] buf)
{
buf[0] = (byte)(0xff & (frameSize >> 24));
buf[1] = (byte)(0xff & (frameSize >> 16));
@@ -145,6 +170,12 @@
}
+ private void CheckNotDisposed()
+ {
+ if (_IsDisposed)
+ throw new ObjectDisposedException("TFramedTransport");
+ }
+
#region " IDisposable Support "
private bool _IsDisposed;
@@ -155,8 +186,8 @@
{
if (disposing)
{
- if (readBuffer != null)
- readBuffer.Dispose();
+ readBuffer.Dispose();
+ writeBuffer.Dispose();
}
}
_IsDisposed = true;
diff --git a/lib/csharp/src/Transport/TTLSSocket.cs b/lib/csharp/src/Transport/TTLSSocket.cs
index 5652556..833b792 100644
--- a/lib/csharp/src/Transport/TTLSSocket.cs
+++ b/lib/csharp/src/Transport/TTLSSocket.cs
@@ -33,43 +33,43 @@
/// <summary>
/// Internal TCP Client
/// </summary>
- private TcpClient client = null;
+ private TcpClient client;
/// <summary>
/// The host
/// </summary>
- private string host = null;
+ private string host;
/// <summary>
/// The port
/// </summary>
- private int port = 0;
+ private int port;
/// <summary>
/// The timeout for the connection
/// </summary>
- private int timeout = 0;
+ private int timeout;
/// <summary>
/// Internal SSL Stream for IO
/// </summary>
- private SslStream secureStream = null;
+ private SslStream secureStream;
/// <summary>
/// Defines wheter or not this socket is a server socket<br/>
/// This is used for the TLS-authentication
/// </summary>
- private bool isServer = false;
+ private bool isServer; // defaults to false, i.e. client-side authentication
/// <summary>
/// The certificate
/// </summary>
- private X509Certificate certificate = null;
+ private X509Certificate certificate; // may be null for client-mode sockets
/// <summary>
/// User defined certificate validator.
/// </summary>
- private RemoteCertificateValidationCallback certValidator = null;
+ private RemoteCertificateValidationCallback certValidator;
/// <summary>
/// The function to determine which certificate to use.
@@ -96,6 +96,10 @@
this.certValidator = certValidator;
this.localCertificateSelectionCallback = localCertificateSelectionCallback;
this.isServer = isServer;
+ if (isServer && certificate == null) // fail fast: only server-mode sockets require a certificate
+ {
+ throw new ArgumentException("TTLSSocket needs certificate to be used for server", "certificate");
+ }
if (IsOpen)
{
@@ -133,7 +137,7 @@
public TTLSSocket(
string host,
int port,
- X509Certificate certificate,
+ X509Certificate certificate = null, // optional for clients; null means no client certificate is offered
RemoteCertificateValidationCallback certValidator = null,
LocalCertificateSelectionCallback localCertificateSelectionCallback = null)
: this(host, port, 0, certificate, certValidator, localCertificateSelectionCallback)
@@ -315,7 +319,8 @@
else
{
// Client authentication
- this.secureStream.AuthenticateAsClient(host, new X509CertificateCollection { certificate }, SslProtocols.Tls, true);
+ X509CertificateCollection certs = certificate != null ? new X509CertificateCollection { certificate } : new X509CertificateCollection(); // no client certificate configured -> pass an empty collection
+ this.secureStream.AuthenticateAsClient(host, certs, SslProtocols.Tls, true); // NOTE(review): SslProtocols.Tls pins TLS 1.0; consider allowing newer protocol versions
}
}
catch (Exception)
diff --git a/lib/csharp/src/Transport/TTransport.cs b/lib/csharp/src/Transport/TTransport.cs
index 2811399..a3639d2 100644
--- a/lib/csharp/src/Transport/TTransport.cs
+++ b/lib/csharp/src/Transport/TTransport.cs
@@ -34,7 +34,7 @@
}
private byte[] _peekBuffer = new byte[1];
- private bool _hasPeekByte = false;
+ private bool _hasPeekByte;
public bool Peek()
{
@@ -66,10 +66,30 @@
public abstract void Close();
+ /// <summary>
+ /// Validates the buffer/offset/length arguments passed to Read, Write and ReadAll.
+ /// </summary>
+ /// <exception cref="ArgumentNullException">buf is null.</exception>
+ /// <exception cref="ArgumentOutOfRangeException">off or len is negative, or the range [off, off + len) does not fit in buf.</exception>
+ protected static void ValidateBufferArgs(byte[] buf, int off, int len)
+ {
+ if (buf == null)
+ throw new ArgumentNullException("buf");
+ if (off < 0)
+ throw new ArgumentOutOfRangeException("off", "Buffer offset is smaller than zero.");
+ if (len < 0)
+ throw new ArgumentOutOfRangeException("len", "Buffer length is smaller than zero.");
+ // Compare via subtraction: off + len can overflow int and wrap negative,
+ // which would let an out-of-range request slip through.
+ if (len > buf.Length - off)
+ throw new ArgumentOutOfRangeException("len", "Not enough data.");
+ }
+
public abstract int Read(byte[] buf, int off, int len);
public int ReadAll(byte[] buf, int off, int len)
{
+ ValidateBufferArgs(buf, off, len);
int got = 0;
//If we previously peeked a byte, we need to use that first.