THRIFT-697. Union support in Ruby
git-svn-id: https://svn.apache.org/repos/asf/incubator/thrift/trunk@910700 13f79535-47bb-0310-9956-ffa450edef68
diff --git a/lib/rb/ext/struct.c b/lib/rb/ext/struct.c
index 7429fb1..d459ddb 100644
--- a/lib/rb/ext/struct.c
+++ b/lib/rb/ext/struct.c
@@ -45,29 +45,18 @@
 static native_proto_method_table *mt;
 static native_proto_method_table *default_mt;
-// static VALUE last_proto_class = Qnil;
+
+VALUE thrift_union_class; // Thrift::Union, looked up in Init_struct. NOTE(review): non-static; make these static if nothing outside struct.c references them
+
+ID setfield_id; // interned "@setfield" (set in Init_struct)
+ID setvalue_id; // interned "@value" (set in Init_struct)
+
+ID to_s_method_id; // interned "to_s" (set in Init_struct)
+ID name_to_id_method_id; // interned "name_to_id" (set in Init_struct)
 #define IS_CONTAINER(ttype) ((ttype) == TTYPE_MAP || (ttype) == TTYPE_LIST || (ttype) == TTYPE_SET)
 #define STRUCT_FIELDS(obj) rb_const_get(CLASS_OF(obj), fields_const_id)
-// static void set_native_proto_function_pointers(VALUE protocol) {
-// VALUE method_table_object = rb_const_get(CLASS_OF(protocol), rb_intern("@native_method_table"));
-// // TODO: check nil?
-// Data_Get_Struct(method_table_object, native_proto_method_table, mt);
-// }
-
-// static void check_native_proto_method_table(VALUE protocol) {
-// VALUE protoclass = CLASS_OF(protocol);
-// if (protoclass != last_proto_class) {
-// last_proto_class = protoclass;
-// if (rb_funcall(protocol, native_qmark_method_id, 0) == Qtrue) {
-// set_native_proto_function_pointers(protocol);
-// } else {
-// mt = default_mt;
-// }
-// }
-// }
-
//-------------------------------------------
// Writing section
//-------------------------------------------
@@ -275,62 +264,62 @@
// end default protocol methods
-
+static VALUE rb_thrift_union_write(VALUE self, VALUE protocol); // write_anything dispatches Thrift::Union values here
 static VALUE rb_thrift_struct_write(VALUE self, VALUE protocol);
 static void write_anything(int ttype, VALUE value, VALUE protocol, VALUE field_info);
