THRIFT-4621 Add THeader for Python
Client: py
diff --git a/lib/py/src/server/TServer.py b/lib/py/src/server/TServer.py
index d5d9c98..df2a7bb 100644
--- a/lib/py/src/server/TServer.py
+++ b/lib/py/src/server/TServer.py
@@ -23,6 +23,7 @@
import threading
from thrift.protocol import TBinaryProtocol
+from thrift.protocol.THeaderProtocol import THeaderProtocolFactory
from thrift.transport import TTransport
logger = logging.getLogger(__name__)
@@ -60,6 +61,12 @@
self.inputProtocolFactory = inputProtocolFactory
self.outputProtocolFactory = outputProtocolFactory
+ input_is_header = isinstance(self.inputProtocolFactory, THeaderProtocolFactory)
+ output_is_header = isinstance(self.outputProtocolFactory, THeaderProtocolFactory)
+ if any((input_is_header, output_is_header)) and input_is_header != output_is_header:
+ raise ValueError("THeaderProtocol servers require that both the input and "
+ "output protocols are THeaderProtocol.")
+
def serve(self):
pass
@@ -76,10 +83,20 @@
client = self.serverTransport.accept()
if not client:
continue
+
itrans = self.inputTransportFactory.getTransport(client)
- otrans = self.outputTransportFactory.getTransport(client)
iprot = self.inputProtocolFactory.getProtocol(itrans)
- oprot = self.outputProtocolFactory.getProtocol(otrans)
+
+ # for THeaderProtocol, we must use the same protocol instance for
+ # input and output so that the response is in the same dialect that
+ # the server detected the request was in.
+ if isinstance(self.inputProtocolFactory, THeaderProtocolFactory):
+ otrans = None
+ oprot = iprot
+ else:
+ otrans = self.outputTransportFactory.getTransport(client)
+ oprot = self.outputProtocolFactory.getProtocol(otrans)
+
try:
while True:
self.processor.process(iprot, oprot)
@@ -89,7 +106,8 @@
logger.exception(x)
itrans.close()
- otrans.close()
+ if otrans:
+ otrans.close()
class TThreadedServer(TServer):
@@ -116,9 +134,18 @@
def handle(self, client):
itrans = self.inputTransportFactory.getTransport(client)
- otrans = self.outputTransportFactory.getTransport(client)
iprot = self.inputProtocolFactory.getProtocol(itrans)
- oprot = self.outputProtocolFactory.getProtocol(otrans)
+
+ # for THeaderProtocol, we must use the same protocol instance for input
+ # and output so that the response is in the same dialect that the
+ # server detected the request was in.
+ if isinstance(self.inputProtocolFactory, THeaderProtocolFactory):
+ otrans = None
+ oprot = iprot
+ else:
+ otrans = self.outputTransportFactory.getTransport(client)
+ oprot = self.outputProtocolFactory.getProtocol(otrans)
+
try:
while True:
self.processor.process(iprot, oprot)
@@ -128,7 +155,8 @@
logger.exception(x)
itrans.close()
- otrans.close()
+ if otrans:
+ otrans.close()
class TThreadPoolServer(TServer):
@@ -156,9 +184,18 @@
def serveClient(self, client):
"""Process input/output from a client for as long as possible"""
itrans = self.inputTransportFactory.getTransport(client)
- otrans = self.outputTransportFactory.getTransport(client)
iprot = self.inputProtocolFactory.getProtocol(itrans)
- oprot = self.outputProtocolFactory.getProtocol(otrans)
+
+ # for THeaderProtocol, we must use the same protocol instance for input
+ # and output so that the response is in the same dialect that the
+ # server detected the request was in.
+ if isinstance(self.inputProtocolFactory, THeaderProtocolFactory):
+ otrans = None
+ oprot = iprot
+ else:
+ otrans = self.outputTransportFactory.getTransport(client)
+ oprot = self.outputProtocolFactory.getProtocol(otrans)
+
try:
while True:
self.processor.process(iprot, oprot)
@@ -168,7 +205,8 @@
logger.exception(x)
itrans.close()
- otrans.close()
+ if otrans:
+ otrans.close()
def serve(self):
"""Start a fixed number of worker threads and put client into a queue"""
@@ -237,10 +275,18 @@
try_close(otrans)
else:
itrans = self.inputTransportFactory.getTransport(client)
- otrans = self.outputTransportFactory.getTransport(client)
-
iprot = self.inputProtocolFactory.getProtocol(itrans)
- oprot = self.outputProtocolFactory.getProtocol(otrans)
+
+ # for THeaderProtocol, we must use the same protocol
+ # instance for input and output so that the response is in
+ # the same dialect that the server detected the request was
+ # in.
+ if isinstance(self.inputProtocolFactory, THeaderProtocolFactory):
+ otrans = None
+ oprot = iprot
+ else:
+ otrans = self.outputTransportFactory.getTransport(client)
+ oprot = self.outputProtocolFactory.getProtocol(otrans)
ecode = 0
try:
@@ -254,7 +300,8 @@
ecode = 1
finally:
try_close(itrans)
- try_close(otrans)
+ if otrans:
+ try_close(otrans)
os._exit(ecode)