THRIFT-3578 Make THeaderTransport detect TCompact framed and unframed
Client: C++
Patch: Nobuaki Sukegawa

This closes #819
diff --git a/lib/cpp/src/thrift/protocol/TCompactProtocol.h b/lib/cpp/src/thrift/protocol/TCompactProtocol.h
index 5b7ade2..d970be2 100644
--- a/lib/cpp/src/thrift/protocol/TCompactProtocol.h
+++ b/lib/cpp/src/thrift/protocol/TCompactProtocol.h
@@ -34,11 +34,12 @@
  */
 template <class Transport_>
 class TCompactProtocolT : public TVirtualProtocol<TCompactProtocolT<Transport_> > {
-
-protected:
+public:
   static const int8_t PROTOCOL_ID = (int8_t)0x82u;
   static const int8_t VERSION_N = 1;
   static const int8_t VERSION_MASK = 0x1f;       // 0001 1111
+
+protected:
   static const int8_t TYPE_MASK = (int8_t)0xE0u; // 1110 0000
   static const int8_t TYPE_BITS = 0x07;          // 0000 0111
   static const int32_t TYPE_SHIFT_AMOUNT = 5;
diff --git a/lib/cpp/src/thrift/transport/THeaderTransport.cpp b/lib/cpp/src/thrift/transport/THeaderTransport.cpp
index fd24fed..b1fe923 100644
--- a/lib/cpp/src/thrift/transport/THeaderTransport.cpp
+++ b/lib/cpp/src/thrift/transport/THeaderTransport.cpp
@@ -21,6 +21,7 @@
 #include <thrift/TApplicationException.h>
 #include <thrift/protocol/TProtocolTypes.h>
 #include <thrift/protocol/TBinaryProtocol.h>
+#include <thrift/protocol/TCompactProtocol.h>
 
 #include <utility>
 #include <cassert>
@@ -41,7 +42,7 @@
 using apache::thrift::protocol::TBinaryProtocol;
 
 uint32_t THeaderTransport::readSlow(uint8_t* buf, uint32_t len) {
-  if (clientType == THRIFT_UNFRAMED_DEPRECATED) {
+  if (clientType == THRIFT_UNFRAMED_BINARY || clientType == THRIFT_UNFRAMED_COMPACT) {
     return transport_->read(buf, len);
   }
 
@@ -51,6 +52,8 @@
 uint16_t THeaderTransport::getProtocolId() const {
   if (clientType == THRIFT_HEADER_CLIENT_TYPE) {
     return protoId;
+  } else if (clientType == THRIFT_UNFRAMED_COMPACT || clientType == THRIFT_FRAMED_COMPACT) {
+    return T_COMPACT_PROTOCOL;
   } else {
     return T_BINARY_PROTOCOL; // Assume other transports use TBinary
   }
@@ -92,19 +95,19 @@
 
   sz = ntohl(szN);
 
-  uint32_t minFrameSize = 0;
-  ensureReadBuffer(minFrameSize + 4);
+  ensureReadBuffer(4);
 
   if ((sz & TBinaryProtocol::VERSION_MASK) == (uint32_t)TBinaryProtocol::VERSION_1) {
     // unframed
-    clientType = THRIFT_UNFRAMED_DEPRECATED;
+    clientType = THRIFT_UNFRAMED_BINARY;
     memcpy(rBuf_.get(), &szN, sizeof(szN));
-    if (minFrameSize > 4) {
-      transport_->readAll(rBuf_.get() + 4, minFrameSize - 4);
-      setReadBuffer(rBuf_.get(), minFrameSize);
-    } else {
-      setReadBuffer(rBuf_.get(), 4);
-    }
+    setReadBuffer(rBuf_.get(), 4);
+  } else if (static_cast<int8_t>(sz >> 24) == TCompactProtocol::PROTOCOL_ID
+             && (static_cast<int8_t>(sz >> 16) & TCompactProtocol::VERSION_MASK)
+                    == TCompactProtocol::VERSION_N) {
+    clientType = THRIFT_UNFRAMED_COMPACT;
+    memcpy(rBuf_.get(), &szN, sizeof(szN));
+    setReadBuffer(rBuf_.get(), 4);
   } else {
     // Could be header format or framed. Check next uint32
     uint32_t magic_n;
@@ -124,7 +127,13 @@
 
     if ((magic & TBinaryProtocol::VERSION_MASK) == (uint32_t)TBinaryProtocol::VERSION_1) {
       // framed
-      clientType = THRIFT_FRAMED_DEPRECATED;
+      clientType = THRIFT_FRAMED_BINARY;
+      transport_->readAll(rBuf_.get() + 4, sz - 4);
+      setReadBuffer(rBuf_.get(), sz);
+    } else if (static_cast<int8_t>(magic >> 24) == TCompactProtocol::PROTOCOL_ID
+               && (static_cast<int8_t>(magic >> 16) & TCompactProtocol::VERSION_MASK)
+                      == TCompactProtocol::VERSION_N) {
+      clientType = THRIFT_FRAMED_COMPACT;
       transport_->readAll(rBuf_.get() + 4, sz - 4);
       setReadBuffer(rBuf_.get(), sz);
     } else if (HEADER_MAGIC == (magic & HEADER_MASK)) {
@@ -506,13 +515,13 @@
 
     outTransport_->write(pktStart, szHbo - haveBytes + 4);
     outTransport_->write(wBuf_.get(), haveBytes);
-  } else if (clientType == THRIFT_FRAMED_DEPRECATED) {
+  } else if (clientType == THRIFT_FRAMED_BINARY || clientType == THRIFT_FRAMED_COMPACT) {
     uint32_t szHbo = (uint32_t)haveBytes;
     uint32_t szNbo = htonl(szHbo);
 
     outTransport_->write(reinterpret_cast<uint8_t*>(&szNbo), 4);
     outTransport_->write(wBuf_.get(), haveBytes);
-  } else if (clientType == THRIFT_UNFRAMED_DEPRECATED) {
+  } else if (clientType == THRIFT_UNFRAMED_BINARY || clientType == THRIFT_UNFRAMED_COMPACT) {
     outTransport_->write(wBuf_.get(), haveBytes);
   } else {
     throw TTransportException(TTransportException::BAD_ARGS, "Unknown client type");
diff --git a/lib/cpp/src/thrift/transport/THeaderTransport.h b/lib/cpp/src/thrift/transport/THeaderTransport.h
index a125632..bf82674 100644
--- a/lib/cpp/src/thrift/transport/THeaderTransport.h
+++ b/lib/cpp/src/thrift/transport/THeaderTransport.h
@@ -34,14 +34,13 @@
 #include <thrift/transport/TTransport.h>
 #include <thrift/transport/TVirtualTransport.h>
 
-// Don't include the unknown client.
-#define CLIENT_TYPES_LEN 3
-
 enum CLIENT_TYPE {
   THRIFT_HEADER_CLIENT_TYPE = 0,
-  THRIFT_FRAMED_DEPRECATED = 1,
-  THRIFT_UNFRAMED_DEPRECATED = 2,
-  THRIFT_UNKNOWN_CLIENT_TYPE = 4,
+  THRIFT_FRAMED_BINARY = 1,
+  THRIFT_UNFRAMED_BINARY = 2,
+  THRIFT_FRAMED_COMPACT = 3,
+  THRIFT_UNFRAMED_COMPACT = 4,
+  THRIFT_UNKNOWN_CLIENT_TYPE = 5,
 };
 
 namespace apache {
@@ -165,10 +164,6 @@
   };
 
 protected:
-  std::bitset<CLIENT_TYPES_LEN> supported_clients;
-
-  void initSupportedClients(std::bitset<CLIENT_TYPES_LEN> const*);
-
   /**
    * Reads a frame of input from the underlying stream.
    *
diff --git a/test/features/tests.json b/test/features/tests.json
index f726dad..cfcb4b6 100644
--- a/test/features/tests.json
+++ b/test/features/tests.json
@@ -4,7 +4,9 @@
     "name": "theader_unframed_binary",
     "command": [
       "python",
-      "theader_binary.py"
+      "theader_binary.py",
+      "--override-protocol=binary",
+      "--override-transport=buffered"
     ],
     "protocols": ["header"],
     "transports": ["buffered"],
@@ -17,6 +19,35 @@
     "command": [
       "python",
       "theader_binary.py",
+      "--override-protocol=binary",
+      "--override-transport=framed"
+    ],
+    "protocols": ["header"],
+    "transports": ["buffered"],
+    "sockets": ["ip"],
+    "workdir": "features"
+  },
+  {
+    "description": "THeader detects unframed compact wire format",
+    "name": "theader_unframed_compact",
+    "command": [
+      "python",
+      "theader_binary.py",
+      "--override-protocol=compact",
+      "--override-transport=buffered"
+    ],
+    "protocols": ["header"],
+    "transports": ["buffered"],
+    "sockets": ["ip"],
+    "workdir": "features"
+  },
+  {
+    "description": "THeader detects framed compact wire format",
+    "name": "theader_framed_compact",
+    "command": [
+      "python",
+      "theader_binary.py",
+      "--override-protocol=compact",
       "--override-transport=framed"
     ],
     "protocols": ["header"],
diff --git a/test/features/theader_binary.py b/test/features/theader_binary.py
index 0316741..62a2671 100644
--- a/test/features/theader_binary.py
+++ b/test/features/theader_binary.py
@@ -10,6 +10,24 @@
 from thrift.transport.TSocket import TSocket
 from thrift.transport.TTransport import TBufferedTransport, TFramedTransport
 from thrift.protocol.TBinaryProtocol import TBinaryProtocol
+from thrift.protocol.TCompactProtocol import TCompactProtocol
+
+
+def test_void(proto):  # one testVoid round trip over `proto`; asserts an empty REPLY comes back
+  proto.writeMessageBegin('testVoid', TMessageType.CALL, 3)  # request: a CALL message named testVoid
+  proto.writeStructBegin('testVoid_args')
+  proto.writeFieldStop()  # no fields are written, so the args struct is empty
+  proto.writeStructEnd()
+  proto.writeMessageEnd()
+  proto.trans.flush()  # flush the transport that sits behind the protocol
+
+  _, mtype, _ = proto.readMessageBegin()  # only the message type of the response is checked
+  assert mtype == TMessageType.REPLY
+  proto.readStructBegin()
+  _, ftype, _ = proto.readFieldBegin()
+  assert ftype == TType.STOP  # the first field read must already be STOP (no result fields)
+  proto.readStructEnd()
+  proto.readMessageEnd()
 
 
 # THeader stack should accept binary protocol with optionally framed transport
@@ -19,6 +37,7 @@
   # Since THeaderTransport acts as framed transport when detected frame, we
   # cannot use --transport=framed as it would result in 2 layered frames.
   p.add_argument('--override-transport')
+  p.add_argument('--override-protocol')
   args = p.parse_args()
   assert args.protocol == 'header'
   assert args.transport == 'buffered'
@@ -28,26 +47,21 @@
   if not args.override_transport or args.override_transport == 'buffered':
     trans = TBufferedTransport(sock)
   elif args.override_transport == 'framed':
+    print('TFRAMED')
     trans = TFramedTransport(sock)
   else:
     raise ValueError('invalid transport')
   trans.open()
-  proto = TBinaryProtocol(trans)
-  proto.writeMessageBegin('testVoid', TMessageType.CALL, 3)
-  proto.writeStructBegin('testVoid_args')
-  proto.writeFieldStop()
-  proto.writeStructEnd()
-  proto.writeMessageEnd()
-  trans.flush()
 
-  _, mtype, _ = proto.readMessageBegin()
-  assert mtype == TMessageType.REPLY
-  proto.readStructBegin()
-  _, ftype, _ = proto.readFieldBegin()
-  assert ftype == TType.STOP
-  proto.readFieldEnd()
-  proto.readStructEnd()
-  proto.readMessageEnd()
+  if not args.override_protocol or args.override_protocol == 'binary':  # binary when the flag is absent
+    proto = TBinaryProtocol(trans)
+  elif args.override_protocol == 'compact':
+    proto = TCompactProtocol(trans)
+  else:
+    raise ValueError('invalid protocol')  # this chain validates --override-protocol, not the transport
+
+  test_void(proto)
+  test_void(proto)
 
   trans.close()