THRIFT-4621 Add THeader for Python

Client: py
diff --git a/lib/py/src/compat.py b/lib/py/src/compat.py
index 41bcf35..0e8271d 100644
--- a/lib/py/src/compat.py
+++ b/lib/py/src/compat.py
@@ -29,6 +29,9 @@
     def str_to_binary(str_val):
         return str_val
 
+    def byte_index(bytes_val, i):
+        return ord(bytes_val[i])
+
 else:
 
     from io import BytesIO as BufferIO  # noqa
@@ -38,3 +41,6 @@
 
     def str_to_binary(str_val):
         return bytes(str_val, 'utf8')
+
+    def byte_index(bytes_val, i):
+        return bytes_val[i]
diff --git a/lib/py/src/protocol/THeaderProtocol.py b/lib/py/src/protocol/THeaderProtocol.py
new file mode 100644
index 0000000..b27a749
--- /dev/null
+++ b/lib/py/src/protocol/THeaderProtocol.py
@@ -0,0 +1,225 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+from thrift.protocol.TBinaryProtocol import TBinaryProtocolAccelerated
+from thrift.protocol.TCompactProtocol import TCompactProtocolAccelerated
+from thrift.protocol.TProtocol import TProtocolBase, TProtocolException
+from thrift.Thrift import TApplicationException, TMessageType
+from thrift.transport.THeaderTransport import THeaderTransport, THeaderSubprotocolID, THeaderClientType
+
+
+PROTOCOLS_BY_ID = {
+    THeaderSubprotocolID.BINARY: TBinaryProtocolAccelerated,
+    THeaderSubprotocolID.COMPACT: TCompactProtocolAccelerated,
+}
+
+
+class THeaderProtocol(TProtocolBase):
+    """A framed protocol with headers and payload transforms.
+
+    THeaderProtocol frames other Thrift protocols and adds support for optional
+    out-of-band headers. The currently supported subprotocols are
+    TBinaryProtocol and TCompactProtocol.
+
+    It's also possible to apply transforms to the encoded message payload. The
+    only transform currently supported is to gzip.
+
+    When used in a server, THeaderProtocol can accept messages from
+    non-THeaderProtocol clients if allowed (see `allowed_client_types`). This
+    includes framed and unframed transports and both TBinaryProtocol and
+    TCompactProtocol. The server will respond in the appropriate dialect for
+    the connected client. HTTP clients are not currently supported.
+
+    THeaderProtocol does not currently support THTTPServer, TNonblockingServer,
+    or TProcessPoolServer.
+
+    See doc/specs/HeaderFormat.md for details of the wire format.
+
+    """
+
+    def __init__(self, transport, allowed_client_types):
+        # much of the actual work for THeaderProtocol happens down in
+        # THeaderTransport since we need to do low-level shenanigans to detect
+        # if the client is sending us headers or one of the headerless formats
+        # we support. this wraps the real transport with the one that does all
+        # the magic.
+        if not isinstance(transport, THeaderTransport):
+            transport = THeaderTransport(transport, allowed_client_types)
+        super(THeaderProtocol, self).__init__(transport)
+        self._set_protocol()
+
+    def get_headers(self):
+        return self.trans.get_headers()
+
+    def set_header(self, key, value):
+        self.trans.set_header(key, value)
+
+    def clear_headers(self):
+        self.trans.clear_headers()
+
+    def add_transform(self, transform_id):
+        self.trans.add_transform(transform_id)
+
+    def writeMessageBegin(self, name, ttype, seqid):
+        self.trans.sequence_id = seqid
+        return self._protocol.writeMessageBegin(name, ttype, seqid)
+
+    def writeMessageEnd(self):
+        return self._protocol.writeMessageEnd()
+
+    def writeStructBegin(self, name):
+        return self._protocol.writeStructBegin(name)
+
+    def writeStructEnd(self):
+        return self._protocol.writeStructEnd()
+
+    def writeFieldBegin(self, name, ttype, fid):
+        return self._protocol.writeFieldBegin(name, ttype, fid)
+
+    def writeFieldEnd(self):
+        return self._protocol.writeFieldEnd()
+
+    def writeFieldStop(self):
+        return self._protocol.writeFieldStop()
+
+    def writeMapBegin(self, ktype, vtype, size):
+        return self._protocol.writeMapBegin(ktype, vtype, size)
+
+    def writeMapEnd(self):
+        return self._protocol.writeMapEnd()
+
+    def writeListBegin(self, etype, size):
+        return self._protocol.writeListBegin(etype, size)
+
+    def writeListEnd(self):
+        return self._protocol.writeListEnd()
+
+    def writeSetBegin(self, etype, size):
+        return self._protocol.writeSetBegin(etype, size)
+
+    def writeSetEnd(self):
+        return self._protocol.writeSetEnd()
+
+    def writeBool(self, bool_val):
+        return self._protocol.writeBool(bool_val)
+
+    def writeByte(self, byte):
+        return self._protocol.writeByte(byte)
+
+    def writeI16(self, i16):
+        return self._protocol.writeI16(i16)
+
+    def writeI32(self, i32):
+        return self._protocol.writeI32(i32)
+
+    def writeI64(self, i64):
+        return self._protocol.writeI64(i64)
+
+    def writeDouble(self, dub):
+        return self._protocol.writeDouble(dub)
+
+    def writeBinary(self, str_val):
+        return self._protocol.writeBinary(str_val)
+
+    def _set_protocol(self):
+        try:
+            protocol_cls = PROTOCOLS_BY_ID[self.trans.protocol_id]
+        except KeyError:
+            raise TApplicationException(
+                TProtocolException.INVALID_PROTOCOL,
+                "Unknown protocol requested.",
+            )
+
+        self._protocol = protocol_cls(self.trans)
+        self._fast_encode = self._protocol._fast_encode
+        self._fast_decode = self._protocol._fast_decode
+
+    def readMessageBegin(self):
+        try:
+            self.trans.readFrame(0)
+            self._set_protocol()
+        except TApplicationException as exc:
+            self._protocol.writeMessageBegin(b"", TMessageType.EXCEPTION, 0)
+            exc.write(self._protocol)
+            self._protocol.writeMessageEnd()
+            self.trans.flush()
+
+        return self._protocol.readMessageBegin()
+
+    def readMessageEnd(self):
+        return self._protocol.readMessageEnd()
+
+    def readStructBegin(self):
+        return self._protocol.readStructBegin()
+
+    def readStructEnd(self):
+        return self._protocol.readStructEnd()
+
+    def readFieldBegin(self):
+        return self._protocol.readFieldBegin()
+
+    def readFieldEnd(self):
+        return self._protocol.readFieldEnd()
+
+    def readMapBegin(self):
+        return self._protocol.readMapBegin()
+
+    def readMapEnd(self):
+        return self._protocol.readMapEnd()
+
+    def readListBegin(self):
+        return self._protocol.readListBegin()
+
+    def readListEnd(self):
+        return self._protocol.readListEnd()
+
+    def readSetBegin(self):
+        return self._protocol.readSetBegin()
+
+    def readSetEnd(self):
+        return self._protocol.readSetEnd()
+
+    def readBool(self):
+        return self._protocol.readBool()
+
+    def readByte(self):
+        return self._protocol.readByte()
+
+    def readI16(self):
+        return self._protocol.readI16()
+
+    def readI32(self):
+        return self._protocol.readI32()
+
+    def readI64(self):
+        return self._protocol.readI64()
+
+    def readDouble(self):
+        return self._protocol.readDouble()
+
+    def readBinary(self):
+        return self._protocol.readBinary()
+
+
+class THeaderProtocolFactory(object):
+    def __init__(self, allowed_client_types=(THeaderClientType.HEADERS,)):
+        self.allowed_client_types = allowed_client_types
+
+    def getProtocol(self, trans):
+        return THeaderProtocol(trans, self.allowed_client_types)
diff --git a/lib/py/src/protocol/TProtocol.py b/lib/py/src/protocol/TProtocol.py
index fd20cb7..8314cf6 100644
--- a/lib/py/src/protocol/TProtocol.py
+++ b/lib/py/src/protocol/TProtocol.py
@@ -37,6 +37,7 @@
     BAD_VERSION = 4
     NOT_IMPLEMENTED = 5
     DEPTH_LIMIT = 6
