Fix python server bugs and go to new protocol wraps transport model

Reviewed By: ccheever


git-svn-id: https://svn.apache.org/repos/asf/incubator/thrift/trunk@664849 13f79535-47bb-0310-9956-ffa450edef68
diff --git a/lib/py/src/Thrift.py b/lib/py/src/Thrift.py
index 0c4a458..1be84e2 100644
--- a/lib/py/src/Thrift.py
+++ b/lib/py/src/Thrift.py
@@ -2,5 +2,5 @@
 
   """Base class for procsessor, which works on two streams."""
 
-  def process(itrans, otrans):
+  def process(iprot, oprot):
     pass
diff --git a/lib/py/src/protocol/TBinaryProtocol.py b/lib/py/src/protocol/TBinaryProtocol.py
index 25f3218..8275e34 100644
--- a/lib/py/src/protocol/TBinaryProtocol.py
+++ b/lib/py/src/protocol/TBinaryProtocol.py
@@ -5,164 +5,172 @@
 
   """Binary implementation of the Thrift protocol driver."""
 
-  def writeMessageBegin(self, otrans, name, type, seqid):
-    self.writeString(otrans, name)
-    self.writeByte(otrans, type)
-    self.writeI32(otrans, seqid)
+  def __init__(self, itrans, otrans=None):
+    TProtocolBase.__init__(self, itrans, otrans)
 
-  def writeMessageEnd(self, otrans):
+  def writeMessageBegin(self, name, type, seqid):
+    self.writeString(name)
+    self.writeByte(type)
+    self.writeI32(seqid)
+
+  def writeMessageEnd(self):
     pass
 
-  def writeStructBegin(self, otrans, name):
+  def writeStructBegin(self, name):
     pass
 
-  def writeStructEnd(self, otrans):
+  def writeStructEnd(self):
     pass
 
-  def writeFieldBegin(self, otrans, name, type, id):
-    self.writeByte(otrans, type)
-    self.writeI16(otrans, id)
+  def writeFieldBegin(self, name, type, id):
+    self.writeByte(type)
+    self.writeI16(id)
 
-  def writeFieldEnd(self, otrans):
+  def writeFieldEnd(self):
     pass
 
-  def writeFieldStop(self, otrans):
-    self.writeByte(otrans, TType.STOP);
+  def writeFieldStop(self):
+    self.writeByte(TType.STOP);
 
-  def writeMapBegin(self, otrans, ktype, vtype, size):
-    self.writeByte(otrans, ktype)
-    self.writeByte(otrans, vtype)
-    self.writeI32(otrans, size)
+  def writeMapBegin(self, ktype, vtype, size):
+    self.writeByte(ktype)
+    self.writeByte(vtype)
+    self.writeI32(size)
 
-  def writeMapEnd(self, otrans):
+  def writeMapEnd(self):
     pass
 
-  def writeListBegin(self, otrans, etype, size):
-    self.writeByte(otrans, etype)
-    self.writeI32(otrans, size)
+  def writeListBegin(self, etype, size):
+    self.writeByte(etype)
+    self.writeI32(size)
 
-  def writeListEnd(self, otrans):
+  def writeListEnd(self):
     pass
 
-  def writeSetBegin(self, otrans, etype, size):
-    self.writeByte(otrans, etype)
-    self.writeI32(otrans, size)
+  def writeSetBegin(self, etype, size):
+    self.writeByte(etype)
+    self.writeI32(size)
 
-  def writeSetEnd(self, otrans):
+  def writeSetEnd(self):
     pass
 
-  def writeBool(self, otrans, bool):
+  def writeBool(self, bool):
     if bool:
-      self.writeByte(otrans, 1)
+      self.writeByte(1)
     else:
-      self.writeByte(otrans, 0)
+      self.writeByte(0)
     
-  def writeByte(self, otrans, byte):
+  def writeByte(self, byte):
     buff = pack("!b", byte)