VALUE get_field_value(VALUE obj, VALUE field_name) {
char name_buf[RSTRING_LEN(field_name) + 1];
-
+
name_buf[0] = '@';
strlcpy(&name_buf[1], RSTRING_PTR(field_name), sizeof(name_buf));
VALUE value = rb_ivar_get(obj, rb_intern(name_buf));
-
+
return value;
}
static void write_container(int ttype, VALUE field_info, VALUE value, VALUE protocol) {
int sz, i;
-
+
if (ttype == TTYPE_MAP) {
VALUE keys;
VALUE key;
VALUE val;
Check_Type(value, T_HASH);
-
+
VALUE key_info = rb_hash_aref(field_info, key_sym);
VALUE keytype_value = rb_hash_aref(key_info, type_sym);
int keytype = FIX2INT(keytype_value);
-
+
VALUE value_info = rb_hash_aref(field_info, value_sym);
VALUE valuetype_value = rb_hash_aref(value_info, type_sym);
int valuetype = FIX2INT(valuetype_value);
-
+
keys = rb_funcall(value, keys_method_id, 0);
-
+
sz = RARRAY_LEN(keys);
-
+
mt->write_map_begin(protocol, keytype_value, valuetype_value, INT2FIX(sz));
-
+
for (i = 0; i < sz; i++) {
key = rb_ary_entry(keys, i);
val = rb_hash_aref(value, key);
-
+
if (IS_CONTAINER(keytype)) {
write_container(keytype, key_info, key, protocol);
} else {
write_anything(keytype, key, protocol, key_info);
}
-
+
if (IS_CONTAINER(valuetype)) {
write_container(valuetype, value_info, val, protocol);
} else {
write_anything(valuetype, val, protocol, value_info);
}
}
-
+
mt->write_map_end(protocol);
} else if (ttype == TTYPE_LIST) {
Check_Type(value, T_ARRAY);
@@ -340,7 +329,7 @@
VALUE element_type_info = rb_hash_aref(field_info, element_sym);
VALUE element_type_value = rb_hash_aref(element_type_info, type_sym);
int element_type = FIX2INT(element_type_value);
-
+
mt->write_list_begin(protocol, element_type_value, INT2FIX(sz));
for (i = 0; i < sz; ++i) {
VALUE val = rb_ary_entry(value, i);
@@ -370,9 +359,9 @@
VALUE element_type_info = rb_hash_aref(field_info, element_sym);
VALUE element_type_value = rb_hash_aref(element_type_info, type_sym);
int element_type = FIX2INT(element_type_value);
-
+
mt->write_set_begin(protocol, element_type_value, INT2FIX(sz));
-
+
for (i = 0; i < sz; i++) {
VALUE val = rb_ary_entry(items, i);
if (IS_CONTAINER(element_type)) {
@@ -381,7 +370,7 @@
write_anything(element_type, val, protocol, element_type_info);
}
}
-
+
mt->write_set_end(protocol);
} else {
rb_raise(rb_eNotImpError, "can't write container of type: %d", ttype);
@@ -406,7 +395,11 @@
} else if (IS_CONTAINER(ttype)) {
write_container(ttype, field_info, value, protocol);
} else if (ttype == TTYPE_STRUCT) {
- rb_thrift_struct_write(value, protocol);
+ if (rb_obj_is_kind_of(value, thrift_union_class)) {
+ rb_thrift_union_write(value, protocol);
+ } else {
+ rb_thrift_struct_write(value, protocol);
+ }
} else {
rb_raise(rb_eNotImpError, "Unknown type for binary_encoding: %d", ttype);
}
@@ -423,24 +416,27 @@
// iterate through all the fields here
VALUE struct_fields = STRUCT_FIELDS(self);
+
VALUE struct_field_ids_unordered = rb_funcall(struct_fields, keys_method_id, 0);
VALUE struct_field_ids_ordered = rb_funcall(struct_field_ids_unordered, sort_method_id, 0);
int i = 0;
for (i=0; i < RARRAY_LEN(struct_field_ids_ordered); i++) {
VALUE field_id = rb_ary_entry(struct_field_ids_ordered, i);
+
VALUE field_info = rb_hash_aref(struct_fields, field_id);
VALUE ttype_value = rb_hash_aref(field_info, type_sym);
int ttype = FIX2INT(ttype_value);
VALUE field_name = rb_hash_aref(field_info, name_sym);
+
VALUE field_value = get_field_value(self, field_name);
if (!NIL_P(field_value)) {
mt->write_field_begin(protocol, field_name, ttype_value, field_id);
-
+
write_anything(ttype, field_value, protocol, field_info);
-
+
mt->write_field_end(protocol);
}
}
@@ -457,6 +453,7 @@
// Reading section
//-------------------------------------------
+static VALUE rb_thrift_union_read(VALUE self, VALUE protocol); // read_anything dispatches Thrift::Union values here
 static VALUE rb_thrift_struct_read(VALUE self, VALUE protocol);
static void set_field_value(VALUE obj, VALUE field_name, VALUE value) {
@@ -488,7 +485,12 @@
} else if (ttype == TTYPE_STRUCT) {
VALUE klass = rb_hash_aref(field_info, class_sym);
result = rb_class_new_instance(0, NULL, klass);
- rb_thrift_struct_read(result, protocol);
+
+ if (rb_obj_is_kind_of(result, thrift_union_class)) {
+ rb_thrift_union_read(result, protocol);
+ } else {
+ rb_thrift_struct_read(result, protocol);
+ }
} else if (ttype == TTYPE_MAP) {
int i;
@@ -524,7 +526,6 @@
rb_ary_push(result, read_anything(protocol, element_ttype, rb_hash_aref(field_info, element_sym)));
}
-
mt->read_list_end(protocol);
} else if (ttype == TTYPE_SET) {
VALUE items;
@@ -539,7 +540,6 @@
rb_ary_push(items, read_anything(protocol, element_ttype, rb_hash_aref(field_info, element_sym)));
}
-
mt->read_set_end(protocol);
result = rb_class_new_instance(1, &items, rb_cSet);
@@ -597,13 +597,110 @@
return Qnil;
}
+
+// --------------------------------
+// Union section
+// --------------------------------
+
+static VALUE rb_thrift_union_read(VALUE self, VALUE protocol) {
+  // Native Thrift::Union#read (also reached from read_anything for union
+  // fields). Reads at most one field, then requires a field STOP. A field
+  // whose id is not in FIELDS, or whose wire type differs from FIELDS, is
+  // skipped and leaves @setfield/@value untouched.
+  mt->read_struct_begin(protocol);
+
+  VALUE struct_fields = STRUCT_FIELDS(self);
+
+  VALUE field_header = mt->read_field_begin(protocol);
+  VALUE field_type_value = rb_ary_entry(field_header, 1);
+  int field_type = FIX2INT(field_type_value);
+
+  // An empty union on the wire is a bare STOP; reading another field header
+  // here would consume bytes past the end of this struct.
+  if (field_type != TTYPE_STOP) {
+    // make sure we got a field id and type we expected
+    VALUE field_info = rb_hash_aref(struct_fields, rb_ary_entry(field_header, 2));
+
+    if (!NIL_P(field_info) && field_type == FIX2INT(rb_hash_aref(field_info, type_sym))) {
+      VALUE name = rb_hash_aref(field_info, name_sym);
+      rb_ivar_set(self, setfield_id, ID2SYM(rb_intern(RSTRING_PTR(name))));
+      rb_ivar_set(self, setvalue_id, read_anything(protocol, field_type, field_info));
+    } else {
+      rb_funcall(protocol, skip_method_id, 1, field_type_value);
+    }
+
+    mt->read_field_end(protocol);
+
+    // a union carries exactly one field, so the next header must be STOP
+    field_header = mt->read_field_begin(protocol);
+    field_type = FIX2INT(rb_ary_entry(field_header, 1));
+
+    if (field_type != TTYPE_STOP) {
+      rb_raise(rb_eRuntimeError, "too many fields in union!");
+    }
+  }
+
+  // read field end (of the STOP header)
+  mt->read_field_end(protocol);
+
+  // read struct end
+  mt->read_struct_end(protocol);
+
+  // validate reports a union left unset (empty, or its field skipped)
+  rb_funcall(self, validate_method_id, 0);
+
+  return Qnil;
+}
+
+static VALUE rb_thrift_union_write(VALUE self, VALUE protocol) {
+  // Native Thrift::Union#write: the set field as a one-field struct plus a
+  // field STOP. validate runs first so an unset union raises before output.
+  rb_funcall(self, validate_method_id, 0);
+
+  VALUE struct_fields = STRUCT_FIELDS(self);
+  VALUE setfield = rb_ivar_get(self, setfield_id);
+  VALUE setvalue = rb_ivar_get(self, setvalue_id);
+  VALUE field_id = rb_funcall(self, name_to_id_method_id, 1, rb_funcall(setfield, to_s_method_id, 0));
+  VALUE field_info = rb_hash_aref(struct_fields, field_id);
+
+  // @setfield must name an entry of FIELDS; handing a nil field_info to
+  // rb_hash_aref below is not safe, so raise before writing anything.
+  if (NIL_P(field_info)) {
+    rb_raise(rb_eRuntimeError, "%s: union set field is not one of its fields", rb_obj_classname(self));
+  }
+
+  VALUE ttype_value = rb_hash_aref(field_info, type_sym);
+  int ttype = FIX2INT(ttype_value);
+
+  // write struct begin
+  mt->write_struct_begin(protocol, rb_class_name(CLASS_OF(self)));
+  mt->write_field_begin(protocol, setfield, ttype_value, field_id);
+  write_anything(ttype, setvalue, protocol, field_info);
+  mt->write_field_end(protocol);
+  mt->write_field_stop(protocol);
+
+  // write struct end
+  mt->write_struct_end(protocol);
+  return Qnil;
+}
+
void Init_struct() {
VALUE struct_module = rb_const_get(thrift_module, rb_intern("Struct"));
rb_define_method(struct_module, "write", rb_thrift_struct_write, 1);
rb_define_method(struct_module, "read", rb_thrift_struct_read, 1);
+ thrift_union_class = rb_const_get(thrift_module, rb_intern("Union"));
+
+ rb_define_method(thrift_union_class, "write", rb_thrift_union_write, 1);
+ rb_define_method(thrift_union_class, "read", rb_thrift_union_read, 1);
+
+ setfield_id = rb_intern("@setfield");
+ setvalue_id = rb_intern("@value");
+
+ to_s_method_id = rb_intern("to_s");
+ name_to_id_method_id = rb_intern("name_to_id");
+
set_default_proto_function_pointers();
mt = default_mt;
-}
-
+}
\ No newline at end of file
diff --git a/lib/rb/ext/struct.h b/lib/rb/ext/struct.h
index 37b1b35..48ccef8 100644
--- a/lib/rb/ext/struct.h
+++ b/lib/rb/ext/struct.h
@@ -17,6 +17,7 @@
* under the License.
*/
+
#include <stdbool.h>
#include <ruby.h>
@@ -41,7 +42,7 @@
VALUE (*write_field_stop)(VALUE);
VALUE (*write_message_begin)(VALUE, VALUE, VALUE, VALUE);
VALUE (*write_message_end)(VALUE);
-
+
VALUE (*read_message_begin)(VALUE);
VALUE (*read_message_end)(VALUE);
VALUE (*read_field_begin)(VALUE);
@@ -61,7 +62,7 @@
VALUE (*read_string)(VALUE);
VALUE (*read_struct_begin)(VALUE);
VALUE (*read_struct_end)(VALUE);
-
} native_proto_method_table;
 void Init_struct();