+    INVALID_PROTOCOL = 7
 
     def __init__(self, type=UNKNOWN, message=None):
         TException.__init__(self, message)
diff --git a/lib/py/src/server/TServer.py b/lib/py/src/server/TServer.py
index d5d9c98..df2a7bb 100644
--- a/lib/py/src/server/TServer.py
+++ b/lib/py/src/server/TServer.py
@@ -23,6 +23,7 @@
 import threading
 
 from thrift.protocol import TBinaryProtocol
+from thrift.protocol.THeaderProtocol import THeaderProtocolFactory
 from thrift.transport import TTransport
 
 logger = logging.getLogger(__name__)
@@ -60,6 +61,12 @@
         self.inputProtocolFactory = inputProtocolFactory
         self.outputProtocolFactory = outputProtocolFactory
 
+        input_is_header = isinstance(self.inputProtocolFactory, THeaderProtocolFactory)
+        output_is_header = isinstance(self.outputProtocolFactory, THeaderProtocolFactory)
+        if any((input_is_header, output_is_header)) and input_is_header != output_is_header:
+            raise ValueError("THeaderProtocol servers require that both the input and "
+                             "output protocols are THeaderProtocol.")
+
     def serve(self):
         pass
 
@@ -76,10 +83,20 @@
             client = self.serverTransport.accept()
             if not client:
                 continue
