THRIFT-1719:SASL client support for Python
Client: py
Patch: Tyler Hobbs
Add SASL client transports that will work with the Java lib's TSaslTransport
diff --git a/lib/py/src/transport/TTransport.py b/lib/py/src/transport/TTransport.py
index 4481371..5ab2fde 100644
--- a/lib/py/src/transport/TTransport.py
+++ b/lib/py/src/transport/TTransport.py
@@ -328,3 +328,114 @@
def flush(self):
self.fileobj.flush()
+
+
+class TSaslClientTransport(TTransportBase, CReadableTransport):
+  """
+  Client-side SASL transport layered over another transport.
+
+  open() runs the SASL negotiation as a sequence of messages, each made of
+  a 1-byte status, a 4-byte big-endian body length and the body itself.
+  Afterwards every flush() sends the buffered writes as one frame (4-byte
+  big-endian length followed by the sasl.wrap()-ed payload), and every
+  incoming frame is passed through sasl.unwrap() before being served to
+  read().
+  """
+
+  # Status codes carried in the first byte of a SASL negotiation message.
+  START = 1
+  OK = 2
+  BAD = 3
+  ERROR = 4
+  COMPLETE = 5
+
+  def __init__(self, transport, host, service, mechanism='GSSAPI',
+      **sasl_kwargs):
+    """
+    transport: an underlying transport to use, typically just a TSocket
+    host: the name of the server, from a SASL perspective
+    service: the name of the server's service, from a SASL perspective
+    mechanism: the name of the preferred mechanism to use
+
+    All other kwargs will be passed to the puresasl.client.SASLClient
+    constructor.
+    """
+
+    # Imported here so that puresasl is only required when this transport
+    # is actually instantiated.
+    from puresasl.client import SASLClient
+
+    self.transport = transport
+    self.sasl = SASLClient(host, service, mechanism, **sasl_kwargs)
+
+    self.__wbuf = StringIO()
+    self.__rbuf = StringIO()
+
+  def isOpen(self):
+    """Report whether the underlying transport is open."""
+    return self.transport.isOpen()
+
+  def open(self):
+    """
+    Open the underlying transport if needed and run the SASL negotiation.
+
+    Raises TTransportException if the server reports COMPLETE before the
+    local SASL client considers the negotiation finished, or if the server
+    answers with any status other than OK or COMPLETE.
+    """
+    if not self.transport.isOpen():
+      self.transport.open()
+
+    self.send_sasl_msg(self.START, self.sasl.mechanism)
+    self.send_sasl_msg(self.OK, self.sasl.process())
+
+    while True:
+      status, challenge = self.recv_sasl_msg()
+      if status == self.OK:
+        self.send_sasl_msg(self.OK, self.sasl.process(challenge))
+      elif status == self.COMPLETE:
+        if not self.sasl.complete:
+          # Pass the type explicitly: the first positional argument of
+          # TTransportException is the type, not the message.
+          raise TTransportException(
+              TTransportException.NOT_OPEN,
+              "The server erroneously indicated "
+              "that SASL negotiation was complete")
+        break
+      else:
+        raise TTransportException(
+            TTransportException.NOT_OPEN,
+            "Bad SASL negotiation status: %d (%s)" % (status, challenge))
+
+  def send_sasl_msg(self, status, body):
+    """
+    Send one negotiation message (status byte, body length, body) and
+    flush the underlying transport.
+
+    A body of None is sent as an empty body.
+    """
+    if body is None:
+      # The SASL client may have nothing to send for this step.
+      body = ""
+    header = pack(">BI", status, len(body))
+    self.transport.write(header + body)
+    self.transport.flush()
+
+  def recv_sasl_msg(self):
+    """
+    Read one negotiation message and return it as (status, payload).
+
+    The payload is an empty string when the announced length is zero.
+    """
+    header = self.transport.readAll(5)
+    status, length = unpack(">BI", header)
+    if length > 0:
+      payload = self.transport.readAll(length)
+    else:
+      payload = ""
+    return status, payload
+
+  def write(self, data):
+    """Buffer data; nothing is sent until flush() is called."""
+    self.__wbuf.write(data)
+
+  def flush(self):
+    """Wrap everything buffered by write() and send it as a single frame."""
+    data = self.__wbuf.getvalue()
+    # Reset the buffer before touching the underlying transport so that a
+    # failed write does not leave stale bytes to be resent by a later flush.
+    self.__wbuf = StringIO()
+    encoded = self.sasl.wrap(data)
+    self.transport.write(''.join((pack("!i", len(encoded)), encoded)))
+    self.transport.flush()
+
+  def read(self, sz):
+    """
+    Return up to sz bytes of unwrapped data.
+
+    Serves from the current frame while it has data left; otherwise reads
+    the next frame from the underlying transport first.
+    """
+    ret = self.__rbuf.read(sz)
+    if ret:
+      return ret
+
+    self._read_frame()
+    return self.__rbuf.read(sz)
+
+  def _read_frame(self):
+    """Read one length-prefixed frame and make its unwrapped payload the
+    current read buffer."""
+    header = self.transport.readAll(4)
+    length, = unpack('!i', header)
+    if length < 0:
+      # The length is decoded as a signed integer; a negative value cannot
+      # be a valid frame size.
+      raise TTransportException(
+          TTransportException.UNKNOWN,
+          "Negative SASL frame length: %d" % length)
+    encoded = self.transport.readAll(length)
+    self.__rbuf = StringIO(self.sasl.unwrap(encoded))
+
+  def close(self):
+    """Dispose of the SASL client and close the underlying transport."""
+    try:
+      self.sasl.dispose()
+    finally:
+      # Close the underlying transport even if disposing the SASL client
+      # fails, so the connection is not leaked.
+      self.transport.close()
+
+  # based on TFramedTransport
+  @property
+  def cstringio_buf(self):
+    """The current read buffer (part of the CReadableTransport interface)."""
+    return self.__rbuf
+
+  def cstringio_refill(self, prefix, reqlen):
+    """
+    Read frames until at least reqlen bytes (including prefix) are
+    available and return a new read buffer holding them.
+    """
+    # self.__rbuf will already be empty here because fastbinary doesn't
+    # ask for a refill until the previous buffer is empty.  Therefore,
+    # we can start reading new frames immediately.
+    while len(prefix) < reqlen:
+      self._read_frame()
+      prefix += self.__rbuf.getvalue()
+    self.__rbuf = StringIO(prefix)
+    return self.__rbuf
+
diff --git a/lib/py/src/transport/TTwisted.py b/lib/py/src/transport/TTwisted.py
index 3ce3eb2..2b77414 100644
--- a/lib/py/src/transport/TTwisted.py
+++ b/lib/py/src/transport/TTwisted.py
@@ -17,14 +17,15 @@
# under the License.
#
+import struct
from cStringIO import StringIO
from zope.interface import implements, Interface, Attribute
-from twisted.internet.protocol import Protocol, ServerFactory, ClientFactory, \
+from twisted.internet.protocol import ServerFactory, ClientFactory, \
connectionDone
from twisted.internet import defer
+from twisted.internet.threads import deferToThread
from twisted.protocols import basic
-from twisted.python import log
from twisted.web import server, resource, http
from thrift.transport import TTransport
@@ -101,6 +102,108 @@
method(iprot, mtype, rseqid)
+class ThriftSASLClientProtocol(ThriftClientProtocol):
+  """
+  Thrift client protocol that performs a SASL negotiation when the
+  connection is made, before handing over to ThriftClientProtocol.
+
+  Negotiation messages consist of a 1-byte status followed by a 4-byte
+  length-prefixed body.  After negotiation, outgoing messages are wrapped
+  with sasl.wrap() and incoming frames are unwrapped with sasl.unwrap().
+  """
+
+  # Status codes carried in the first byte of a SASL negotiation message.
+  START = 1
+  OK = 2
+  BAD = 3
+  ERROR = 4
+  COMPLETE = 5
+
+  MAX_LENGTH = 2 ** 31 - 1
+
+  def __init__(self, client_class, iprot_factory, oprot_factory=None,
+      host=None, service=None, mechanism='GSSAPI', **sasl_kwargs):
+    """
+    host: the name of the server, from a SASL perspective
+    service: the name of the server's service, from a SASL perspective
+    mechanism: the name of the preferred mechanism to use
+
+    All other kwargs will be passed to the puresasl.client.SASLClient
+    constructor.
+
+    When host is None no SASL client is created here; createSASLClient()
+    must then be called before the connection is made.
+    """
+
+    # Imported here so that puresasl is only required when this protocol
+    # is actually instantiated.  The class is kept on the instance because
+    # createSASLClient() may be called later.
+    from puresasl.client import SASLClient
+    self.SASLClient = SASLClient
+
+    ThriftClientProtocol.__init__(self, client_class, iprot_factory,
+                                  oprot_factory)
+
+    # Deferred fired with (status, body) for the negotiation message being
+    # awaited; None once negotiation has completed.
+    self._sasl_negotiation_deferred = None
+    # Status byte of the negotiation message currently being received;
+    # None until that byte has been seen.
+    self._sasl_negotiation_status = None
+    self.client = None
+
+    if host is not None:
+      self.createSASLClient(host, service, mechanism, **sasl_kwargs)
+
+  def createSASLClient(self, host, service, mechanism, **kwargs):
+    """Create the SASL client used for negotiation and message wrapping."""
+    self.sasl = self.SASLClient(host, service, mechanism, **kwargs)
+
+  def dispatch(self, msg):
+    """Wrap msg, prefix it with its 4-byte length and send it through
+    ThriftClientProtocol.dispatch()."""
+    encoded = self.sasl.wrap(msg)
+    len_and_encoded = ''.join((struct.pack('!i', len(encoded)), encoded))
+    ThriftClientProtocol.dispatch(self, len_and_encoded)
+
+  @defer.inlineCallbacks
+  def connectionMade(self):
+    """
+    Run the SASL negotiation, then call ThriftClientProtocol.connectionMade.
+
+    Each sasl.process() call is run through deferToThread.  Raises
+    TTransportException if the server reports COMPLETE before the local
+    SASL client considers the negotiation finished, or if it answers with
+    any status other than OK or COMPLETE.
+    """
+    self._sendSASLMessage(self.START, self.sasl.mechanism)
+    initial_message = yield deferToThread(self.sasl.process)
+    self._sendSASLMessage(self.OK, initial_message)
+
+    while True:
+      status, challenge = yield self._receiveSASLMessage()
+      if status == self.OK:
+        response = yield deferToThread(self.sasl.process, challenge)
+        self._sendSASLMessage(self.OK, response)
+      elif status == self.COMPLETE:
+        if not self.sasl.complete:
+          msg = "The server erroneously indicated that SASL " \
+                "negotiation was complete"
+          # Pass the type explicitly: the first positional argument of
+          # TTransportException is the type, not the message.
+          raise TTransport.TTransportException(
+              TTransport.TTransportException.NOT_OPEN, msg)
+        break
+      else:
+        msg = "Bad SASL negotiation status: %d (%s)" % (status, challenge)
+        raise TTransport.TTransportException(
+            TTransport.TTransportException.NOT_OPEN, msg)
+
+    # From here on incoming data is treated as regular wrapped frames.
+    self._sasl_negotiation_deferred = None
+    ThriftClientProtocol.connectionMade(self)
+
+  def _sendSASLMessage(self, status, body):
+    """Write one negotiation message; a body of None is sent as empty."""
+    if body is None:
+      body = ""
+    header = struct.pack(">BI", status, len(body))
+    self.transport.write(header + body)
+
+  def _receiveSASLMessage(self):
+    """Return a Deferred that fires with (status, body) once the next
+    negotiation message has been received."""
+    self._sasl_negotiation_deferred = defer.Deferred()
+    self._sasl_negotiation_status = None
+    return self._sasl_negotiation_deferred
+
+  def connectionLost(self, reason=connectionDone):
+    """Forward to ThriftClientProtocol only if a client has been set."""
+    if self.client:
+      ThriftClientProtocol.connectionLost(self, reason)
+
+  def dataReceived(self, data):
+    """
+    Feed received bytes to the length-prefixed string receiver, stripping
+    the leading status byte of a negotiation message first.
+    """
+    if self._sasl_negotiation_deferred:
+      if not data:
+        # Nothing to parse; avoid indexing an empty chunk.
+        return
+      # we got a sasl challenge in the format (status, length, challenge)
+      if self._sasl_negotiation_status is None:
+        # Only the first chunk of a message starts with the status byte; a
+        # message split over several chunks must not lose a byte from each
+        # later chunk.
+        self._sasl_negotiation_status, = struct.unpack("B", data[0])
+        data = data[1:]
+      # let IntNStringReceiver piece the challenge data together
+      ThriftClientProtocol.dataReceived(self, data)
+    else:
+      # normal frame, let IntNStringReceiver piece it together
+      ThriftClientProtocol.dataReceived(self, data)
+
+  def stringReceived(self, frame):
+    """
+    Handle one complete length-prefixed string: either the body of a
+    negotiation message or a wrapped Thrift frame.
+    """
+    if self._sasl_negotiation_deferred:
+      # the frame is just a SASL challenge
+      response = (self._sasl_negotiation_status, frame)
+      self._sasl_negotiation_deferred.callback(response)
+    else:
+      # there's a second 4 byte length prefix inside the frame
+      decoded_frame = self.sasl.unwrap(frame[4:])
+      ThriftClientProtocol.stringReceived(self, decoded_frame)
+
+
class ThriftServerProtocol(basic.Int32StringReceiver):
MAX_LENGTH = 2 ** 31 - 1