-- Protocol and transport factories now wrap around a single protocol/transport
Summary:
- This is an analagous to the C++ change made in r31441
Reviewed By: cheever, mcslee
git-svn-id: https://svn.apache.org/repos/asf/incubator/thrift/trunk@664975 13f79535-47bb-0310-9956-ffa450edef68
diff --git a/lib/py/src/protocol/TBinaryProtocol.py b/lib/py/src/protocol/TBinaryProtocol.py
index 8275e34..7fdfdda 100644
--- a/lib/py/src/protocol/TBinaryProtocol.py
+++ b/lib/py/src/protocol/TBinaryProtocol.py
@@ -5,8 +5,8 @@
"""Binary implementation of the Thrift protocol driver."""
- def __init__(self, itrans, otrans=None):
- TProtocolBase.__init__(self, itrans, otrans)
+ def __init__(self, trans):
+ TProtocolBase.__init__(self, trans)
def writeMessageBegin(self, name, type, seqid):
self.writeString(name)
@@ -62,27 +62,27 @@
def writeByte(self, byte):
buff = pack("!b", byte)
- self.otrans.write(buff)
+ self.trans.write(buff)
def writeI16(self, i16):
buff = pack("!h", i16)
- self.otrans.write(buff)
+ self.trans.write(buff)
def writeI32(self, i32):
buff = pack("!i", i32)
- self.otrans.write(buff)
+ self.trans.write(buff)
def writeI64(self, i64):
buff = pack("!q", i64)
- self.otrans.write(buff)
+ self.trans.write(buff)
def writeDouble(self, dub):
buff = pack("!d", dub)
- self.otrans.write(buff)
+ self.trans.write(buff)
def writeString(self, str):
self.writeI32(len(str))
- self.otrans.write(str)
+ self.trans.write(str)
def readMessageBegin(self):
name = self.readString()
@@ -141,36 +141,36 @@
return True
def readByte(self):
- buff = self.itrans.readAll(1)
+ buff = self.trans.readAll(1)
val, = unpack('!b', buff)
return val
def readI16(self):
- buff = self.itrans.readAll(2)
+ buff = self.trans.readAll(2)
val, = unpack('!h', buff)
return val
def readI32(self):
- buff = self.itrans.readAll(4)
+ buff = self.trans.readAll(4)
val, = unpack('!i', buff)
return val
def readI64(self):
- buff = self.itrans.readAll(8)
+ buff = self.trans.readAll(8)
val, = unpack('!q', buff)
return val
def readDouble(self):
- buff = self.itrans.readAll(8)
+ buff = self.trans.readAll(8)
val, = unpack('!d', buff)
return val
def readString(self):
len = self.readI32()
- str = self.itrans.readAll(len)
+ str = self.trans.readAll(len)
return str
class TBinaryProtocolFactory:
- def getIOProtocols(self, itrans, otrans):
- prot = TBinaryProtocol(itrans, otrans)
- return (prot, prot)
+ def getProtocol(self, trans):
+ prot = TBinaryProtocol(trans)
+ return prot
diff --git a/lib/py/src/protocol/TProtocol.py b/lib/py/src/protocol/TProtocol.py
index cc9517c..15206b0 100644
--- a/lib/py/src/protocol/TProtocol.py
+++ b/lib/py/src/protocol/TProtocol.py
@@ -25,10 +25,8 @@
"""Base class for Thrift protocol driver."""
- def __init__(self, itrans, otrans=None):
- self.itrans = self.otrans = itrans
- if otrans != None:
- self.otrans = otrans
+ def __init__(self, trans):
+ self.trans = trans
def writeMessageBegin(self, name, type, seqid):
pass
@@ -191,5 +189,5 @@
self.readListEnd()
class TProtocolFactory:
- def getIOProtocols(self, itrans, otrans):
+ def getProtocol(self, trans):
pass
diff --git a/lib/py/src/server/TServer.py b/lib/py/src/server/TServer.py
index 5b9e1fd..48a4fcb 100644
--- a/lib/py/src/server/TServer.py
+++ b/lib/py/src/server/TServer.py
@@ -11,17 +11,34 @@
"""Base interface for a server, which must have a serve method."""
- def __init__(self, processor, serverTransport, transportFactory=None, protocolFactory=None):
+ """ 3 constructors for all servers:
+ 1) (processor, serverTransport)
+ 2) (processor, serverTransport, transportFactory, protocolFactory)
+ 3) (processor, serverTransport,
+ inputTransportFactory, outputTransportFactory,
+ inputProtocolFactory, outputProtocolFactory)"""
+ def __init__(self, *args):
+ print args
+ if (len(args) == 2):
+ self.__initArgs__(args[0], args[1],
+ TTransport.TTransportFactoryBase(),
+ TTransport.TTransportFactoryBase(),
+ TBinaryProtocol.TBinaryProtocolFactory(),
+ TBinaryProtocol.TBinaryProtocolFactory())
+ elif (len(args) == 4):
+ self.__initArgs__(args[0], args[1], args[2], args[2], args[3], args[3])
+ elif (len(args) == 6):
+ self.__initArgs__(args[0], args[1], args[2], args[3], args[4], args[5])
+
+ def __initArgs__(self, processor, serverTransport,
+ inputTransportFactory, outputTransportFactory,
+ inputProtocolFactory, outputProtocolFactory):
self.processor = processor
self.serverTransport = serverTransport
- if transportFactory == None:
- self.transportFactory = TTransport.TTransportFactoryBase()
- else:
- self.transportFactory = transportFactory
- if protocolFactory == None:
- self.protocolFactory = TBinaryProtocol.TBinaryProtocolFactory()
- else:
- self.protocolFactory = protocolFactory
+ self.inputTransportFactory = inputTransportFactory
+ self.outputTransportFactory = outputTransportFactory
+ self.inputProtocolFactory = inputProtocolFactory
+ self.outputProtocolFactory = outputProtocolFactory
def serve(self):
pass
@@ -30,15 +47,17 @@
"""Simple single-threaded server that just pumps around one transport."""
- def __init__(self, processor, serverTransport, transportFactory=None, protocolFactory=None):
- TServer.__init__(self, processor, serverTransport, transportFactory, protocolFactory)
+ def __init__(self, *args):
+ TServer.__init__(self, *args)
def serve(self):
self.serverTransport.listen()
while True:
client = self.serverTransport.accept()
- (itrans, otrans) = self.transportFactory.getIOTransports(client)
- (iprot, oprot) = self.protocolFactory.getIOProtocols(itrans, otrans)
+ itrans = self.inputTransportFactory.getTransport(client)
+ otrans = self.outputTransportFactory.getTransport(client)
+ iprot = self.inputProtocolFactory.getProtocol(itrans)
+ oprot = self.oututProtocolFactory.getProtocol(otrans)
try:
while True:
self.processor.process(iprot, oprot)
@@ -54,8 +73,8 @@
"""Threaded server that spawns a new thread per each connection."""
- def __init__(self, processor, serverTransport, transportFactory=None, protocolFactory=None):
- TServer.__init__(self, processor, serverTransport, transportFactory, protocolFactory)
+ def __init__(self, *args):
+ TServer.__init__(self, *args)
def serve(self):
self.serverTransport.listen()
@@ -68,8 +87,10 @@
print '%s, %s, %s,' % (type(x), x, traceback.format_exc())
def handle(self, client):
- (itrans, otrans) = self.transportFactory.getIOTransports(client)
- (iprot, oprot) = self.protocolFactory.getIOProtocols(itrans, otrans)
+ itrans = self.inputTransportFactory.getTransport(client)
+ otrans = self.outputTransportFactory.getTransport(client)
+ iprot = self.inputProtocolFactory.getProtocol(itrans)
+ oprot = self.oututProtocolFactory.getProtocol(otrans)
try:
while True:
self.processor.process(iprot, oprot)
@@ -85,8 +106,8 @@
"""Server with a fixed size pool of threads which service requests."""
- def __init__(self, processor, serverTransport, transportFactory=None, protocolFactory=None):
- TServer.__init__(self, processor, serverTransport, transportFactory, protocolFactory)
+ def __init__(self, *args):
+ TServer.__init__(self, *args)
self.clients = Queue.Queue()
self.threads = 10
@@ -105,8 +126,10 @@
def serveClient(self, client):
"""Process input/output from a client for as long as possible"""
- (itrans, otrans) = self.transportFactory.getIOTransports(client)
- (iprot, oprot) = self.protocolFactory.getIOProtocols(itrans, otrans)
+ itrans = self.inputTransportFactory.getTransport(client)
+ otrans = self.outputTransportFactory.getTransport(client)
+ iprot = self.inputProtocolFactory.getProtocol(itrans)
+ oprot = self.oututProtocolFactory.getProtocol(otrans)
try:
while True:
self.processor.process(iprot, oprot)
diff --git a/lib/py/src/transport/TTransport.py b/lib/py/src/transport/TTransport.py
index e6e202b..502b327 100644
--- a/lib/py/src/transport/TTransport.py
+++ b/lib/py/src/transport/TTransport.py
@@ -55,16 +55,16 @@
"""Base class for a Transport Factory"""
- def getIOTransports(self, trans):
- return (trans, trans)
+ def getTransport(self, trans):
+ return trans
class TBufferedTransportFactory:
"""Factory transport that builds buffered transports"""
- def getIOTransports(self, trans):
+ def getTransport(self, trans):
buffered = TBufferedTransport(trans)
- return (buffered, buffered)
+ return buffered
class TBufferedTransport(TTransportBase):
@@ -99,9 +99,9 @@
"""Factory transport that builds framed transports"""
- def getIOTransports(self, trans):
+ def getTransport(self, trans):
framed = TFramedTransport(trans)
- return (framed, framed)
+ return framed
class TFramedTransport(TTransportBase):