THRIFT-1100. py: python TSSLSocket improvements, including certificate validation

This patch adds a number of features to TSSLSocket and TSSLServerSocket.

Patch: Will Pierce

git-svn-id: https://svn.apache.org/repos/asf/thrift/trunk@1083880 13f79535-47bb-0310-9956-ffa450edef68
diff --git a/lib/py/src/transport/TSSLSocket.py b/lib/py/src/transport/TSSLSocket.py
index 8ab91ca..5eff5e6 100644
--- a/lib/py/src/transport/TSSLSocket.py
+++ b/lib/py/src/transport/TSSLSocket.py
@@ -1,38 +1,166 @@
-import sys
-sys.path.append('/usr/lib/python2.6/site-packages/')
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+import os
+import socket
+import ssl
 
 from thrift.transport import TSocket
-import socket, ssl
+from thrift.transport.TTransport import TTransportException
 
 class TSSLSocket(TSocket.TSocket):
-    def open(self):
+  """
+  SSL implementation of client-side TSocket
+
+  This class creates outbound sockets wrapped using the
+  python standard ssl module for encrypted connections.
+  
+  The protocol used is set using the class variable
+  SSL_VERSION, which must be one of ssl.PROTOCOL_* and
+  defaults to  ssl.PROTOCOL_TLSv1 for greatest security.
+  """
+  SSL_VERSION = ssl.PROTOCOL_TLSv1
+
+  def __init__(self, validate=True, ca_certs=None, *args, **kwargs):
+    """
+    @param validate: Set to False to disable SSL certificate validation entirely.
+    @type validate: bool
+    @param ca_certs: Filename to the Certificate Authority pem file, possibly a
+    file downloaded from: http://curl.haxx.se/ca/cacert.pem  This is passed to
+    the ssl_wrap function as the 'ca_certs' parameter.
+    @type ca_certs: str
+    
+    Raises an IOError exception if validate is True and the ca_certs file is
+    None, not present or unreadable.
+    """
+    self.validate = validate
+    self.is_valid = False
+    self.peercert = None
+    if not validate:
+      self.cert_reqs = ssl.CERT_NONE
+    else:
+      self.cert_reqs = ssl.CERT_REQUIRED
+    self.ca_certs = ca_certs
+    if validate and ca_certs is not None:
+      if not os.access(ca_certs, os.R_OK):
+        raise IOError('Certificate Authority ca_certs file "%s" is not readable, cannot validate SSL certificates.' % (ca_certs))
+    TSocket.TSocket.__init__(self, *args, **kwargs)
+
+  def open(self):
+    try:
+      res0 = self._resolveAddr()
+      for res in res0:
+        sock_family, sock_type= res[0:2]
+        ip_port = res[4]
+        plain_sock = socket.socket(sock_family, sock_type)
+        self.handle = ssl.wrap_socket(plain_sock, ssl_version=self.SSL_VERSION,
+            do_handshake_on_connect=True, ca_certs=self.ca_certs, cert_reqs=self.cert_reqs) 
+        self.handle.settimeout(self._timeout)
         try:
-          res0 = self._resolveAddr()
-          for res in res0:
-            plain_sock = socket.socket(res[0], res[1])
-            #TODO verify server cert
-            self.handle = ssl.wrap_socket(plain_sock, ssl_version=ssl.PROTOCOL_TLSv1) 
-            self.handle.settimeout(self._timeout)
-            try:
-              self.handle.connect(res[4])
-            except socket.error, e:
-              if res is not res0[-1]:
-                continue
-              else:
-                raise e
-            break
+          self.handle.connect(ip_port)
         except socket.error, e:
-          if self._unix_socket:
-            message = 'Could not connect to secure socket %s' % self._unix_socket
+          if res is not res0[-1]:
+            continue
           else:
