THRIFT-254. rb: Add optional strict version support to binary protocols
Author: Michael Stockton
git-svn-id: https://svn.apache.org/repos/asf/incubator/thrift/trunk@740930 13f79535-47bb-0310-9956-ffa450edef68
diff --git a/lib/rb/ext/binary_protocol_accelerated.c b/lib/rb/ext/binary_protocol_accelerated.c
index 8a6757f..fc4b675 100644
--- a/lib/rb/ext/binary_protocol_accelerated.c
+++ b/lib/rb/ext/binary_protocol_accelerated.c
@@ -5,6 +5,8 @@
#include <struct.h>
#define GET_TRANSPORT(obj) rb_ivar_get(obj, transport_ivar_id)
+#define GET_STRICT_READ(obj) rb_ivar_get(obj, strict_read_ivar_id)
+#define GET_STRICT_WRITE(obj) rb_ivar_get(obj, strict_write_ivar_id)
#define WRITE(obj, data, length) rb_funcall(obj, write_method_id, 1, rb_str_new(data, length))
#define CHECK_NIL(obj) if (NIL_P(obj)) { rb_raise(rb_eStandardError, "nil argument not allowed!");}
@@ -16,6 +18,7 @@
static int VERSION_1;
static int VERSION_MASK;
+static int TYPE_MASK;
static int BAD_VERSION;
static void write_byte_direct(VALUE trans, int8_t b) {
@@ -97,9 +100,17 @@
VALUE rb_thrift_binary_proto_write_message_begin(VALUE self, VALUE name, VALUE type, VALUE seqid) {
VALUE trans = GET_TRANSPORT(self);
- write_i32_direct(trans, VERSION_1 | FIX2INT(type));
- write_string_direct(trans, name);
- write_i32_direct(trans, FIX2INT(seqid));
+ VALUE strict_write = GET_STRICT_WRITE(self);
+
+ if (strict_write == Qtrue) {
+ write_i32_direct(trans, VERSION_1 | FIX2INT(type));
+ write_string_direct(trans, name);
+ write_i32_direct(trans, FIX2INT(seqid));
+ } else {
+ write_string_direct(trans, name);
+ write_byte_direct(trans, type);
+ write_i32_direct(trans, FIX2INT(seqid));
+ }
return Qnil;
}
@@ -260,14 +271,27 @@
}
VALUE rb_thrift_binary_proto_read_message_begin(VALUE self) {
- int version = read_i32_direct(self);
- if ((version & VERSION_MASK) != VERSION_1) {
- rb_exc_raise(get_protocol_exception(INT2FIX(BAD_VERSION), rb_str_new2("Missing version identifier")));
- }
+ VALUE strict_read = GET_STRICT_READ(self);
+ VALUE name, seqid;
+ int type;
- int type = version & 0x000000ff;
- VALUE name = rb_thrift_binary_proto_read_string(self);
- VALUE seqid = rb_thrift_binary_proto_read_i32(self);
+ int version = read_i32_direct(self);
+
+ if (version < 0) {
+ if ((version & VERSION_MASK) != VERSION_1) {
+ rb_exc_raise(get_protocol_exception(INT2FIX(BAD_VERSION), rb_str_new2("Missing version identifier")));
+ }
+ type = version & TYPE_MASK;
+ name = rb_thrift_binary_proto_read_string(self);
+ seqid = rb_thrift_binary_proto_read_i32(self);
+ } else {
+ if (strict_read == Qtrue) {
+ rb_exc_raise(get_protocol_exception(INT2FIX(BAD_VERSION), rb_str_new2("No version identifier, old protocol client?")));
+ }
+ name = READ(self, version);
+ type = rb_thrift_binary_proto_read_byte(self);
+ seqid = rb_thrift_binary_proto_read_i32(self);
+ }
return rb_ary_new3(3, name, INT2FIX(type), seqid);
}
@@ -339,6 +363,7 @@
VERSION_1 = rb_num2ll(rb_const_get(thrift_binary_protocol_class, rb_intern("VERSION_1")));
VERSION_MASK = rb_num2ll(rb_const_get(thrift_binary_protocol_class, rb_intern("VERSION_MASK")));
+ TYPE_MASK = rb_num2ll(rb_const_get(thrift_binary_protocol_class, rb_intern("TYPE_MASK")));
VALUE bpa_class = rb_define_class_under(thrift_module, "BinaryProtocolAccelerated", thrift_binary_protocol_class);
diff --git a/lib/rb/ext/constants.h b/lib/rb/ext/constants.h
index e540234..1922fb1 100644
--- a/lib/rb/ext/constants.h
+++ b/lib/rb/ext/constants.h
@@ -60,6 +60,8 @@
extern ID fields_const_id;
extern ID transport_ivar_id;
+extern ID strict_read_ivar_id;
+extern ID strict_write_ivar_id;
extern VALUE type_sym;
extern VALUE name_sym;
diff --git a/lib/rb/ext/thrift_native.c b/lib/rb/ext/thrift_native.c
index 89d32c5..4d5623d 100644
--- a/lib/rb/ext/thrift_native.c
+++ b/lib/rb/ext/thrift_native.c
@@ -73,6 +73,8 @@
// constant ids
ID fields_const_id;
ID transport_ivar_id;
+ID strict_read_ivar_id;
+ID strict_write_ivar_id;
// cached symbols
VALUE type_sym;
@@ -153,6 +155,8 @@
// constant ids
fields_const_id = rb_intern("FIELDS");
transport_ivar_id = rb_intern("@trans");
+ strict_read_ivar_id = rb_intern("@strict_read");
+ strict_write_ivar_id = rb_intern("@strict_write");
// cached symbols
type_sym = ID2SYM(rb_intern("type"));
diff --git a/lib/rb/lib/thrift/protocol/binaryprotocol.rb b/lib/rb/lib/thrift/protocol/binaryprotocol.rb
index f0fa3ad..c3d927a 100644
--- a/lib/rb/lib/thrift/protocol/binaryprotocol.rb
+++ b/lib/rb/lib/thrift/protocol/binaryprotocol.rb
@@ -13,14 +13,29 @@
class BinaryProtocol < Protocol
VERSION_MASK = 0xffff0000
VERSION_1 = 0x80010000
+ TYPE_MASK = 0x000000ff
+
+ attr_reader :strict_read, :strict_write
+
+ def initialize(trans, strict_read=true, strict_write=true)
+ super(trans)
+ @strict_read = strict_read
+ @strict_write = strict_write
+ end
def write_message_begin(name, type, seqid)
# this is necessary because we added (needed) bounds checking to
# write_i32, and 0x80010000 is too big for that.
- write_i16(VERSION_1 >> 16)
- write_i16(type)
- write_string(name)
- write_i32(seqid)
+ if strict_write
+ write_i16(VERSION_1 >> 16)
+ write_i16(type)
+ write_string(name)
+ write_i32(seqid)
+ else
+ write_string(name)
+ write_byte(type)
+ write_i32(seqid)
+ end
end
def write_field_begin(name, type, id)
@@ -82,13 +97,23 @@
def read_message_begin
version = read_i32
- if (version & VERSION_MASK != VERSION_1)
- raise ProtocolException.new(ProtocolException::BAD_VERSION, 'Missing version identifier')
+ if version < 0
+ if (version & VERSION_MASK != VERSION_1)
+ raise ProtocolException.new(ProtocolException::BAD_VERSION, 'Missing version identifier')
+ end
+ type = version & TYPE_MASK
+ name = read_string
+ seqid = read_i32
+ [name, type, seqid]
+ else
+ if strict_read
+ raise ProtocolException.new(ProtocolException::BAD_VERSION, 'No version identifier, old protocol client?')
+ end
+ name = trans.read_all(version)
+ type = read_byte
+ seqid = read_i32
+ [name, type, seqid]
end
- type = version & 0x000000ff
- name = read_string
- seqid = read_i32
- [name, type, seqid]
end
def read_field_begin
diff --git a/lib/rb/spec/binaryprotocol_spec.rb b/lib/rb/spec/binaryprotocol_spec.rb
index b85f096..2d5b375 100644
--- a/lib/rb/spec/binaryprotocol_spec.rb
+++ b/lib/rb/spec/binaryprotocol_spec.rb
@@ -13,17 +13,28 @@
end
it "should read a message header" do
- @prot.should_receive(:read_i32).and_return(protocol_class.const_get(:VERSION_1) | Thrift::MessageTypes::REPLY, 42)
+ @trans.should_receive(:read_all).exactly(2).times.and_return(
+ [protocol_class.const_get(:VERSION_1) | Thrift::MessageTypes::REPLY].pack('N'),
+ [42].pack('N')
+ )
@prot.should_receive(:read_string).and_return('testMessage')
@prot.read_message_begin.should == ['testMessage', Thrift::MessageTypes::REPLY, 42]
end
it "should raise an exception if the message header has the wrong version" do
- @prot.should_receive(:read_i32).and_return(42)
+ @prot.should_receive(:read_i32).and_return(-1)
lambda { @prot.read_message_begin }.should raise_error(Thrift::ProtocolException, 'Missing version identifier') do |e|
e.type == Thrift::ProtocolException::BAD_VERSION
end
end
+
+ it "should raise an exception if the message header does not exist and strict_read is enabled" do
+ @prot.should_receive(:read_i32).and_return(42)
+ @prot.should_receive(:strict_read).and_return(true)
+ lambda { @prot.read_message_begin }.should raise_error(Thrift::ProtocolException, 'No version identifier, old protocol client?') do |e|
+ e.type == Thrift::ProtocolException::BAD_VERSION
+ end
+ end
end
describe BinaryProtocolFactory do
diff --git a/lib/rb/spec/binaryprotocol_spec_shared.rb b/lib/rb/spec/binaryprotocol_spec_shared.rb
index 78e2ccb..1d685b6 100644
--- a/lib/rb/spec/binaryprotocol_spec_shared.rb
+++ b/lib/rb/spec/binaryprotocol_spec_shared.rb
@@ -6,15 +6,37 @@
@prot = protocol_class.new(@trans)
end
- it "should define the proper VERSION_1 and VERSION_MASK" do
+ it "should define the proper VERSION_1, VERSION_MASK AND TYPE_MASK" do
protocol_class.const_get(:VERSION_MASK).should == 0xffff0000
protocol_class.const_get(:VERSION_1).should == 0x80010000
+ protocol_class.const_get(:TYPE_MASK).should == 0x000000ff
end
+ it "should make strict_read readable" do
+ @prot.strict_read.should eql(true)
+ end
+
+ it "should make strict_write readable" do
+ @prot.strict_write.should eql(true)
+ end
+
it "should write the message header" do
@prot.write_message_begin('testMessage', Thrift::MessageTypes::CALL, 17)
@trans.read(1000).should == [protocol_class.const_get(:VERSION_1) | Thrift::MessageTypes::CALL, "testMessage".size, "testMessage", 17].pack("NNa11N")
end
+
+ it "should write the message header without version when writes are not strict" do
+ @prot = protocol_class.new(@trans, true, false) # no strict write
+ @prot.write_message_begin('testMessage', Thrift::MessageTypes::CALL, 17)
+ @trans.read(1000).should == "\000\000\000\vtestMessage\001\000\000\000\021"
+ end
+
+ it "should write the message header with a version when writes are strict" do
+ @prot = protocol_class.new(@trans) # strict write
+ @prot.write_message_begin('testMessage', Thrift::MessageTypes::CALL, 17)
+ @trans.read(1000).should == "\200\001\000\001\000\000\000\vtestMessage\000\000\000\021"
+ end
+
# message footer is a noop