THRIFT-697. Union support in Ruby

git-svn-id: https://svn.apache.org/repos/asf/incubator/thrift/trunk@910700 13f79535-47bb-0310-9956-ffa450edef68
diff --git a/lib/rb/ext/struct.c b/lib/rb/ext/struct.c
index 7429fb1..d459ddb 100644
--- a/lib/rb/ext/struct.c
+++ b/lib/rb/ext/struct.c
@@ -45,29 +45,18 @@
 
 static native_proto_method_table *mt;
 static native_proto_method_table *default_mt;
-// static VALUE last_proto_class = Qnil;
+
+VALUE thrift_union_class; // Thrift::Union, assigned in Init_struct
+
+ID setfield_id;            // rb_intern("@setfield")
+ID setvalue_id;            // rb_intern("@value")
+
+ID to_s_method_id;         // rb_intern("to_s")
+ID name_to_id_method_id;   // rb_intern("name_to_id")
 
 #define IS_CONTAINER(ttype) ((ttype) == TTYPE_MAP || (ttype) == TTYPE_LIST || (ttype) == TTYPE_SET)
 #define STRUCT_FIELDS(obj) rb_const_get(CLASS_OF(obj), fields_const_id)
 
-// static void set_native_proto_function_pointers(VALUE protocol) {
-//   VALUE method_table_object = rb_const_get(CLASS_OF(protocol), rb_intern("@native_method_table"));
-//   // TODO: check nil?
-//   Data_Get_Struct(method_table_object, native_proto_method_table, mt);
-// }
-
-// static void check_native_proto_method_table(VALUE protocol) {
-//   VALUE protoclass = CLASS_OF(protocol);
-//   if (protoclass != last_proto_class) {
-//     last_proto_class = protoclass;
-//     if (rb_funcall(protocol, native_qmark_method_id, 0) == Qtrue) {
-//       set_native_proto_function_pointers(protocol);
-//     } else {
-//       mt = default_mt;
-//     }
-//   }
-// }
-
 //-------------------------------------------
 // Writing section
 //-------------------------------------------
@@ -275,62 +264,62 @@
 
 // end default protocol methods
 
