THRIFT-697. Union support in Ruby
git-svn-id: https://svn.apache.org/repos/asf/incubator/thrift/trunk@910700 13f79535-47bb-0310-9956-ffa450edef68
diff --git a/lib/rb/ext/struct.c b/lib/rb/ext/struct.c
index 7429fb1..d459ddb 100644
--- a/lib/rb/ext/struct.c
+++ b/lib/rb/ext/struct.c
@@ -45,29 +45,18 @@
static native_proto_method_table *mt;
static native_proto_method_table *default_mt;
-// static VALUE last_proto_class = Qnil;
+// Thrift::Union class (set in Init_struct); TTYPE_STRUCT values of this class use the union read/write.
+VALUE thrift_union_class;
+// Interned IDs of a union's @setfield / @value instance variables.
+ID setfield_id;
+ID setvalue_id;
+// IDs of the to_s / name_to_id methods, used to turn a union's set field name into its field id.
+ID to_s_method_id;
+ID name_to_id_method_id;
#define IS_CONTAINER(ttype) ((ttype) == TTYPE_MAP || (ttype) == TTYPE_LIST || (ttype) == TTYPE_SET)
#define STRUCT_FIELDS(obj) rb_const_get(CLASS_OF(obj), fields_const_id)
-// static void set_native_proto_function_pointers(VALUE protocol) {
-// VALUE method_table_object = rb_const_get(CLASS_OF(protocol), rb_intern("@native_method_table"));
-// // TODO: check nil?
-// Data_Get_Struct(method_table_object, native_proto_method_table, mt);
-// }
-
-// static void check_native_proto_method_table(VALUE protocol) {
-// VALUE protoclass = CLASS_OF(protocol);
-// if (protoclass != last_proto_class) {
-// last_proto_class = protoclass;
-// if (rb_funcall(protocol, native_qmark_method_id, 0) == Qtrue) {
-// set_native_proto_function_pointers(protocol);
-// } else {
-// mt = default_mt;
-// }
-// }
-// }
-
//-------------------------------------------
// Writing section
//-------------------------------------------
@@ -275,62 +264,62 @@
// end default protocol methods
-
+static VALUE rb_thrift_union_write(VALUE self, VALUE protocol);
static VALUE rb_thrift_struct_write(VALUE self, VALUE protocol);
static void write_anything(int ttype, VALUE value, VALUE protocol, VALUE field_info);
VALUE get_field_value(VALUE obj, VALUE field_name) {
-  char name_buf[RSTRING_LEN(field_name) + 1];
+  char name_buf[RSTRING_LEN(field_name) + 2]; // ivar name: '@' + field name + NUL terminator
-
+
name_buf[0] = '@';
-  strlcpy(&name_buf[1], RSTRING_PTR(field_name), sizeof(name_buf));
+  strlcpy(&name_buf[1], RSTRING_PTR(field_name), sizeof(name_buf) - 1); // bound = bytes after the '@'
VALUE value = rb_ivar_get(obj, rb_intern(name_buf));
-
+
return value;
}
static void write_container(int ttype, VALUE field_info, VALUE value, VALUE protocol) {
int sz, i;
-
+
if (ttype == TTYPE_MAP) {
VALUE keys;
VALUE key;
VALUE val;
Check_Type(value, T_HASH);
-
+
VALUE key_info = rb_hash_aref(field_info, key_sym);
VALUE keytype_value = rb_hash_aref(key_info, type_sym);
int keytype = FIX2INT(keytype_value);
-
+
VALUE value_info = rb_hash_aref(field_info, value_sym);
VALUE valuetype_value = rb_hash_aref(value_info, type_sym);
int valuetype = FIX2INT(valuetype_value);
-
+
keys = rb_funcall(value, keys_method_id, 0);
-
+
sz = RARRAY_LEN(keys);
-
+
mt->write_map_begin(protocol, keytype_value, valuetype_value, INT2FIX(sz));
-
+
for (i = 0; i < sz; i++) {
key = rb_ary_entry(keys, i);
val = rb_hash_aref(value, key);
-
+
if (IS_CONTAINER(keytype)) {
write_container(keytype, key_info, key, protocol);
} else {
write_anything(keytype, key, protocol, key_info);
}
-
+
if (IS_CONTAINER(valuetype)) {
write_container(valuetype, value_info, val, protocol);
} else {
write_anything(valuetype, val, protocol, value_info);
}
}
-
+
mt->write_map_end(protocol);
} else if (ttype == TTYPE_LIST) {
Check_Type(value, T_ARRAY);
@@ -340,7 +329,7 @@
VALUE element_type_info = rb_hash_aref(field_info, element_sym);
VALUE element_type_value = rb_hash_aref(element_type_info, type_sym);
int element_type = FIX2INT(element_type_value);
-
+
mt->write_list_begin(protocol, element_type_value, INT2FIX(sz));
for (i = 0; i < sz; ++i) {
VALUE val = rb_ary_entry(value, i);
@@ -370,9 +359,9 @@
VALUE element_type_info = rb_hash_aref(field_info, element_sym);
VALUE element_type_value = rb_hash_aref(element_type_info, type_sym);
int element_type = FIX2INT(element_type_value);
-
+
mt->write_set_begin(protocol, element_type_value, INT2FIX(sz));
-
+
for (i = 0; i < sz; i++) {
VALUE val = rb_ary_entry(items, i);
if (IS_CONTAINER(element_type)) {
@@ -381,7 +370,7 @@
write_anything(element_type, val, protocol, element_type_info);
}
}
-
+
mt->write_set_end(protocol);
} else {
rb_raise(rb_eNotImpError, "can't write container of type: %d", ttype);
@@ -406,7 +395,11 @@
} else if (IS_CONTAINER(ttype)) {
write_container(ttype, field_info, value, protocol);
} else if (ttype == TTYPE_STRUCT) {
- rb_thrift_struct_write(value, protocol);
+ if (rb_obj_is_kind_of(value, thrift_union_class)) {
+ rb_thrift_union_write(value, protocol);
+ } else {
+ rb_thrift_struct_write(value, protocol);
+ }
} else {
rb_raise(rb_eNotImpError, "Unknown type for binary_encoding: %d", ttype);
}
@@ -423,24 +416,27 @@
// iterate through all the fields here
VALUE struct_fields = STRUCT_FIELDS(self);
+
VALUE struct_field_ids_unordered = rb_funcall(struct_fields, keys_method_id, 0);
VALUE struct_field_ids_ordered = rb_funcall(struct_field_ids_unordered, sort_method_id, 0);
int i = 0;
for (i=0; i < RARRAY_LEN(struct_field_ids_ordered); i++) {
VALUE field_id = rb_ary_entry(struct_field_ids_ordered, i);
+
VALUE field_info = rb_hash_aref(struct_fields, field_id);
VALUE ttype_value = rb_hash_aref(field_info, type_sym);
int ttype = FIX2INT(ttype_value);
VALUE field_name = rb_hash_aref(field_info, name_sym);
+
VALUE field_value = get_field_value(self, field_name);
if (!NIL_P(field_value)) {
mt->write_field_begin(protocol, field_name, ttype_value, field_id);
-
+
write_anything(ttype, field_value, protocol, field_info);
-
+
mt->write_field_end(protocol);
}
}
@@ -457,6 +453,7 @@
// Reading section
//-------------------------------------------
+static VALUE rb_thrift_union_read(VALUE self, VALUE protocol);
static VALUE rb_thrift_struct_read(VALUE self, VALUE protocol);
static void set_field_value(VALUE obj, VALUE field_name, VALUE value) {
@@ -488,7 +485,12 @@
} else if (ttype == TTYPE_STRUCT) {
VALUE klass = rb_hash_aref(field_info, class_sym);
result = rb_class_new_instance(0, NULL, klass);
- rb_thrift_struct_read(result, protocol);
+
+ if (rb_obj_is_kind_of(result, thrift_union_class)) {
+ rb_thrift_union_read(result, protocol);
+ } else {
+ rb_thrift_struct_read(result, protocol);
+ }
} else if (ttype == TTYPE_MAP) {
int i;
@@ -524,7 +526,6 @@
rb_ary_push(result, read_anything(protocol, element_ttype, rb_hash_aref(field_info, element_sym)));
}
-
mt->read_list_end(protocol);
} else if (ttype == TTYPE_SET) {
VALUE items;
@@ -539,7 +540,6 @@
rb_ary_push(items, read_anything(protocol, element_ttype, rb_hash_aref(field_info, element_sym)));
}
-
mt->read_set_end(protocol);
result = rb_class_new_instance(1, &items, rb_cSet);
@@ -597,13 +597,110 @@
return Qnil;
}
+
+// --------------------------------
+// Union section
+// --------------------------------
+
+// Reads a union: at most one field, which must be followed by STOP.
+// Unknown or mistyped fields are skipped without setting @setfield/@value.
+static VALUE rb_thrift_union_read(VALUE self, VALUE protocol) {
+  mt->read_struct_begin(protocol);
+  VALUE struct_fields = STRUCT_FIELDS(self);
+
+  // The field header is an array: index 1 holds the type, index 2 the id.
+  VALUE field_header = mt->read_field_begin(protocol);
+  VALUE field_type_value = rb_ary_entry(field_header, 1);
+  int field_type = FIX2INT(field_type_value);
+
+  // Empty union on the wire: this STOP already ends the struct, so reading
+  // another field header would consume bytes that belong to the caller.
+  if (field_type == TTYPE_STOP) {
+    mt->read_struct_end(protocol);
+    rb_funcall(self, validate_method_id, 0);
+    return Qnil;
+  }
+
+  // Accept the field only if its id is known and its wire type matches.
+  VALUE field_info = rb_hash_aref(struct_fields, rb_ary_entry(field_header, 2));
+  if (!NIL_P(field_info) &&
+      field_type == FIX2INT(rb_hash_aref(field_info, type_sym))) {
+    VALUE name = rb_hash_aref(field_info, name_sym);
+    // rb_str_intern honors the string's length instead of assuming
+    // RSTRING_PTR is a NUL-terminated C string.
+    rb_ivar_set(self, setfield_id, rb_str_intern(name));
+    rb_ivar_set(self, setvalue_id, read_anything(protocol, field_type, field_info));
+  } else {
+    rb_funcall(protocol, skip_method_id, 1, field_type_value);
+  }
+  mt->read_field_end(protocol);
+
+  // A union carries a single field, so the next header must be STOP.
+  field_header = mt->read_field_begin(protocol);
+  field_type_value = rb_ary_entry(field_header, 1);
+  field_type = FIX2INT(field_type_value);
+  if (field_type != TTYPE_STOP) {
+    rb_raise(rb_eRuntimeError, "too many fields in union!");
+  }
+  // STOP is a terminator, not a field: go straight to read_struct_end.
+  mt->read_struct_end(protocol);
+
+  // Final check delegated to the Ruby-level validate (presumably rejects
+  // a union left unset, e.g. because its only field was skipped).
+  rb_funcall(self, validate_method_id, 0);
+
+  return Qnil;
+}
+
+// Writes a union: its single set field followed by a field STOP.
+static VALUE rb_thrift_union_write(VALUE self, VALUE protocol) {
+  // Run the Ruby-level validate before anything reaches the protocol.
+  rb_funcall(self, validate_method_id, 0);
+
+  // Resolve @setfield to its field id and field metadata.
+  VALUE struct_fields = STRUCT_FIELDS(self);
+  VALUE setfield = rb_ivar_get(self, setfield_id);
+  VALUE setvalue = rb_ivar_get(self, setvalue_id);
+  VALUE field_id = rb_funcall(self, name_to_id_method_id, 1, rb_funcall(setfield, to_s_method_id, 0));
+  VALUE field_info = rb_hash_aref(struct_fields, field_id);
+
+  // A set field unknown to struct_fields leaves field_info nil; raise before
+  // writing any bytes instead of handing nil to rb_hash_aref, which needs a Hash.
+  if (NIL_P(field_info)) {
+    rb_raise(rb_eRuntimeError, "set_field is not valid for this union!");
+  }
+
+  VALUE ttype_value = rb_hash_aref(field_info, type_sym);
+  int ttype = FIX2INT(ttype_value);
+
+  mt->write_struct_begin(protocol, rb_class_name(CLASS_OF(self)));
+  mt->write_field_begin(protocol, setfield, ttype_value, field_id);
+  write_anything(ttype, setvalue, protocol, field_info);
+  mt->write_field_end(protocol);
+  // A union carries exactly one field, so STOP follows immediately.
+  mt->write_field_stop(protocol);
+  mt->write_struct_end(protocol);
+
+  return Qnil;
+}
+
void Init_struct() {
VALUE struct_module = rb_const_get(thrift_module, rb_intern("Struct"));
rb_define_method(struct_module, "write", rb_thrift_struct_write, 1);
rb_define_method(struct_module, "read", rb_thrift_struct_read, 1);
+  thrift_union_class = rb_const_get(thrift_module, rb_intern("Union"));
+  // Thrift::Union gets its own native write/read (one field followed by STOP).
+  rb_define_method(thrift_union_class, "write", rb_thrift_union_write, 1);
+  rb_define_method(thrift_union_class, "read", rb_thrift_union_read, 1);
+  // Interned ivar names for a union's set field name and its value.
+  setfield_id = rb_intern("@setfield");
+  setvalue_id = rb_intern("@value");
+  // Method IDs used to map the set field's name back to its field id.
+  to_s_method_id = rb_intern("to_s");
+  name_to_id_method_id = rb_intern("name_to_id");
+
set_default_proto_function_pointers();
mt = default_mt;
-}
-
+}
\ No newline at end of file