-    otrans.write(buff)
+    self.otrans.write(buff)
 
-  def writeI16(self, otrans, i16):
+  def writeI16(self, i16):
     buff = pack("!h", i16)
-    otrans.write(buff)
+    self.otrans.write(buff)
 
-  def writeI32(self, otrans, i32):
+  def writeI32(self, i32):
     buff = pack("!i", i32)
-    otrans.write(buff)
+    self.otrans.write(buff)
     
-  def writeI64(self, otrans, i64):
+  def writeI64(self, i64):
     buff = pack("!q", i64)
-    otrans.write(buff)
+    self.otrans.write(buff)
 
-  def writeDouble(self, otrans, dub):
+  def writeDouble(self, dub):
     buff = pack("!d", dub)
-    otrans.write(buff)
+    self.otrans.write(buff)
 
-  def writeString(self, otrans, str):
-    self.writeI32(otrans, len(str))
-    otrans.write(str)
+  def writeString(self, str):
+    self.writeI32(len(str))
+    self.otrans.write(str)
 
-  def readMessageBegin(self, itrans):
-    name = self.readString(itrans)
-    type = self.readByte(itrans)
-    seqid = self.readI32(itrans)
+  def readMessageBegin(self):
+    name = self.readString()
+    type = self.readByte()
+    seqid = self.readI32()
     return (name, type, seqid)
 
-  def readMessageEnd(self, itrans):
+  def readMessageEnd(self):
     pass
 
-  def readStructBegin(self, itrans):
+  def readStructBegin(self):
     pass
 
-  def readStructEnd(self, itrans):
+  def readStructEnd(self):
     pass
 
-  def readFieldBegin(self, itrans):
-    type = self.readByte(itrans)
+  def readFieldBegin(self):
+    type = self.readByte()
     if type == TType.STOP:
       return (None, type, 0)
-    id = self.readI16(itrans)
+    id = self.readI16()
     return (None, type, id)
 
-  def readFieldEnd(self, itrans):
+  def readFieldEnd(self):
     pass
 
-  def readMapBegin(self, itrans):
-    ktype = self.readByte(itrans)
-    vtype = self.readByte(itrans)
-    size = self.readI32(itrans)
+  def readMapBegin(self):
+    ktype = self.readByte()
+    vtype = self.readByte()
+    size = self.readI32()
     return (ktype, vtype, size)
 
-  def readMapEnd(self, itrans):
+  def readMapEnd(self):
     pass
 
-  def readListBegin(self, itrans):
-    etype = self.readByte(itrans)
-    size = self.readI32(itrans)
+  def readListBegin(self):
+    etype = self.readByte()
+    size = self.readI32()
     return (etype, size)
 
-  def readListEnd(self, itrans):
+  def readListEnd(self):
     pass
 
-  def readSetBegin(self, itrans):
-    etype = self.readByte(itrans)
-    size = self.readI32(itrans)
+  def readSetBegin(self):
+    etype = self.readByte()
+    size = self.readI32()
     return (etype, size)
 
-  def readSetEnd(self, itrans):
+  def readSetEnd(self):
     pass
 
-  def readBool(self, itrans):
-    byte = self.readByte(itrans)
+  def readBool(self):
+    byte = self.readByte()
     if byte == 0:
       return False
     return True
 
-  def readByte(self, itrans):
-    buff = itrans.readAll(1)
+  def readByte(self):
+    buff = self.itrans.readAll(1)
     val, = unpack('!b', buff)
     return val
 
-  def readI16(self, itrans):
-    buff = itrans.readAll(2)
+  def readI16(self):
+    buff = self.itrans.readAll(2)
     val, = unpack('!h', buff)
     return val
 
-  def readI32(self, itrans):
-    buff = itrans.readAll(4)
+  def readI32(self):
+    buff = self.itrans.readAll(4)
     val, = unpack('!i', buff)
     return val
 