-            message = 'Could not connect to %s:%d' % (self.host, self.port)
-          raise TTransportException(type=TTransportException.NOT_OPEN, message=message)
+            raise e
+        break
+    except socket.error, e:
+      if self._unix_socket:
+        message = 'Could not connect to secure socket %s' % self._unix_socket
+      else:
+        message = 'Could not connect to %s:%d' % (self.host, self.port)
+      raise TTransportException(type=TTransportException.NOT_OPEN, message=message)
+    if self.validate:
+      self._validate_cert()
+
+  def _validate_cert(self):
+    """internal method to validate the peer's SSL certificate, and to check the
+    commonName of the certificate to ensure it matches the hostname we
+    used to make this connection.  Does not support subjectAltName records
+    in certificates.
+    
+    raises TTransportException if the certificate fails validation."""
+    cert = self.handle.getpeercert()
+    self.peercert = cert
+    if 'subject' not in cert:
+      raise TTransportException(type=TTransportException.NOT_OPEN,
+                      message='No SSL certificate found from %s:%s' % (self.host, self.port))
+    fields = cert['subject']
+    for field in fields:
+      # ensure structure we get back is what we expect
+      if not isinstance(field, tuple):
+        continue
+      cert_pair = field[0]
+      if len(cert_pair) < 2:
+        continue
+      cert_key, cert_value = cert_pair[0:2]
+      if cert_key != 'commonName':
+        continue
+      certhost = cert_value
+      if certhost == self.host:
+        # success, cert commonName matches desired hostname
+        self.is_valid = True
+        return 
+      else:
+        raise TTransportException(type=TTransportException.UNKNOWN,
+                          message='Host name we connected to "%s" doesn\'t match certificate provided commonName "%s"' % (self.host, certhost))
+    raise TTransportException(type=TTransportException.UNKNOWN,
+                      message='Could not validate SSL certificate from host "%s".  Cert=%s' % (self.host, cert))
 
 class TSSLServerSocket(TSocket.TServerSocket):
-    def accept(self):
-        plain_client, addr = self.handle.accept()
-        result = TSocket.TSocket()
-        #TODO take certfile/keyfile as a parameter at setup
-        client = ssl.wrap_socket(plain_client, certfile='cert.pem', server_side=True)
-        result.setHandle(client)
-        return result
+  """
+  SSL implementation of TServerSocket
+
+  This uses the ssl module's wrap_socket() method to provide SSL
+  negotiated encryption.
+  """
+  SSL_VERSION = ssl.PROTOCOL_TLSv1
+
+  def __init__(self, certfile='cert.pem', *args, **kwargs):
+    """Initialize a TSSLServerSocket
+    
+    @param certfile: The filename of the server certificate file, defaults to cert.pem
+    @type certfile: str
+    @param host: The hostname or IP to bind the listen socket to, i.e. 'localhost' for only allowing
+    local network connections. Pass None to bind to all interfaces.
+    @type host: str
+    @param port: The port to listen on for inbound connections.
+    @type port: int
+    """
+    self.setCertfile(certfile)
+    TSocket.TServerSocket.__init__(self, *args, **kwargs)
+
+  def setCertfile(self, certfile):
+    """Set or change the server certificate file used to wrap new connections.
+    
+    @param certfile: The filename of the server certificate, i.e. '/etc/certs/server.pem'
+    @type certfile: str
+    
+    Raises an IOError exception if the certfile is not present or unreadable.
+    """
+    if not os.access(certfile, os.R_OK):
+      raise IOError('No such certfile found: %s' % (certfile))
+    self.certfile = certfile
+
+  def accept(self):
+    plain_client, addr = self.handle.accept()
+    result = TSocket.TSocket()
+    client = ssl.wrap_socket(plain_client, certfile=self.certfile,
+                    server_side=True, ssl_version=self.SSL_VERSION)
+    result.setHandle(client)
+    return result
diff --git a/lib/py/src/transport/TSocket.py b/lib/py/src/transport/TSocket.py
index 085a5ee..be61678 100644
--- a/lib/py/src/transport/TSocket.py
+++ b/lib/py/src/transport/TSocket.py
@@ -57,7 +57,7 @@
     self.handle = h
 
   def isOpen(self):
-    return self.handle != None
+    return self.handle is not None
 
   def setTimeout(self, ms):
     if ms is None:
@@ -65,7 +65,7 @@
     else:
       self._timeout = ms/1000.0
 
-    if (self.handle != None):
+    if self.handle is not None:
       self.handle.settimeout(self._timeout)
 
   def open(self):