Fix python server bugs and go to new protocol wraps transport model
Reviewed By: ccheever
git-svn-id: https://svn.apache.org/repos/asf/incubator/thrift/trunk@664849 13f79535-47bb-0310-9956-ffa450edef68
diff --git a/lib/py/src/Thrift.py b/lib/py/src/Thrift.py
index 0c4a458..1be84e2 100644
--- a/lib/py/src/Thrift.py
+++ b/lib/py/src/Thrift.py
@@ -2,5 +2,5 @@
"""Base class for procsessor, which works on two streams."""
- def process(itrans, otrans):
+ def process(iprot, oprot):
pass
diff --git a/lib/py/src/protocol/TBinaryProtocol.py b/lib/py/src/protocol/TBinaryProtocol.py
index 25f3218..8275e34 100644
--- a/lib/py/src/protocol/TBinaryProtocol.py
+++ b/lib/py/src/protocol/TBinaryProtocol.py
@@ -5,164 +5,172 @@
"""Binary implementation of the Thrift protocol driver."""
- def writeMessageBegin(self, otrans, name, type, seqid):
- self.writeString(otrans, name)
- self.writeByte(otrans, type)
- self.writeI32(otrans, seqid)
+ def __init__(self, itrans, otrans=None):
+ TProtocolBase.__init__(self, itrans, otrans)
- def writeMessageEnd(self, otrans):
+ def writeMessageBegin(self, name, type, seqid):
+ self.writeString(name)
+ self.writeByte(type)
+ self.writeI32(seqid)
+
+ def writeMessageEnd(self):
pass
- def writeStructBegin(self, otrans, name):
+ def writeStructBegin(self, name):
pass
- def writeStructEnd(self, otrans):
+ def writeStructEnd(self):
pass
- def writeFieldBegin(self, otrans, name, type, id):
- self.writeByte(otrans, type)
- self.writeI16(otrans, id)
+ def writeFieldBegin(self, name, type, id):
+ self.writeByte(type)
+ self.writeI16(id)
- def writeFieldEnd(self, otrans):
+ def writeFieldEnd(self):
pass
- def writeFieldStop(self, otrans):
- self.writeByte(otrans, TType.STOP);
+ def writeFieldStop(self):
+ self.writeByte(TType.STOP);
- def writeMapBegin(self, otrans, ktype, vtype, size):
- self.writeByte(otrans, ktype)
- self.writeByte(otrans, vtype)
- self.writeI32(otrans, size)
+ def writeMapBegin(self, ktype, vtype, size):
+ self.writeByte(ktype)
+ self.writeByte(vtype)
+ self.writeI32(size)
- def writeMapEnd(self, otrans):
+ def writeMapEnd(self):
pass
- def writeListBegin(self, otrans, etype, size):
- self.writeByte(otrans, etype)
- self.writeI32(otrans, size)
+ def writeListBegin(self, etype, size):
+ self.writeByte(etype)
+ self.writeI32(size)
- def writeListEnd(self, otrans):
+ def writeListEnd(self):
pass
- def writeSetBegin(self, otrans, etype, size):
- self.writeByte(otrans, etype)
- self.writeI32(otrans, size)
+ def writeSetBegin(self, etype, size):
+ self.writeByte(etype)
+ self.writeI32(size)
- def writeSetEnd(self, otrans):
+ def writeSetEnd(self):
pass
- def writeBool(self, otrans, bool):
+ def writeBool(self, bool):
if bool:
- self.writeByte(otrans, 1)
+ self.writeByte(1)
else:
- self.writeByte(otrans, 0)
+ self.writeByte(0)
- def writeByte(self, otrans, byte):
+ def writeByte(self, byte):
buff = pack("!b", byte)
- otrans.write(buff)
+ self.otrans.write(buff)
- def writeI16(self, otrans, i16):
+ def writeI16(self, i16):
buff = pack("!h", i16)
- otrans.write(buff)
+ self.otrans.write(buff)
- def writeI32(self, otrans, i32):
+ def writeI32(self, i32):
buff = pack("!i", i32)
- otrans.write(buff)
+ self.otrans.write(buff)
- def writeI64(self, otrans, i64):
+ def writeI64(self, i64):
buff = pack("!q", i64)
- otrans.write(buff)
+ self.otrans.write(buff)
- def writeDouble(self, otrans, dub):
+ def writeDouble(self, dub):
buff = pack("!d", dub)
- otrans.write(buff)
+ self.otrans.write(buff)
- def writeString(self, otrans, str):
- self.writeI32(otrans, len(str))
- otrans.write(str)
+ def writeString(self, str):
+ self.writeI32(len(str))
+ self.otrans.write(str)
- def readMessageBegin(self, itrans):
- name = self.readString(itrans)
- type = self.readByte(itrans)
- seqid = self.readI32(itrans)
+ def readMessageBegin(self):
+ name = self.readString()
+ type = self.readByte()
+ seqid = self.readI32()
return (name, type, seqid)
- def readMessageEnd(self, itrans):
+ def readMessageEnd(self):
pass
- def readStructBegin(self, itrans):
+ def readStructBegin(self):
pass
- def readStructEnd(self, itrans):
+ def readStructEnd(self):
pass
- def readFieldBegin(self, itrans):
- type = self.readByte(itrans)
+ def readFieldBegin(self):
+ type = self.readByte()
if type == TType.STOP:
return (None, type, 0)
- id = self.readI16(itrans)
+ id = self.readI16()
return (None, type, id)
- def readFieldEnd(self, itrans):
+ def readFieldEnd(self):
pass
- def readMapBegin(self, itrans):
- ktype = self.readByte(itrans)
- vtype = self.readByte(itrans)
- size = self.readI32(itrans)
+ def readMapBegin(self):
+ ktype = self.readByte()
+ vtype = self.readByte()
+ size = self.readI32()
return (ktype, vtype, size)
- def readMapEnd(self, itrans):
+ def readMapEnd(self):
pass
- def readListBegin(self, itrans):
- etype = self.readByte(itrans)
- size = self.readI32(itrans)
+ def readListBegin(self):
+ etype = self.readByte()
+ size = self.readI32()
return (etype, size)
- def readListEnd(self, itrans):
+ def readListEnd(self):
pass
- def readSetBegin(self, itrans):
- etype = self.readByte(itrans)
- size = self.readI32(itrans)
+ def readSetBegin(self):
+ etype = self.readByte()
+ size = self.readI32()
return (etype, size)
- def readSetEnd(self, itrans):
+ def readSetEnd(self):
pass
- def readBool(self, itrans):
- byte = self.readByte(itrans)
+ def readBool(self):
+ byte = self.readByte()
if byte == 0:
return False
return True
- def readByte(self, itrans):
- buff = itrans.readAll(1)
+ def readByte(self):
+ buff = self.itrans.readAll(1)
val, = unpack('!b', buff)
return val
- def readI16(self, itrans):
- buff = itrans.readAll(2)
+ def readI16(self):
+ buff = self.itrans.readAll(2)
val, = unpack('!h', buff)
return val
- def readI32(self, itrans):
- buff = itrans.readAll(4)
+ def readI32(self):
+ buff = self.itrans.readAll(4)
val, = unpack('!i', buff)
return val
- def readI64(self, itrans):
- buff = itrans.readAll(8)
+ def readI64(self):
+ buff = self.itrans.readAll(8)
val, = unpack('!q', buff)
return val
- def readDouble(self, itrans):
- buff = itrans.readAll(8)
+ def readDouble(self):
+ buff = self.itrans.readAll(8)
val, = unpack('!d', buff)
return val
- def readString(self, itrans):
- len = self.readI32(itrans)
- str = itrans.readAll(len)
+ def readString(self):
+ len = self.readI32()
+ str = self.itrans.readAll(len)
return str
+
+class TBinaryProtocolFactory:
+ def getIOProtocols(self, itrans, otrans):
+ prot = TBinaryProtocol(itrans, otrans)
+ return (prot, prot)
diff --git a/lib/py/src/protocol/TProtocol.py b/lib/py/src/protocol/TProtocol.py
index 0b480d3..cc9517c 100644
--- a/lib/py/src/protocol/TProtocol.py
+++ b/lib/py/src/protocol/TProtocol.py
@@ -25,165 +25,171 @@
"""Base class for Thrift protocol driver."""
- def writeMessageBegin(self, otrans, name, type, seqid):
+ def __init__(self, itrans, otrans=None):
+ self.itrans = self.otrans = itrans
+ if otrans != None:
+ self.otrans = otrans
+
+ def writeMessageBegin(self, name, type, seqid):
pass
- def writeMessageEnd(self, otrans):
+ def writeMessageEnd(self):
pass
- def writeStructBegin(self, otrans, name):
+ def writeStructBegin(self, name):
pass
- def writeStructEnd(self, otrans):
+ def writeStructEnd(self):
pass
- def writeFieldBegin(self, otrans, name, type, id):
+ def writeFieldBegin(self, name, type, id):
pass
- def writeFieldEnd(self, otrans):
+ def writeFieldEnd(self):
pass
- def writeFieldStop(self, otrans):
+ def writeFieldStop(self):
pass
- def writeMapBegin(self, otrans, ktype, vtype, size):
+ def writeMapBegin(self, ktype, vtype, size):
pass
- def writeMapEnd(self, otrans):
+ def writeMapEnd(self):
pass
- def writeListBegin(self, otrans, etype, size):
+ def writeListBegin(self, etype, size):
pass
- def writeListEnd(self, otrans):
+ def writeListEnd(self):
pass
- def writeSetBegin(self, otrans, etype, size):
+ def writeSetBegin(self, etype, size):
pass
- def writeSetEnd(self, otrans):
+ def writeSetEnd(self):
pass
- def writeBool(self, otrans, bool):
+ def writeBool(self, bool):
pass
- def writeByte(self, otrans, byte):
+ def writeByte(self, byte):
pass
- def writeI16(self, otrans, i16):
+ def writeI16(self, i16):
pass
- def writeI32(self, otrans, i32):
+ def writeI32(self, i32):
pass
- def writeI64(self, otrans, i64):
+ def writeI64(self, i64):
pass
- def writeDouble(self, otrans, dub):
+ def writeDouble(self, dub):
pass
- def writeString(self, otrans, str):
+ def writeString(self, str):
pass
- def readMessageBegin(self, itrans):
+ def readMessageBegin(self):
pass
- def readMessageEnd(self, itrans):
+ def readMessageEnd(self):
pass
- def readStructBegin(self, itrans):
+ def readStructBegin(self):
pass
- def readStructEnd(self, itrans):
+ def readStructEnd(self):
pass
- def readFieldBegin(self, itrans):
+ def readFieldBegin(self):
pass
- def readFieldEnd(self, itrans):
+ def readFieldEnd(self):
pass
- def readMapBegin(self, itrans):
+ def readMapBegin(self):
pass
- def readMapEnd(self, itrans):
+ def readMapEnd(self):
pass
- def readListBegin(self, itrans):
+ def readListBegin(self):
pass
- def readListEnd(self, itrans):
+ def readListEnd(self):
pass
- def readSetBegin(self, itrans):
+ def readSetBegin(self):
pass
- def readSetEnd(self, itrans):
+ def readSetEnd(self):
pass
- def readBool(self, itrans):
+ def readBool(self):
pass
- def readByte(self, itrans):
+ def readByte(self):
pass
- def readI16(self, itrans):
+ def readI16(self):
pass
- def readI32(self, itrans):
+ def readI32(self):
pass
- def readI64(self, itrans):
+ def readI64(self):
pass
- def readDouble(self, itrans):
+ def readDouble(self):
pass
- def readString(self, itrans):
+ def readString(self):
pass
- def skip(self, itrans, type):
+ def skip(self, type):
if type == TType.STOP:
return
elif type == TType.BOOL:
- self.readBool(itrans)
+ self.readBool()
elif type == TType.BYTE:
- self.readByte(itrans)
+ self.readByte()
elif type == TType.I16:
- self.readI16(itrans)
+ self.readI16()
elif type == TType.I32:
- self.readI32(itrans)
+ self.readI32()
elif type == TType.I64:
- self.readI64(itrans)
+ self.readI64()
elif type == TType.DOUBLE:
- self.readDouble(itrans)
+ self.readDouble()
elif type == TType.STRING:
- self.readString(itrans)
+ self.readString()
elif type == TType.STRUCT:
- name = self.readStructBegin(itrans)
+ name = self.readStructBegin()
while True:
- (name, type, id) = self.readFieldBegin(itrans)
+ (name, type, id) = self.readFieldBegin()
if type == TType.STOP:
break
- self.skip(itrans, type)
- self.readFieldEnd(itrans)
- self.readStructEnd(itrans)
+ self.skip(type)
+ self.readFieldEnd()
+ self.readStructEnd()
elif type == TType.MAP:
- (ktype, vtype, size) = self.readMapBegin(itrans)
+ (ktype, vtype, size) = self.readMapBegin()
for i in range(size):
- self.skip(itrans, ktype)
- self.skip(itrans, vtype)
- self.readMapEnd(itrans)
+ self.skip(ktype)
+ self.skip(vtype)
+ self.readMapEnd()
elif type == TType.SET:
- (etype, size) = self.readSetBegin(itrans)
+ (etype, size) = self.readSetBegin()
for i in range(size):
- self.skip(itrans, etype)
- self.readSetEnd(itrans)
+ self.skip(etype)
+ self.readSetEnd()
elif type == TType.LIST:
- (etype, size) = self.readListBegin(itrans)
+ (etype, size) = self.readListBegin()
for i in range(size):
- self.skip(itrans, etype)
- self.readListEnd(itrans)
+ self.skip(etype)
+ self.readListEnd()
-
-
+class TProtocolFactory:
+ def getIOProtocols(self, itrans, otrans):
+ pass
diff --git a/lib/py/src/server/TServer.py b/lib/py/src/server/TServer.py
index 56ee9c0..7514264 100644
--- a/lib/py/src/server/TServer.py
+++ b/lib/py/src/server/TServer.py
@@ -5,18 +5,23 @@
from thrift.Thrift import TProcessor
from thrift.transport import TTransport
+from thrift.protocol import TBinaryProtocol
class TServer:
"""Base interface for a server, which must have a serve method."""
- def __init__(self, processor, serverTransport, transportFactory=None):
+ def __init__(self, processor, serverTransport, transportFactory=None, protocolFactory=None):
self.processor = processor
self.serverTransport = serverTransport
if transportFactory == None:
self.transportFactory = TTransport.TTransportFactoryBase()
else:
self.transportFactory = transportFactory
+ if protocolFactory == None:
+ self.protocolFactory = TBinaryProtocol.TBinaryProtocolFactory()
+ else:
+ self.protocolFactory = protocolFactory
def serve(self):
pass
@@ -25,31 +30,32 @@
"""Simple single-threaded server that just pumps around one transport."""
- def __init__(self, processor, serverTransport, transportFactory=None):
- TServer.__init__(self, processor, serverTransport, transportFactory)
+ def __init__(self, processor, serverTransport, transportFactory=None, protocolFactory=None):
+ TServer.__init__(self, processor, serverTransport, transportFactory, protocolFactory)
def serve(self):
self.serverTransport.listen()
while True:
client = self.serverTransport.accept()
- (input, output) = self.transportFactory.getIOTransports(client)
+ (itrans, otrans) = self.transportFactory.getIOTransports(client)
+ (iprot, oprot) = self.protocolFactory.getIOProtocols(itrans, otrans)
try:
while True:
- self.processor.process(input, output)
+ self.processor.process(iprot, oprot)
except TTransport.TTransportException, tx:
pass
except Exception, x:
print '%s, %s, %s' % (type(x), x, traceback.format_exc())
- input.close()
- output.close()
+ itrans.close()
+ otrans.close()
class TThreadedServer(TServer):
"""Threaded server that spawns a new thread per each connection."""
- def __init__(self, processor, serverTransport, transportFactory=None):
- TServer.__init__(self, processor, serverTransport, transportFactory)
+ def __init__(self, processor, serverTransport, transportFactory=None, protocolFactory=None):
+ TServer.__init__(self, processor, serverTransport, transportFactory, protocolFactory)
def serve(self):
self.serverTransport.listen()
@@ -62,15 +68,19 @@
print '%s, %s, %s,' % (type(x), x, traceback.format_exc())
def handle(self, client):
- (input, output) = self.transportFactory.getIOTransports(client)
+ (itrans, otrans) = self.transportFactory.getIOTransports(client)
+ (iprot, oprot) = self.protocolFactory.getIOProtocols(itrans, otrans)
try:
while True:
- self.processor.process(input, output)
+ self.processor.process(iprot, oprot)
except TTransport.TTransportException, tx:
pass
except Exception, x:
print '%s, %s, %s' % (type(x), x, traceback.format_exc())
+ itrans.close()
+ otrans.close()
+
class TThreadPoolServer(TServer):
"""Server with a fixed size pool of threads which service requests."""
@@ -95,15 +105,19 @@
def serveClient(self, client):
"""Process input/output from a client for as long as possible"""
- (input, output) = self.transportFactory.getIOTransports(client)
+ (itrans, otrans) = self.transportFactory.getIOTransports(client)
+ (iprot, oprot) = self.protocolFactory.getIOProtocols(itrans, otrans)
try:
while True:
- self.processor.process(input, output)
+ self.processor.process(iprot, oprot)
except TTransport.TTransportException, tx:
pass
except Exception, x:
print '%s, %s, %s' % (type(x), x, traceback.format_exc())
+ itrans.close()
+ otrans.close()
+
def serve(self):
"""Start a fixed number of worker threads and put client into a queue"""
for i in range(self.threads):