-
+static VALUE rb_thrift_union_write (VALUE self, VALUE protocol);
 static VALUE rb_thrift_struct_write(VALUE self, VALUE protocol);
 static void write_anything(int ttype, VALUE value, VALUE protocol, VALUE field_info);
 
 VALUE get_field_value(VALUE obj, VALUE field_name) {
-  char name_buf[RSTRING_LEN(field_name) + 1];
-  
+  // Reads the instance variable "@<field_name>" of obj.
+  char name_buf[RSTRING_LEN(field_name) + 2]; // '@' + name + NUL
   name_buf[0] = '@';
-  strlcpy(&name_buf[1], RSTRING_PTR(field_name), sizeof(name_buf));
+  strlcpy(&name_buf[1], RSTRING_PTR(field_name), sizeof(name_buf) - 1);
 
   VALUE value = rb_ivar_get(obj, rb_intern(name_buf));
-  
+
   return value;
 }
 
 static void write_container(int ttype, VALUE field_info, VALUE value, VALUE protocol) {
   int sz, i;
-  
+
   if (ttype == TTYPE_MAP) {
     VALUE keys;
     VALUE key;
     VALUE val;
 
     Check_Type(value, T_HASH);
-    
+
     VALUE key_info = rb_hash_aref(field_info, key_sym);
     VALUE keytype_value = rb_hash_aref(key_info, type_sym);
     int keytype = FIX2INT(keytype_value);
-    
+
     VALUE value_info = rb_hash_aref(field_info, value_sym);
     VALUE valuetype_value = rb_hash_aref(value_info, type_sym);
     int valuetype = FIX2INT(valuetype_value);
-    
+
     keys = rb_funcall(value, keys_method_id, 0);
-    
+
     sz = RARRAY_LEN(keys);
-    
+
     mt->write_map_begin(protocol, keytype_value, valuetype_value, INT2FIX(sz));
-    
+
     for (i = 0; i < sz; i++) {
       key = rb_ary_entry(keys, i);
       val = rb_hash_aref(value, key);
-      
+
       if (IS_CONTAINER(keytype)) {
         write_container(keytype, key_info, key, protocol);
       } else {
         write_anything(keytype, key, protocol, key_info);
       }
-      
+
       if (IS_CONTAINER(valuetype)) {
         write_container(valuetype, value_info, val, protocol);
       } else {
         write_anything(valuetype, val, protocol, value_info);
       }
     }
-    
+
     mt->write_map_end(protocol);
   } else if (ttype == TTYPE_LIST) {
     Check_Type(value, T_ARRAY);
@@ -340,7 +329,7 @@
     VALUE element_type_info = rb_hash_aref(field_info, element_sym);
     VALUE element_type_value = rb_hash_aref(element_type_info, type_sym);
     int element_type = FIX2INT(element_type_value);
-    
+
     mt->write_list_begin(protocol, element_type_value, INT2FIX(sz));
     for (i = 0; i < sz; ++i) {
       VALUE val = rb_ary_entry(value, i);
@@ -370,9 +359,9 @@
     VALUE element_type_info = rb_hash_aref(field_info, element_sym);
     VALUE element_type_value = rb_hash_aref(element_type_info, type_sym);
     int element_type = FIX2INT(element_type_value);
-    
+
     mt->write_set_begin(protocol, element_type_value, INT2FIX(sz));
-    
+
     for (i = 0; i < sz; i++) {
       VALUE val = rb_ary_entry(items, i);
       if (IS_CONTAINER(element_type)) {
@@ -381,7 +370,7 @@
         write_anything(element_type, val, protocol, element_type_info);
       }
     }
-    
+
     mt->write_set_end(protocol);
   } else {
     rb_raise(rb_eNotImpError, "can't write container of type: %d", ttype);
@@ -406,7 +395,11 @@
   } else if (IS_CONTAINER(ttype)) {
     write_container(ttype, field_info, value, protocol);
   } else if (ttype == TTYPE_STRUCT) {
-    rb_thrift_struct_write(value, protocol);
+    if (rb_obj_is_kind_of(value, thrift_union_class)) {
+      rb_thrift_union_write(value, protocol);
+    } else {
+      rb_thrift_struct_write(value, protocol);
+    }
   } else {
     rb_raise(rb_eNotImpError, "Unknown type for binary_encoding: %d", ttype);
   }
@@ -423,24 +416,27 @@
 
   // iterate through all the fields here
   VALUE struct_fields = STRUCT_FIELDS(self);
+
   VALUE struct_field_ids_unordered = rb_funcall(struct_fields, keys_method_id, 0);
   VALUE struct_field_ids_ordered = rb_funcall(struct_field_ids_unordered, sort_method_id, 0);
 
   int i = 0;
   for (i=0; i < RARRAY_LEN(struct_field_ids_ordered); i++) {
     VALUE field_id = rb_ary_entry(struct_field_ids_ordered, i);
+
     VALUE field_info = rb_hash_aref(struct_fields, field_id);
 
     VALUE ttype_value = rb_hash_aref(field_info, type_sym);
     int ttype = FIX2INT(ttype_value);
     VALUE field_name = rb_hash_aref(field_info, name_sym);
+
     VALUE field_value = get_field_value(self, field_name);
 
     if (!NIL_P(field_value)) {
       mt->write_field_begin(protocol, field_name, ttype_value, field_id);
-      
+
       write_anything(ttype, field_value, protocol, field_info);
-      
+
       mt->write_field_end(protocol);
     }
   }
@@ -457,6 +453,7 @@
 // Reading section
 //-------------------------------------------
 
+static VALUE rb_thrift_union_read(VALUE self, VALUE protocol);
 static VALUE rb_thrift_struct_read(VALUE self, VALUE protocol);
 
 static void set_field_value(VALUE obj, VALUE field_name, VALUE value) {
@@ -488,7 +485,12 @@
   } else if (ttype == TTYPE_STRUCT) {
     VALUE klass = rb_hash_aref(field_info, class_sym);
     result = rb_class_new_instance(0, NULL, klass);
-    rb_thrift_struct_read(result, protocol);
+
+    if (rb_obj_is_kind_of(result, thrift_union_class)) {
+      rb_thrift_union_read(result, protocol);
+    } else {
+      rb_thrift_struct_read(result, protocol);
+    }
   } else if (ttype == TTYPE_MAP) {
     int i;
 
@@ -524,7 +526,6 @@
       rb_ary_push(result, read_anything(protocol, element_ttype, rb_hash_aref(field_info, element_sym)));
     }
 
-
     mt->read_list_end(protocol);
   } else if (ttype == TTYPE_SET) {
     VALUE items;
@@ -539,7 +540,6 @@
       rb_ary_push(items, read_anything(protocol, element_ttype, rb_hash_aref(field_info, element_sym)));
     }
 
-
     mt->read_set_end(protocol);
 
     result = rb_class_new_instance(1, &items, rb_cSet);
@@ -597,13 +597,110 @@
   return Qnil;
 }
 
