Thrift TBinaryProtocol change

Summary: New Thrift TBinaryProtocol with a version identifier

Reviewed By: aditya, eugene

Test Plan: Modify your services to have strictRead_ and strictWrite_ both set to FALSE. Then redeploy your services and test running clients against them. Once you have clients and servers running stably on this new code, you should redploy versions with strictWrite_ set to TRUE. Once that's all good, we can set strictRead_ to TRUE as well, and eventually deprecate the old protocol code entirely.


git-svn-id: https://svn.apache.org/repos/asf/incubator/thrift/trunk@665138 13f79535-47bb-0310-9956-ffa450edef68
diff --git a/lib/cpp/src/protocol/TBinaryProtocol.cpp b/lib/cpp/src/protocol/TBinaryProtocol.cpp
index 5116e86..2645601 100644
--- a/lib/cpp/src/protocol/TBinaryProtocol.cpp
+++ b/lib/cpp/src/protocol/TBinaryProtocol.cpp
@@ -19,34 +19,34 @@
 // understanding_strict_aliasing.html
 template <typename To, typename From>
 static inline To bitwise_cast(From from) {
-	BOOST_STATIC_ASSERT(sizeof(From) == sizeof(To));
+  BOOST_STATIC_ASSERT(sizeof(From) == sizeof(To));
 
-	// BAD!!!  These are all broken with -O2.
-	//return *reinterpret_cast<To*>(&from);  // BAD!!!
-	//return *static_cast<To*>(static_cast<void*>(&from));  // BAD!!!
-	//return *(To*)(void*)&from;  // BAD!!!
+  // BAD!!!  These are all broken with -O2.
+  //return *reinterpret_cast<To*>(&from);  // BAD!!!
+  //return *static_cast<To*>(static_cast<void*>(&from));  // BAD!!!
+  //return *(To*)(void*)&from;  // BAD!!!
+  
+  // Super clean and paritally blessed by section 3.9 of the standard.
+  //unsigned char c[sizeof(from)];
+  //memcpy(c, &from, sizeof(from));
+  //To to;
+  //memcpy(&to, c, sizeof(c));
+  //return to;
 
-	// Super clean and paritally blessed by section 3.9 of the standard.
-	//unsigned char c[sizeof(from)];
-	//memcpy(c, &from, sizeof(from));
-	//To to;
-	//memcpy(&to, c, sizeof(c));
-	//return to;
+  // Slightly more questionable.
+  // Same code emitted by GCC.
+  //To to;
+  //memcpy(&to, &from, sizeof(from));
+  //return to;
 
-	// Slightly more questionable.
-	// Same code emitted by GCC.
-	//To to;
-	//memcpy(&to, &from, sizeof(from));
-	//return to;
-
-	// Technically undefined, but almost universally supported,
-	// and the most efficient implementation.
-	union {
-		From f;
-		To t;
-	} u;
-	u.f = from;
-	return u.t;
+  // Technically undefined, but almost universally supported,
+  // and the most efficient implementation.
+  union {
+    From f;
+    To t;
+  } u;
+  u.f = from;
+  return u.t;
 }
 
 