+
             itrans = self.inputTransportFactory.getTransport(client)
-            otrans = self.outputTransportFactory.getTransport(client)
             iprot = self.inputProtocolFactory.getProtocol(itrans)
-            oprot = self.outputProtocolFactory.getProtocol(otrans)
+
+            # for THeaderProtocol, we must use the same protocol instance for
+            # input and output so that the response is in the same dialect that
+            # the server detected the request was in.
+            if isinstance(self.inputProtocolFactory, THeaderProtocolFactory):
+                otrans = None
+                oprot = iprot
+            else:
+                otrans = self.outputTransportFactory.getTransport(client)
+                oprot = self.outputProtocolFactory.getProtocol(otrans)
+
             try:
                 while True:
                     self.processor.process(iprot, oprot)
@@ -89,7 +106,8 @@
                 logger.exception(x)
 
             itrans.close()
-            otrans.close()
+            if otrans:
+                otrans.close()
 
 
 class TThreadedServer(TServer):
@@ -116,9 +134,18 @@
 
     def handle(self, client):
         itrans = self.inputTransportFactory.getTransport(client)
-        otrans = self.outputTransportFactory.getTransport(client)
         iprot = self.inputProtocolFactory.getProtocol(itrans)
-        oprot = self.outputProtocolFactory.getProtocol(otrans)
+
+        # for THeaderProtocol, we must use the same protocol instance for input
+        # and output so that the response is in the same dialect that the
+        # server detected the request was in.
+        if isinstance(self.inputProtocolFactory, THeaderProtocolFactory):
+            otrans = None
+            oprot = iprot
+        else:
+            otrans = self.outputTransportFactory.getTransport(client)
+            oprot = self.outputProtocolFactory.getProtocol(otrans)
+
         try:
             while True:
                 self.processor.process(iprot, oprot)
@@ -128,7 +155,8 @@
             logger.exception(x)
 
         itrans.close()
-        otrans.close()
+        if otrans:
+            otrans.close()
 
 
 class TThreadPoolServer(TServer):
@@ -156,9 +184,18 @@
     def serveClient(self, client):
         """Process input/output from a client for as long as possible"""
         itrans = self.inputTransportFactory.getTransport(client)
-        otrans = self.outputTransportFactory.getTransport(client)
         iprot = self.inputProtocolFactory.getProtocol(itrans)
-        oprot = self.outputProtocolFactory.getProtocol(otrans)
+
+        # for THeaderProtocol, we must use the same protocol instance for input
+        # and output so that the response is in the same dialect that the
+        # server detected the request was in.
+        if isinstance(self.inputProtocolFactory, THeaderProtocolFactory):
+            otrans = None
+            oprot = iprot
+        else:
+            otrans = self.outputTransportFactory.getTransport(client)
+            oprot = self.outputProtocolFactory.getProtocol(otrans)
+
         try:
             while True:
                 self.processor.process(iprot, oprot)
@@ -168,7 +205,8 @@
             logger.exception(x)
 
         itrans.close()
-        otrans.close()
+        if otrans:
+            otrans.close()
 
     def serve(self):
         """Start a fixed number of worker threads and put client into a queue"""
@@ -237,10 +275,18 @@
                     try_close(otrans)
                 else:
                     itrans = self.inputTransportFactory.getTransport(client)
-                    otrans = self.outputTransportFactory.getTransport(client)
-
                     iprot = self.inputProtocolFactory.getProtocol(itrans)
-                    oprot = self.outputProtocolFactory.getProtocol(otrans)
+
+                    # for THeaderProtocol, we must use the same protocol
+                    # instance for input and output so that the response is in
+                    # the same dialect that the server detected the request was
+                    # in.
+                    if isinstance(self.inputProtocolFactory, THeaderProtocolFactory):
+                        otrans = None
+                        oprot = iprot
+                    else:
+                        otrans = self.outputTransportFactory.getTransport(client)
+                        oprot = self.outputProtocolFactory.getProtocol(otrans)
 
                     ecode = 0
                     try:
