THRIFT-3070 Add ability to set the LocalCertificateSelectionCallback
Client: C#
Patch: Hans-Peter Klett <hansk@spectralogic.com>
This closes #415
Added an optional LocalCertificateSelectionCallback. Also cleans up the connection when a secure authentication fails on the server.
diff --git a/lib/csharp/src/Transport/TTLSServerSocket.cs b/lib/csharp/src/Transport/TTLSServerSocket.cs
index 2e2d299..631a593 100644
--- a/lib/csharp/src/Transport/TTLSServerSocket.cs
+++ b/lib/csharp/src/Transport/TTLSServerSocket.cs
@@ -60,6 +60,11 @@
private RemoteCertificateValidationCallback clientCertValidator;
/// <summary>
+ /// The function to determine which certificate to use.
+ /// </summary>
+ private LocalCertificateSelectionCallback localCertificateSelectionCallback;
+
+ /// <summary>
/// Initializes a new instance of the <see cref="TTLSServerSocket" /> class.
/// </summary>
/// <param name="port">The port where the server runs.</param>
@@ -88,7 +93,14 @@
/// <param name="useBufferedSockets">If set to <c>true</c> [use buffered sockets].</param>
/// <param name="certificate">The certificate object.</param>
/// <param name="clientCertValidator">The certificate validator.</param>
- public TTLSServerSocket(int port, int clientTimeout, bool useBufferedSockets, X509Certificate2 certificate, RemoteCertificateValidationCallback clientCertValidator = null)
+ /// <param name="localCertificateSelectionCallback">The callback to select which certificate to use.</param>
+ public TTLSServerSocket(
+ int port,
+ int clientTimeout,
+ bool useBufferedSockets,
+ X509Certificate2 certificate,
+ RemoteCertificateValidationCallback clientCertValidator = null,
+ LocalCertificateSelectionCallback localCertificateSelectionCallback = null)
{
if (!certificate.HasPrivateKey)
{
@@ -99,6 +111,7 @@
this.serverCertificate = certificate;
this.useBufferedSockets = useBufferedSockets;
this.clientCertValidator = clientCertValidator;
+ this.localCertificateSelectionCallback = localCertificateSelectionCallback;
try
{
// Create server socket
@@ -150,7 +163,13 @@
client.SendTimeout = client.ReceiveTimeout = this.clientTimeout;
//wrap the client in an SSL Socket passing in the SSL cert
- TTLSSocket socket = new TTLSSocket(client, this.serverCertificate, true, this.clientCertValidator);
+ TTLSSocket socket = new TTLSSocket(
+ client,
+ this.serverCertificate,
+ true,
+ this.clientCertValidator,
+ this.localCertificateSelectionCallback
+ );
socket.setupTLS();
diff --git a/lib/csharp/src/Transport/TTLSSocket.cs b/lib/csharp/src/Transport/TTLSSocket.cs
index ca8ee41..5652556 100644
--- a/lib/csharp/src/Transport/TTLSSocket.cs
+++ b/lib/csharp/src/Transport/TTLSSocket.cs
@@ -72,17 +72,29 @@
private RemoteCertificateValidationCallback certValidator = null;
/// <summary>
+ /// The function to determine which certificate to use.
+ /// </summary>
+ private LocalCertificateSelectionCallback localCertificateSelectionCallback;
+
+ /// <summary>
/// Initializes a new instance of the <see cref="TTLSSocket"/> class.
/// </summary>
/// <param name="client">An already created TCP-client</param>
/// <param name="certificate">The certificate.</param>
/// <param name="isServer">if set to <c>true</c> [is server].</param>
/// <param name="certValidator">User defined cert validator.</param>
- public TTLSSocket(TcpClient client, X509Certificate certificate, bool isServer = false, RemoteCertificateValidationCallback certValidator = null)
+ /// <param name="localCertificateSelectionCallback">The callback to select which certificate to use.</param>
+ public TTLSSocket(
+ TcpClient client,
+ X509Certificate certificate,
+ bool isServer = false,
+ RemoteCertificateValidationCallback certValidator = null,
+ LocalCertificateSelectionCallback localCertificateSelectionCallback = null)
{
this.client = client;
this.certificate = certificate;
this.certValidator = certValidator;
+ this.localCertificateSelectionCallback = localCertificateSelectionCallback;
this.isServer = isServer;
if (IsOpen)
@@ -99,8 +111,14 @@
/// <param name="port">The port.</param>
/// <param name="certificatePath">The certificate path.</param>
/// <param name="certValidator">User defined cert validator.</param>
- public TTLSSocket(string host, int port, string certificatePath, RemoteCertificateValidationCallback certValidator = null)
- : this(host, port, 0, X509Certificate.CreateFromCertFile(certificatePath), certValidator)
+ /// <param name="localCertificateSelectionCallback">The callback to select which certificate to use.</param>
+ public TTLSSocket(
+ string host,
+ int port,
+ string certificatePath,
+ RemoteCertificateValidationCallback certValidator = null,
+ LocalCertificateSelectionCallback localCertificateSelectionCallback = null)
+ : this(host, port, 0, X509Certificate.CreateFromCertFile(certificatePath), certValidator, localCertificateSelectionCallback)
{
}
@@ -111,8 +129,14 @@
/// <param name="port">The port.</param>
/// <param name="certificate">The certificate.</param>
/// <param name="certValidator">User defined cert validator.</param>
- public TTLSSocket(string host, int port, X509Certificate certificate, RemoteCertificateValidationCallback certValidator = null)
- : this(host, port, 0, certificate, certValidator)
+ /// <param name="localCertificateSelectionCallback">The callback to select which certificate to use.</param>
+ public TTLSSocket(
+ string host,
+ int port,
+ X509Certificate certificate,
+ RemoteCertificateValidationCallback certValidator = null,
+ LocalCertificateSelectionCallback localCertificateSelectionCallback = null)
+ : this(host, port, 0, certificate, certValidator, localCertificateSelectionCallback)
{
}
@@ -124,13 +148,21 @@
/// <param name="timeout">The timeout.</param>
/// <param name="certificate">The certificate.</param>
/// <param name="certValidator">User defined cert validator.</param>
- public TTLSSocket(string host, int port, int timeout, X509Certificate certificate, RemoteCertificateValidationCallback certValidator = null)
+ /// <param name="localCertificateSelectionCallback">The callback to select which certificate to use.</param>
+ public TTLSSocket(
+ string host,
+ int port,
+ int timeout,
+ X509Certificate certificate,
+ RemoteCertificateValidationCallback certValidator = null,
+ LocalCertificateSelectionCallback localCertificateSelectionCallback = null)
{
this.host = host;
this.port = port;
this.timeout = timeout;
this.certificate = certificate;
this.certValidator = certValidator;
+ this.localCertificateSelectionCallback = localCertificateSelectionCallback;
InitSocket();
}
@@ -213,7 +245,7 @@
/// <param name="chain">The certificate chain.</param>
/// <param name="sslPolicyErrors">An enum, which lists all the errors from the .NET certificate check.</param>
/// <returns></returns>
- private bool CertificateValidator(object sender, X509Certificate certificate, X509Chain chain, SslPolicyErrors sslValidationErrors)
+ private bool DefaultCertificateValidator(object sender, X509Certificate certificate, X509Chain chain, SslPolicyErrors sslValidationErrors)
{
return (sslValidationErrors == SslPolicyErrors.None);
}
@@ -253,16 +285,43 @@
/// </summary>
public void setupTLS()
{
- this.secureStream = new SslStream(this.client.GetStream(), false, this.certValidator ?? CertificateValidator);
- if (isServer)
+ RemoteCertificateValidationCallback validator = this.certValidator ?? DefaultCertificateValidator;
+
+ if( this.localCertificateSelectionCallback != null)
{
- // Server authentication
- this.secureStream.AuthenticateAsServer(this.certificate, this.certValidator != null, SslProtocols.Tls, true);
+ this.secureStream = new SslStream(
+ this.client.GetStream(),
+ false,
+ validator,
+ this.localCertificateSelectionCallback
+ );
}
else
{
- // Client authentication
- this.secureStream.AuthenticateAsClient(host, new X509CertificateCollection { certificate }, SslProtocols.Tls, true);
+ this.secureStream = new SslStream(
+ this.client.GetStream(),
+ false,
+ validator
+ );
+ }
+
+ try
+ {
+ if (isServer)
+ {
+ // Server authentication
+ this.secureStream.AuthenticateAsServer(this.certificate, this.certValidator != null, SslProtocols.Tls, true);
+ }
+ else
+ {
+ // Client authentication
+ this.secureStream.AuthenticateAsClient(host, new X509CertificateCollection { certificate }, SslProtocols.Tls, true);
+ }
+ }
+ catch (Exception)
+ {
+ this.Close();
+ throw;
}
inputStream = this.secureStream;