THRIFT-4621 Add THeader for Python
Client: py
diff --git a/lib/py/src/compat.py b/lib/py/src/compat.py
index 41bcf35..0e8271d 100644
--- a/lib/py/src/compat.py
+++ b/lib/py/src/compat.py
@@ -29,6 +29,9 @@
def str_to_binary(str_val):
return str_val
+ def byte_index(bytes_val, i):
+ return ord(bytes_val[i])
+
else:
from io import BytesIO as BufferIO # noqa
@@ -38,3 +41,6 @@
def str_to_binary(str_val):
return bytes(str_val, 'utf8')
+
+ def byte_index(bytes_val, i):
+ return bytes_val[i]
diff --git a/lib/py/src/protocol/THeaderProtocol.py b/lib/py/src/protocol/THeaderProtocol.py
new file mode 100644
index 0000000..b27a749
--- /dev/null
+++ b/lib/py/src/protocol/THeaderProtocol.py
@@ -0,0 +1,225 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+from thrift.protocol.TBinaryProtocol import TBinaryProtocolAccelerated
+from thrift.protocol.TCompactProtocol import TCompactProtocolAccelerated
+from thrift.protocol.TProtocol import TProtocolBase, TProtocolException
+from thrift.Thrift import TApplicationException, TMessageType
+from thrift.transport.THeaderTransport import THeaderTransport, THeaderSubprotocolID, THeaderClientType
+
+
+PROTOCOLS_BY_ID = {
+ THeaderSubprotocolID.BINARY: TBinaryProtocolAccelerated,
+ THeaderSubprotocolID.COMPACT: TCompactProtocolAccelerated,
+}
+
+
+class THeaderProtocol(TProtocolBase):
+ """A framed protocol with headers and payload transforms.
+
+ THeaderProtocol frames other Thrift protocols and adds support for optional
+ out-of-band headers. The currently supported subprotocols are
+ TBinaryProtocol and TCompactProtocol.
+
+ It's also possible to apply transforms to the encoded message payload. The
+ only transform currently supported is to gzip.
+
+ When used in a server, THeaderProtocol can accept messages from
+ non-THeaderProtocol clients if allowed (see `allowed_client_types`). This
+ includes framed and unframed transports and both TBinaryProtocol and
+ TCompactProtocol. The server will respond in the appropriate dialect for
+ the connected client. HTTP clients are not currently supported.
+
+ THeaderProtocol does not currently support THTTPServer, TNonblockingServer,
+ or TProcessPoolServer.
+
+ See doc/specs/HeaderFormat.md for details of the wire format.
+
+ """
+
+ def __init__(self, transport, allowed_client_types):
+ # much of the actual work for THeaderProtocol happens down in
+ # THeaderTransport since we need to do low-level shenanigans to detect
+ # if the client is sending us headers or one of the headerless formats
+ # we support. this wraps the real transport with the one that does all
+ # the magic.
+ if not isinstance(transport, THeaderTransport):
+ transport = THeaderTransport(transport, allowed_client_types)
+ super(THeaderProtocol, self).__init__(transport)
+ self._set_protocol()
+
+ def get_headers(self):
+ return self.trans.get_headers()
+
+ def set_header(self, key, value):
+ self.trans.set_header(key, value)
+
+ def clear_headers(self):
+ self.trans.clear_headers()
+
+ def add_transform(self, transform_id):
+ self.trans.add_transform(transform_id)
+
+ def writeMessageBegin(self, name, ttype, seqid):
+ self.trans.sequence_id = seqid
+ return self._protocol.writeMessageBegin(name, ttype, seqid)
+
+ def writeMessageEnd(self):
+ return self._protocol.writeMessageEnd()
+
+ def writeStructBegin(self, name):
+ return self._protocol.writeStructBegin(name)
+
+ def writeStructEnd(self):
+ return self._protocol.writeStructEnd()
+
+ def writeFieldBegin(self, name, ttype, fid):
+ return self._protocol.writeFieldBegin(name, ttype, fid)
+
+ def writeFieldEnd(self):
+ return self._protocol.writeFieldEnd()
+
+ def writeFieldStop(self):
+ return self._protocol.writeFieldStop()
+
+ def writeMapBegin(self, ktype, vtype, size):
+ return self._protocol.writeMapBegin(ktype, vtype, size)
+
+ def writeMapEnd(self):
+ return self._protocol.writeMapEnd()
+
+ def writeListBegin(self, etype, size):
+ return self._protocol.writeListBegin(etype, size)
+
+ def writeListEnd(self):
+ return self._protocol.writeListEnd()
+
+ def writeSetBegin(self, etype, size):
+ return self._protocol.writeSetBegin(etype, size)
+
+ def writeSetEnd(self):
+ return self._protocol.writeSetEnd()
+
+ def writeBool(self, bool_val):
+ return self._protocol.writeBool(bool_val)
+
+ def writeByte(self, byte):
+ return self._protocol.writeByte(byte)
+
+ def writeI16(self, i16):
+ return self._protocol.writeI16(i16)
+
+ def writeI32(self, i32):
+ return self._protocol.writeI32(i32)
+
+ def writeI64(self, i64):
+ return self._protocol.writeI64(i64)
+
+ def writeDouble(self, dub):
+ return self._protocol.writeDouble(dub)
+
+ def writeBinary(self, str_val):
+ return self._protocol.writeBinary(str_val)
+
+ def _set_protocol(self):
+ try:
+ protocol_cls = PROTOCOLS_BY_ID[self.trans.protocol_id]
+ except KeyError:
+ raise TApplicationException(
+ TProtocolException.INVALID_PROTOCOL,
+ "Unknown protocol requested.",
+ )
+
+ self._protocol = protocol_cls(self.trans)
+ self._fast_encode = self._protocol._fast_encode
+ self._fast_decode = self._protocol._fast_decode
+
+ def readMessageBegin(self):
+ try:
+ self.trans.readFrame(0)
+ self._set_protocol()
+ except TApplicationException as exc:
+ self._protocol.writeMessageBegin(b"", TMessageType.EXCEPTION, 0)
+ exc.write(self._protocol)
+ self._protocol.writeMessageEnd()
+ self.trans.flush()
+
+ return self._protocol.readMessageBegin()
+
+ def readMessageEnd(self):
+ return self._protocol.readMessageEnd()
+
+ def readStructBegin(self):
+ return self._protocol.readStructBegin()
+
+ def readStructEnd(self):
+ return self._protocol.readStructEnd()
+
+ def readFieldBegin(self):
+ return self._protocol.readFieldBegin()
+
+ def readFieldEnd(self):
+ return self._protocol.readFieldEnd()
+
+ def readMapBegin(self):
+ return self._protocol.readMapBegin()
+
+ def readMapEnd(self):
+ return self._protocol.readMapEnd()
+
+ def readListBegin(self):
+ return self._protocol.readListBegin()
+
+ def readListEnd(self):
+ return self._protocol.readListEnd()
+
+ def readSetBegin(self):
+ return self._protocol.readSetBegin()
+
+ def readSetEnd(self):
+ return self._protocol.readSetEnd()
+
+ def readBool(self):
+ return self._protocol.readBool()
+
+ def readByte(self):
+ return self._protocol.readByte()
+
+ def readI16(self):
+ return self._protocol.readI16()
+
+ def readI32(self):
+ return self._protocol.readI32()
+
+ def readI64(self):
+ return self._protocol.readI64()
+
+ def readDouble(self):
+ return self._protocol.readDouble()
+
+ def readBinary(self):
+ return self._protocol.readBinary()
+
+
+class THeaderProtocolFactory(object):
+ def __init__(self, allowed_client_types=(THeaderClientType.HEADERS,)):
+ self.allowed_client_types = allowed_client_types
+
+ def getProtocol(self, trans):
+ return THeaderProtocol(trans, self.allowed_client_types)
diff --git a/lib/py/src/protocol/TProtocol.py b/lib/py/src/protocol/TProtocol.py
index fd20cb7..8314cf6 100644
--- a/lib/py/src/protocol/TProtocol.py
+++ b/lib/py/src/protocol/TProtocol.py
@@ -37,6 +37,7 @@
BAD_VERSION = 4
NOT_IMPLEMENTED = 5
DEPTH_LIMIT = 6
+ INVALID_PROTOCOL = 7
def __init__(self, type=UNKNOWN, message=None):
TException.__init__(self, message)
diff --git a/lib/py/src/server/TServer.py b/lib/py/src/server/TServer.py
index d5d9c98..df2a7bb 100644
--- a/lib/py/src/server/TServer.py
+++ b/lib/py/src/server/TServer.py
@@ -23,6 +23,7 @@
import threading
from thrift.protocol import TBinaryProtocol
+from thrift.protocol.THeaderProtocol import THeaderProtocolFactory
from thrift.transport import TTransport
logger = logging.getLogger(__name__)
@@ -60,6 +61,12 @@
self.inputProtocolFactory = inputProtocolFactory
self.outputProtocolFactory = outputProtocolFactory
+ input_is_header = isinstance(self.inputProtocolFactory, THeaderProtocolFactory)
+ output_is_header = isinstance(self.outputProtocolFactory, THeaderProtocolFactory)
+ if any((input_is_header, output_is_header)) and input_is_header != output_is_header:
+ raise ValueError("THeaderProtocol servers require that both the input and "
+ "output protocols are THeaderProtocol.")
+
def serve(self):
pass
@@ -76,10 +83,20 @@
client = self.serverTransport.accept()
if not client:
continue
+
itrans = self.inputTransportFactory.getTransport(client)
- otrans = self.outputTransportFactory.getTransport(client)
iprot = self.inputProtocolFactory.getProtocol(itrans)
- oprot = self.outputProtocolFactory.getProtocol(otrans)
+
+ # for THeaderProtocol, we must use the same protocol instance for
+ # input and output so that the response is in the same dialect that
+ # the server detected the request was in.
+ if isinstance(self.inputProtocolFactory, THeaderProtocolFactory):
+ otrans = None
+ oprot = iprot
+ else:
+ otrans = self.outputTransportFactory.getTransport(client)
+ oprot = self.outputProtocolFactory.getProtocol(otrans)
+
try:
while True:
self.processor.process(iprot, oprot)
@@ -89,7 +106,8 @@
logger.exception(x)
itrans.close()
- otrans.close()
+ if otrans:
+ otrans.close()
class TThreadedServer(TServer):
@@ -116,9 +134,18 @@
def handle(self, client):
itrans = self.inputTransportFactory.getTransport(client)
- otrans = self.outputTransportFactory.getTransport(client)
iprot = self.inputProtocolFactory.getProtocol(itrans)
- oprot = self.outputProtocolFactory.getProtocol(otrans)
+
+ # for THeaderProtocol, we must use the same protocol instance for input
+ # and output so that the response is in the same dialect that the
+ # server detected the request was in.
+ if isinstance(self.inputProtocolFactory, THeaderProtocolFactory):
+ otrans = None
+ oprot = iprot
+ else:
+ otrans = self.outputTransportFactory.getTransport(client)
+ oprot = self.outputProtocolFactory.getProtocol(otrans)
+
try:
while True:
self.processor.process(iprot, oprot)
@@ -128,7 +155,8 @@
logger.exception(x)
itrans.close()
- otrans.close()
+ if otrans:
+ otrans.close()
class TThreadPoolServer(TServer):
@@ -156,9 +184,18 @@
def serveClient(self, client):
"""Process input/output from a client for as long as possible"""
itrans = self.inputTransportFactory.getTransport(client)
- otrans = self.outputTransportFactory.getTransport(client)
iprot = self.inputProtocolFactory.getProtocol(itrans)
- oprot = self.outputProtocolFactory.getProtocol(otrans)
+
+ # for THeaderProtocol, we must use the same protocol instance for input
+ # and output so that the response is in the same dialect that the
+ # server detected the request was in.
+ if isinstance(self.inputProtocolFactory, THeaderProtocolFactory):
+ otrans = None
+ oprot = iprot
+ else:
+ otrans = self.outputTransportFactory.getTransport(client)
+ oprot = self.outputProtocolFactory.getProtocol(otrans)
+
try:
while True:
self.processor.process(iprot, oprot)
@@ -168,7 +205,8 @@
logger.exception(x)
itrans.close()
- otrans.close()
+ if otrans:
+ otrans.close()
def serve(self):
"""Start a fixed number of worker threads and put client into a queue"""
@@ -237,10 +275,18 @@
try_close(otrans)
else:
itrans = self.inputTransportFactory.getTransport(client)
- otrans = self.outputTransportFactory.getTransport(client)
-
iprot = self.inputProtocolFactory.getProtocol(itrans)
- oprot = self.outputProtocolFactory.getProtocol(otrans)
+
+ # for THeaderProtocol, we must use the same protocol
+ # instance for input and output so that the response is in
+ # the same dialect that the server detected the request was
+ # in.
+ if isinstance(self.inputProtocolFactory, THeaderProtocolFactory):
+ otrans = None
+ oprot = iprot
+ else:
+ otrans = self.outputTransportFactory.getTransport(client)
+ oprot = self.outputProtocolFactory.getProtocol(otrans)
ecode = 0
try:
@@ -254,7 +300,8 @@
ecode = 1
finally:
try_close(itrans)
- try_close(otrans)
+ if otrans:
+ try_close(otrans)
os._exit(ecode)
diff --git a/lib/py/src/transport/THeaderTransport.py b/lib/py/src/transport/THeaderTransport.py
new file mode 100644
index 0000000..c0d5640
--- /dev/null
+++ b/lib/py/src/transport/THeaderTransport.py
@@ -0,0 +1,352 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+import struct
+import zlib
+
+from thrift.compat import BufferIO, byte_index
+from thrift.protocol.TBinaryProtocol import TBinaryProtocol
+from thrift.protocol.TCompactProtocol import TCompactProtocol, readVarint, writeVarint
+from thrift.Thrift import TApplicationException
+from thrift.transport.TTransport import (
+ CReadableTransport,
+ TMemoryBuffer,
+ TTransportBase,
+ TTransportException,
+)
+
+
+U16 = struct.Struct("!H")
+I32 = struct.Struct("!i")
+HEADER_MAGIC = 0x0FFF
+HARD_MAX_FRAME_SIZE = 0x3FFFFFFF
+
+
+class THeaderClientType(object):
+ HEADERS = 0x00
+
+ FRAMED_BINARY = 0x01
+ UNFRAMED_BINARY = 0x02
+
+ FRAMED_COMPACT = 0x03
+ UNFRAMED_COMPACT = 0x04
+
+
+class THeaderSubprotocolID(object):
+ BINARY = 0x00
+ COMPACT = 0x02
+
+
+class TInfoHeaderType(object):
+ KEY_VALUE = 0x01
+
+
+class THeaderTransformID(object):
+ ZLIB = 0x01
+
+
+READ_TRANSFORMS_BY_ID = {
+ THeaderTransformID.ZLIB: zlib.decompress,
+}
+
+
+WRITE_TRANSFORMS_BY_ID = {
+ THeaderTransformID.ZLIB: zlib.compress,
+}
+
+
+def _readString(trans):
+ size = readVarint(trans)
+ if size < 0:
+ raise TTransportException(
+ TTransportException.NEGATIVE_SIZE,
+ "Negative length"
+ )
+ return trans.read(size)
+
+
+def _writeString(trans, value):
+ writeVarint(trans, len(value))
+ trans.write(value)
+
+
+class THeaderTransport(TTransportBase, CReadableTransport):
+ def __init__(self, transport, allowed_client_types):
+ self._transport = transport
+ self._client_type = THeaderClientType.HEADERS
+ self._allowed_client_types = allowed_client_types
+
+ self._read_buffer = BufferIO(b"")
+ self._read_headers = {}
+
+ self._write_buffer = BufferIO()
+ self._write_headers = {}
+ self._write_transforms = []
+
+ self.flags = 0
+ self.sequence_id = 0
+ self._protocol_id = THeaderSubprotocolID.BINARY
+ self._max_frame_size = HARD_MAX_FRAME_SIZE
+
+ def isOpen(self):
+ return self._transport.isOpen()
+
+ def open(self):
+ return self._transport.open()
+
+ def close(self):
+ return self._transport.close()
+
+ def get_headers(self):
+ return self._read_headers
+
+ def set_header(self, key, value):
+ if not isinstance(key, bytes):
+ raise ValueError("header names must be bytes")
+ if not isinstance(value, bytes):
+ raise ValueError("header values must be bytes")
+ self._write_headers[key] = value
+
+ def clear_headers(self):
+ self._write_headers.clear()
+
+ def add_transform(self, transform_id):
+ if transform_id not in WRITE_TRANSFORMS_BY_ID:
+ raise ValueError("unknown transform")
+ self._write_transforms.append(transform_id)
+
+ def set_max_frame_size(self, size):
+ if not 0 < size < HARD_MAX_FRAME_SIZE:
+ raise ValueError("maximum frame size should be < %d and > 0" % HARD_MAX_FRAME_SIZE)
+ self._max_frame_size = size
+
+ @property
+ def protocol_id(self):
+ if self._client_type == THeaderClientType.HEADERS:
+ return self._protocol_id
+ elif self._client_type in (THeaderClientType.FRAMED_BINARY, THeaderClientType.UNFRAMED_BINARY):
+ return THeaderSubprotocolID.BINARY
+ elif self._client_type in (THeaderClientType.FRAMED_COMPACT, THeaderClientType.UNFRAMED_COMPACT):
+ return THeaderSubprotocolID.COMPACT
+ else:
+ raise TTransportException(
+ TTransportException.INVALID_CLIENT_TYPE,
+ "Protocol ID not know for client type %d" % self._client_type,
+ )
+
+ def read(self, sz):
+ # if there are bytes left in the buffer, produce those first.
+ bytes_read = self._read_buffer.read(sz)
+ bytes_left_to_read = sz - len(bytes_read)
+ if bytes_left_to_read == 0:
+ return bytes_read
+
+ # if we've determined this is an unframed client, just pass the read
+ # through to the underlying transport until we're reset again at the
+ # beginning of the next message.
+ if self._client_type in (THeaderClientType.UNFRAMED_BINARY, THeaderClientType.UNFRAMED_COMPACT):
+ return bytes_read + self._transport.read(bytes_left_to_read)
+
+ # we're empty and (maybe) framed. fill the buffers with the next frame.
+ self.readFrame(bytes_left_to_read)
+ return bytes_read + self._read_buffer.read(bytes_left_to_read)
+
+ def _set_client_type(self, client_type):
+ if client_type not in self._allowed_client_types:
+ raise TTransportException(
+ TTransportException.INVALID_CLIENT_TYPE,
+ "Client type %d not allowed by server." % client_type,
+ )
+ self._client_type = client_type
+
+ def readFrame(self, req_sz):
+ # the first word could either be the length field of a framed message
+ # or the first bytes of an unframed message.
+ first_word = self._transport.readAll(I32.size)
+ frame_size, = I32.unpack(first_word)
+ is_unframed = False
+ if frame_size & TBinaryProtocol.VERSION_MASK == TBinaryProtocol.VERSION_1:
+ self._set_client_type(THeaderClientType.UNFRAMED_BINARY)
+ is_unframed = True
+ elif (byte_index(first_word, 0) == TCompactProtocol.PROTOCOL_ID and
+ byte_index(first_word, 1) & TCompactProtocol.VERSION_MASK == TCompactProtocol.VERSION):
+ self._set_client_type(THeaderClientType.UNFRAMED_COMPACT)
+ is_unframed = True
+
+ if is_unframed:
+ bytes_left_to_read = req_sz - I32.size
+ if bytes_left_to_read > 0:
+ rest = self._transport.read(bytes_left_to_read)
+ else:
+ rest = b""
+ self._read_buffer = BufferIO(first_word + rest)
+ return
+
+ # ok, we're still here so we're framed.
+ if frame_size > self._max_frame_size:
+ raise TTransportException(
+ TTransportException.SIZE_LIMIT,
+ "Frame was too large.",
+ )
+ read_buffer = BufferIO(self._transport.readAll(frame_size))
+
+ # the next word is either going to be the version field of a
+ # binary/compact protocol message or the magic value + flags of a
+ # header protocol message.
+ second_word = read_buffer.read(I32.size)
+ version, = I32.unpack(second_word)
+ read_buffer.seek(0)
+ if version >> 16 == HEADER_MAGIC:
+ self._set_client_type(THeaderClientType.HEADERS)
+ self._read_buffer = self._parse_header_format(read_buffer)
+ elif version & TBinaryProtocol.VERSION_MASK == TBinaryProtocol.VERSION_1:
+ self._set_client_type(THeaderClientType.FRAMED_BINARY)
+ self._read_buffer = read_buffer
+ elif (byte_index(second_word, 0) == TCompactProtocol.PROTOCOL_ID and
+ byte_index(second_word, 1) & TCompactProtocol.VERSION_MASK == TCompactProtocol.VERSION):
+ self._set_client_type(THeaderClientType.FRAMED_COMPACT)
+ self._read_buffer = read_buffer
+ else:
+ raise TTransportException(
+ TTransportException.INVALID_CLIENT_TYPE,
+ "Could not detect client transport type.",
+ )
+
+ def _parse_header_format(self, buffer):
+ # make BufferIO look like TTransport for varint helpers
+ buffer_transport = TMemoryBuffer()
+ buffer_transport._buffer = buffer
+
+ buffer.read(2) # discard the magic bytes
+ self.flags, = U16.unpack(buffer.read(U16.size))
+ self.sequence_id, = I32.unpack(buffer.read(I32.size))
+
+ header_length = U16.unpack(buffer.read(U16.size))[0] * 4
+ end_of_headers = buffer.tell() + header_length
+ if end_of_headers > len(buffer.getvalue()):
+ raise TTransportException(
+ TTransportException.SIZE_LIMIT,
+ "Header size is larger than whole frame.",
+ )
+
+ self._protocol_id = readVarint(buffer_transport)
+
+ transforms = []
+ transform_count = readVarint(buffer_transport)
+ for _ in range(transform_count):
+ transform_id = readVarint(buffer_transport)
+ if transform_id not in READ_TRANSFORMS_BY_ID:
+ raise TApplicationException(
+ TApplicationException.INVALID_TRANSFORM,
+ "Unknown transform: %d" % transform_id,
+ )
+ transforms.append(transform_id)
+ transforms.reverse()
+
+ headers = {}
+ while buffer.tell() < end_of_headers:
+ header_type = readVarint(buffer_transport)
+ if header_type == TInfoHeaderType.KEY_VALUE:
+ count = readVarint(buffer_transport)
+ for _ in range(count):
+ key = _readString(buffer_transport)
+ value = _readString(buffer_transport)
+ headers[key] = value
+ else:
+ break # ignore unknown headers
+ self._read_headers = headers
+
+ # skip padding / anything we didn't understand
+ buffer.seek(end_of_headers)
+
+ payload = buffer.read()
+ for transform_id in transforms:
+ transform_fn = READ_TRANSFORMS_BY_ID[transform_id]
+ payload = transform_fn(payload)
+ return BufferIO(payload)
+
+ def write(self, buf):
+ self._write_buffer.write(buf)
+
+ def flush(self):
+ payload = self._write_buffer.getvalue()
+ self._write_buffer = BufferIO()
+
+ buffer = BufferIO()
+ if self._client_type == THeaderClientType.HEADERS:
+ for transform_id in self._write_transforms:
+ transform_fn = WRITE_TRANSFORMS_BY_ID[transform_id]
+ payload = transform_fn(payload)
+
+ headers = BufferIO()
+ writeVarint(headers, self._protocol_id)
+ writeVarint(headers, len(self._write_transforms))
+ for transform_id in self._write_transforms:
+ writeVarint(headers, transform_id)
+ if self._write_headers:
+ writeVarint(headers, TInfoHeaderType.KEY_VALUE)
+ writeVarint(headers, len(self._write_headers))
+ for key, value in self._write_headers.items():
+ _writeString(headers, key)
+ _writeString(headers, value)
+ self._write_headers = {}
+ padding_needed = (4 - (len(headers.getvalue()) % 4)) % 4
+ headers.write(b"\x00" * padding_needed)
+ header_bytes = headers.getvalue()
+
+ buffer.write(I32.pack(10 + len(header_bytes) + len(payload)))
+ buffer.write(U16.pack(HEADER_MAGIC))
+ buffer.write(U16.pack(self.flags))
+ buffer.write(I32.pack(self.sequence_id))
+ buffer.write(U16.pack(len(header_bytes) // 4))
+ buffer.write(header_bytes)
+ buffer.write(payload)
+ elif self._client_type in (THeaderClientType.FRAMED_BINARY, THeaderClientType.FRAMED_COMPACT):
+ buffer.write(I32.pack(len(payload)))
+ buffer.write(payload)
+ elif self._client_type in (THeaderClientType.UNFRAMED_BINARY, THeaderClientType.UNFRAMED_COMPACT):
+ buffer.write(payload)
+ else:
+ raise TTransportException(
+ TTransportException.INVALID_CLIENT_TYPE,
+ "Unknown client type.",
+ )
+
+ # the frame length field doesn't count towards the frame payload size
+ frame_bytes = buffer.getvalue()
+ frame_payload_size = len(frame_bytes) - 4
+ if frame_payload_size > self._max_frame_size:
+ raise TTransportException(
+ TTransportException.SIZE_LIMIT,
+ "Attempting to send frame that is too large.",
+ )
+
+ self._transport.write(frame_bytes)
+ self._transport.flush()
+
+ @property
+ def cstringio_buf(self):
+ return self._read_buffer
+
+ def cstringio_refill(self, partialread, reqlen):
+ result = bytearray(partialread)
+ while len(result) < reqlen:
+ result += self.read(reqlen - len(result))
+ self._read_buffer = BufferIO(result)
+ return self._read_buffer
diff --git a/lib/py/src/transport/TTransport.py b/lib/py/src/transport/TTransport.py
index c8855ca..d13060f 100644
--- a/lib/py/src/transport/TTransport.py
+++ b/lib/py/src/transport/TTransport.py
@@ -32,6 +32,7 @@
END_OF_FILE = 4
NEGATIVE_SIZE = 5
SIZE_LIMIT = 6
+ INVALID_CLIENT_TYPE = 7
def __init__(self, type=UNKNOWN, message=None):
TException.__init__(self, message)
diff --git a/test/known_failures_Linux.json b/test/known_failures_Linux.json
index e523085..9d6d54b 100644
--- a/test/known_failures_Linux.json
+++ b/test/known_failures_Linux.json
@@ -83,6 +83,8 @@
"cpp-py3_compact-accelc_http-ip-ssl",
"cpp-py3_compact_http-ip",
"cpp-py3_compact_http-ip-ssl",
+ "cpp-py3_header_http-ip",
+ "cpp-py3_header_http-ip-ssl",
"cpp-py3_json_http-ip",
"cpp-py3_json_http-ip-ssl",
"cpp-py3_multi-accel_http-ip",
@@ -101,6 +103,8 @@
"cpp-py3_multic-multiac_http-ip-ssl",
"cpp-py3_multic_http-ip",
"cpp-py3_multic_http-ip-ssl",
+ "cpp-py3_multih-header_http-ip",
+ "cpp-py3_multih-header_http-ip-ssl",
"cpp-py3_multij-json_http-ip",
"cpp-py3_multij-json_http-ip-ssl",
"cpp-py3_multij_http-ip",
@@ -113,6 +117,8 @@
"cpp-py_compact-accelc_http-ip-ssl",
"cpp-py_compact_http-ip",
"cpp-py_compact_http-ip-ssl",
+ "cpp-py_header_http-ip",
+ "cpp-py_header_http-ip-ssl",
"cpp-py_json_http-ip",
"cpp-py_json_http-ip-ssl",
"cpp-py_multi-accel_http-ip",
@@ -131,6 +137,8 @@
"cpp-py_multic-multiac_http-ip-ssl",
"cpp-py_multic_http-ip",
"cpp-py_multic_http-ip-ssl",
+ "cpp-py_multih-header_http-ip",
+ "cpp-py_multih-header_http-ip-ssl",
"cpp-py_multij-json_http-ip",
"cpp-py_multij-json_http-ip-ssl",
"cpp-py_multij_http-ip",
@@ -375,6 +383,8 @@
"py-cpp_binary_http-ip-ssl",
"py-cpp_compact_http-ip",
"py-cpp_compact_http-ip-ssl",
+ "py-cpp_header_http-ip",
+ "py-cpp_header_http-ip-ssl",
"py-cpp_json_http-ip",
"py-cpp_json_http-ip-ssl",
"py-d_accel-binary_http-ip",
@@ -396,6 +406,7 @@
"py-hs_accelc-compact_http-ip",
"py-hs_binary_http-ip",
"py-hs_compact_http-ip",
+ "py-hs_header_http-ip",
"py-hs_json_http-ip",
"py-java_accel-binary_http-ip",
"py-java_accel-binary_http-ip-ssl",
@@ -420,6 +431,8 @@
"py3-cpp_binary_http-ip-ssl",
"py3-cpp_compact_http-ip",
"py3-cpp_compact_http-ip-ssl",
+ "py3-cpp_header_http-ip",
+ "py3-cpp_header_http-ip-ssl",
"py3-cpp_json_http-ip",
"py3-cpp_json_http-ip-ssl",
"py3-d_accel-binary_http-ip",
@@ -441,6 +454,7 @@
"py3-hs_accelc-compact_http-ip",
"py3-hs_binary_http-ip",
"py3-hs_compact_http-ip",
+ "py3-hs_header_http-ip",
"py3-hs_json_http-ip",
"py3-java_accel-binary_http-ip",
"py3-java_accel-binary_http-ip-ssl",
diff --git a/test/py/RunClientServer.py b/test/py/RunClientServer.py
index b213d1a..56a408e 100755
--- a/test/py/RunClientServer.py
+++ b/test/py/RunClientServer.py
@@ -56,6 +56,7 @@
'binary',
'compact',
'json',
+ 'header',
]
diff --git a/test/py/TestClient.py b/test/py/TestClient.py
index 2164162..ddcce8d 100755
--- a/test/py/TestClient.py
+++ b/test/py/TestClient.py
@@ -348,6 +348,12 @@
return TMultiplexedProtocol.TMultiplexedProtocol(wrapped_proto, "SecondService")
+class HeaderTest(MultiplexedOptionalTest):
+ def get_protocol(self, transport):
+ factory = THeaderProtocol.THeaderProtocolFactory()
+ return factory.getProtocol(transport)
+
+
def suite():
suite = unittest.TestSuite()
loader = unittest.TestLoader()
@@ -359,6 +365,8 @@
suite.addTest(loader.loadTestsFromTestCase(AcceleratedCompactTest))
elif options.proto == 'compact':
suite.addTest(loader.loadTestsFromTestCase(CompactTest))
+ elif options.proto == 'header':
+ suite.addTest(loader.loadTestsFromTestCase(HeaderTest))
elif options.proto == 'json':
suite.addTest(loader.loadTestsFromTestCase(JSONTest))
elif options.proto == 'multi':
@@ -408,7 +416,7 @@
dest="verbose", const=0,
help="minimal output")
parser.add_option('--protocol', dest="proto", type="string",
- help="protocol to use, one of: accel, accelc, binary, compact, json, multi, multia, multiac, multic, multij")
+ help="protocol to use, one of: accel, accelc, binary, compact, header, json, multi, multia, multiac, multic, multij")
parser.add_option('--transport', dest="trans", type="string",
help="transport to use, one of: buffered, framed, http")
parser.set_defaults(framed=False, http_path=None, verbose=1, host='localhost', port=9090, proto='binary')
@@ -431,6 +439,7 @@
from thrift.transport import TZlibTransport
from thrift.protocol import TBinaryProtocol
from thrift.protocol import TCompactProtocol
+ from thrift.protocol import THeaderProtocol
from thrift.protocol import TJSONProtocol
from thrift.protocol import TMultiplexedProtocol
diff --git a/test/py/TestServer.py b/test/py/TestServer.py
index 4dc4c07..aba0d42 100755
--- a/test/py/TestServer.py
+++ b/test/py/TestServer.py
@@ -181,16 +181,22 @@
def main(options):
# set up the protocol factory form the --protocol option
prot_factories = {
- 'accel': TBinaryProtocol.TBinaryProtocolAcceleratedFactory,
- 'accelc': TCompactProtocol.TCompactProtocolAcceleratedFactory,
- 'binary': TBinaryProtocol.TBinaryProtocolFactory,
- 'compact': TCompactProtocol.TCompactProtocolFactory,
- 'json': TJSONProtocol.TJSONProtocolFactory
+ 'accel': TBinaryProtocol.TBinaryProtocolAcceleratedFactory(),
+ 'accelc': TCompactProtocol.TCompactProtocolAcceleratedFactory(),
+ 'binary': TBinaryProtocol.TBinaryProtocolFactory(),
+ 'compact': TCompactProtocol.TCompactProtocolFactory(),
+ 'header': THeaderProtocol.THeaderProtocolFactory(allowed_client_types=[
+ THeaderTransport.THeaderClientType.HEADERS,
+ THeaderTransport.THeaderClientType.FRAMED_BINARY,
+ THeaderTransport.THeaderClientType.UNFRAMED_BINARY,
+ THeaderTransport.THeaderClientType.FRAMED_COMPACT,
+ THeaderTransport.THeaderClientType.UNFRAMED_COMPACT,
+ ]),
+ 'json': TJSONProtocol.TJSONProtocolFactory(),
}
- pfactory_cls = prot_factories.get(options.proto, None)
- if pfactory_cls is None:
+ pfactory = prot_factories.get(options.proto, None)
+ if pfactory is None:
raise AssertionError('Unknown --protocol option: %s' % options.proto)
- pfactory = pfactory_cls()
try:
pfactory.string_length_limit = options.string_limit
pfactory.container_length_limit = options.container_limit
@@ -323,11 +329,13 @@
from ThriftTest import ThriftTest
from ThriftTest.ttypes import Xtruct, Xception, Xception2, Insanity
from thrift.Thrift import TException
+ from thrift.transport import THeaderTransport
from thrift.transport import TTransport
from thrift.transport import TSocket
from thrift.transport import TZlibTransport
from thrift.protocol import TBinaryProtocol
from thrift.protocol import TCompactProtocol
+ from thrift.protocol import THeaderProtocol
from thrift.protocol import TJSONProtocol
from thrift.server import TServer, TNonblockingServer, THttpServer
diff --git a/test/tests.json b/test/tests.json
index 72790ac..85a0c07 100644
--- a/test/tests.json
+++ b/test/tests.json
@@ -273,7 +273,8 @@
"binary",
"json",
"binary:accel",
- "compact:accelc"
+ "compact:accelc",
+ "header"
],
"workdir": "py"
},
@@ -319,7 +320,8 @@
"binary",
"json",
"binary:accel",
- "compact:accelc"
+ "compact:accelc",
+ "header"
],
"workdir": "py"
},