-- Protocol and transport factories now wrap around a single protocol/transport

Summary:
- This is an analagous to the C++ change made in r31441

Reviewed By: cheever, mcslee


git-svn-id: https://svn.apache.org/repos/asf/incubator/thrift/trunk@664975 13f79535-47bb-0310-9956-ffa450edef68
diff --git a/lib/py/src/protocol/TBinaryProtocol.py b/lib/py/src/protocol/TBinaryProtocol.py
index 8275e34..7fdfdda 100644
--- a/lib/py/src/protocol/TBinaryProtocol.py
+++ b/lib/py/src/protocol/TBinaryProtocol.py
@@ -5,8 +5,8 @@
 
   """Binary implementation of the Thrift protocol driver."""
 
-  def __init__(self, itrans, otrans=None):
-    TProtocolBase.__init__(self, itrans, otrans)
+  def __init__(self, trans):
+    TProtocolBase.__init__(self, trans)
 
   def writeMessageBegin(self, name, type, seqid):
     self.writeString(name)
@@ -62,27 +62,27 @@
     
   def writeByte(self, byte):
     buff = pack("!b", byte)
-    self.otrans.write(buff)
+    self.trans.write(buff)
 
   def writeI16(self, i16):
     buff = pack("!h", i16)
-    self.otrans.write(buff)
+    self.trans.write(buff)
 
   def writeI32(self, i32):
     buff = pack("!i", i32)
-    self.otrans.write(buff)
+    self.trans.write(buff)
     
   def writeI64(self, i64):
     buff = pack("!q", i64)
-    self.otrans.write(buff)
+    self.trans.write(buff)
 
   def writeDouble(self, dub):
     buff = pack("!d", dub)
-    self.otrans.write(buff)
+    self.trans.write(buff)
 
   def writeString(self, str):
     self.writeI32(len(str))
-    self.otrans.write(str)
+    self.trans.write(str)
 
   def readMessageBegin(self):
     name = self.readString()
@@ -141,36 +141,36 @@
     return True
 
   def readByte(self):
-    buff = self.itrans.readAll(1)
+    buff = self.trans.readAll(1)
     val, = unpack('!b', buff)
     return val
 
   def readI16(self):
-    buff = self.itrans.readAll(2)
+    buff = self.trans.readAll(2)
     val, = unpack('!h', buff)
     return val
 
   def readI32(self):
-    buff = self.itrans.readAll(4)
+    buff = self.trans.readAll(4)
     val, = unpack('!i', buff)
     return val
 
   def readI64(self):
-    buff = self.itrans.readAll(8)
+    buff = self.trans.readAll(8)
     val, = unpack('!q', buff)
     return val
 
   def readDouble(self):
-    buff = self.itrans.readAll(8)
+    buff = self.trans.readAll(8)
     val, = unpack('!d', buff)
     return val
 
   def readString(self):
     len = self.readI32()
-    str = self.itrans.readAll(len)
+    str = self.trans.readAll(len)
     return str
 
 class TBinaryProtocolFactory:
-  def getIOProtocols(self, itrans, otrans):
-    prot = TBinaryProtocol(itrans, otrans)
-    return (prot, prot)
+  def getProtocol(self, trans):
+    prot = TBinaryProtocol(trans)
+    return prot
diff --git a/lib/py/src/protocol/TProtocol.py b/lib/py/src/protocol/TProtocol.py
index cc9517c..15206b0 100644
--- a/lib/py/src/protocol/TProtocol.py
+++ b/lib/py/src/protocol/TProtocol.py
@@ -25,10 +25,8 @@
 
   """Base class for Thrift protocol driver."""
 
-  def __init__(self, itrans, otrans=None):
-    self.itrans = self.otrans = itrans
-    if otrans != None:
-      self.otrans = otrans
+  def __init__(self, trans):
+    self.trans = trans
 
   def writeMessageBegin(self, name, type, seqid):
     pass
@@ -191,5 +189,5 @@
       self.readListEnd()
 
 class TProtocolFactory:
-  def getIOProtocols(self, itrans, otrans):
+  def getProtocol(self, trans):
     pass
diff --git a/lib/py/src/server/TServer.py b/lib/py/src/server/TServer.py
index 5b9e1fd..48a4fcb 100644
--- a/lib/py/src/server/TServer.py
+++ b/lib/py/src/server/TServer.py
@@ -11,17 +11,34 @@
 
   """Base interface for a server, which must have a serve method."""
 
-  def __init__(self, processor, serverTransport, transportFactory=None, protocolFactory=None):
+  """ 3 constructors for all servers:
+  1) (processor, serverTransport)
+  2) (processor, serverTransport, transportFactory, protocolFactory)
+  3) (processor, serverTransport,
+      inputTransportFactory, outputTransportFactory,
+      inputProtocolFactory, outputProtocolFactory)"""
+  def __init__(self, *args):
+    print args
+    if (len(args) == 2):
+      self.__initArgs__(args[0], args[1],
+                        TTransport.TTransportFactoryBase(),
+                        TTransport.TTransportFactoryBase(),
+                        TBinaryProtocol.TBinaryProtocolFactory(),
+                        TBinaryProtocol.TBinaryProtocolFactory())
+    elif (len(args) == 4):
+      self.__initArgs__(args[0], args[1], args[2], args[2], args[3], args[3])
+    elif (len(args) == 6):
+      self.__initArgs__(args[0], args[1], args[2], args[3], args[4], args[5])
+
+  def __initArgs__(self, processor, serverTransport,
+                   inputTransportFactory, outputTransportFactory,
+                   inputProtocolFactory, outputProtocolFactory):
     self.processor = processor
     self.serverTransport = serverTransport
