THRIFT-254. rb: Add optional strict version support to binary protocols

Author: Michael Stockton

git-svn-id: https://svn.apache.org/repos/asf/incubator/thrift/trunk@740930 13f79535-47bb-0310-9956-ffa450edef68
diff --git a/lib/rb/ext/binary_protocol_accelerated.c b/lib/rb/ext/binary_protocol_accelerated.c
index 8a6757f..fc4b675 100644
--- a/lib/rb/ext/binary_protocol_accelerated.c
+++ b/lib/rb/ext/binary_protocol_accelerated.c
@@ -5,6 +5,8 @@
 #include <struct.h>
 
 #define GET_TRANSPORT(obj) rb_ivar_get(obj, transport_ivar_id)
+#define GET_STRICT_READ(obj) rb_ivar_get(obj, strict_read_ivar_id) /* protocol object's @strict_read ivar; compared against Qtrue */
+#define GET_STRICT_WRITE(obj) rb_ivar_get(obj, strict_write_ivar_id) /* protocol object's @strict_write ivar; compared against Qtrue */
 #define WRITE(obj, data, length) rb_funcall(obj, write_method_id, 1, rb_str_new(data, length))
 #define CHECK_NIL(obj) if (NIL_P(obj)) { rb_raise(rb_eStandardError, "nil argument not allowed!");}
 
@@ -16,6 +18,7 @@
 
 static int VERSION_1;
 static int VERSION_MASK;
+static int TYPE_MASK; /* copied from the Ruby BinaryProtocol TYPE_MASK constant at init; extracts the message type from a versioned header word */
 static int BAD_VERSION;
 
 static void write_byte_direct(VALUE trans, int8_t b) {
@@ -97,9 +100,17 @@
 
 VALUE rb_thrift_binary_proto_write_message_begin(VALUE self, VALUE name, VALUE type, VALUE seqid) {
   VALUE trans = GET_TRANSPORT(self);
-  write_i32_direct(trans, VERSION_1 | FIX2INT(type));
-  write_string_direct(trans, name);
-  write_i32_direct(trans, FIX2INT(seqid));
+  VALUE strict_write = GET_STRICT_WRITE(self);
+  /* Strict: i32(VERSION_1 | type), name, i32 seqid. Otherwise legacy layout: name, type byte, i32 seqid. */
+  if (strict_write == Qtrue) {
+    write_i32_direct(trans, VERSION_1 | FIX2INT(type));
+    write_string_direct(trans, name);
+    write_i32_direct(trans, FIX2INT(seqid));
+  } else {
+    write_string_direct(trans, name);
+    write_byte_direct(trans, FIX2INT(type)); /* type is a Ruby Fixnum VALUE; unwrap it before narrowing to int8_t */
+    write_i32_direct(trans, FIX2INT(seqid));
+  }
   
   return Qnil;
 }
@@ -260,14 +271,27 @@
 }
 
 VALUE rb_thrift_binary_proto_read_message_begin(VALUE self) {
-  int version = read_i32_direct(self);
-  if ((version & VERSION_MASK) != VERSION_1) {
-    rb_exc_raise(get_protocol_exception(INT2FIX(BAD_VERSION), rb_str_new2("Missing version identifier")));
-  }
+  VALUE strict_read = GET_STRICT_READ(self);
+  VALUE name, seqid;
+  int type;
   
-  int type = version & 0x000000ff;
-  VALUE name = rb_thrift_binary_proto_read_string(self);
-  VALUE seqid = rb_thrift_binary_proto_read_i32(self);
+  int version = read_i32_direct(self);
+  /* Negative first word => versioned header (version | type); otherwise it is the name length of a legacy header. */
+  if (version < 0) {
+    if ((version & VERSION_MASK) != VERSION_1) {
+      rb_exc_raise(get_protocol_exception(INT2FIX(BAD_VERSION), rb_str_new2("Missing version identifier")));
+    }
+    type = version & TYPE_MASK;
+    name = rb_thrift_binary_proto_read_string(self);
+    seqid = rb_thrift_binary_proto_read_i32(self);
+  } else {
+    if (strict_read == Qtrue) {
+      rb_exc_raise(get_protocol_exception(INT2FIX(BAD_VERSION), rb_str_new2("No version identifier, old protocol client?")));
+    }
+    name = READ(self, version);
+    type = FIX2INT(rb_thrift_binary_proto_read_byte(self)); /* reader returns a Fixnum VALUE (like read_i32); unwrap so INT2FIX below is applied once */
+    seqid = rb_thrift_binary_proto_read_i32(self);
+  }
   
   return rb_ary_new3(3, name, INT2FIX(type), seqid);
 }
@@ -339,6 +363,7 @@
   
   VERSION_1 = rb_num2ll(rb_const_get(thrift_binary_protocol_class, rb_intern("VERSION_1")));
   VERSION_MASK = rb_num2ll(rb_const_get(thrift_binary_protocol_class, rb_intern("VERSION_MASK")));
+  TYPE_MASK = rb_num2ll(rb_const_get(thrift_binary_protocol_class, rb_intern("TYPE_MASK")));
   
   VALUE bpa_class = rb_define_class_under(thrift_module, "BinaryProtocolAccelerated", thrift_binary_protocol_class);
   
diff --git a/lib/rb/ext/constants.h b/lib/rb/ext/constants.h
index e540234..1922fb1 100644
--- a/lib/rb/ext/constants.h
+++ b/lib/rb/ext/constants.h
@@ -60,6 +60,8 @@
 
 extern ID fields_const_id;
 extern ID transport_ivar_id;
+extern ID strict_read_ivar_id;  /* interned "@strict_read" ivar name */
+extern ID strict_write_ivar_id; /* interned "@strict_write" ivar name */
 
 extern VALUE type_sym;
 extern VALUE name_sym;
diff --git a/lib/rb/ext/thrift_native.c b/lib/rb/ext/thrift_native.c
index 89d32c5..4d5623d 100644
--- a/lib/rb/ext/thrift_native.c
+++ b/lib/rb/ext/thrift_native.c
@@ -73,6 +73,8 @@
 // constant ids
 ID fields_const_id;
 ID transport_ivar_id;
+ID strict_read_ivar_id;  /* rb_intern("@strict_read"), set during extension init */
+ID strict_write_ivar_id; /* rb_intern("@strict_write"), set during extension init */
 
 // cached symbols
 VALUE type_sym;
@@ -153,6 +155,8 @@
   // constant ids
   fields_const_id = rb_intern("FIELDS");
   transport_ivar_id = rb_intern("@trans");
+  strict_read_ivar_id = rb_intern("@strict_read");
+  strict_write_ivar_id = rb_intern("@strict_write");  
   
   // cached symbols
   type_sym = ID2SYM(rb_intern("type"));