@@ -254,7 +300,8 @@
                             ecode = 1
                     finally:
                         try_close(itrans)
-                        try_close(otrans)
+                        if otrans:
+                            try_close(otrans)
 
                     os._exit(ecode)
 
diff --git a/lib/py/src/transport/THeaderTransport.py b/lib/py/src/transport/THeaderTransport.py
new file mode 100644
index 0000000..c0d5640
--- /dev/null
+++ b/lib/py/src/transport/THeaderTransport.py
@@ -0,0 +1,352 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+import struct
+import zlib
+
+from thrift.compat import BufferIO, byte_index
+from thrift.protocol.TBinaryProtocol import TBinaryProtocol
+from thrift.protocol.TCompactProtocol import TCompactProtocol, readVarint, writeVarint
+from thrift.Thrift import TApplicationException
+from thrift.transport.TTransport import (
+    CReadableTransport,
+    TMemoryBuffer,
+    TTransportBase,
+    TTransportException,
+)
+
+
+U16 = struct.Struct("!H")
+I32 = struct.Struct("!i")
+HEADER_MAGIC = 0x0FFF
+HARD_MAX_FRAME_SIZE = 0x3FFFFFFF
+
+
+class THeaderClientType(object):
+    HEADERS = 0x00
+
+    FRAMED_BINARY = 0x01
+    UNFRAMED_BINARY = 0x02
+
+    FRAMED_COMPACT = 0x03
+    UNFRAMED_COMPACT = 0x04
+
+
+class THeaderSubprotocolID(object):
+    BINARY = 0x00
+    COMPACT = 0x02
+
+
+class TInfoHeaderType(object):
+    KEY_VALUE = 0x01
+
+
+class THeaderTransformID(object):
+    ZLIB = 0x01
+
+
+READ_TRANSFORMS_BY_ID = {
+    THeaderTransformID.ZLIB: zlib.decompress,
+}
+
+
+WRITE_TRANSFORMS_BY_ID = {
+    THeaderTransformID.ZLIB: zlib.compress,
+}
+
+
+def _readString(trans):
+    size = readVarint(trans)
+    if size < 0:
+        raise TTransportException(
+            TTransportException.NEGATIVE_SIZE,
+            "Negative length"
+        )
+    return trans.read(size)
+
+
+def _writeString(trans, value):
+    writeVarint(trans, len(value))
+    trans.write(value)
+
+
+class THeaderTransport(TTransportBase, CReadableTransport):
+    def __init__(self, transport, allowed_client_types):
+        self._transport = transport
+        self._client_type = THeaderClientType.HEADERS
+        self._allowed_client_types = allowed_client_types
+
+        self._read_buffer = BufferIO(b"")
+        self._read_headers = {}
+
+        self._write_buffer = BufferIO()
+        self._write_headers = {}
+        self._write_transforms = []
+
+        self.flags = 0
+        self.sequence_id = 0
+        self._protocol_id = THeaderSubprotocolID.BINARY
+        self._max_frame_size = HARD_MAX_FRAME_SIZE
+
+    def isOpen(self):
+        return self._transport.isOpen()
+
+    def open(self):
+        return self._transport.open()
+
+    def close(self):
+        return self._transport.close()
+
+    def get_headers(self):
+        return self._read_headers
+
+    def set_header(self, key, value):
+        if not isinstance(key, bytes):
+            raise ValueError("header names must be bytes")
+        if not isinstance(value, bytes):
+            raise ValueError("header values must be bytes")
+        self._write_headers[key] = value
+
+    def clear_headers(self):
+        self._write_headers.clear()
+
+    def add_transform(self, transform_id):
+        if transform_id not in WRITE_TRANSFORMS_BY_ID:
+            raise ValueError("unknown transform")
+        self._write_transforms.append(transform_id)
+
+    def set_max_frame_size(self, size):
+        if not 0 < size < HARD_MAX_FRAME_SIZE:
+            raise ValueError("maximum frame size should be < %d and > 0" % HARD_MAX_FRAME_SIZE)
+        self._max_frame_size = size
+
+    @property
+    def protocol_id(self):
+        if self._client_type == THeaderClientType.HEADERS:
+            return self._protocol_id
+        elif self._client_type in (THeaderClientType.FRAMED_BINARY, THeaderClientType.UNFRAMED_BINARY):
+            return THeaderSubprotocolID.BINARY
+        elif self._client_type in (THeaderClientType.FRAMED_COMPACT, THeaderClientType.UNFRAMED_COMPACT):
+            return THeaderSubprotocolID.COMPACT
+        else:
+            raise TTransportException(
+                TTransportException.INVALID_CLIENT_TYPE,
+                "Protocol ID not know for client type %d" % self._client_type,
+            )
+
+    def read(self, sz):
+        # if there are bytes left in the buffer, produce those first.
+        bytes_read = self._read_buffer.read(sz)
+        bytes_left_to_read = sz - len(bytes_read)
+        if bytes_left_to_read == 0:
+            return bytes_read
+
+        # if we've determined this is an unframed client, just pass the read
+        # through to the underlying transport until we're reset again at the
+        # beginning of the next message.
+        if self._client_type in (THeaderClientType.UNFRAMED_BINARY, THeaderClientType.UNFRAMED_COMPACT):
+            return bytes_read + self._transport.read(bytes_left_to_read)
+
+        # we're empty and (maybe) framed. fill the buffers with the next frame.
+        self.readFrame(bytes_left_to_read)
+        return bytes_read + self._read_buffer.read(bytes_left_to_read)
+
+    def _set_client_type(self, client_type):
+        if client_type not in self._allowed_client_types:
+            raise TTransportException(
+                TTransportException.INVALID_CLIENT_TYPE,
+                "Client type %d not allowed by server." % client_type,
+            )
+        self._client_type = client_type
+
+    def readFrame(self, req_sz):
+        # the first word could either be the length field of a framed message
+        # or the first bytes of an unframed message.
+        first_word = self._transport.readAll(I32.size)
+        frame_size, = I32.unpack(first_word)
+        is_unframed = False
+        if frame_size & TBinaryProtocol.VERSION_MASK == TBinaryProtocol.VERSION_1:
+            self._set_client_type(THeaderClientType.UNFRAMED_BINARY)
+            is_unframed = True
+        elif (byte_index(first_word, 0) == TCompactProtocol.PROTOCOL_ID and
+              byte_index(first_word, 1) & TCompactProtocol.VERSION_MASK == TCompactProtocol.VERSION):
+            self._set_client_type(THeaderClientType.UNFRAMED_COMPACT)
+            is_unframed = True
+
+        if is_unframed:
+            bytes_left_to_read = req_sz - I32.size
+            if bytes_left_to_read > 0:
+                rest = self._transport.read(bytes_left_to_read)
+            else:
+                rest = b""
+            self._read_buffer = BufferIO(first_word + rest)
+            return
+
+        # ok, we're still here so we're framed.
+        if frame_size > self._max_frame_size:
+            raise TTransportException(
+                TTransportException.SIZE_LIMIT,
+                "Frame was too large.",
+            )
+        read_buffer = BufferIO(self._transport.readAll(frame_size))
+
+        # the next word is either going to be the version field of a
+        # binary/compact protocol message or the magic value + flags of a
+        # header protocol message.
+        second_word = read_buffer.read(I32.size)
+        version, = I32.unpack(second_word)
+        read_buffer.seek(0)
+        if version >> 16 == HEADER_MAGIC:
+            self._set_client_type(THeaderClientType.HEADERS)
+            self._read_buffer = self._parse_header_format(read_buffer)
+        elif version & TBinaryProtocol.VERSION_MASK == TBinaryProtocol.VERSION_1:
+            self._set_client_type(THeaderClientType.FRAMED_BINARY)
+            self._read_buffer = read_buffer
+        elif (byte_index(second_word, 0) == TCompactProtocol.PROTOCOL_ID and
+              byte_index(second_word, 1) & TCompactProtocol.VERSION_MASK == TCompactProtocol.VERSION):
+            self._set_client_type(THeaderClientType.FRAMED_COMPACT)
+            self._read_buffer = read_buffer
+        else:
+            raise TTransportException(
+                TTransportException.INVALID_CLIENT_TYPE,
+                "Could not detect client transport type.",
+            )
+
+    def _parse_header_format(self, buffer):
+        # make BufferIO look like TTransport for varint helpers
+        buffer_transport = TMemoryBuffer()
+        buffer_transport._buffer = buffer
+
+        buffer.read(2)  # discard the magic bytes
+        self.flags, = U16.unpack(buffer.read(U16.size))
+        self.sequence_id, = I32.unpack(buffer.read(I32.size))
+
+        header_length = U16.unpack(buffer.read(U16.size))[0] * 4
+        end_of_headers = buffer.tell() + header_length
+        if end_of_headers > len(buffer.getvalue()):
+            raise TTransportException(
+                TTransportException.SIZE_LIMIT,
+                "Header size is larger than whole frame.",
+            )
+
+        self._protocol_id = readVarint(buffer_transport)
+
+        transforms = []
+        transform_count = readVarint(buffer_transport)
+        for _ in range(transform_count):
+            transform_id = readVarint(buffer_transport)
+            if transform_id not in READ_TRANSFORMS_BY_ID:
+                raise TApplicationException(
+                    TApplicationException.INVALID_TRANSFORM,
+                    "Unknown transform: %d" % transform_id,
+                )
+            transforms.append(transform_id)
+        transforms.reverse()
+
+        headers = {}
+        while buffer.tell() < end_of_headers:
+            header_type = readVarint(buffer_transport)
+            if header_type == TInfoHeaderType.KEY_VALUE:
+                count = readVarint(buffer_transport)
+                for _ in range(count):
+                    key = _readString(buffer_transport)
+                    value = _readString(buffer_transport)
+                    headers[key] = value
+            else:
+                break  # ignore unknown headers
+        self._read_headers = headers
+
+        # skip padding / anything we didn't understand
+        buffer.seek(end_of_headers)
+
+        payload = buffer.read()
+        for transform_id in transforms:
+            transform_fn = READ_TRANSFORMS_BY_ID[transform_id]
+            payload = transform_fn(payload)
+        return BufferIO(payload)
+
+    def write(self, buf):
+        self._write_buffer.write(buf)
+
+    def flush(self):
+        payload = self._write_buffer.getvalue()
+        self._write_buffer = BufferIO()
+
+        buffer = BufferIO()
+        if self._client_type == THeaderClientType.HEADERS:
+            for transform_id in self._write_transforms:
+                transform_fn = WRITE_TRANSFORMS_BY_ID[transform_id]
+                payload = transform_fn(payload)
+
+            headers = BufferIO()
+            writeVarint(headers, self._protocol_id)
+            writeVarint(headers, len(self._write_transforms))
+            for transform_id in self._write_transforms:
+                writeVarint(headers, transform_id)
+            if self._write_headers:
+                writeVarint(headers, TInfoHeaderType.KEY_VALUE)
+                writeVarint(headers, len(self._write_headers))
+                for key, value in self._write_headers.items():
+                    _writeString(headers, key)
+                    _writeString(headers, value)
+                self._write_headers = {}
+            padding_needed = (4 - (len(headers.getvalue()) % 4)) % 4
+            headers.write(b"\x00" * padding_needed)
+            header_bytes = headers.getvalue()
+
+            buffer.write(I32.pack(10 + len(header_bytes) + len(payload)))
+            buffer.write(U16.pack(HEADER_MAGIC))
+            buffer.write(U16.pack(self.flags))
+            buffer.write(I32.pack(self.sequence_id))
+            buffer.write(U16.pack(len(header_bytes) // 4))
+            buffer.write(header_bytes)
+            buffer.write(payload)
+        elif self._client_type in (THeaderClientType.FRAMED_BINARY, THeaderClientType.FRAMED_COMPACT):
+            buffer.write(I32.pack(len(payload)))
+            buffer.write(payload)
+        elif self._client_type in (THeaderClientType.UNFRAMED_BINARY, THeaderClientType.UNFRAMED_COMPACT):
+            buffer.write(payload)
+        else:
+            raise TTransportException(
+                TTransportException.INVALID_CLIENT_TYPE,
+                "Unknown client type.",
+            )
+
+        # the frame length field doesn't count towards the frame payload size
+        frame_bytes = buffer.getvalue()
+        frame_payload_size = len(frame_bytes) - 4
+        if frame_payload_size > self._max_frame_size:
+            raise TTransportException(
+                TTransportException.SIZE_LIMIT,
+                "Attempting to send frame that is too large.",
+            )
+
+        self._transport.write(frame_bytes)
+        self._transport.flush()
+
+    @property
+    def cstringio_buf(self):
+        return self._read_buffer
+
+    def cstringio_refill(self, partialread, reqlen):
+        result = bytearray(partialread)
+        while len(result) < reqlen:
+            result += self.read(reqlen - len(result))
+        self._read_buffer = BufferIO(result)
+        return self._read_buffer
diff --git a/lib/py/src/transport/TTransport.py b/lib/py/src/transport/TTransport.py
index c8855ca..d13060f 100644
--- a/lib/py/src/transport/TTransport.py
+++ b/lib/py/src/transport/TTransport.py
@@ -32,6 +32,7 @@
     END_OF_FILE = 4
     NEGATIVE_SIZE = 5
     SIZE_LIMIT = 6
+    INVALID_CLIENT_TYPE = 7
 
     def __init__(self, type=UNKNOWN, message=None):
         TException.__init__(self, message)
diff --git a/test/known_failures_Linux.json b/test/known_failures_Linux.json
index e523085..9d6d54b 100644
--- a/test/known_failures_Linux.json
+++ b/test/known_failures_Linux.json
@@ -83,6 +83,8 @@
   "cpp-py3_compact-accelc_http-ip-ssl",
   "cpp-py3_compact_http-ip",
   "cpp-py3_compact_http-ip-ssl",
+  "cpp-py3_header_http-ip",
+  "cpp-py3_header_http-ip-ssl",
   "cpp-py3_json_http-ip",
   "cpp-py3_json_http-ip-ssl",
   "cpp-py3_multi-accel_http-ip",
@@ -101,6 +103,8 @@
   "cpp-py3_multic-multiac_http-ip-ssl",
   "cpp-py3_multic_http-ip",
   "cpp-py3_multic_http-ip-ssl",
+  "cpp-py3_multih-header_http-ip",
+  "cpp-py3_multih-header_http-ip-ssl",
   "cpp-py3_multij-json_http-ip",
   "cpp-py3_multij-json_http-ip-ssl",
   "cpp-py3_multij_http-ip",
@@ -113,6 +117,8 @@
   "cpp-py_compact-accelc_http-ip-ssl",
   "cpp-py_compact_http-ip",
   "cpp-py_compact_http-ip-ssl",
+  "cpp-py_header_http-ip",
+  "cpp-py_header_http-ip-ssl",
   "cpp-py_json_http-ip",
   "cpp-py_json_http-ip-ssl",
   "cpp-py_multi-accel_http-ip",
@@ -131,6 +137,8 @@
   "cpp-py_multic-multiac_http-ip-ssl",
   "cpp-py_multic_http-ip",
   "cpp-py_multic_http-ip-ssl",
+  "cpp-py_multih-header_http-ip",
+  "cpp-py_multih-header_http-ip-ssl",
   "cpp-py_multij-json_http-ip",
   "cpp-py_multij-json_http-ip-ssl",
   "cpp-py_multij_http-ip",
@@ -375,6 +383,8 @@
   "py-cpp_binary_http-ip-ssl",
   "py-cpp_compact_http-ip",
   "py-cpp_compact_http-ip-ssl",
+  "py-cpp_header_http-ip",
+  "py-cpp_header_http-ip-ssl",
   "py-cpp_json_http-ip",
   "py-cpp_json_http-ip-ssl",
   "py-d_accel-binary_http-ip",
@@ -396,6 +406,7 @@
   "py-hs_accelc-compact_http-ip",
   "py-hs_binary_http-ip",
   "py-hs_compact_http-ip",
+  "py-hs_header_http-ip",
   "py-hs_json_http-ip",
   "py-java_accel-binary_http-ip",
   "py-java_accel-binary_http-ip-ssl",
@@ -420,6 +431,8 @@
   "py3-cpp_binary_http-ip-ssl",
   "py3-cpp_compact_http-ip",
   "py3-cpp_compact_http-ip-ssl",
+  "py3-cpp_header_http-ip",
+  "py3-cpp_header_http-ip-ssl",
   "py3-cpp_json_http-ip",
   "py3-cpp_json_http-ip-ssl",
   "py3-d_accel-binary_http-ip",
@@ -441,6 +454,7 @@
   "py3-hs_accelc-compact_http-ip",
   "py3-hs_binary_http-ip",
   "py3-hs_compact_http-ip",
+  "py3-hs_header_http-ip",
   "py3-hs_json_http-ip",
   "py3-java_accel-binary_http-ip",
   "py3-java_accel-binary_http-ip-ssl",
diff --git a/test/py/RunClientServer.py b/test/py/RunClientServer.py
index b213d1a..56a408e 100755
--- a/test/py/RunClientServer.py
+++ b/test/py/RunClientServer.py
@@ -56,6 +56,7 @@
     'binary',
     'compact',
     'json',
+    'header',
 ]
 
 
diff --git a/test/py/TestClient.py b/test/py/TestClient.py
index 2164162..ddcce8d 100755
--- a/test/py/TestClient.py
+++ b/test/py/TestClient.py
@@ -348,6 +348,12 @@
         return TMultiplexedProtocol.TMultiplexedProtocol(wrapped_proto, "SecondService")
 
 
+class HeaderTest(MultiplexedOptionalTest):
+    def get_protocol(self, transport):
+        factory = THeaderProtocol.THeaderProtocolFactory()
+        return factory.getProtocol(transport)
+
+
 def suite():
     suite = unittest.TestSuite()
     loader = unittest.TestLoader()
@@ -359,6 +365,8 @@
         suite.addTest(loader.loadTestsFromTestCase(AcceleratedCompactTest))
     elif options.proto == 'compact':
         suite.addTest(loader.loadTestsFromTestCase(CompactTest))
+    elif options.proto == 'header':
+        suite.addTest(loader.loadTestsFromTestCase(HeaderTest))
     elif options.proto == 'json':
         suite.addTest(loader.loadTestsFromTestCase(JSONTest))
     elif options.proto == 'multi':
@@ -408,7 +416,7 @@
                       dest="verbose", const=0,
                       help="minimal output")
     parser.add_option('--protocol', dest="proto", type="string",
-                      help="protocol to use, one of: accel, accelc, binary, compact, json, multi, multia, multiac, multic, multij")
+                      help="protocol to use, one of: accel, accelc, binary, compact, header, json, multi, multia, multiac, multic, multij")
     parser.add_option('--transport', dest="trans", type="string",
                       help="transport to use, one of: buffered, framed, http")
     parser.set_defaults(framed=False, http_path=None, verbose=1, host='localhost', port=9090, proto='binary')