@@ -55,10 +55,18 @@
 uint32_t TBinaryProtocol::writeMessageBegin(const std::string& name,
                                             const TMessageType messageType,
                                             const int32_t seqid) {
-  return 
-    writeString(name) + 
-    writeByte((int8_t)messageType) +
-    writeI32(seqid);
+  if (strict_write_) {
+    int32_t version = (VERSION_1) | ((int32_t)messageType);
+    return
+      writeI32(version) +
+      writeString(name) +
+      writeI32(seqid);
+  } else {
+    return 
+      writeString(name) + 
+      writeByte((int8_t)messageType) +
+      writeI32(seqid);
+  }
 }
 
 uint32_t TBinaryProtocol::writeMessageEnd() {
@@ -181,13 +189,31 @@
 uint32_t TBinaryProtocol::readMessageBegin(std::string& name,
 					   TMessageType& messageType,
 					   int32_t& seqid) {
-
   uint32_t result = 0;
-  int8_t type;
-  result+= readString(name);
-  result+=  readByte(type);
-  messageType = (TMessageType)type;
-  result+= readI32(seqid);
+  int32_t sz;
+  result += readI32(sz);
+
+  if (sz < 0) {
+    // Check for correct version number
+    int32_t version = sz & VERSION_MASK;
+    if (version != VERSION_1) {
+      throw TProtocolException(TProtocolException::BAD_VERSION, "Bad version identifier");
+    }
+    messageType = (TMessageType)(sz & 0x000000ff);
+    result += readString(name);
+    result += readI32(seqid);
+  } else {
+    if (strict_read_) {
+      throw TProtocolException(TProtocolException::BAD_VERSION, "No version identifier... old protocol client in strict mode?");
+    } else {
+      // Handle pre-versioned input
+      int8_t type;
+      result += readStringBody(name, sz);
+      result += readByte(type);
+      messageType = (TMessageType)type;
+      result += readI32(seqid);
+    }
+  }
   return result;
 }
 
@@ -344,6 +370,11 @@
   uint32_t result;
   int32_t size;
   result = readI32(size);
+  return result + readStringBody(str, size);
+}
+
+uint32_t TBinaryProtocol::readStringBody(string& str, int32_t size) {
+  uint32_t result = 0;
 
   // Catch error cases
   if (size < 0) {
@@ -370,8 +401,7 @@
   }
   trans_->readAll(string_buf_, size);
   str = string((char*)string_buf_, size);
-
-  return result + (uint32_t)size;
+  return (uint32_t)size;
 }
 
 }}} // facebook::thrift::protocol
diff --git a/lib/cpp/src/protocol/TBinaryProtocol.h b/lib/cpp/src/protocol/TBinaryProtocol.h
index 89220e2..6cf2c4b 100644
--- a/lib/cpp/src/protocol/TBinaryProtocol.h
+++ b/lib/cpp/src/protocol/TBinaryProtocol.h
@@ -20,20 +20,30 @@
  * @author Mark Slee <mcslee@facebook.com>
  */
 class TBinaryProtocol : public TProtocol {
+ protected:
+  static const int32_t VERSION_MASK = 0xffff0000;
+  static const int32_t VERSION_1 = 0x80010000;
+
  public:
   TBinaryProtocol(boost::shared_ptr<TTransport> trans) :
     TProtocol(trans),
     string_limit_(0),
     container_limit_(0),
+    strict_read_(false),
+    strict_write_(true),
     string_buf_(NULL),
     string_buf_size_(0) {}
 
   TBinaryProtocol(boost::shared_ptr<TTransport> trans,
                   int32_t string_limit,
-                  int32_t container_limit) :
+                  int32_t container_limit,
+                  bool strict_read,
+                  bool strict_write) :
     TProtocol(trans),
     string_limit_(string_limit),
     container_limit_(container_limit),
+    strict_read_(strict_read),
+    strict_write_(strict_write),
     string_buf_(NULL),
     string_buf_size_(0) {}
 
@@ -52,6 +62,11 @@
     container_limit_ = container_limit;
   }
 
+  void setStrict(bool strict_read, bool strict_write) {
+    strict_read_ = strict_read;
+    strict_write_ = strict_write;
+  }
+
   /**
    * Writing functions.
    */
@@ -157,10 +172,17 @@
 
   uint32_t readString(std::string& str);
 
+ protected:
+  uint32_t readStringBody(std::string& str, int32_t sz);
+
  private:
   int32_t string_limit_;
   int32_t container_limit_;
 
+  // Enforce presence of version identifier
+  bool strict_read_;
+  bool strict_write_;
+
   // Buffer for reading strings, save for the lifetime of the protocol to
   // avoid memory churn allocating memory on every string read
   uint8_t* string_buf_;