-  def readI64(self, itrans):
-    buff = itrans.readAll(8)
+  def readI64(self):
+    buff = self.itrans.readAll(8)
     val, = unpack('!q', buff)
     return val
 
-  def readDouble(self, itrans):
-    buff = itrans.readAll(8)
+  def readDouble(self):
+    buff = self.itrans.readAll(8)
     val, = unpack('!d', buff)
     return val
 
-  def readString(self, itrans):
-    len = self.readI32(itrans)
-    str = itrans.readAll(len)
+  def readString(self):
+    len = self.readI32()
+    str = self.itrans.readAll(len)
     return str
+
+class TBinaryProtocolFactory:
+  def getIOProtocols(self, itrans, otrans):
+    prot = TBinaryProtocol(itrans, otrans)
+    return (prot, prot)
diff --git a/lib/py/src/protocol/TProtocol.py b/lib/py/src/protocol/TProtocol.py
index 0b480d3..cc9517c 100644
--- a/lib/py/src/protocol/TProtocol.py
+++ b/lib/py/src/protocol/TProtocol.py
@@ -25,165 +25,171 @@
 
   """Base class for Thrift protocol driver."""
 
-  def writeMessageBegin(self, otrans, name, type, seqid):
+  def __init__(self, itrans, otrans=None):
+    self.itrans = self.otrans = itrans
+    if otrans != None:
+      self.otrans = otrans
+
+  def writeMessageBegin(self, name, type, seqid):
     pass
 
-  def writeMessageEnd(self, otrans):
+  def writeMessageEnd(self):
     pass
 
-  def writeStructBegin(self, otrans, name):
+  def writeStructBegin(self, name):
     pass
 
-  def writeStructEnd(self, otrans):
+  def writeStructEnd(self):
     pass
 
-  def writeFieldBegin(self, otrans, name, type, id):
+  def writeFieldBegin(self, name, type, id):
     pass
 
-  def writeFieldEnd(self, otrans):
+  def writeFieldEnd(self):
     pass
 
-  def writeFieldStop(self, otrans):
+  def writeFieldStop(self):
     pass
 
-  def writeMapBegin(self, otrans, ktype, vtype, size):
+  def writeMapBegin(self, ktype, vtype, size):
     pass
 
-  def writeMapEnd(self, otrans):
+  def writeMapEnd(self):
     pass
 
-  def writeListBegin(self, otrans, etype, size):
+  def writeListBegin(self, etype, size):
     pass
 
-  def writeListEnd(self, otrans):
+  def writeListEnd(self):
     pass
 
-  def writeSetBegin(self, otrans, etype, size):
+  def writeSetBegin(self, etype, size):
     pass
 
-  def writeSetEnd(self, otrans):
+  def writeSetEnd(self):
     pass
 
-  def writeBool(self, otrans, bool):
+  def writeBool(self, bool):
     pass
 
-  def writeByte(self, otrans, byte):
+  def writeByte(self, byte):
     pass
 
-  def writeI16(self, otrans, i16):
+  def writeI16(self, i16):
     pass
 
-  def writeI32(self, otrans, i32):
+  def writeI32(self, i32):
     pass
 
-  def writeI64(self, otrans, i64):
+  def writeI64(self, i64):
     pass
 
-  def writeDouble(self, otrans, dub):
+  def writeDouble(self, dub):
     pass
 
-  def writeString(self, otrans, str):
+  def writeString(self, str):
     pass
 
-  def readMessageBegin(self, itrans):
+  def readMessageBegin(self):
     pass
 
-  def readMessageEnd(self, itrans):
+  def readMessageEnd(self):
     pass
 
-  def readStructBegin(self, itrans):
+  def readStructBegin(self):
     pass
 
-  def readStructEnd(self, itrans):
+  def readStructEnd(self):
     pass
 