+// --------------------------------
+// Union section
+// --------------------------------
+
+// Native Thrift::Union#read. Reads one field followed by STOP. If the field
+// id is in STRUCT_FIELDS(self) and the wire type equals the declared type,
+// the field name (as a Symbol) goes to @setfield and the value to @value;
+// otherwise the value is skipped. Raises RuntimeError when a second field
+// follows, and finally calls validate on self. Returns nil.
+static VALUE rb_thrift_union_read(VALUE self, VALUE protocol) {
+  mt->read_struct_begin(protocol);
+
+  VALUE struct_fields = STRUCT_FIELDS(self);
+
+  // field_header is [name, type, id]
+  VALUE field_header = mt->read_field_begin(protocol);
+  VALUE field_type_value = rb_ary_entry(field_header, 1);
+  int field_type = FIX2INT(field_type_value);
+
+  // If the first header is already STOP, the union carried no field; reading
+  // another header would consume bytes belonging to the enclosing message.
+  if (field_type != TTYPE_STOP) {
+    VALUE field_info = rb_hash_aref(struct_fields, rb_ary_entry(field_header, 2));
+
+    // accept only a known field id whose declared type matches the wire type
+    if (!NIL_P(field_info) && field_type == FIX2INT(rb_hash_aref(field_info, type_sym))) {
+      VALUE name = rb_hash_aref(field_info, name_sym);
+      rb_ivar_set(self, setfield_id, ID2SYM(rb_intern(RSTRING_PTR(name))));
+      rb_ivar_set(self, setvalue_id, read_anything(protocol, field_type, field_info));
+    } else {
+      rb_funcall(protocol, skip_method_id, 1, field_type_value);
+    }
+
+    mt->read_field_end(protocol);
+
+    // a union carries at most one field, so the next header must be STOP
+    field_header = mt->read_field_begin(protocol);
+    field_type = FIX2INT(rb_ary_entry(field_header, 1));
+    if (field_type != TTYPE_STOP) {
+      rb_raise(rb_eRuntimeError, "too many fields in union!");
+    }
+  }
+
+  // STOP does not open a field, so no read_field_end is issued for it.
+  mt->read_struct_end(protocol);
+
+  rb_funcall(self, validate_method_id, 0);
+
+  return Qnil;
+}
+
+// Native Thrift::Union#write. Calls validate, then writes self as a struct
+// holding the single field named by @setfield (with value @value), followed
+// by STOP. Returns nil.
+static VALUE rb_thrift_union_write(VALUE self, VALUE protocol) {
+  rb_funcall(self, validate_method_id, 0);
+
+  VALUE struct_fields = STRUCT_FIELDS(self);
+
+  VALUE setfield = rb_ivar_get(self, setfield_id);
+  VALUE setvalue = rb_ivar_get(self, setvalue_id);
+  VALUE field_id = rb_funcall(self, name_to_id_method_id, 1, rb_funcall(setfield, to_s_method_id, 0));
+
+  // Resolve the field before writing anything, so a bad @setfield raises
+  // instead of feeding nil to rb_hash_aref/FIX2INT or leaving a partial struct.
+  VALUE field_info = rb_hash_aref(struct_fields, field_id);
+  if (NIL_P(field_info)) {
+    rb_raise(rb_eRuntimeError, "set field is not a field of this union!");
+  }
+
+  VALUE ttype_value = rb_hash_aref(field_info, type_sym);
+  int ttype = FIX2INT(ttype_value);
+
+  mt->write_struct_begin(protocol, rb_class_name(CLASS_OF(self)));
+
+  mt->write_field_begin(protocol, setfield, ttype_value, field_id);
+  write_anything(ttype, setvalue, protocol, field_info);
+  mt->write_field_end(protocol);
+
+  mt->write_field_stop(protocol);
+  mt->write_struct_end(protocol);
+
+  return Qnil;
+}
+
 void Init_struct() {
   VALUE struct_module = rb_const_get(thrift_module, rb_intern("Struct"));
 
   rb_define_method(struct_module, "write", rb_thrift_struct_write, 1);
   rb_define_method(struct_module, "read", rb_thrift_struct_read, 1);
 
+  // Unions get their own native read/write (one field + STOP).
+  thrift_union_class = rb_const_get(thrift_module, rb_intern("Union"));
+
+  rb_define_method(thrift_union_class, "write", rb_thrift_union_write, 1);
+  rb_define_method(thrift_union_class, "read", rb_thrift_union_read, 1);
+
+  // IDs cached for the union read/write functions
+  setfield_id = rb_intern("@setfield");
+  setvalue_id = rb_intern("@value");
+
+  to_s_method_id = rb_intern("to_s");
+  name_to_id_method_id = rb_intern("name_to_id");
+
   set_default_proto_function_pointers();
   mt = default_mt;
 }
-