THRIFT-67. python: Add TNonblockingServer
This TNonblockingServer is very similar to the C++ implementation.
It assumes the framed transport, but it uses select instead of libevent.
git-svn-id: https://svn.apache.org/repos/asf/incubator/thrift/trunk@712306 13f79535-47bb-0310-9956-ffa450edef68
diff --git a/lib/py/src/server/TNonblockingServer.py b/lib/py/src/server/TNonblockingServer.py
new file mode 100644
index 0000000..a588fe3
--- /dev/null
+++ b/lib/py/src/server/TNonblockingServer.py
@@ -0,0 +1,291 @@
+"""Implementation of non-blocking server.
+
+The main idea of the server is reciving and sending requests
+only from main thread.
+
+It also makes thread pool server in tasks terms, not connections.
+"""
+import threading
+import socket
+import Queue
+import select
+import struct
+import logging
+
+from thrift.transport import TTransport
+from thrift.protocol.TBinaryProtocol import TBinaryProtocolFactory
+
+__all__ = ['TNonblockingServer']
+
+class Worker(threading.Thread):
+ """Worker is a small helper to process incoming connection."""
+ def __init__(self, queue):
+ threading.Thread.__init__(self)
+ self.queue = queue
+
+ def run(self):
+ """Process queries from task queue, stop if processor is None."""
+ while True:
+ try:
+ processor, iprot, oprot, otrans, callback = self.queue.get()
+ if processor is None:
+ break
+ processor.process(iprot, oprot)
+ callback(True, otrans.getvalue())
+ except Exception:
+ logging.exception("Exception while processing request")
+ callback(False, '')
+
+WAIT_LEN = 0
+WAIT_MESSAGE = 1
+WAIT_PROCESS = 2
+SEND_ANSWER = 3
+CLOSED = 4
+
+def locked(func):
+ "Decorator which locks self.lock."
+ def nested(self, *args, **kwargs):
+ self.lock.acquire()
+ try:
+ return func(self, *args, **kwargs)
+ finally:
+ self.lock.release()
+ return nested
+
+def socket_exception(func):
+ "Decorator close object on socket.error."
+ def read(self, *args, **kwargs):
+ try:
+ return func(self, *args, **kwargs)
+ except socket.error:
+ self.close()
+ return read
+
+class Connection:
+ """Basic class is represented connection.
+
+ It can be in state:
+ WAIT_LEN --- connection is reading request len.
+ WAIT_MESSAGE --- connection is reading request.
+ WAIT_PROCESS --- connection has just read whole request and
+ waits for call ready routine.
+ SEND_ANSWER --- connection is sending answer string (including length
+ of answer).
+ CLOSED --- socket was closed and connection should be deleted.
+ """
+ def __init__(self, new_socket, wake_up):
+ self.socket = new_socket
+ self.socket.setblocking(False)
+ self.status = WAIT_LEN
+ self.len = 0
+ self.message = ''
+ self.lock = threading.Lock()
+ self.wake_up = wake_up
+
+ def _read_len(self):
+ """Reads length of request.
+
+ It's really paranoic routine and it may be replaced by
+ self.socket.recv(4)."""
+ read = self.socket.recv(4 - len(self.message))
+ if len(read) == 0:
+ # if we read 0 bytes and self.message is empty, it means client close
+ # connection
+ if len(self.message) != 0:
+ logging.error("can't read frame size from socket")
+ self.close()
+ return
+ self.message += read
+ if len(self.message) == 4:
+ self.len, = struct.unpack('!i', self.message)
+ if self.len < 0:
+ logging.error("negative frame size, it seems client"\
+ " doesn't use FramedTransport")
+ self.close()
+ elif self.len == 0:
+ logging.error("empty frame, it's really strange")
+ self.close()
+ else:
+ self.message = ''
+ self.status = WAIT_MESSAGE
+
+ @socket_exception
+ def read(self):
+ """Reads data from stream and switch state."""
+ assert self.status in (WAIT_LEN, WAIT_MESSAGE)
+ if self.status == WAIT_LEN:
+ self._read_len()
+ # go back to the main loop here for simplicity instead of
+ # falling through, even though there is a good chance that
+ # the message is already available
+ elif self.status == WAIT_MESSAGE:
+ read = self.socket.recv(self.len - len(self.message))
+ if len(read) == 0:
+ logging.error("can't read frame from socket (get %d of %d bytes)" %
+ (len(self.message), self.len))
+ self.close()
+ return
+ self.message += read
+ if len(self.message) == self.len:
+ self.status = WAIT_PROCESS
+
+ @socket_exception
+ def write(self):
+ """Writes data from socket and switch state."""
+ assert self.status == SEND_ANSWER
+ sent = self.socket.send(self.message)
+ if sent == len(self.message):
+ self.status = WAIT_LEN
+ self.message = ''
+ self.len = 0
+ else:
+ self.message = self.message[sent:]
+
+ @locked
+ def ready(self, all_ok, message):
+ """Callback function for switching state and waking up main thread.
+
+ This function is the only function witch can be called asynchronous.
+
+ The ready can switch Connection to three states:
+ WAIT_LEN if request was async.
+ SEND_ANSWER if request was processed in normal way.
+ CLOSED if request throws unexpected exception.
+
+ The one wakes up main thread.
+ """
+ assert self.status == WAIT_PROCESS
+ if not all_ok:
+ self.close()
+ self.wake_up()
+ return
+ self.len = ''
+ self.message = struct.pack('!i', len(message)) + message
+ if len(message) == 0:
+ # it was async request, do not write answer
+ self.status = WAIT_LEN
+ else:
+ self.status = SEND_ANSWER
+ self.wake_up()
+
+ @locked
+ def is_writeable(self):
+ "Returns True if connection should be added to write list of select."
+ return self.status == SEND_ANSWER
+
+ # it's not necessary, but...
+ @locked
+ def is_readable(self):
+ "Returns True if connection should be added to read list of select."
+ return self.status in (WAIT_LEN, WAIT_MESSAGE)
+
+ @locked
+ def is_closed(self):
+ "Returns True if connection is closed."
+ return self.status == CLOSED
+
+ def fileno(self):
+ "Returns the file descriptor of the associated socket."
+ return self.socket.fileno()
+
+ def close(self):
+ "Closes connection"
+ self.status = CLOSED
+ self.socket.close()
+
+class TNonblockingServer:
+ """Non-blocking server."""
+ def __init__(self, processor, lsocket, inputProtocolFactory=None,
+ outputProtocolFactory=None, threads=10):
+ self.processor = processor
+ self.socket = lsocket
+ self.in_protocol = inputProtocolFactory or TBinaryProtocolFactory()
+ self.out_protocol = outputProtocolFactory or self.in_protocol
+ self.threads = int(threads)
+ self.clients = {}
+ self.tasks = Queue.Queue()
+ self._read, self._write = socket.socketpair()
+ self.prepared = False
+
+ def setNumThreads(self, num):
+ """Set the number of worker threads that should be created."""
+ # implement ThreadPool interface
+ assert not self.prepared, "You can't change number of threads for working server"
+ self.threads = num
+
+ def prepare(self):
+ """Prepares server for serve requests."""
+ self.socket.listen()
+ for _ in xrange(self.threads):
+ thread = Worker(self.tasks)
+ thread.setDaemon(True)
+ thread.start()
+ self.prepared = True
+
+ def wake_up(self):
+ """Wake up main thread.
+
+ The server usualy waits in select call in we should terminate one.
+ The simplest way is using socketpair.
+
+ Select always wait to read from the first socket of socketpair.
+
+ In this case, we can just write anything to the second socket from
+ socketpair."""
+ self._write.send('1')
+
+ def _select(self):
+ """Does select on open connections."""
+ readable = [self.socket.handle.fileno(), self._read.fileno()]
+ writable = []
+ for i, connection in self.clients.items():
+ if connection.is_readable():
+ readable.append(connection.fileno())
+ if connection.is_writeable():
+ writable.append(connection.fileno())
+ if connection.is_closed():
+ del self.clients[i]
+ return select.select(readable, writable, readable)
+
+ def handle(self):
+ """Handle requests.
+
+ WARNING! You must call prepare BEFORE calling handle.
+ """
+ assert self.prepared, "You have to call prepare before handle"
+ rset, wset, xset = self._select()
+ for readable in rset:
+ if readable == self._read.fileno():
+ # don't care i just need to clean readable flag
+ self._read.recv(1024)
+ elif readable == self.socket.handle.fileno():
+ client = self.socket.accept().handle
+ self.clients[client.fileno()] = Connection(client, self.wake_up)
+ else:
+ connection = self.clients[readable]
+ connection.read()
+ if connection.status == WAIT_PROCESS:
+ itransport = TTransport.TMemoryBuffer(connection.message)
+ otransport = TTransport.TMemoryBuffer()
+ iprot = self.in_protocol.getProtocol(itransport)
+ oprot = self.out_protocol.getProtocol(otransport)
+ self.tasks.put([self.processor, iprot, oprot,
+ otransport, connection.ready])
+ for writeable in wset:
+ self.clients[writeable].write()
+ for oob in xset:
+ self.clients[oob].close()
+ del self.clients[oob]
+
+ def close(self):
+ """Closes the server."""
+ for _ in xrange(self.threads):
+ self.tasks.put([None, None, None, None, None])
+ self.socket.close()
+ self.prepared = False
+
+ def serve(self):
+ """Serve forever."""
+ self.prepare()
+ while True:
+ self.handle()
diff --git a/lib/py/src/server/__init__.py b/lib/py/src/server/__init__.py
index b4b46a1..f017abd 100644
--- a/lib/py/src/server/__init__.py
+++ b/lib/py/src/server/__init__.py
@@ -4,4 +4,4 @@
# See accompanying file LICENSE or visit the Thrift site at:
# http://developers.facebook.com/thrift/
-__all__ = ['TServer']
+__all__ = ['TServer', 'TNonblockingServer']
diff --git a/test/py/RunClientServer.py b/test/py/RunClientServer.py
index cbff372..48eadb6 100755
--- a/test/py/RunClientServer.py
+++ b/test/py/RunClientServer.py
@@ -9,12 +9,16 @@
def relfile(fname):
return os.path.join(os.path.dirname(__file__), fname)
+FRAMED = ["TNonblockingServer"]
+
def runTest(server_class):
print "Testing ", server_class
serverproc = subprocess.Popen([sys.executable, relfile("TestServer.py"), server_class])
try:
-
- ret = subprocess.call([sys.executable, relfile("TestClient.py")])
+ argv = [sys.executable, relfile("TestClient.py")]
+ if server_class in FRAMED:
+ argv.append('--framed')
+ ret = subprocess.call(argv)
if ret != 0:
raise Exception("subprocess failed")
finally:
@@ -25,4 +29,4 @@
time.sleep(5)
map(runTest, ["TForkingServer", "TThreadPoolServer",
- "TThreadedServer", "TSimpleServer"])
+ "TThreadedServer", "TSimpleServer", "TNonblockingServer"])
diff --git a/test/py/TestClient.py b/test/py/TestClient.py
index fb0133a..78dc80a 100755
--- a/test/py/TestClient.py
+++ b/test/py/TestClient.py
@@ -15,24 +15,29 @@
parser = OptionParser()
+parser.set_defaults(framed=False, verbose=1, host='localhost', port=9090)
+parser.add_option("--port", type="int", dest="port",
+ help="connect to server at port")
+parser.add_option("--host", type="string", dest="host",
+ help="connect to server")
+parser.add_option("--framed", action="store_true", dest="framed",
+ help="use framed transport")
+parser.add_option('-v', '--verbose', action="store_const",
+ dest="verbose", const=2,
+ help="verbose output")
+parser.add_option('-q', '--quiet', action="store_const",
+ dest="verbose", const=0,
+ help="minimal output")
-parser.add_option("--port", type="int", dest="port", default=9090)
-parser.add_option("--host", type="string", dest="host", default='localhost')
-parser.add_option("--framed-input", action="store_true", dest="framed_input")
-parser.add_option("--framed-output", action="store_false", dest="framed_output")
-
-(options, args) = parser.parse_args()
+options, args = parser.parse_args()
class AbstractTest(unittest.TestCase):
-
def setUp(self):
- global options
-
socket = TSocket.TSocket(options.host, options.port)
# Frame or buffer depending upon args
- if options.framed_input or options.framed_output:
- self.transport = TTransport.TFramedTransport(socket, options.framed_input, options.framed_output)
+ if options.framed:
+ self.transport = TTransport.TFramedTransport(socket)
else:
self.transport = TTransport.TBufferedTransport(socket)
@@ -113,5 +118,13 @@
suite.addTest(loader.loadTestsFromTestCase(AcceleratedBinaryTest))
return suite
+class OwnArgsTestProgram(unittest.TestProgram):
+ def parseArgs(self, argv):
+ if args:
+ self.testNames = args
+ else:
+ self.testNames = (self.defaultTest,)
+ self.createTests()
+
if __name__ == "__main__":
- unittest.main(defaultTest="suite", testRunner=unittest.TextTestRunner(verbosity=2))
+ OwnArgsTestProgram(defaultTest="suite", testRunner=unittest.TextTestRunner(verbosity=2))
diff --git a/test/py/TestServer.py b/test/py/TestServer.py
index 0247bc2..a7bf6d0 100755
--- a/test/py/TestServer.py
+++ b/test/py/TestServer.py
@@ -9,7 +9,7 @@
from thrift.transport import TTransport
from thrift.transport import TSocket
from thrift.protocol import TBinaryProtocol
-from thrift.server import TServer
+from thrift.server import TServer, TNonblockingServer
class TestHandler:
@@ -59,13 +59,33 @@
time.sleep(seconds)
print 'done sleeping'
+ def testNest(self, thing):
+ return thing
+
+ def testMap(self, thing):
+ return thing
+
+ def testSet(self, thing):
+ return thing
+
+ def testList(self, thing):
+ return thing
+
+ def testEnum(self, thing):
+ return thing
+
+ def testTypedef(self, thing):
+ return thing
+
handler = TestHandler()
processor = ThriftTest.Processor(handler)
transport = TSocket.TServerSocket(9090)
tfactory = TTransport.TBufferedTransportFactory()
pfactory = TBinaryProtocol.TBinaryProtocolFactory()
-ServerClass = getattr(TServer, sys.argv[1])
-
-server = ServerClass(processor, transport, tfactory, pfactory)
+if sys.argv[1] == "TNonblockingServer":
+ server = TNonblockingServer.TNonblockingServer(processor, transport)
+else:
+ ServerClass = getattr(TServer, sys.argv[1])
+ server = ServerClass(processor, transport, tfactory, pfactory)
server.serve()