Thrift TBinaryProtocol change

Summary: New Thrift TBinaryProtocol with a version identifier

Reviewed By: aditya, eugene

Test Plan: Modify your services to set strict_read_ and strict_write_ both to false (e.g. via setStrict(false, false); note that strict_write_ defaults to true). Then redeploy your services and test running clients against them. Once you have clients and servers running stably on this new code, redeploy versions with strict_write_ set to true. Once that's all good, we can set strict_read_ to true as well, and eventually deprecate the old protocol code entirely.


git-svn-id: https://svn.apache.org/repos/asf/incubator/thrift/trunk@665138 13f79535-47bb-0310-9956-ffa450edef68
diff --git a/lib/cpp/src/protocol/TBinaryProtocol.cpp b/lib/cpp/src/protocol/TBinaryProtocol.cpp
index 5116e86..2645601 100644
--- a/lib/cpp/src/protocol/TBinaryProtocol.cpp
+++ b/lib/cpp/src/protocol/TBinaryProtocol.cpp
@@ -19,34 +19,34 @@
 // understanding_strict_aliasing.html
 template <typename To, typename From>
 static inline To bitwise_cast(From from) {
-	BOOST_STATIC_ASSERT(sizeof(From) == sizeof(To));
+  BOOST_STATIC_ASSERT(sizeof(From) == sizeof(To));
 
-	// BAD!!!  These are all broken with -O2.
-	//return *reinterpret_cast<To*>(&from);  // BAD!!!
-	//return *static_cast<To*>(static_cast<void*>(&from));  // BAD!!!
-	//return *(To*)(void*)&from;  // BAD!!!
+  // BAD!!!  These are all broken with -O2.
+  //return *reinterpret_cast<To*>(&from);  // BAD!!!
+  //return *static_cast<To*>(static_cast<void*>(&from));  // BAD!!!
+  //return *(To*)(void*)&from;  // BAD!!!
+  
+  // Super clean and partially blessed by section 3.9 of the standard.
+  //unsigned char c[sizeof(from)];
+  //memcpy(c, &from, sizeof(from));
+  //To to;
+  //memcpy(&to, c, sizeof(c));
+  //return to;
 
-	// Super clean and paritally blessed by section 3.9 of the standard.
-	//unsigned char c[sizeof(from)];
-	//memcpy(c, &from, sizeof(from));
-	//To to;
-	//memcpy(&to, c, sizeof(c));
-	//return to;
+  // Slightly more questionable.
+  // Same code emitted by GCC.
+  //To to;
+  //memcpy(&to, &from, sizeof(from));
+  //return to;
 
-	// Slightly more questionable.
-	// Same code emitted by GCC.
-	//To to;
-	//memcpy(&to, &from, sizeof(from));
-	//return to;
-
-	// Technically undefined, but almost universally supported,
-	// and the most efficient implementation.
-	union {
-		From f;
-		To t;
-	} u;
-	u.f = from;
-	return u.t;
+  // Technically undefined, but almost universally supported,
+  // and the most efficient implementation.
+  union {
+    From f;
+    To t;
+  } u;
+  u.f = from;
+  return u.t;
 }
 
 
@@ -55,10 +55,18 @@
 uint32_t TBinaryProtocol::writeMessageBegin(const std::string& name,
                                             const TMessageType messageType,
                                             const int32_t seqid) {
-  return 
-    writeString(name) + 
-    writeByte((int8_t)messageType) +
-    writeI32(seqid);
+  uint32_t wsize = 0;
+  // One write per statement: operands of '+' are evaluated in unspecified order.
+  if (strict_write_) {  // [VERSION_1 | type][name][seqid]
+    int32_t version = (VERSION_1) | ((int32_t)messageType);
+    wsize += writeI32(version);
+    wsize += writeString(name);
+  } else {  // legacy: [name][type byte][seqid]
+    wsize += writeString(name);
+    wsize += writeByte((int8_t)messageType);
+  }
+  wsize += writeI32(seqid);
+  return wsize;
 }
 
 uint32_t TBinaryProtocol::writeMessageEnd() {
@@ -181,13 +189,31 @@
 uint32_t TBinaryProtocol::readMessageBegin(std::string& name,
 					   TMessageType& messageType,
 					   int32_t& seqid) {
-
   uint32_t result = 0;
-  int8_t type;
-  result+= readString(name);
-  result+=  readByte(type);
-  messageType = (TMessageType)type;
-  result+= readI32(seqid);
+  // The first word is a negative version|type header, or (legacy) a name length.
+  int32_t sz;
+  result += readI32(sz);
+
+  if (sz >= 0) {
+    // Pre-versioned input: sz is the length of the message name.
+    if (strict_read_) {
+      throw TProtocolException(TProtocolException::BAD_VERSION, "No version identifier... old protocol client in strict mode?");
+    }
+    int8_t type;
+    result += readStringBody(name, sz);
+    result += readByte(type);
+    messageType = (TMessageType)type;
+    result += readI32(seqid);
+    return result;
+  }
+
+  // Versioned input: the upper 16 bits carry the version, the low byte the type.
+  if ((sz & VERSION_MASK) != VERSION_1) {
+    throw TProtocolException(TProtocolException::BAD_VERSION, "Bad version identifier");
+  }
+  messageType = (TMessageType)(sz & 0x000000ff);
+  result += readString(name);
+  result += readI32(seqid);
   return result;
 }
 
@@ -344,6 +370,11 @@
   uint32_t result;
   int32_t size;
   result = readI32(size);
+  return result + readStringBody(str, size);
+}
+
+uint32_t TBinaryProtocol::readStringBody(string& str, int32_t size) {
+  uint32_t result = 0;
 
   // Catch error cases
   if (size < 0) {
@@ -370,8 +401,7 @@
   }
   trans_->readAll(string_buf_, size);
   str = string((char*)string_buf_, size);
-
-  return result + (uint32_t)size;
+  return (uint32_t)size;
 }
 
 }}} // facebook::thrift::protocol