@@ -175,11 +197,15 @@
  public:
   TBinaryProtocolFactory() :
     string_limit_(0),
-    container_limit_(0) {}
+    container_limit_(0),
+    strict_read_(false),
+    strict_write_(true) {}
 
-  TBinaryProtocolFactory(int32_t string_limit, int32_t container_limit) :
+  TBinaryProtocolFactory(int32_t string_limit, int32_t container_limit, bool strict_read, bool strict_write) :
     string_limit_(string_limit),
-    container_limit_(container_limit) {}
+    container_limit_(container_limit),
+    strict_read_(strict_read),
+    strict_write_(strict_write) {}
 
   virtual ~TBinaryProtocolFactory() {}
 
@@ -191,13 +217,20 @@
     container_limit_ = container_limit;
   }
 
+  void setStrict(bool strict_read, bool strict_write) {
+    strict_read_ = strict_read;
+    strict_write_ = strict_write;
+  }
+
   boost::shared_ptr<TProtocol> getProtocol(boost::shared_ptr<TTransport> trans) {
-    return boost::shared_ptr<TProtocol>(new TBinaryProtocol(trans, string_limit_, container_limit_));
+    return boost::shared_ptr<TProtocol>(new TBinaryProtocol(trans, string_limit_, container_limit_, strict_read_, strict_write_));
   }
 
  private:
   int32_t string_limit_;
   int32_t container_limit_;
+  bool strict_read_;
+  bool strict_write_;
 
 };
 
diff --git a/lib/cpp/src/protocol/TProtocolException.h b/lib/cpp/src/protocol/TProtocolException.h
index 8f939c6..50959b7 100644
--- a/lib/cpp/src/protocol/TProtocolException.h
+++ b/lib/cpp/src/protocol/TProtocolException.h
@@ -31,7 +31,8 @@
     UNKNOWN = 0,
     INVALID_DATA = 1,
     NEGATIVE_SIZE = 2,
-    SIZE_LIMIT = 3
+    SIZE_LIMIT = 3,
+    BAD_VERSION = 4,
   };
 
   TProtocolException() :
