THRIFT-1719: SASL client support for Python
Client: py
Patch: Tyler Hobbs

Add a SASL client protocol for Twisted that works with the Java library's TSaslTransport.
diff --git a/lib/py/src/transport/TTwisted.py b/lib/py/src/transport/TTwisted.py
index 3ce3eb2..2b77414 100644
--- a/lib/py/src/transport/TTwisted.py
+++ b/lib/py/src/transport/TTwisted.py
@@ -17,14 +17,15 @@
 # under the License.
 #
 
+import struct
 from cStringIO import StringIO
 
 from zope.interface import implements, Interface, Attribute
-from twisted.internet.protocol import Protocol, ServerFactory, ClientFactory, \
+from twisted.internet.protocol import ServerFactory, ClientFactory, \
     connectionDone
 from twisted.internet import defer
+from twisted.internet.threads import deferToThread
 from twisted.protocols import basic
-from twisted.python import log
 from twisted.web import server, resource, http
 
 from thrift.transport import TTransport
@@ -101,6 +102,154 @@
         method(iprot, mtype, rseqid)
 
 
+class ThriftSASLClientProtocol(ThriftClientProtocol):
+    """
+    Thrift client protocol that runs a SASL negotiation when the connection
+    is made and, once it succeeds, SASL-wraps every outgoing Thrift message
+    and unwraps every incoming one.
+
+    Each negotiation message is a 1-byte status, a 4-byte big-endian
+    payload length and the payload itself.
+    """
+
+    # Negotiation status codes, sent as the first byte of a SASL message.
+    START = 1
+    OK = 2
+    BAD = 3
+    ERROR = 4
+    COMPLETE = 5
+
+    MAX_LENGTH = 2 ** 31 - 1
+
+    def __init__(self, client_class, iprot_factory, oprot_factory=None,
+            host=None, service=None, mechanism='GSSAPI', **sasl_kwargs):
+        """
+        host: the name of the server, from a SASL perspective
+        service: the name of the server's service, from a SASL perspective
+        mechanism: the name of the preferred mechanism to use
+
+        All other kwargs will be passed to the puresasl.client.SASLClient
+        constructor.
+
+        If host is None no SASL client is created here, and
+        createSASLClient() must be called before the connection is made.
+        """
+
+        from puresasl.client import SASLClient
+        # Must match the attribute name createSASLClient() looks up.
+        self.SASLClient = SASLClient
+
+        ThriftClientProtocol.__init__(self, client_class, iprot_factory, oprot_factory)
+
+        # Deferred for the SASL message currently awaited; reset to None
+        # once negotiation has completed successfully.
+        self._sasl_negotiation_deferred = None
+        # Status byte of the SASL message being received; None until that
+        # byte has been read from the stream.
+        self._sasl_negotiation_status = None
+        self.client = None
+
+        if host is not None:
+            self.createSASLClient(host, service, mechanism, **sasl_kwargs)
+
+    def createSASLClient(self, host, service, mechanism, **kwargs):
+        """Create the SASL client used for negotiation and wrapping."""
+        self.sasl = self.SASLClient(host, service, mechanism, **kwargs)
+
+    def dispatch(self, msg):
+        """SASL-wrap an outgoing message, prefix it with the 4-byte
+        big-endian length of the wrapped data and hand it to the base class."""
+        encoded = self.sasl.wrap(msg)
+        len_and_encoded = ''.join((struct.pack('!i', len(encoded)), encoded))
+        ThriftClientProtocol.dispatch(self, len_and_encoded)
+
+    @defer.inlineCallbacks
+    def connectionMade(self):
+        """
+        Run the SASL negotiation, then complete the normal client setup.
+
+        Sends START with the mechanism name and the initial response, then
+        answers each OK challenge until the server reports COMPLETE.
+        sasl.process() is run with deferToThread so it stays off the
+        reactor thread.  The returned Deferred fails with
+        TTransportException on any other status, or if the server reports
+        COMPLETE before the local SASL client considers itself complete.
+        """
+        self._sendSASLMessage(self.START, self.sasl.mechanism)
+        initial_message = yield deferToThread(self.sasl.process)
+        self._sendSASLMessage(self.OK, initial_message)
+
+        while True:
+            status, challenge = yield self._receiveSASLMessage()
+            if status == self.OK:
+                response = yield deferToThread(self.sasl.process, challenge)
+                self._sendSASLMessage(self.OK, response)
+            elif status == self.COMPLETE:
+                if not self.sasl.complete:
+                    msg = "The server erroneously indicated that SASL " \
+                          "negotiation was complete"
+                    # TTransportException's first positional parameter is
+                    # the error type, so pass the text only as message.
+                    raise TTransport.TTransportException(message=msg)
+                else:
+                    break
+            else:
+                msg = "Bad SASL negotiation status: %d (%s)" % (status, challenge)
+                raise TTransport.TTransportException(message=msg)
+
+        # From here on, incoming data is ordinary SASL-wrapped Thrift frames.
+        self._sasl_negotiation_deferred = None
+        ThriftClientProtocol.connectionMade(self)
+
+    def _sendSASLMessage(self, status, body):
+        """Write one negotiation message: status byte, 4-byte big-endian
+        body length, body.  A None body is sent as an empty one."""
+        if body is None:
+            body = ""
+        header = struct.pack(">BI", status, len(body))
+        self.transport.write(header + body)
+
+    def _receiveSASLMessage(self):
+        """Return a Deferred that fires with (status, payload) for the next
+        negotiation message from the server."""
+        self._sasl_negotiation_deferred = defer.Deferred()
+        # The next byte received is this message's status byte.
+        self._sasl_negotiation_status = None
+        return self._sasl_negotiation_deferred
+
+    def connectionLost(self, reason=connectionDone):
+        # self.client stays None (see __init__) unless it has been set up,
+        # presumably by ThriftClientProtocol.connectionMade() after a
+        # successful negotiation; only then hand off to the base class.
+        if self.client:
+            ThriftClientProtocol.connectionLost(self, reason)
+
+    def dataReceived(self, data):
+        """
+        During negotiation, strip the leading status byte of each SASL
+        message and let IntNStringReceiver reassemble the length-prefixed
+        challenge; otherwise pass the data through unchanged.
+        """
+        if (self._sasl_negotiation_deferred and
+                self._sasl_negotiation_status is None and data):
+            # Only the first chunk of a message carries the status byte;
+            # later chunks of a message split across reads are left intact.
+            self._sasl_negotiation_status, = struct.unpack("B", data[:1])
+            data = data[1:]
+        ThriftClientProtocol.dataReceived(self, data)
+
+    def stringReceived(self, frame):
+        """Handle one payload reassembled by IntNStringReceiver."""
+        if self._sasl_negotiation_deferred:
+            # the frame is just a SASL challenge
+            response = (self._sasl_negotiation_status, frame)
+            self._sasl_negotiation_deferred.callback(response)
+        else:
+            # Mirrors dispatch(): a 4-byte length, then the wrapped data.
+            decoded_frame = self.sasl.unwrap(frame[4:])
+            ThriftClientProtocol.stringReceived(self, decoded_frame)
+
+
 class ThriftServerProtocol(basic.Int32StringReceiver):
 
     MAX_LENGTH = 2 ** 31 - 1