@@ -431,6 +439,7 @@
     from thrift.transport import TZlibTransport
     from thrift.protocol import TBinaryProtocol
     from thrift.protocol import TCompactProtocol
+    from thrift.protocol import THeaderProtocol
     from thrift.protocol import TJSONProtocol
     from thrift.protocol import TMultiplexedProtocol
 
diff --git a/test/py/TestServer.py b/test/py/TestServer.py
index 4dc4c07..aba0d42 100755
--- a/test/py/TestServer.py
+++ b/test/py/TestServer.py
@@ -181,16 +181,22 @@
 def main(options):
     # set up the protocol factory form the --protocol option
     prot_factories = {
-        'accel': TBinaryProtocol.TBinaryProtocolAcceleratedFactory,
-        'accelc': TCompactProtocol.TCompactProtocolAcceleratedFactory,
-        'binary': TBinaryProtocol.TBinaryProtocolFactory,
-        'compact': TCompactProtocol.TCompactProtocolFactory,
-        'json': TJSONProtocol.TJSONProtocolFactory
+        'accel': TBinaryProtocol.TBinaryProtocolAcceleratedFactory(),
+        'accelc': TCompactProtocol.TCompactProtocolAcceleratedFactory(),
+        'binary': TBinaryProtocol.TBinaryProtocolFactory(),
+        'compact': TCompactProtocol.TCompactProtocolFactory(),
+        'header': THeaderProtocol.THeaderProtocolFactory(allowed_client_types=[
+            THeaderTransport.THeaderClientType.HEADERS,
+            THeaderTransport.THeaderClientType.FRAMED_BINARY,
+            THeaderTransport.THeaderClientType.UNFRAMED_BINARY,
+            THeaderTransport.THeaderClientType.FRAMED_COMPACT,
+            THeaderTransport.THeaderClientType.UNFRAMED_COMPACT,
+        ]),
+        'json': TJSONProtocol.TJSONProtocolFactory(),
     }
