THRIFT-1719:SASL client support for Python
Client: py
Patch: Tyler Hobbs
Add SASL client transports that are compatible with the Java library's TSaslTransport
diff --git a/lib/py/src/transport/TTwisted.py b/lib/py/src/transport/TTwisted.py
index 3ce3eb2..2b77414 100644
--- a/lib/py/src/transport/TTwisted.py
+++ b/lib/py/src/transport/TTwisted.py
@@ -17,14 +17,15 @@
# under the License.
#
+import struct
from cStringIO import StringIO
from zope.interface import implements, Interface, Attribute
-from twisted.internet.protocol import Protocol, ServerFactory, ClientFactory, \
+from twisted.internet.protocol import ServerFactory, ClientFactory, \
connectionDone
from twisted.internet import defer
+from twisted.internet.threads import deferToThread
from twisted.protocols import basic
-from twisted.python import log
from twisted.web import server, resource, http
from thrift.transport import TTransport
@@ -101,6 +102,108 @@
method(iprot, mtype, rseqid)
+class ThriftSASLClientProtocol(ThriftClientProtocol):
+    """Thrift client protocol that performs a SASL handshake on connect.
+
+    Negotiation messages use the framing written by _sendSASLMessage: a
+    1-byte status, a 4-byte big-endian length and the payload. After
+    negotiation completes, each outgoing Thrift message is wrapped with
+    sasl.wrap() and given an inner 4-byte length prefix before being sent
+    through ThriftClientProtocol.dispatch; incoming frames are unwrapped
+    with sasl.unwrap() after skipping that inner prefix.
+    """
+
+    # Negotiation status codes (first byte of each negotiation message).
+    START = 1
+    OK = 2
+    BAD = 3
+    ERROR = 4
+    COMPLETE = 5
+
+    # Largest frame accepted by the Int32StringReceiver framing.
+    MAX_LENGTH = 2 ** 31 - 1
+
+    def __init__(self, client_class, iprot_factory, oprot_factory=None,
+                 host=None, service=None, mechanism='GSSAPI', **sasl_kwargs):
+        """
+        host: the name of the server, from a SASL perspective
+        service: the name of the server's service, from a SASL perspective
+        mechanism: the name of the preferred mechanism to use
+
+        All other kwargs will be passed to the puresasl.client.SASLClient
+        constructor.
+
+        If host is None no SASL client is created here, and
+        createSASLClient() must be called before the connection is made.
+        """
+
+        # Imported here rather than at module level, presumably so that
+        # puresasl is only needed by users of this class.
+        from puresasl.client import SASLClient
+        # Must match the attribute name createSASLClient() looks up.
+        self.SASLClient = SASLClient
+
+        ThriftClientProtocol.__init__(self, client_class, iprot_factory,
+                                      oprot_factory)
+
+        # State of the negotiation read in progress: the Deferred fired with
+        # (status, payload), and the status byte of the message being read.
+        self._sasl_negotiation_deferred = None
+        self._sasl_negotiation_status = None
+        self.client = None
+
+        if host is not None:
+            self.createSASLClient(host, service, mechanism, **sasl_kwargs)
+
+    def createSASLClient(self, host, service, mechanism, **kwargs):
+        """Create the SASL client used for negotiation and wrap/unwrap."""
+        self.sasl = self.SASLClient(host, service, mechanism, **kwargs)
+
+    def dispatch(self, msg):
+        """Wrap an outgoing Thrift message with the SASL layer and send it.
+
+        The wrapped payload gets its own 4-byte big-endian length prefix
+        before being handed to ThriftClientProtocol.dispatch.
+        """
+        encoded = self.sasl.wrap(msg)
+        len_and_encoded = ''.join((struct.pack('!i', len(encoded)), encoded))
+        ThriftClientProtocol.dispatch(self, len_and_encoded)
+
+    @defer.inlineCallbacks
+    def connectionMade(self):
+        """Run the SASL handshake, then finish the normal client setup.
+
+        sasl.process() runs via deferToThread so mechanism work does not
+        execute on the reactor thread. If negotiation fails, the connection
+        is dropped and the error is re-raised into the Deferred returned by
+        this method.
+
+        Raises:
+            TTransport.TTransportException: if the server sends an
+                unexpected status or reports completion before the local
+                SASL client is complete.
+        """
+        try:
+            self._sendSASLMessage(self.START, self.sasl.mechanism)
+            initial_message = yield deferToThread(self.sasl.process)
+            self._sendSASLMessage(self.OK, initial_message)
+
+            while True:
+                status, challenge = yield self._receiveSASLMessage()
+                if status == self.OK:
+                    response = yield deferToThread(self.sasl.process, challenge)
+                    self._sendSASLMessage(self.OK, response)
+                elif status == self.COMPLETE:
+                    if not self.sasl.complete:
+                        msg = "The server erroneously indicated that SASL " \
+                              "negotiation was complete"
+                        raise TTransport.TTransportException(message=msg)
+                    break
+                else:
+                    msg = "Bad SASL negotiation status: %d (%s)" % (
+                        status, challenge)
+                    raise TTransport.TTransportException(message=msg)
+        except Exception:
+            # Twisted ignores connectionMade's return value, so without this
+            # the failure would only surface as an unhandled Deferred error
+            # while the connection stayed open half-negotiated.
+            # NOTE(review): anything waiting for the client to become ready
+            # is not notified here; consider errbacking it as well.
+            self._sasl_negotiation_deferred = None
+            self._sasl_negotiation_status = None
+            self.transport.loseConnection()
+            raise
+
+        self._sasl_negotiation_deferred = None
+        ThriftClientProtocol.connectionMade(self)
+
+    def _sendSASLMessage(self, status, body):
+        """Send one negotiation message: status byte, 4-byte length, body."""
+        if body is None:
+            body = ""
+        header = struct.pack(">BI", status, len(body))
+        self.transport.write(header + body)
+
+    def _receiveSASLMessage(self):
+        """Return a Deferred fired with (status, payload) for the next
+        negotiation message received from the server."""
+        self._sasl_negotiation_deferred = defer.Deferred()
+        self._sasl_negotiation_status = None
+        return self._sasl_negotiation_deferred
+
+    def connectionLost(self, reason=connectionDone):
+        # A handshake waiting on the server would otherwise never resume;
+        # fail it so connectionMade's error handling runs.
+        pending = self._sasl_negotiation_deferred
+        if pending is not None and not pending.called:
+            self._sasl_negotiation_deferred = None
+            pending.errback(reason)
+        # self.client stays None (set in __init__) until negotiation has
+        # succeeded and ThriftClientProtocol.connectionMade has run.
+        if self.client is not None:
+            ThriftClientProtocol.connectionLost(self, reason)
+
+    def dataReceived(self, data):
+        """Feed incoming bytes to the Int32StringReceiver framing.
+
+        While a negotiation message is awaited, its first byte is the
+        status: it is stripped and remembered, and the remainder (4-byte
+        length + payload) is framed by the base class. The status byte is
+        taken only once per message, because a message may arrive split
+        across several dataReceived calls.
+        """
+        if (self._sasl_negotiation_deferred is not None
+                and self._sasl_negotiation_status is None and data):
+            self._sasl_negotiation_status, = struct.unpack("B", data[:1])
+            data = data[1:]
+        ThriftClientProtocol.dataReceived(self, data)
+
+    def stringReceived(self, frame):
+        """Handle one complete frame produced by the length-prefix framing."""
+        pending = self._sasl_negotiation_deferred
+        if pending is not None:
+            # The frame is a SASL negotiation payload. Clear the pending
+            # state before firing: a fired Deferred cannot be fired again,
+            # and the callback resumes connectionMade, which installs a
+            # fresh Deferred for the next message when it needs one.
+            status = self._sasl_negotiation_status
+            self._sasl_negotiation_deferred = None
+            self._sasl_negotiation_status = None
+            pending.callback((status, frame))
+        else:
+            # Normal frame: skip the inner 4-byte length prefix and remove
+            # the SASL security layer.
+            decoded_frame = self.sasl.unwrap(frame[4:])
+            ThriftClientProtocol.stringReceived(self, decoded_frame)
+
+
class ThriftServerProtocol(basic.Int32StringReceiver):
MAX_LENGTH = 2 ** 31 - 1