+void Init_union(); // NOTE(review): no Init_union definition appears in this patch (union methods are registered in Init_struct) -- confirm it exists elsewhere or drop this declaration
diff --git a/lib/rb/ext/thrift_native.c b/lib/rb/ext/thrift_native.c
index effa202..09b9fe4 100644
--- a/lib/rb/ext/thrift_native.c
+++ b/lib/rb/ext/thrift_native.c
@@ -111,7 +111,7 @@
thrift_types_module = rb_const_get(thrift_module, rb_intern("Types"));
rb_cSet = rb_const_get(rb_cObject, rb_intern("Set"));
protocol_exception_class = rb_const_get(thrift_module, rb_intern("ProtocolException"));
-
+
// Init ttype constants
TTYPE_BOOL = FIX2INT(rb_const_get(thrift_types_module, rb_intern("BOOL")));
TTYPE_BYTE = FIX2INT(rb_const_get(thrift_types_module, rb_intern("BYTE")));
@@ -171,13 +171,13 @@
write_method_id = rb_intern("write");
read_all_method_id = rb_intern("read_all");
native_qmark_method_id = rb_intern("native?");
-
+
// constant ids
fields_const_id = rb_intern("FIELDS");
transport_ivar_id = rb_intern("@trans");
strict_read_ivar_id = rb_intern("@strict_read");
strict_write_ivar_id = rb_intern("@strict_write");
-
+
// cached symbols
type_sym = ID2SYM(rb_intern("type"));
name_sym = ID2SYM(rb_intern("name"));
diff --git a/lib/rb/lib/thrift.rb b/lib/rb/lib/thrift.rb
index 4d4e130..02d67b8 100644
--- a/lib/rb/lib/thrift.rb
+++ b/lib/rb/lib/thrift.rb
@@ -28,6 +28,8 @@
require 'thrift/processor'
require 'thrift/client'
require 'thrift/struct'
+require 'thrift/union'
+require 'thrift/struct_union'
# serializer
require 'thrift/serializer/serializer'
diff --git a/lib/rb/lib/thrift/protocol/binary_protocol_accelerated.rb b/lib/rb/lib/thrift/protocol/binary_protocol_accelerated.rb
index eaf64f6..70ea652 100644
--- a/lib/rb/lib/thrift/protocol/binary_protocol_accelerated.rb
+++ b/lib/rb/lib/thrift/protocol/binary_protocol_accelerated.rb
@@ -29,7 +29,13 @@
 module Thrift
   class BinaryProtocolAcceleratedFactory < BaseProtocolFactory