diff --git a/lib/java/src/protocol/TBinaryProtocol.java b/lib/java/src/protocol/TBinaryProtocol.java
index be02a3e..7dff485 100644
--- a/lib/java/src/protocol/TBinaryProtocol.java
+++ b/lib/java/src/protocol/TBinaryProtocol.java
@@ -16,12 +16,30 @@
  */
 public class TBinaryProtocol extends TProtocol {
 
+  protected static final int VERSION_MASK = 0xffff0000;
+  protected static final int VERSION_1 = 0x80010000;
+
+  protected boolean strictRead_ = false;
+  protected boolean strictWrite_ = true;
+
   /**
    * Factory
    */
   public static class Factory implements TProtocolFactory {
+    protected boolean strictRead_ = false;
+    protected boolean strictWrite_ = true;
+    
+    public Factory() {
+      this(false, false);
+    }
+
+    public Factory(boolean strictRead, boolean strictWrite) {
+      strictRead_ = strictRead;
+      strictWrite_ = strictWrite;
+    }
+
     public TProtocol getProtocol(TTransport trans) {
-      return new TBinaryProtocol(trans);
+      return new TBinaryProtocol(trans, strictRead_, strictWrite_);
     }
   }
 
@@ -29,13 +47,27 @@
    * Constructor
    */
   public TBinaryProtocol(TTransport trans) {
-    super(trans);
+    this(trans, false, false);
   }
 
+  public TBinaryProtocol(TTransport trans, boolean strictRead, boolean strictWrite) {
+    super(trans);
+    strictRead_ = strictRead;
+    strictWrite_ = strictWrite;
+  }
+
+
   public void writeMessageBegin(TMessage message) throws TException {
-    writeString(message.name);
-    writeByte(message.type);
-    writeI32(message.seqid);
+    if (strictWrite_) {
+      int version = VERSION_1 | message.type;
+      writeI32(version);
+      writeString(message.name);
+      writeI32(message.seqid);
+    } else {
+      writeString(message.name);
+      writeByte(message.type);
+      writeI32(message.seqid);
+    }
   }
 
   public void writeMessageEnd() {}
@@ -137,9 +169,24 @@
 
   public TMessage readMessageBegin() throws TException {
     TMessage message = new TMessage();
-    message.name = readString();
-    message.type = readByte();
-    message.seqid = readI32();
+
+    int size = readI32();
+    if (size < 0) {
+      int version = size & VERSION_MASK;
+      if (version != VERSION_1) {
+        throw new TProtocolException(TProtocolException.BAD_VERSION, "Bad version in readMessageBegin");
+      }
+      message.type = (byte)(version & 0x000000ff);
+      message.name = readString();
+      message.seqid = readI32();
+    } else {
+      if (strictRead_) {
+        throw new TProtocolException(TProtocolException.BAD_VERSION, "Missing version in readMessageBegin, old client?");
+      }
+      message.name = readStringBody(size);
+      message.type = readByte();
+      message.seqid = readI32();
+    }
     return message;
   }
 
@@ -239,6 +286,10 @@
 
   public String readString() throws TException {
     int size = readI32();
+    return readStringBody(size);
+  }
+
+  public String readStringBody(int size) throws TException {
     byte[] buf = new byte[size];
     trans_.readAll(buf, 0, size);
     return new String(buf);
diff --git a/lib/java/src/protocol/TProtocolException.java b/lib/java/src/protocol/TProtocolException.java
index efc54b4..783b5be 100644
--- a/lib/java/src/protocol/TProtocolException.java
+++ b/lib/java/src/protocol/TProtocolException.java
@@ -19,6 +19,7 @@
   public static final int INVALID_DATA = 1;
   public static final int NEGATIVE_SIZE = 2;
   public static final int SIZE_LIMIT = 3;
+  public static final int BAD_VERSION = 4;
 
   protected int type_ = UNKNOWN;
 
diff --git a/lib/java/src/server/TThreadPoolServer.java b/lib/java/src/server/TThreadPoolServer.java
index 22930d5..0945fbe 100644
--- a/lib/java/src/server/TThreadPoolServer.java
+++ b/lib/java/src/server/TThreadPoolServer.java
@@ -58,6 +58,15 @@
 
   public TThreadPoolServer(TProcessor processor,
                            TServerTransport serverTransport,
+                           TProtocolFactory protocolFactory) {
+    this(processor, serverTransport,
+         new TTransportFactory(), new TTransportFactory(),
+         protocolFactory, protocolFactory,
+         new Options());
+  }
+
+  public TThreadPoolServer(TProcessor processor,
+                           TServerTransport serverTransport,
                            TTransportFactory transportFactory,
                            TProtocolFactory protocolFactory) {
     this(processor, serverTransport, 
diff --git a/lib/perl/lib/Thrift/BinaryProtocol.pm b/lib/perl/lib/Thrift/BinaryProtocol.pm
index c17fe91..c31e1d7 100644
--- a/lib/perl/lib/Thrift/BinaryProtocol.pm
+++ b/lib/perl/lib/Thrift/BinaryProtocol.pm
@@ -26,6 +26,9 @@
 package Thrift::BinaryProtocol;
 use base('Thrift::Protocol');
 
+use constant VERSION_MASK   => 0xffff0000;
+use constant VERSION_1      => 0x80010000;
+
 sub new
 {
     my $classname = shift;
@@ -41,8 +44,8 @@
     my ($name, $type, $seqid) = @_;
 
     return
+        $self->writeI32(VERSION_1 | $type) +
         $self->writeString($name) +
-        $self->writeByte($type) +
         $self->writeI32($seqid);
 }
 
@@ -224,9 +227,15 @@
     my $self = shift;
     my ($name, $type, $seqid) = @_;
 
+    my $version = 0;
+    my $result = $self->readI32($version);
+    if ($version & VERSION_MASK != VERSION_1) {
+      die new Thrift::TException('Missing version identifier')
+    }
+    $$type = $version & 0x000000ff;
     return
+        $result +
         $self->readString($name) +
-        $self->readByte($type) +
         $self->readI32($seqid);
 }
 
diff --git a/lib/perl/lib/Thrift/Protocol.pm b/lib/perl/lib/Thrift/Protocol.pm
index ceea52e..d58d8b0 100644
--- a/lib/perl/lib/Thrift/Protocol.pm
+++ b/lib/perl/lib/Thrift/Protocol.pm
@@ -26,7 +26,7 @@
 use constant INVALID_DATA  => 1;
 use constant NEGATIVE_SIZE => 2;
 use constant SIZE_LIMIT    => 3;
-
+use constant BAD_VERSION   => 4;
 
 sub new {
     my $classname = shift;
diff --git a/lib/php/src/protocol/TBinaryProtocol.php b/lib/php/src/protocol/TBinaryProtocol.php
index 962fcfe..69723da 100644
--- a/lib/php/src/protocol/TBinaryProtocol.php
+++ b/lib/php/src/protocol/TBinaryProtocol.php
@@ -19,15 +19,31 @@
  */
 class TBinaryProtocol extends TProtocol {
 
-  public function __construct($trans) {
+  const VERSION_MASK = 0xffff0000;
+  const VERSION_1 = 0x80010000;
+
+  private $strictRead_ = false;
+  private $strictWrite_ = true;
+
+  public function __construct($trans, $strictRead=false, $strictWrite=true) {
     parent::__construct($trans);
+    $this->strictRead_ = $strictRead;
+    $this->strictWrite_ = $strictWrite;
   }
 
   public function writeMessageBegin($name, $type, $seqid) {
-    return 
-      $this->writeString($name) +
-      $this->writeByte($type) +
-      $this->writeI32($seqid);
+    if ($this->strictWrite_) {
+      $version = self::VERSION_1 | $type;
+      return
+        $this->writeI32($version) +
+        $this->writeString($name) +
+        $this->writeI32($seqid);
+    } else {
+      return 
+        $this->writeString($name) +
+        $this->writeByte($type) +
+        $this->writeI32($seqid);
+    }
   }
 
   public function writeMessageEnd() {
@@ -50,7 +66,7 @@
 
   public function writeFieldEnd() {
     return 0;
-  } 
+  }
 
   public function writeFieldStop() {
     return
@@ -115,29 +131,29 @@
   public function writeI64($value) {
     // If we are on a 32bit architecture we have to explicitly deal with
     // 64-bit twos-complement arithmetic since PHP wants to treat all ints
-    // as signed and any int over 2^31 - 1 as a float   
+    // as signed and any int over 2^31 - 1 as a float
     if (PHP_INT_SIZE == 4) {
       $neg = $value < 0;
 
       if ($neg) {
-	$value *= -1;
+        $value *= -1;
       }
-   
+
       $hi = (int)($value / 4294967296);
       $lo = (int)$value;
-    
+
       if ($neg) {
-	$hi = ~$hi;
-	$lo = ~$lo;
-	if (($lo & (int)0xffffffff) == (int)0xffffffff) {
-	  $lo = 0;
-	  $hi++;
-	} else {
-	  $lo++;
-	}
+        $hi = ~$hi;
+        $lo = ~$lo;
+        if (($lo & (int)0xffffffff) == (int)0xffffffff) {
+          $lo = 0;
+          $hi++;
+        } else {
+          $lo++;
+        }
       }
       $data = pack('N2', $hi, $lo);
-    
+
     } else {
       $hi = $value >> 32;
       $lo = $value & 0xFFFFFFFF;
@@ -164,10 +180,29 @@
   }
 
   public function readMessageBegin(&$name, &$type, &$seqid) {
-    return 
-      $this->readString($name) +
-      $this->readByte($type) +
-      $this->readI32($seqid);
+    $result = $this->readI32($sz);
+    if ($sz < 0) {
+      $version = $sz & self::VERSION_MASK;
+      if ($version != self::VERSION_1) {
+        throw new TProtocolException('Bad version identifier: '.$sz, TProtocolException::BAD_VERSION);
+      }
+      $type = $sz & 0x000000ff;
+      $result +=
+        $this->readString($name) +
+        $this->readI32($seqid);
+    } else {
+      if ($this->strictRead_) {
+        throw new TProtocolException('No version identifier, old protocol client?', TProtocolException::BAD_VERSION);
+      } else {
+        // Handle pre-versioned input
+        $name = $this->trans_->readAll($sz);
+        $result += 
+          $sz +
+          $this->readByte($type) +
+          $this->readI32($seqid);
+      }
+    }
+    return $result;
   }
 
   public function readMessageEnd() {
@@ -266,7 +301,7 @@
     $data = $this->trans_->readAll(8);
 
     $arr = unpack('N2', $data);
-    
+
     // If we are on a 32bit architecture we have to explicitly deal with
     // 64-bit twos-complement arithmetic since PHP wants to treat all ints
     // as signed and any int over 2^31 - 1 as a float
@@ -350,8 +385,16 @@
  * Binary Protocol Factory
  */
 class TBinaryProtocolFactory implements TProtocolFactory {
+  private $strictRead_ = false;
+  private $strictWrite_ = false;
+
+  public function __construct($strictRead=false, $strictWrite=false) {
+    $this->strictRead_ = $strictRead;
+    $this->strictWrite_ = $strictWrite;
+  } 
+
   public function getProtocol($trans) {
-    return new TBinaryProtocol($trans);
+    return new TBinaryProtocol($trans, $this->strictRead, $this->strictWrite);
   }
 }
 
diff --git a/lib/php/src/protocol/TProtocol.php b/lib/php/src/protocol/TProtocol.php
index deddfa2..658e54c 100644
--- a/lib/php/src/protocol/TProtocol.php
+++ b/lib/php/src/protocol/TProtocol.php
@@ -28,6 +28,7 @@
   const INVALID_DATA = 1;
   const NEGATIVE_SIZE = 2;
   const SIZE_LIMIT = 3;
+  const BAD_VERSION = 4;
 
   function __construct($message=null, $code=0) {
     parent::__construct($message, $code);
diff --git a/lib/py/src/protocol/TBinaryProtocol.py b/lib/py/src/protocol/TBinaryProtocol.py
index 93734d4..3c236fb 100644
--- a/lib/py/src/protocol/TBinaryProtocol.py
+++ b/lib/py/src/protocol/TBinaryProtocol.py
@@ -13,13 +13,23 @@
 
   """Binary implementation of the Thrift protocol driver."""
 
-  def __init__(self, trans):
+  VERSION_MASK = 0xffff0000
+  VERSION_1 = 0x80010000
+
+  def __init__(self, trans, strictRead=False, strictWrite=True):
     TProtocolBase.__init__(self, trans)
+    self.strictRead = strictRead
+    self.strictWrite = strictWrite
 
   def writeMessageBegin(self, name, type, seqid):
-    self.writeString(name)
-    self.writeByte(type)
-    self.writeI32(seqid)
+    if self.strictWrite:
+      self.writeI32(VERSION_1 | type)
+      self.writeString(name)
+      self.writeI32(seqid)
+    else:
+      self.writeString(name)
+      self.writeByte(type)
+      self.writeI32(seqid)
 
   def writeMessageEnd(self):
     pass
@@ -93,9 +103,20 @@
     self.trans.write(str)
 
   def readMessageBegin(self):
-    name = self.readString()
-    type = self.readByte()
-    seqid = self.readI32()
+    sz = self.readI32()
+    if sz < 0:
+      version = sz & VERSION_MASK
+      if version != VERSION_1:
+        raise TProtocolException(TProtocolException.BAD_VERSION, 'Bad version in readMessageBegin: %d' % (sz))
+      type = version & 0x000000ff
+      name = self.readString()
+      seqid = self.readI32()
+    else:
+      if self.strictRead:
+        raise TProtocolException(TProtocolException.BAD_VERSION, 'No protocol version header')
+      name = self.trans.readAll(sz)
+      type = self.readByte()
+      seqid = self.readI32()
     return (name, type, seqid)
 
   def readMessageEnd(self):
@@ -179,6 +200,10 @@
     return str
 
 class TBinaryProtocolFactory:
+  def __init__(self, strictRead=False, strictWrite=True):
+    self.strictRead = strictRead
+    self.strictWrite = strictWrite
+
   def getProtocol(self, trans):
-    prot = TBinaryProtocol(trans)
+    prot = TBinaryProtocol(trans, self.strictRead, self.strictWrite)
     return prot
diff --git a/lib/py/src/protocol/TProtocol.py b/lib/py/src/protocol/TProtocol.py
index f7d0b34..146a802 100644
--- a/lib/py/src/protocol/TProtocol.py
+++ b/lib/py/src/protocol/TProtocol.py
@@ -16,6 +16,7 @@
   INVALID_DATA = 1
   NEGATIVE_SIZE = 2
   SIZE_LIMIT = 3
+  BAD_VERSION = 4
 
   def __init__(self, type=UNKNOWN, message=None):
     TException.__init__(self, message)
diff --git a/lib/rb/lib/thrift/protocol/tbinaryprotocol.rb b/lib/rb/lib/thrift/protocol/tbinaryprotocol.rb
index aaf7776..c8fa57d 100644
--- a/lib/rb/lib/thrift/protocol/tbinaryprotocol.rb
+++ b/lib/rb/lib/thrift/protocol/tbinaryprotocol.rb
@@ -11,13 +11,17 @@
 require 'thrift/protocol/tprotocol'
 
 class TBinaryProtocol < TProtocol
+
+  VERSION_MASK = 0xffff0000
+  VERSION_1 = 0x80010000
+
   def initialize(trans)
     super(trans)
   end
 
   def writeMessageBegin(name, type, seqid)
+    writeI32(VERSION_1 & type)
     writeString(name)
-    writeByte(type)
     writeI32(seqid)
   end
 
@@ -82,8 +86,12 @@
   end
 
   def readMessageBegin()
+    version = readI32()
+    if (version & VERSION_MASK != VERSION_1)
+      raise TProtocolException.new(TProtocolException::BAD_VERSION, 'Missing version identifier')
+    end
+    type = version & 0x000000ff
     name = readString()
-    type = readByte()
     seqid = readI32()
     return name, type, seqid
   end
diff --git a/lib/rb/lib/thrift/protocol/tprotocol.rb b/lib/rb/lib/thrift/protocol/tprotocol.rb
index 3c13b4b..da079d1 100644
--- a/lib/rb/lib/thrift/protocol/tprotocol.rb
+++ b/lib/rb/lib/thrift/protocol/tprotocol.rb
@@ -11,6 +11,23 @@
 
 require 'thrift/thrift'
 
+class TProtocolException < TException
+
+  UNKNOWN = 0
+  INVALID_DATA = 1
+  NEGATIVE_SIZE = 2
+  SIZE_LIMIT = 3
+  BAD_VERSION = 4
+
+  attr_reader :type
+
+  def initialize(type=UNKNOWN, message=nil)
+    super(message)
+    @type = type
+  end
+
+end
+
 class TProtocol
   
   attr_reader :trans