-  def readFieldBegin(self, itrans):
+  def readFieldBegin(self):
     pass
 
-  def readFieldEnd(self, itrans):
+  def readFieldEnd(self):
     pass
 
-  def readMapBegin(self, itrans):
+  def readMapBegin(self):
     pass
 
-  def readMapEnd(self, itrans):
+  def readMapEnd(self):
     pass
 
-  def readListBegin(self, itrans):
+  def readListBegin(self):
     pass
 
-  def readListEnd(self, itrans):
+  def readListEnd(self):
     pass
 
-  def readSetBegin(self, itrans):
+  def readSetBegin(self):
     pass
 
-  def readSetEnd(self, itrans):
+  def readSetEnd(self):
     pass
 
-  def readBool(self, itrans):
+  def readBool(self):
     pass
 
-  def readByte(self, itrans):
+  def readByte(self):
     pass
 
-  def readI16(self, itrans):
+  def readI16(self):
     pass
 
-  def readI32(self, itrans):
+  def readI32(self):
     pass
 
-  def readI64(self, itrans):
+  def readI64(self):
     pass
 
-  def readDouble(self, itrans):
+  def readDouble(self):
     pass
 
-  def readString(self, itrans):
+  def readString(self):
     pass
 
-  def skip(self, itrans, type):
+  def skip(self, type):
     if type == TType.STOP:
       return
     elif type == TType.BOOL:
-      self.readBool(itrans)
+      self.readBool()
     elif type == TType.BYTE:
-      self.readByte(itrans)
+      self.readByte()
     elif type == TType.I16:
-      self.readI16(itrans)
+      self.readI16()
     elif type == TType.I32:
-      self.readI32(itrans)
+      self.readI32()
     elif type == TType.I64:
-      self.readI64(itrans)
+      self.readI64()
     elif type == TType.DOUBLE:
-      self.readDouble(itrans)
+      self.readDouble()
     elif type == TType.STRING:
-      self.readString(itrans)
+      self.readString()
     elif type == TType.STRUCT:
-      name = self.readStructBegin(itrans)
+      name = self.readStructBegin()
       while True:
-        (name, type, id) = self.readFieldBegin(itrans)
+        (name, type, id) = self.readFieldBegin()
         if type == TType.STOP:
           break
-        self.skip(itrans, type)
-        self.readFieldEnd(itrans)
-      self.readStructEnd(itrans)
+        self.skip(type)
+        self.readFieldEnd()
+      self.readStructEnd()
     elif type == TType.MAP:
-      (ktype, vtype, size) = self.readMapBegin(itrans)
+      (ktype, vtype, size) = self.readMapBegin()
       for i in range(size):
-        self.skip(itrans, ktype)
-        self.skip(itrans, vtype)
-      self.readMapEnd(itrans)
+        self.skip(ktype)
+        self.skip(vtype)
+      self.readMapEnd()
     elif type == TType.SET:
-      (etype, size) = self.readSetBegin(itrans)
+      (etype, size) = self.readSetBegin()
       for i in range(size):
-        self.skip(itrans, etype)
-      self.readSetEnd(itrans)
+        self.skip(etype)
+      self.readSetEnd()
     elif type == TType.LIST:
-      (etype, size) = self.readListBegin(itrans)
+      (etype, size) = self.readListBegin()
       for i in range(size):
-        self.skip(itrans, etype)
-      self.readListEnd(itrans)
+        self.skip(etype)
+      self.readListEnd()
 
-
-    
+class TProtocolFactory:
+  def getIOProtocols(self, itrans, otrans):
+    pass
diff --git a/lib/py/src/server/TServer.py b/lib/py/src/server/TServer.py
index 56ee9c0..7514264 100644
--- a/lib/py/src/server/TServer.py
+++ b/lib/py/src/server/TServer.py
@@ -5,18 +5,23 @@
 
 from thrift.Thrift import TProcessor
 from thrift.transport import TTransport