diff --git a/lib/cpp/src/protocol/TBinaryProtocol.h b/lib/cpp/src/protocol/TBinaryProtocol.h
index 89220e2..6cf2c4b 100644
--- a/lib/cpp/src/protocol/TBinaryProtocol.h
+++ b/lib/cpp/src/protocol/TBinaryProtocol.h
@@ -20,20 +20,30 @@
  * @author Mark Slee <mcslee@facebook.com>
  */
 class TBinaryProtocol : public TProtocol {
+ protected:
+  // Both literals exceed INT32_MAX; cast explicitly to the intended negative int32_t values.
+  static const int32_t VERSION_MASK = ((int32_t)0xffff0000);
+  static const int32_t VERSION_1 = ((int32_t)0x80010000);
  public:
   TBinaryProtocol(boost::shared_ptr<TTransport> trans) :
     TProtocol(trans),
     string_limit_(0),
     container_limit_(0),
+    strict_read_(false),
+    strict_write_(true),
     string_buf_(NULL),
     string_buf_size_(0) {}
 
   TBinaryProtocol(boost::shared_ptr<TTransport> trans,
                   int32_t string_limit,
-                  int32_t container_limit) :
+                  int32_t container_limit,
+                  bool strict_read = false,  // defaults keep 3-argument callers compiling
+                  bool strict_write = true) :
     TProtocol(trans),
     string_limit_(string_limit),
     container_limit_(container_limit),
+    strict_read_(strict_read),
+    strict_write_(strict_write),
     string_buf_(NULL),
     string_buf_size_(0) {}
 
@@ -52,6 +62,11 @@
     container_limit_ = container_limit;
   }
 
+  // strict_read: reject messages lacking a version header; strict_write: emit one.
+  void setStrict(bool strict_read, bool strict_write) {
+    strict_read_ = strict_read; strict_write_ = strict_write;
+  }
+
   /**
    * Writing functions.
    */
@@ -157,10 +172,17 @@
 
   uint32_t readString(std::string& str);
 
+ protected:
+  uint32_t readStringBody(std::string& str, int32_t sz);  // reads the sz-byte body into str; caller already read the length
+
  private:
   int32_t string_limit_;
   int32_t container_limit_;
 
+  // Version identifier policy: strict_read_ rejects input lacking one; strict_write_ emits one.
+  bool strict_read_;
+  bool strict_write_;
+
   // Buffer for reading strings, save for the lifetime of the protocol to
   // avoid memory churn allocating memory on every string read
   uint8_t* string_buf_;
@@ -175,11 +197,15 @@
  public:
   TBinaryProtocolFactory() :
     string_limit_(0),
-    container_limit_(0) {}
+    container_limit_(0),
+    strict_read_(false),
+    strict_write_(true) {}
 
-  TBinaryProtocolFactory(int32_t string_limit, int32_t container_limit) :
+  TBinaryProtocolFactory(int32_t string_limit, int32_t container_limit, bool strict_read = false, bool strict_write = true) :
     string_limit_(string_limit),
-    container_limit_(container_limit) {}
+    container_limit_(container_limit),
+    strict_read_(strict_read),
+    strict_write_(strict_write) {}  // defaults match the no-argument ctor
 
   virtual ~TBinaryProtocolFactory() {}
 
@@ -191,13 +217,20 @@
     container_limit_ = container_limit;
   }
 
+  // Applied by later getProtocol() calls: strict_read rejects unversioned input, strict_write emits versions.
+  void setStrict(bool strict_read, bool strict_write) {
+    strict_read_ = strict_read; strict_write_ = strict_write;
+  }
+
   boost::shared_ptr<TProtocol> getProtocol(boost::shared_ptr<TTransport> trans) {
-    return boost::shared_ptr<TProtocol>(new TBinaryProtocol(trans, string_limit_, container_limit_));
+    return boost::shared_ptr<TProtocol>(new TBinaryProtocol(trans, string_limit_, container_limit_, strict_read_, strict_write_));
   }
 
  private:
   int32_t string_limit_;
   int32_t container_limit_;
+  bool strict_read_;
+  bool strict_write_;
 
 };
 
diff --git a/lib/cpp/src/protocol/TProtocolException.h b/lib/cpp/src/protocol/TProtocolException.h
index 8f939c6..50959b7 100644
--- a/lib/cpp/src/protocol/TProtocolException.h
+++ b/lib/cpp/src/protocol/TProtocolException.h
@@ -31,7 +31,8 @@
     UNKNOWN = 0,
     INVALID_DATA = 1,
     NEGATIVE_SIZE = 2,
-    SIZE_LIMIT = 3
+    SIZE_LIMIT = 3,
+    BAD_VERSION = 4  // message header has a missing or unsupported version
   };
 
   TProtocolException() :