Python basic threaded server

Reviewed By: ccheever-pillar


git-svn-id: https://svn.apache.org/repos/asf/incubator/thrift/trunk@664812 13f79535-47bb-0310-9956-ffa450edef68
diff --git a/lib/py/src/server/TServer.py b/lib/py/src/server/TServer.py
index c9a701e..53260b8 100644
--- a/lib/py/src/server/TServer.py
+++ b/lib/py/src/server/TServer.py
@@ -34,9 +34,37 @@
       try:
         while True:
           self.processor.process(input, output)
+      except TTransport.TTransportException, tx:
+        pass
       except Exception, x:
         print '%s, %s, %s' % (type(x), x, traceback.format_exc())
-        print 'Client died.'
 
       input.close()
       output.close()
+
+class TThreadedServer(TServer):
+
+  """Threaded server that spawns a new thread per each connection."""
+
+  def __init__(self, processor, serverTransport, transportFactory=None):
+    TServer.__init__(self, processor, serverTransport, transportFactory)
+
+  def serve(self):
+    self.serverTransport.listen()
+    while True:
+      try:
+        client = self.serverTransport.accept()
+        t = threading.Thread(target = self.handle, args=(client,))
+        t.start()
+      except Exception, x:
+        print '%s, %s, %s,' % (type(x), x, traceback.format_exc())
+
+  def handle(self, client):
+    (input, output) = self.transportFactory.getIOTransports(client)
+    try:
+      while True:
+        self.processor.process(input, output)
+    except TTransport.TTransportException, tx:
+      pass
+    except Exception, x:
+      print '%s, %s, %s' % (type(x), x, traceback.format_exc())
diff --git a/lib/py/src/transport/TSocket.py b/lib/py/src/transport/TSocket.py
index 2c7dd3e..dd4a166 100644
--- a/lib/py/src/transport/TSocket.py
+++ b/lib/py/src/transport/TSocket.py
@@ -10,7 +10,7 @@
     self.port = port
     self.handle = None
 
-  def set_handle(self, h):
+  def setHandle(self, h):
     self.handle = h
 
   def isOpen(self):
@@ -37,7 +37,7 @@
   def read(self, sz):
     buff = self.handle.recv(sz)
     if len(buff) == 0:
-      raise Exception('TSocket read 0 bytes')
+      raise TTransportException('TSocket read 0 bytes')
     return buff
 
   def write(self, buff):
@@ -46,7 +46,7 @@
     while sent < have:
       plus = self.handle.send(buff)
       if plus == 0:
-        raise Exception('sent 0 bytes')
+        raise TTransportException('sent 0 bytes')
       sent += plus
       buff = buff[plus:]
 
@@ -63,13 +63,16 @@
  
   def listen(self):
     self.handle = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+    self.handle.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+    if hasattr(self.handle, 'set_timeout'):
+      self.handle.set_timeout(None)
     self.handle.bind(('', self.port))
     self.handle.listen(128)
 
   def accept(self):
     (client, addr) = self.handle.accept()
     result = TSocket()
-    result.set_handle(client)
+    result.setHandle(client)
     return result
 
   def close(self):
diff --git a/lib/py/src/transport/TTransport.py b/lib/py/src/transport/TTransport.py
index a7eb3b0..7117aa0 100644
--- a/lib/py/src/transport/TTransport.py
+++ b/lib/py/src/transport/TTransport.py
@@ -1,3 +1,11 @@
+from cStringIO import StringIO
+
+class TTransportException(Exception):
+
+  """Custom Transport Exception class"""
+
+  pass
+
 class TTransportBase:
 
   """Base class for Thrift transport layer."""
@@ -58,7 +66,7 @@
 
   def __init__(self, trans):
     self.__trans = trans
-    self.__buf = ''
+    self.__buf = StringIO()
 
   def isOpen(self):
     return self.__trans.isOpen()
@@ -76,8 +84,8 @@
     return self.__trans.readAll(sz)
 
   def write(self, buf):
-    self.__buf += buf
+    self.__buf.write(buf)
 
   def flush(self):
-    self.__trans.write(self.__buf)
-    self.__buf = ''
+    self.__trans.write(self.__buf.getvalue())
+    self.__buf = StringIO()