-    pfactory_cls = prot_factories.get(options.proto, None)
-    if pfactory_cls is None:
+    pfactory = prot_factories.get(options.proto, None)
+    if pfactory is None:
         raise AssertionError('Unknown --protocol option: %s' % options.proto)
-    pfactory = pfactory_cls()
     try:
         pfactory.string_length_limit = options.string_limit
         pfactory.container_length_limit = options.container_limit
@@ -323,11 +329,13 @@
     from ThriftTest import ThriftTest
     from ThriftTest.ttypes import Xtruct, Xception, Xception2, Insanity
     from thrift.Thrift import TException
+    from thrift.transport import THeaderTransport
     from thrift.transport import TTransport
     from thrift.transport import TSocket
     from thrift.transport import TZlibTransport
     from thrift.protocol import TBinaryProtocol
     from thrift.protocol import TCompactProtocol
+    from thrift.protocol import THeaderProtocol
     from thrift.protocol import TJSONProtocol
     from thrift.server import TServer, TNonblockingServer, THttpServer
 
diff --git a/test/tests.json b/test/tests.json
index 72790ac..85a0c07 100644
--- a/test/tests.json
+++ b/test/tests.json
@@ -273,7 +273,8 @@
       "binary",
       "json",
       "binary:accel",
-      "compact:accelc"
+      "compact:accelc",
+      "header"
     ],
     "workdir": "py"
   },
@@ -319,7 +320,8 @@
       "binary",
       "json",
       "binary:accel",
-      "compact:accelc"
+      "compact:accelc",
+      "header"
     ],
     "workdir": "py"
   },