+from thrift.protocol import TBinaryProtocol
 
 class TServer:
 
   """Base interface for a server, which must have a serve method."""
 
-  def __init__(self, processor, serverTransport, transportFactory=None):
+  def __init__(self, processor, serverTransport, transportFactory=None, protocolFactory=None):
     self.processor = processor
     self.serverTransport = serverTransport
     if transportFactory == None:
       self.transportFactory = TTransport.TTransportFactoryBase()
     else:
       self.transportFactory = transportFactory
+    if protocolFactory == None:
+      self.protocolFactory = TBinaryProtocol.TBinaryProtocolFactory()
+    else:
+      self.protocolFactory = protocolFactory
 
   def serve(self):
     pass
@@ -25,31 +30,32 @@
 
   """Simple single-threaded server that just pumps around one transport."""
 
-  def __init__(self, processor, serverTransport, transportFactory=None):
-    TServer.__init__(self, processor, serverTransport, transportFactory)
+  def __init__(self, processor, serverTransport, transportFactory=None, protocolFactory=None):
+    TServer.__init__(self, processor, serverTransport, transportFactory, protocolFactory)
 
   def serve(self):
     self.serverTransport.listen()
     while True:
       client = self.serverTransport.accept()
-      (input, output) = self.transportFactory.getIOTransports(client)
+      (itrans, otrans) = self.transportFactory.getIOTransports(client)
+      (iprot, oprot) = self.protocolFactory.getIOProtocols(itrans, otrans)
       try:
         while True:
-          self.processor.process(input, output)
+          self.processor.process(iprot, oprot)
       except TTransport.TTransportException, tx:
         pass
       except Exception, x:
         print '%s, %s, %s' % (type(x), x, traceback.format_exc())
 
-      input.close()
-      output.close()
+      itrans.close()
+      otrans.close()
 
 class TThreadedServer(TServer):
 
   """Threaded server that spawns a new thread per each connection."""
 
-  def __init__(self, processor, serverTransport, transportFactory=None):
-    TServer.__init__(self, processor, serverTransport, transportFactory)
+  def __init__(self, processor, serverTransport, transportFactory=None, protocolFactory=None):
+    TServer.__init__(self, processor, serverTransport, transportFactory, protocolFactory)
 
   def serve(self):
     self.serverTransport.listen()
@@ -62,15 +68,19 @@
         print '%s, %s, %s,' % (type(x), x, traceback.format_exc())
 
   def handle(self, client):
-    (input, output) = self.transportFactory.getIOTransports(client)
+    (itrans, otrans) = self.transportFactory.getIOTransports(client)
+    (iprot, oprot) = self.protocolFactory.getIOProtocols(itrans, otrans)
     try:
       while True:
-        self.processor.process(input, output)
+        self.processor.process(iprot, oprot)
     except TTransport.TTransportException, tx:
       pass
     except Exception, x:
       print '%s, %s, %s' % (type(x), x, traceback.format_exc())
 
+    itrans.close()
+    otrans.close()
+
 class TThreadPoolServer(TServer):
 
   """Server with a fixed size pool of threads which service requests."""
@@ -95,15 +105,19 @@
       
   def serveClient(self, client):
     """Process input/output from a client for as long as possible"""
-    (input, output) = self.transportFactory.getIOTransports(client)
+    (itrans, otrans) = self.transportFactory.getIOTransports(client)
+    (iprot, oprot) = self.protocolFactory.getIOProtocols(itrans, otrans)
     try:
       while True:
-        self.processor.process(input, output)
+        self.processor.process(iprot, oprot)
     except TTransport.TTransportException, tx:
       pass
     except Exception, x:
       print '%s, %s, %s' % (type(x), x, traceback.format_exc())
 
+    itrans.close()
+    otrans.close()
+
   def serve(self):
     """Start a fixed number of worker threads and put client into a queue"""
     for i in range(self.threads):