-    if transportFactory == None:
-      self.transportFactory = TTransport.TTransportFactoryBase()
-    else:
-      self.transportFactory = transportFactory
-    if protocolFactory == None:
-      self.protocolFactory = TBinaryProtocol.TBinaryProtocolFactory()
-    else:
-      self.protocolFactory = protocolFactory
+    self.inputTransportFactory = inputTransportFactory
+    self.outputTransportFactory = outputTransportFactory
+    self.inputProtocolFactory = inputProtocolFactory
+    self.outputProtocolFactory = outputProtocolFactory
 
   def serve(self):
     pass
@@ -30,15 +47,17 @@
 
   """Simple single-threaded server that just pumps around one transport."""
 
-  def __init__(self, processor, serverTransport, transportFactory=None, protocolFactory=None):
-    TServer.__init__(self, processor, serverTransport, transportFactory, protocolFactory)
+  def __init__(self, *args):
+    TServer.__init__(self, *args)
 
   def serve(self):
     self.serverTransport.listen()
     while True:
       client = self.serverTransport.accept()
-      (itrans, otrans) = self.transportFactory.getIOTransports(client)
-      (iprot, oprot) = self.protocolFactory.getIOProtocols(itrans, otrans)
+      itrans = self.inputTransportFactory.getTransport(client)
+      otrans = self.outputTransportFactory.getTransport(client)
+      iprot = self.inputProtocolFactory.getProtocol(itrans)
+      oprot = self.oututProtocolFactory.getProtocol(otrans)
       try:
         while True:
           self.processor.process(iprot, oprot)
@@ -54,8 +73,8 @@
 
   """Threaded server that spawns a new thread per each connection."""
 
-  def __init__(self, processor, serverTransport, transportFactory=None, protocolFactory=None):
-    TServer.__init__(self, processor, serverTransport, transportFactory, protocolFactory)
+  def __init__(self, *args):
+    TServer.__init__(self, *args)
 
   def serve(self):
     self.serverTransport.listen()
@@ -68,8 +87,10 @@
         print '%s, %s, %s,' % (type(x), x, traceback.format_exc())
 
   def handle(self, client):
-    (itrans, otrans) = self.transportFactory.getIOTransports(client)
-    (iprot, oprot) = self.protocolFactory.getIOProtocols(itrans, otrans)
+    itrans = self.inputTransportFactory.getTransport(client)
+    otrans = self.outputTransportFactory.getTransport(client)
+    iprot = self.inputProtocolFactory.getProtocol(itrans)
+    oprot = self.oututProtocolFactory.getProtocol(otrans)
     try:
       while True:
         self.processor.process(iprot, oprot)
@@ -85,8 +106,8 @@
 
   """Server with a fixed size pool of threads which service requests."""
 
-  def __init__(self, processor, serverTransport, transportFactory=None, protocolFactory=None):
-    TServer.__init__(self, processor, serverTransport, transportFactory, protocolFactory)
+  def __init__(self, *args):
+    TServer.__init__(self, *args)
     self.clients = Queue.Queue()
     self.threads = 10
 
@@ -105,8 +126,10 @@
       
   def serveClient(self, client):
     """Process input/output from a client for as long as possible"""
-    (itrans, otrans) = self.transportFactory.getIOTransports(client)
-    (iprot, oprot) = self.protocolFactory.getIOProtocols(itrans, otrans)
+    itrans = self.inputTransportFactory.getTransport(client)
+    otrans = self.outputTransportFactory.getTransport(client)
+    iprot = self.inputProtocolFactory.getProtocol(itrans)
+    oprot = self.oututProtocolFactory.getProtocol(otrans)
     try:
       while True:
         self.processor.process(iprot, oprot)
diff --git a/lib/py/src/transport/TTransport.py b/lib/py/src/transport/TTransport.py
index e6e202b..502b327 100644
--- a/lib/py/src/transport/TTransport.py
+++ b/lib/py/src/transport/TTransport.py
@@ -55,16 +55,16 @@
 
   """Base class for a Transport Factory"""
 
-  def getIOTransports(self, trans):
-    return (trans, trans)
+  def getTransport(self, trans):
+    return trans
 
 class TBufferedTransportFactory:
 
   """Factory transport that builds buffered transports"""
 
-  def getIOTransports(self, trans):
+  def getTransport(self, trans):
     buffered = TBufferedTransport(trans)
-    return (buffered, buffered)
+    return buffered
 
 
 class TBufferedTransport(TTransportBase):
@@ -99,9 +99,9 @@
 
   """Factory transport that builds framed transports"""
 
-  def getIOTransports(self, trans):
+  def getTransport(self, trans):
     framed = TFramedTransport(trans)
-    return (framed, framed)
+    return framed
 
 
 class TFramedTransport(TTransportBase):