+    # Prefers BinaryProtocolAccelerated, falling back to BinaryProtocol when
+    # that class is not defined (e.g. the native extension is not built).
     def get_protocol(trans)
-      BinaryProtocolAccelerated.new(trans)
+      if defined?(BinaryProtocolAccelerated)
+        BinaryProtocolAccelerated.new(trans)
+      else
+        BinaryProtocol.new(trans)
+      end
     end
   end
 end
diff --git a/lib/rb/lib/thrift/struct.rb b/lib/rb/lib/thrift/struct.rb
index dfc8a2f..9e52073 100644
--- a/lib/rb/lib/thrift/struct.rb
+++ b/lib/rb/lib/thrift/struct.rb
@@ -65,26 +65,7 @@
end
fields_with_default_values
end
-
- def name_to_id(name)
- names_to_ids = self.class.instance_variable_get("@names_to_ids")
- unless names_to_ids
- names_to_ids = {}
- struct_fields.each do |fid, field_def|
- names_to_ids[field_def[:name]] = fid
- end
- self.class.instance_variable_set("@names_to_ids", names_to_ids)
- end
- names_to_ids[name]
- end
-
- def each_field
- struct_fields.keys.sort.each do |fid|
- data = struct_fields[fid]
- yield fid, data
- end
- end
-
+
def inspect(skip_optional_nulls = true)
fields = []
each_field do |fid, field_info|
@@ -115,7 +96,8 @@
each_field do |fid, field_info|
name = field_info[:name]
type = field_info[:type]
- if (value = instance_variable_get("@#{name}"))
+ value = instance_variable_get("@#{name}")
+ unless value.nil?
if is_container? type
oprot.write_field_begin(name, type, fid)
write_container(oprot, value, field_info)
@@ -210,89 +192,5 @@
iprot.skip(ftype)
end
end
-
- def read_field(iprot, field = {})
- case field[:type]
- when Types::STRUCT
- value = field[:class].new
- value.read(iprot)
- when Types::MAP
- key_type, val_type, size = iprot.read_map_begin
- value = {}
- size.times do
- k = read_field(iprot, field_info(field[:key]))
- v = read_field(iprot, field_info(field[:value]))
- value[k] = v
- end
- iprot.read_map_end
- when Types::LIST
- e_type, size = iprot.read_list_begin
- value = Array.new(size) do |n|
- read_field(iprot, field_info(field[:element]))
- end
- iprot.read_list_end
- when Types::SET
- e_type, size = iprot.read_set_begin
- value = Set.new
- size.times do
- element = read_field(iprot, field_info(field[:element]))
- value << element
- end
- iprot.read_set_end
- else
- value = iprot.read_type(field[:type])
- end
- value
- end
-
- def write_data(oprot, value, field)
- if is_container? field[:type]
- write_container(oprot, value, field)
- else
- oprot.write_type(field[:type], value)
- end
- end
-
- def write_container(oprot, value, field = {})
- case field[:type]
- when Types::MAP
- oprot.write_map_begin(field[:key][:type], field[:value][:type], value.size)
- value.each do |k, v|
- write_data(oprot, k, field[:key])
- write_data(oprot, v, field[:value])
- end
- oprot.write_map_end
- when Types::LIST
- oprot.write_list_begin(field[:element][:type], value.size)
- value.each do |elem|
- write_data(oprot, elem, field[:element])
- end
- oprot.write_list_end
- when Types::SET
- oprot.write_set_begin(field[:element][:type], value.size)
- value.each do |v,| # the , is to preserve compatibility with the old Hash-style sets
- write_data(oprot, v, field[:element])
- end
- oprot.write_set_end
- else
- raise "Not a container type: #{field[:type]}"
- end
- end
-
- CONTAINER_TYPES = []
- CONTAINER_TYPES[Types::LIST] = true
- CONTAINER_TYPES[Types::MAP] = true
- CONTAINER_TYPES[Types::SET] = true
- def is_container?(type)
- CONTAINER_TYPES[type]
- end
-
- def field_info(field)
- { :type => field[:type],
- :class => field[:class],
- :key => field[:key],
- :value => field[:value],
- :element => field[:element] }
- end
end
end
diff --git a/lib/rb/lib/thrift/struct_union.rb b/lib/rb/lib/thrift/struct_union.rb
new file mode 100644
index 0000000..9a5903f
--- /dev/null
+++ b/lib/rb/lib/thrift/struct_union.rb
@@ -0,0 +1,146 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+require 'set'
+
+module Thrift
+  # Helpers shared by Thrift structs and unions: field-name lookup, ordered
+  # field iteration, and recursive reading/writing of field values,
+  # containers included.
+  #
+  # NOTE(review): expects the including class to provide struct_fields,
+  # presumably a Hash of field id => field info -- confirm.
+  module Struct_Union
+    # Returns the id of the field named +name+, or nil if there is none.
+    # The name => id map is built once per class and cached in the
+    # class-level @names_to_ids instance variable.
+    def name_to_id(name)
+      names_to_ids = self.class.instance_variable_get("@names_to_ids")
+      unless names_to_ids
+        names_to_ids = {}
+        struct_fields.each do |fid, field_def|
+          names_to_ids[field_def[:name]] = fid
+        end
+        self.class.instance_variable_set("@names_to_ids", names_to_ids)
+      end
+      names_to_ids[name]
+    end
+
+    # Yields each field id and its info hash, in ascending id order.
+    def each_field
+      struct_fields.keys.sort.each do |fid|
+        yield fid, struct_fields[fid]
+      end
+    end
+
+    # Reads and returns one value described by +field+ (its :type plus the
+    # :class, :key, :value or :element entries that type needs). Maps are
+    # returned as Hash, lists as Array and sets as Set.
+    def read_field(iprot, field = {})
+      case field[:type]
+      when Types::STRUCT
+        value = field[:class].new
+        value.read(iprot)
+      when Types::MAP
+        _key_type, _val_type, size = iprot.read_map_begin
+        value = {}
+        size.times do
+          k = read_field(iprot, field_info(field[:key]))
+          v = read_field(iprot, field_info(field[:value]))
+          value[k] = v
+        end
+        iprot.read_map_end
+      when Types::LIST
+        _e_type, size = iprot.read_list_begin
+        value = Array.new(size) do
+          read_field(iprot, field_info(field[:element]))
+        end
+        iprot.read_list_end
+      when Types::SET
+        _e_type, size = iprot.read_set_begin
+        value = Set.new
+        size.times do
+          value << read_field(iprot, field_info(field[:element]))
+        end
+        iprot.read_set_end
+      else
+        value = iprot.read_type(field[:type])
+      end
+      value
+    end
+
+    # Writes +value+ as +field+'s type, recursing into containers.
+    def write_data(oprot, value, field)
+      if is_container? field[:type]
+        write_container(oprot, value, field)
+      else
+        oprot.write_type(field[:type], value)
+      end
+    end
+
+    # Writes a map, list or set +value+ described by +field+; raises for
+    # any other type.
+    def write_container(oprot, value, field = {})
+      case field[:type]
+      when Types::MAP
+        oprot.write_map_begin(field[:key][:type], field[:value][:type], value.size)
+        value.each do |k, v|
+          write_data(oprot, k, field[:key])
+          write_data(oprot, v, field[:value])
+        end
+        oprot.write_map_end
+      when Types::LIST
+        oprot.write_list_begin(field[:element][:type], value.size)
+        value.each do |elem|
+          write_data(oprot, elem, field[:element])
+        end
+        oprot.write_list_end
+      when Types::SET
+        oprot.write_set_begin(field[:element][:type], value.size)
+        value.each do |v,| # the , is to preserve compatibility with the old Hash-style sets
+          write_data(oprot, v, field[:element])
+        end
+        oprot.write_set_end
+      else
+        raise "Not a container type: #{field[:type]}"
+      end
+    end
+
+    # Indexed by type id: true at LIST, MAP and SET, nil elsewhere.
+    CONTAINER_TYPES = []
+    CONTAINER_TYPES[Types::LIST] = true
+    CONTAINER_TYPES[Types::MAP] = true
+    CONTAINER_TYPES[Types::SET] = true
+    CONTAINER_TYPES.freeze
+
+    # True when +type+ is LIST, MAP or SET; false otherwise.
+    def is_container?(type)
+      CONTAINER_TYPES[type] ? true : false
+    end
+
+    # Copies the :type, :class, :key, :value and :element entries of
+    # +field+ into a new Hash (absent entries become nil).
+    def field_info(field)
+      { :type => field[:type],
+        :class => field[:class],
+        :key => field[:key],
+        :value => field[:value],
+        :element => field[:element] }
+    end
+  end
+end
diff --git a/lib/rb/lib/thrift/types.rb b/lib/rb/lib/thrift/types.rb
index 20e4ca2..cac5269 100644
--- a/lib/rb/lib/thrift/types.rb
+++ b/lib/rb/lib/thrift/types.rb
@@ -57,7 +57,7 @@
when Types::STRING
String
when Types::STRUCT
- Struct
+ [Struct, Union]
when Types::MAP
Hash
when Types::SET
diff --git a/lib/rb/lib/thrift/union.rb b/lib/rb/lib/thrift/union.rb
new file mode 100644
index 0000000..0b41ed4
--- /dev/null
+++ b/lib/rb/lib/thrift/union.rb
@@ -0,0 +1,167 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+module Thrift
+  # Base class for Thrift unions: at most one field is set at a time.
+  # @setfield holds the set field's name and @value holds its value.
+  #
+  # NOTE(review): this class calls struct_fields, name_to_id, validate,
+  # read_field, write_container and is_container? without defining them;
+  # presumably generated subclasses supply them (e.g. by mixing in
+  # Thrift::Struct_Union) -- confirm against the code generator.
+  class Union
+    # Creates a union, optionally setting field +name+ to +value+.
+    #
+    # Raises ArgumentError when a name comes with a nil value or a value
+    # comes without a name. With Thrift.type_checking on, also raises
+    # ArgumentError for a name that is not a field of this union, and
+    # type-checks +value+ against that field.
+    def initialize(name=nil, value=nil)
+      if name
+        if value.nil?
+          raise ArgumentError, "Union #{self.class} cannot be instantiated with setfield and nil value!"
+        end
+
+        if Thrift.type_checking
+          field = struct_fields[name_to_id(name.to_s)]
+          raise ArgumentError, "#{self.class} does not contain a field named #{name}!" unless field
+          Thrift.check_type(value, field, name)
+        end
+      elsif !value.nil?
+        raise ArgumentError, "Value provided, but no name!"
+      end
+      @setfield = name
+      @value = value
+    end
+
+    # Renders e.g. "<SpecNamespace::My_union integer32: 25>".
+    def inspect
+      "<#{self.class} #{@setfield}: #{@value}>"
+    end
+
+    # Reads a union written by #write: at most one field, then a field STOP.
+    # A field with an unknown id, or whose wire type differs from FIELDS, is
+    # skipped. Raises if a second field follows; finally calls validate,
+    # which reports a union that ended up unset.
+    def read(iprot)
+      iprot.read_struct_begin
+      _fname, ftype, fid = iprot.read_field_begin
+
+      # An empty union on the wire is a bare STOP; reading another field
+      # header here would consume bytes past the end of this struct.
+      unless ftype == Types::STOP
+        handle_message(iprot, fid, ftype)
+        iprot.read_field_end
+
+        _fname, ftype, _fid = iprot.read_field_begin
+        raise "Too many fields for union" unless ftype == Types::STOP
+      end
+
+      iprot.read_struct_end
+      validate
+    end
+
+    # Writes the set field as a one-field struct followed by a field STOP.
+    # Calls validate first, and raises before writing anything if
+    # @setfield does not name a field of this union.
+    def write(oprot)
+      validate
+
+      fid = name_to_id(@setfield.to_s)
+      field_info = struct_fields[fid]
+      raise "#{self.class} does not contain a field named #{@setfield}!" unless field_info
+
+      oprot.write_struct_begin(self.class.name)
+      type = field_info[:type]
+      if is_container? type
+        oprot.write_field_begin(@setfield, type, fid)
+        write_container(oprot, @value, field_info)
+        oprot.write_field_end
+      else
+        oprot.write_field(@setfield, type, fid, @value)
+      end
+
+      oprot.write_field_stop
+      oprot.write_struct_end
+    end
+
+    # Unions are equal when they have the same set field and value. Returns
+    # false, rather than raising, for nil or any non-union object.
+    def ==(other)
+      other.respond_to?(:get_set_field) && other.respond_to?(:get_value) &&
+        @setfield == other.get_set_field && @value == other.get_value
+    end
+
+    # Like == but also requires the exact same class.
+    def eql?(other)
+      self.class == other.class && self == other
+    end
+
+    def hash
+      [self.class.name, @setfield, @value].hash
+    end
+
+    # Defines, on +klass+, a reader and a writer for each of +fields+. The
+    # reader raises RuntimeError unless that field is the one currently set;
+    # the writer makes it the set field (type-checked when
+    # Thrift.type_checking is on).
+    def self.field_accessor(klass, *fields)
+      fields.each do |field|
+        klass.send :define_method, "#{field}" do
+          if field == @setfield
+            @value
+          else
+            raise RuntimeError, "#{field} is not union's set field."
+          end
+        end
+
+        klass.send :define_method, "#{field}=" do |value|
+          Thrift.check_type(value, klass::FIELDS.values.find {|f| f[:name].to_s == field.to_s }, field) if Thrift.type_checking
+          @setfield = field
+          @value = value
+        end
+      end
+    end
+
+    # Returns the name of the currently set field (nil when unset).
+    def get_set_field
+      @setfield
+    end
+
+    # Returns the current value regardless of which field is set. Generally
+    # only useful when you don't know in advance which field to expect.
+    def get_value
+      @value
+    end
+
+    protected
+
+    # Stores the value of field +fid+ in @value/@setfield when it is a known
+    # field whose declared type equals +ftype+; otherwise skips it.
+    def handle_message(iprot, fid, ftype)
+      field = struct_fields[fid]
+      if field && field[:type] == ftype
+        @value = read_field(iprot, field)
+        @setfield = field[:name].to_sym
+      else
+        iprot.skip(ftype)
+      end
+    end
+  end
+end
diff --git a/lib/rb/spec/ThriftSpec.thrift b/lib/rb/spec/ThriftSpec.thrift
index fe5a8aa..f5c8c09 100644
--- a/lib/rb/spec/ThriftSpec.thrift
+++ b/lib/rb/spec/ThriftSpec.thrift
@@ -42,6 +42,38 @@
1: string greeting = "hello world"
}
+union My_union { // fields of the basic scalar types; used by union_spec.rb
+  1: bool im_true,
+  2: byte a_bite,
+  3: i16 integer16,
+  4: i32 integer32,
+  5: i64 integer64,
+  6: double double_precision,
+  7: string some_characters,
+  8: i32 other_i32 // same type as integer32: lets specs check that equality compares the set field too
+}
+
+struct Struct_with_union { // a union nested inside a struct
+  1: My_union fun_union,
+  2: i32 integer32,
+  3: string some_characters
+}
+
+enum SomeEnum {
+  ONE,
+  TWO
+}
+
+union TestUnion { // enum_field drives the enum-validation spec
+  /**
+   * A doc string
+   */
+  1: string string_field,
+  2: i32 i32_field,
+  3: i32 other_i32_field,
+  4: SomeEnum enum_field
+}
+
struct Foo {
1: i32 simple = 53,
2: string words = "words",
diff --git a/lib/rb/spec/binary_protocol_accelerated_spec.rb b/lib/rb/spec/binary_protocol_accelerated_spec.rb
index 48c22e4..b8518c8 100644
--- a/lib/rb/spec/binary_protocol_accelerated_spec.rb
+++ b/lib/rb/spec/binary_protocol_accelerated_spec.rb
@@ -20,22 +20,27 @@
require File.dirname(__FILE__) + '/spec_helper'
require File.dirname(__FILE__) + '/binary_protocol_spec_shared'
-class ThriftBinaryProtocolAcceleratedSpec < Spec::ExampleGroup
- include Thrift
+if defined? Thrift::BinaryProtocolAccelerated # the accelerated protocol may be undefined (e.g. native extension not built); these specs need it
- describe Thrift::BinaryProtocolAccelerated do
- # since BinaryProtocolAccelerated should be directly equivalent to
- # BinaryProtocol, we don't need any custom specs!
- it_should_behave_like 'a binary protocol'
+ class ThriftBinaryProtocolAcceleratedSpec < Spec::ExampleGroup # only defined when BinaryProtocolAccelerated exists
+ include Thrift
- def protocol_class
- BinaryProtocolAccelerated
+ describe Thrift::BinaryProtocolAccelerated do
+ # since BinaryProtocolAccelerated should be directly equivalent to
+ # BinaryProtocol, we don't need any custom specs!
+ it_should_behave_like 'a binary protocol'
+
+ def protocol_class
+ BinaryProtocolAccelerated
+ end
+ end
+
+ describe BinaryProtocolAcceleratedFactory do
+ it "should create a BinaryProtocolAccelerated" do
+ BinaryProtocolAcceleratedFactory.new.get_protocol(mock("MockTransport")).should be_instance_of(BinaryProtocolAccelerated)
+ end
 end
 end
-
- describe BinaryProtocolAcceleratedFactory do
- it "should create a BinaryProtocolAccelerated" do
- BinaryProtocolAcceleratedFactory.new.get_protocol(mock("MockTransport")).should be_instance_of(BinaryProtocolAccelerated)
- end
- end
-end
+else
+ puts "skipping BinaryProtocolAccelerated spec because it is not defined."
+end
\ No newline at end of file
diff --git a/lib/rb/spec/binary_protocol_spec_shared.rb b/lib/rb/spec/binary_protocol_spec_shared.rb
index 84f5920..28da760 100644
--- a/lib/rb/spec/binary_protocol_spec_shared.rb
+++ b/lib/rb/spec/binary_protocol_spec_shared.rb
@@ -349,9 +349,9 @@
# first block
firstblock.call(client)
-
+
processor.process(serverproto, serverproto)
-
+
# second block
secondblock.call(client)
ensure
diff --git a/lib/rb/spec/union_spec.rb b/lib/rb/spec/union_spec.rb
new file mode 100644
index 0000000..4835288
--- /dev/null
+++ b/lib/rb/spec/union_spec.rb
@@ -0,0 +1,145 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+require File.dirname(__FILE__) + '/spec_helper'
+
+class ThriftUnionSpec < Spec::ExampleGroup
+  include Thrift
+  include SpecNamespace
+
+  describe Union do
+    it "should return nil value in unset union" do
+      union = My_union.new
+      union.get_set_field.should == nil
+      union.get_value.should == nil
+    end
+
+    it "should set a field and be accessible through get_value and the named field accessor" do
+      union = My_union.new
+      union.integer32 = 25
+      union.get_set_field.should == :integer32
+      union.get_value.should == 25
+      union.integer32.should == 25
+    end
+
+    it "should work correctly when instantiated with static field constructors" do
+      union = My_union.integer32(5)
+      union.get_set_field.should == :integer32
+      union.integer32.should == 5
+    end
+
+    it "should raise for wrong set field" do
+      union = My_union.new
+      union.integer32 = 25
+      lambda { union.some_characters }.should raise_error(RuntimeError, "some_characters is not union's set field.")
+    end
+
+    it "should not be equal to nil" do
+      union = My_union.new
+      union.should_not == nil
+    end
+
+    it "should not equate two different unions, i32 vs. string" do
+      union = My_union.new(:integer32, 25)
+      other_union = My_union.new(:some_characters, "blah!")
+      union.should_not == other_union
+    end
+
+    it "should properly reset setfield and setvalue" do
+      union = My_union.new(:integer32, 25)
+      union.get_set_field.should == :integer32
+      union.some_characters = "blah!"
+      union.get_set_field.should == :some_characters
+      union.get_value.should == "blah!"
+      lambda { union.integer32 }.should raise_error(RuntimeError, "integer32 is not union's set field.")
+    end
+
+    it "should not equate two different unions with different values" do
+      union = My_union.new(:integer32, 25)
+      other_union = My_union.new(:integer32, 400)
+      union.should_not == other_union
+    end
+
+    it "should not equate two different unions with different fields" do
+      union = My_union.new(:integer32, 25)
+      other_union = My_union.new(:other_i32, 25)
+      union.should_not == other_union
+    end
+
+    it "should inspect properly" do
+      union = My_union.new(:integer32, 25)
+      union.inspect.should == "<SpecNamespace::My_union integer32: 25>"
+    end
+
+    it "should not allow setting with instance_variable_set" do
+      union = My_union.new(:integer32, 27)
+      union.instance_variable_set(:@some_characters, "hallo!")
+      union.get_set_field.should == :integer32
+      union.get_value.should == 27
+      lambda { union.some_characters }.should raise_error(RuntimeError, "some_characters is not union's set field.")
+    end
+
+    it "should serialize correctly" do
+      trans = Thrift::MemoryBufferTransport.new
+      proto = Thrift::BinaryProtocol.new(trans)
+
+      union = My_union.new(:integer32, 25)
+      union.write(proto)
+
+      other_union = My_union.new(:some_characters, "blah!") # start different, so the assertion proves read overwrote it
+      other_union.read(proto)
+      other_union.should == union
+    end
+
+    it "should raise when validating unset union" do
+      union = My_union.new
+      lambda { union.validate }.should raise_error(StandardError, "Union fields are not set.")
+
+      other_union = My_union.new(:integer32, 1)
+      lambda { other_union.validate }.should_not raise_error # no arguments: any error at all must fail this expectation
+    end
+
+    it "should validate an enum field properly" do
+      union = TestUnion.new(:enum_field, 3)
+      union.get_set_field.should == :enum_field
+      lambda { union.validate }.should raise_error(ProtocolException, "Invalid value of field enum_field!")
+
+      other_union = TestUnion.new(:enum_field, 1)
+      lambda { other_union.validate }.should_not raise_error # no arguments: any error at all must fail this expectation
+    end
+
+    it "should properly serialize and match structs with a union" do
+      union = My_union.new(:integer32, 26)
+      swu = Struct_with_union.new(:fun_union => union)
+
+      trans = Thrift::MemoryBufferTransport.new
+      proto = Thrift::CompactProtocol.new(trans)
+
+      swu.write(proto)
+
+      other_union = My_union.new(:some_characters, "hello there")
+      swu2 = Struct_with_union.new(:fun_union => other_union)
+
+      swu2.should_not == swu
+
+      swu2.read(proto)
+      swu2.should == swu
+    end
+  end
+end