THRIFT-697. Union support in Ruby

git-svn-id: https://svn.apache.org/repos/asf/incubator/thrift/trunk@910700 13f79535-47bb-0310-9956-ffa450edef68
diff --git a/lib/rb/ext/struct.c b/lib/rb/ext/struct.c
index 7429fb1..d459ddb 100644
--- a/lib/rb/ext/struct.c
+++ b/lib/rb/ext/struct.c
@@ -45,29 +45,18 @@
 
 static native_proto_method_table *mt;
 static native_proto_method_table *default_mt;
-// static VALUE last_proto_class = Qnil;
+
+VALUE thrift_union_class;
+
+ID setfield_id;
+ID setvalue_id;
+
+ID to_s_method_id;
+ID name_to_id_method_id;
 
 #define IS_CONTAINER(ttype) ((ttype) == TTYPE_MAP || (ttype) == TTYPE_LIST || (ttype) == TTYPE_SET)
 #define STRUCT_FIELDS(obj) rb_const_get(CLASS_OF(obj), fields_const_id)
 
-// static void set_native_proto_function_pointers(VALUE protocol) {
-//   VALUE method_table_object = rb_const_get(CLASS_OF(protocol), rb_intern("@native_method_table"));
-//   // TODO: check nil?
-//   Data_Get_Struct(method_table_object, native_proto_method_table, mt);
-// }
-
-// static void check_native_proto_method_table(VALUE protocol) {
-//   VALUE protoclass = CLASS_OF(protocol);
-//   if (protoclass != last_proto_class) {
-//     last_proto_class = protoclass;
-//     if (rb_funcall(protocol, native_qmark_method_id, 0) == Qtrue) {
-//       set_native_proto_function_pointers(protocol);
-//     } else {
-//       mt = default_mt;
-//     }
-//   }
-// }
-
 //-------------------------------------------
 // Writing section
 //-------------------------------------------
@@ -275,62 +264,62 @@
 
 // end default protocol methods
 
-
+static VALUE rb_thrift_union_write (VALUE self, VALUE protocol);
 static VALUE rb_thrift_struct_write(VALUE self, VALUE protocol);
 static void write_anything(int ttype, VALUE value, VALUE protocol, VALUE field_info);
 
 VALUE get_field_value(VALUE obj, VALUE field_name) {
-  char name_buf[RSTRING_LEN(field_name) + 1];
-  
+  // '@' prefix + name + terminating NUL => RSTRING_LEN + 2 bytes.
+  char name_buf[RSTRING_LEN(field_name) + 2];
   name_buf[0] = '@';
-  strlcpy(&name_buf[1], RSTRING_PTR(field_name), sizeof(name_buf));
+  strlcpy(&name_buf[1], RSTRING_PTR(field_name), sizeof(name_buf) - 1); // bytes left after '@'
 
   VALUE value = rb_ivar_get(obj, rb_intern(name_buf));
-  
+
   return value;
 }
 
 static void write_container(int ttype, VALUE field_info, VALUE value, VALUE protocol) {
   int sz, i;
-  
+
   if (ttype == TTYPE_MAP) {
     VALUE keys;
     VALUE key;
     VALUE val;
 
     Check_Type(value, T_HASH);
-    
+
     VALUE key_info = rb_hash_aref(field_info, key_sym);
     VALUE keytype_value = rb_hash_aref(key_info, type_sym);
     int keytype = FIX2INT(keytype_value);
-    
+
     VALUE value_info = rb_hash_aref(field_info, value_sym);
     VALUE valuetype_value = rb_hash_aref(value_info, type_sym);
     int valuetype = FIX2INT(valuetype_value);
-    
+
     keys = rb_funcall(value, keys_method_id, 0);
-    
+
     sz = RARRAY_LEN(keys);
-    
+
     mt->write_map_begin(protocol, keytype_value, valuetype_value, INT2FIX(sz));
-    
+
     for (i = 0; i < sz; i++) {
       key = rb_ary_entry(keys, i);
       val = rb_hash_aref(value, key);
-      
+
       if (IS_CONTAINER(keytype)) {
         write_container(keytype, key_info, key, protocol);
       } else {
         write_anything(keytype, key, protocol, key_info);
       }
-      
+
       if (IS_CONTAINER(valuetype)) {
         write_container(valuetype, value_info, val, protocol);
       } else {
         write_anything(valuetype, val, protocol, value_info);
       }
     }
-    
+
     mt->write_map_end(protocol);
   } else if (ttype == TTYPE_LIST) {
     Check_Type(value, T_ARRAY);
@@ -340,7 +329,7 @@
     VALUE element_type_info = rb_hash_aref(field_info, element_sym);
     VALUE element_type_value = rb_hash_aref(element_type_info, type_sym);
     int element_type = FIX2INT(element_type_value);
-    
+
     mt->write_list_begin(protocol, element_type_value, INT2FIX(sz));
     for (i = 0; i < sz; ++i) {
       VALUE val = rb_ary_entry(value, i);
@@ -370,9 +359,9 @@
     VALUE element_type_info = rb_hash_aref(field_info, element_sym);
     VALUE element_type_value = rb_hash_aref(element_type_info, type_sym);
     int element_type = FIX2INT(element_type_value);
-    
+
     mt->write_set_begin(protocol, element_type_value, INT2FIX(sz));
-    
+
     for (i = 0; i < sz; i++) {
       VALUE val = rb_ary_entry(items, i);
       if (IS_CONTAINER(element_type)) {
@@ -381,7 +370,7 @@
         write_anything(element_type, val, protocol, element_type_info);
       }
     }
-    
+
     mt->write_set_end(protocol);
   } else {
     rb_raise(rb_eNotImpError, "can't write container of type: %d", ttype);
@@ -406,7 +395,11 @@
   } else if (IS_CONTAINER(ttype)) {
     write_container(ttype, field_info, value, protocol);
   } else if (ttype == TTYPE_STRUCT) {
-    rb_thrift_struct_write(value, protocol);
+    if (rb_obj_is_kind_of(value, thrift_union_class)) {
+      rb_thrift_union_write(value, protocol);
+    } else {
+      rb_thrift_struct_write(value, protocol);
+    }
   } else {
     rb_raise(rb_eNotImpError, "Unknown type for binary_encoding: %d", ttype);
   }
@@ -423,24 +416,27 @@
 
   // iterate through all the fields here
   VALUE struct_fields = STRUCT_FIELDS(self);
+
   VALUE struct_field_ids_unordered = rb_funcall(struct_fields, keys_method_id, 0);
   VALUE struct_field_ids_ordered = rb_funcall(struct_field_ids_unordered, sort_method_id, 0);
 
   int i = 0;
   for (i=0; i < RARRAY_LEN(struct_field_ids_ordered); i++) {
     VALUE field_id = rb_ary_entry(struct_field_ids_ordered, i);
+
     VALUE field_info = rb_hash_aref(struct_fields, field_id);
 
     VALUE ttype_value = rb_hash_aref(field_info, type_sym);
     int ttype = FIX2INT(ttype_value);
     VALUE field_name = rb_hash_aref(field_info, name_sym);
+
     VALUE field_value = get_field_value(self, field_name);
 
     if (!NIL_P(field_value)) {
       mt->write_field_begin(protocol, field_name, ttype_value, field_id);
-      
+
       write_anything(ttype, field_value, protocol, field_info);
-      
+
       mt->write_field_end(protocol);
     }
   }
@@ -457,6 +453,7 @@
 // Reading section
 //-------------------------------------------
 
+static VALUE rb_thrift_union_read(VALUE self, VALUE protocol);
 static VALUE rb_thrift_struct_read(VALUE self, VALUE protocol);
 
 static void set_field_value(VALUE obj, VALUE field_name, VALUE value) {
@@ -488,7 +485,12 @@
   } else if (ttype == TTYPE_STRUCT) {
     VALUE klass = rb_hash_aref(field_info, class_sym);
     result = rb_class_new_instance(0, NULL, klass);
-    rb_thrift_struct_read(result, protocol);
+
+    if (rb_obj_is_kind_of(result, thrift_union_class)) {
+      rb_thrift_union_read(result, protocol);
+    } else {
+      rb_thrift_struct_read(result, protocol);
+    }
   } else if (ttype == TTYPE_MAP) {
     int i;
 
@@ -524,7 +526,6 @@
       rb_ary_push(result, read_anything(protocol, element_ttype, rb_hash_aref(field_info, element_sym)));
     }
 
-
     mt->read_list_end(protocol);
   } else if (ttype == TTYPE_SET) {
     VALUE items;
@@ -539,7 +540,6 @@
       rb_ary_push(items, read_anything(protocol, element_ttype, rb_hash_aref(field_info, element_sym)));
     }
 
-
     mt->read_set_end(protocol);
 
     result = rb_class_new_instance(1, &items, rb_cSet);
@@ -597,13 +597,125 @@
   return Qnil;
 }
 
+
+// --------------------------------
+// Union section
+// --------------------------------
+
+// Reads a union from `protocol` into `self`. A serialized union is a struct
+// holding at most one field followed by STOP (an empty union is just STOP).
+// A field with an unknown id or an unexpected type is skipped. Calls
+// `validate` on `self` once the struct has been read.
+//
+// Raises RuntimeError if a second field follows the first.
+static VALUE rb_thrift_union_read(VALUE self, VALUE protocol) {
+  // read struct begin
+  mt->read_struct_begin(protocol);
+
+  VALUE struct_fields = STRUCT_FIELDS(self);
+
+  VALUE field_header = mt->read_field_begin(protocol);
+  VALUE field_type_value = rb_ary_entry(field_header, 1);
+  int field_type = FIX2INT(field_type_value);
+
+  // Only consume a field body (and the trailing STOP) when a field is
+  // present; otherwise we would read past the end of the struct.
+  if (field_type != TTYPE_STOP) {
+    // make sure we got a type we expected
+    VALUE field_info = rb_hash_aref(struct_fields, rb_ary_entry(field_header, 2));
+
+    if (!NIL_P(field_info) && field_type == FIX2INT(rb_hash_aref(field_info, type_sym))) {
+      // read the value
+      VALUE name = rb_hash_aref(field_info, name_sym);
+      rb_iv_set(self, "@setfield", ID2SYM(rb_intern(RSTRING_PTR(name))));
+      rb_iv_set(self, "@value", read_anything(protocol, field_type, field_info));
+    } else {
+      // unknown field id or mismatched type
+      rb_funcall(protocol, skip_method_id, 1, field_type_value);
+    }
+
+    // read field end
+    mt->read_field_end(protocol);
+
+    // a union holds at most one field, so the next header must be STOP
+    field_header = mt->read_field_begin(protocol);
+    field_type_value = rb_ary_entry(field_header, 1);
+    field_type = FIX2INT(field_type_value);
+
+    if (field_type != TTYPE_STOP) {
+      rb_raise(rb_eRuntimeError, "too many fields in union!");
+    }
+  }
+
+  // read struct end (no read_field_end follows STOP, as in Union#read)
+  mt->read_struct_end(protocol);
+
+  // call validate
+  rb_funcall(self, validate_method_id, 0);
+
+  return Qnil;
+}
+
+// Writes the union `self` to `protocol` as a struct holding only its set
+// field, followed by STOP. Calls `validate` first.
+//
+// Raises RuntimeError if the set field is not a field of this union.
+static VALUE rb_thrift_union_write(VALUE self, VALUE protocol) {
+  // call validate
+  rb_funcall(self, validate_method_id, 0);
+
+  VALUE struct_fields = STRUCT_FIELDS(self);
+
+  VALUE setfield = rb_ivar_get(self, setfield_id);
+  VALUE setvalue = rb_ivar_get(self, setvalue_id);
+  VALUE field_id = rb_funcall(self, name_to_id_method_id, 1, rb_funcall(setfield, to_s_method_id, 0));
+
+  VALUE field_info = rb_hash_aref(struct_fields, field_id);
+
+  // field_info is nil when @setfield names no field of this union. Fail with
+  // a Ruby error instead of handing nil to rb_hash_aref(), and do it before
+  // anything has been written to the protocol.
+  if (NIL_P(field_info)) {
+    rb_raise(rb_eRuntimeError, "set_field is not valid for this union!");
+  }
+
+  VALUE ttype_value = rb_hash_aref(field_info, type_sym);
+  int ttype = FIX2INT(ttype_value);
+
+  // write struct begin
+  mt->write_struct_begin(protocol, rb_class_name(CLASS_OF(self)));
+
+  mt->write_field_begin(protocol, setfield, ttype_value, field_id);
+
+  write_anything(ttype, setvalue, protocol, field_info);
+
+  mt->write_field_end(protocol);
+
+  mt->write_field_stop(protocol);
+
+  // write struct end
+  mt->write_struct_end(protocol);
+
+  return Qnil;
+}
+
 void Init_struct() {
   VALUE struct_module = rb_const_get(thrift_module, rb_intern("Struct"));
 
   rb_define_method(struct_module, "write", rb_thrift_struct_write, 1);
   rb_define_method(struct_module, "read", rb_thrift_struct_read, 1);
 
+  thrift_union_class = rb_const_get(thrift_module, rb_intern("Union"));
+
+  rb_define_method(thrift_union_class, "write", rb_thrift_union_write, 1);
+  rb_define_method(thrift_union_class, "read", rb_thrift_union_read, 1);
+
+  setfield_id = rb_intern("@setfield");
+  setvalue_id = rb_intern("@value");
+
+  to_s_method_id = rb_intern("to_s");
+  name_to_id_method_id = rb_intern("name_to_id");
+
   set_default_proto_function_pointers();
   mt = default_mt;
-}
-
+}
diff --git a/lib/rb/ext/struct.h b/lib/rb/ext/struct.h
index 37b1b35..48ccef8 100644
--- a/lib/rb/ext/struct.h
+++ b/lib/rb/ext/struct.h
@@ -17,6 +17,7 @@
  * under the License.
  */
 
+
 #include <stdbool.h>
 #include <ruby.h>
 
@@ -41,7 +42,7 @@
   VALUE (*write_field_stop)(VALUE);
   VALUE (*write_message_begin)(VALUE, VALUE, VALUE, VALUE);
   VALUE (*write_message_end)(VALUE);
-  
+
   VALUE (*read_message_begin)(VALUE);
   VALUE (*read_message_end)(VALUE);
   VALUE (*read_field_begin)(VALUE);
@@ -61,7 +62,6 @@
   VALUE (*read_string)(VALUE);
   VALUE (*read_struct_begin)(VALUE);
   VALUE (*read_struct_end)(VALUE);
-  
 } native_proto_method_table;
 
 void Init_struct();
diff --git a/lib/rb/ext/thrift_native.c b/lib/rb/ext/thrift_native.c
index effa202..09b9fe4 100644
--- a/lib/rb/ext/thrift_native.c
+++ b/lib/rb/ext/thrift_native.c
@@ -111,7 +111,7 @@
   thrift_types_module = rb_const_get(thrift_module, rb_intern("Types"));
   rb_cSet = rb_const_get(rb_cObject, rb_intern("Set"));
   protocol_exception_class = rb_const_get(thrift_module, rb_intern("ProtocolException"));
-  
+
   // Init ttype constants
   TTYPE_BOOL = FIX2INT(rb_const_get(thrift_types_module, rb_intern("BOOL")));
   TTYPE_BYTE = FIX2INT(rb_const_get(thrift_types_module, rb_intern("BYTE")));
@@ -171,13 +171,13 @@
   write_method_id = rb_intern("write");
   read_all_method_id = rb_intern("read_all");
   native_qmark_method_id = rb_intern("native?");
-  
+
   // constant ids
   fields_const_id = rb_intern("FIELDS");
   transport_ivar_id = rb_intern("@trans");
   strict_read_ivar_id = rb_intern("@strict_read");
   strict_write_ivar_id = rb_intern("@strict_write");  
-  
+
   // cached symbols
   type_sym = ID2SYM(rb_intern("type"));
   name_sym = ID2SYM(rb_intern("name"));
diff --git a/lib/rb/lib/thrift.rb b/lib/rb/lib/thrift.rb
index 4d4e130..02d67b8 100644
--- a/lib/rb/lib/thrift.rb
+++ b/lib/rb/lib/thrift.rb
@@ -28,6 +28,8 @@
 require 'thrift/processor'
 require 'thrift/client'
 require 'thrift/struct'
+require 'thrift/union'
+require 'thrift/struct_union'
 
 # serializer
 require 'thrift/serializer/serializer'
diff --git a/lib/rb/lib/thrift/protocol/binary_protocol_accelerated.rb b/lib/rb/lib/thrift/protocol/binary_protocol_accelerated.rb
index eaf64f6..70ea652 100644
--- a/lib/rb/lib/thrift/protocol/binary_protocol_accelerated.rb
+++ b/lib/rb/lib/thrift/protocol/binary_protocol_accelerated.rb
@@ -29,7 +29,11 @@
 module Thrift
   class BinaryProtocolAcceleratedFactory < BaseProtocolFactory
     def get_protocol(trans)
-      BinaryProtocolAccelerated.new(trans)
+      # Use BinaryProtocolAccelerated when it is defined (presumably by the
+      # native extension); otherwise fall back to the pure-Ruby BinaryProtocol.
+      protocol_class =
+        defined?(BinaryProtocolAccelerated) ? BinaryProtocolAccelerated : BinaryProtocol
+      protocol_class.new(trans)
     end
   end
 end
diff --git a/lib/rb/lib/thrift/struct.rb b/lib/rb/lib/thrift/struct.rb
index dfc8a2f..9e52073 100644
--- a/lib/rb/lib/thrift/struct.rb
+++ b/lib/rb/lib/thrift/struct.rb
@@ -65,26 +65,7 @@
       end
       fields_with_default_values
     end
-
-    def name_to_id(name)
-      names_to_ids = self.class.instance_variable_get("@names_to_ids")
-      unless names_to_ids
-        names_to_ids = {}
-        struct_fields.each do |fid, field_def|
-          names_to_ids[field_def[:name]] = fid
-        end
-        self.class.instance_variable_set("@names_to_ids", names_to_ids)
-      end
-      names_to_ids[name]
-    end
-
-    def each_field
-      struct_fields.keys.sort.each do |fid|
-        data = struct_fields[fid]
-        yield fid, data
-      end
-    end
-
+    
     def inspect(skip_optional_nulls = true)
       fields = []
       each_field do |fid, field_info|
@@ -115,7 +96,8 @@
       each_field do |fid, field_info|
         name = field_info[:name]
         type = field_info[:type]
-        if (value = instance_variable_get("@#{name}"))
+        value = instance_variable_get("@#{name}")
+        unless value.nil?
           if is_container? type
             oprot.write_field_begin(name, type, fid)
             write_container(oprot, value, field_info)
@@ -210,89 +192,5 @@
         iprot.skip(ftype)
       end
     end
-
-    def read_field(iprot, field = {})
-      case field[:type]
-      when Types::STRUCT
-        value = field[:class].new
-        value.read(iprot)
-      when Types::MAP
-        key_type, val_type, size = iprot.read_map_begin
-        value = {}
-        size.times do
-          k = read_field(iprot, field_info(field[:key]))
-          v = read_field(iprot, field_info(field[:value]))
-          value[k] = v
-        end
-        iprot.read_map_end
-      when Types::LIST
-        e_type, size = iprot.read_list_begin
-        value = Array.new(size) do |n|
-          read_field(iprot, field_info(field[:element]))
-        end
-        iprot.read_list_end
-      when Types::SET
-        e_type, size = iprot.read_set_begin
-        value = Set.new
-        size.times do
-          element = read_field(iprot, field_info(field[:element]))
-          value << element
-        end
-        iprot.read_set_end
-      else
-        value = iprot.read_type(field[:type])
-      end
-      value
-    end
-
-    def write_data(oprot, value, field)
-      if is_container? field[:type]
-        write_container(oprot, value, field)
-      else
-        oprot.write_type(field[:type], value)
-      end
-    end
-
-    def write_container(oprot, value, field = {})
-      case field[:type]
-      when Types::MAP
-        oprot.write_map_begin(field[:key][:type], field[:value][:type], value.size)
-        value.each do |k, v|
-          write_data(oprot, k, field[:key])
-          write_data(oprot, v, field[:value])
-        end
-        oprot.write_map_end
-      when Types::LIST
-        oprot.write_list_begin(field[:element][:type], value.size)
-        value.each do |elem|
-          write_data(oprot, elem, field[:element])
-        end
-        oprot.write_list_end
-      when Types::SET
-        oprot.write_set_begin(field[:element][:type], value.size)
-        value.each do |v,| # the , is to preserve compatibility with the old Hash-style sets
-          write_data(oprot, v, field[:element])
-        end
-        oprot.write_set_end
-      else
-        raise "Not a container type: #{field[:type]}"
-      end
-    end
-
-    CONTAINER_TYPES = []
-    CONTAINER_TYPES[Types::LIST] = true
-    CONTAINER_TYPES[Types::MAP] = true
-    CONTAINER_TYPES[Types::SET] = true
-    def is_container?(type)
-      CONTAINER_TYPES[type]
-    end
-
-    def field_info(field)
-      { :type => field[:type],
-        :class => field[:class],
-        :key => field[:key],
-        :value => field[:value],
-        :element => field[:element] }
-    end
   end
 end
diff --git a/lib/rb/lib/thrift/struct_union.rb b/lib/rb/lib/thrift/struct_union.rb
new file mode 100644
index 0000000..9a5903f
--- /dev/null
+++ b/lib/rb/lib/thrift/struct_union.rb
@@ -0,0 +1,142 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+require 'set'
+
+module Thrift
+  # Serialization helpers shared by Thrift::Struct and Thrift::Union. The
+  # including class must provide struct_fields (field id => field info).
+  module Struct_Union
+    # Returns the id of the field whose :name equals +name+, or nil. The
+    # name => id table is built once and cached on the including class.
+    def name_to_id(name)
+      names_to_ids = self.class.instance_variable_get("@names_to_ids")
+      unless names_to_ids
+        names_to_ids = {}
+        struct_fields.each do |fid, field_def|
+          names_to_ids[field_def[:name]] = fid
+        end
+        self.class.instance_variable_set("@names_to_ids", names_to_ids)
+      end
+      names_to_ids[name]
+    end
+
+    # Yields each field id and its field info, in ascending id order.
+    def each_field
+      struct_fields.keys.sort.each do |fid|
+        yield fid, struct_fields[fid]
+      end
+    end
+
+    # Reads one value described by +field+ from +iprot+ and returns it.
+    # Structs/unions are built from field[:class]; containers recurse.
+    def read_field(iprot, field = {})
+      case field[:type]
+      when Types::STRUCT
+        value = field[:class].new
+        value.read(iprot)
+      when Types::MAP
+        _key_type, _val_type, size = iprot.read_map_begin
+        value = {}
+        size.times do
+          k = read_field(iprot, field_info(field[:key]))
+          v = read_field(iprot, field_info(field[:value]))
+          value[k] = v
+        end
+        iprot.read_map_end
+      when Types::LIST
+        _e_type, size = iprot.read_list_begin
+        value = Array.new(size) do
+          read_field(iprot, field_info(field[:element]))
+        end
+        iprot.read_list_end
+      when Types::SET
+        _e_type, size = iprot.read_set_begin
+        value = Set.new
+        size.times do
+          value << read_field(iprot, field_info(field[:element]))
+        end
+        iprot.read_set_end
+      else
+        value = iprot.read_type(field[:type])
+      end
+      value
+    end
+
+    # Writes +value+ described by +field+ to +oprot+; containers are
+    # delegated to #write_container.
+    def write_data(oprot, value, field)
+      if is_container? field[:type]
+        write_container(oprot, value, field)
+      else
+        oprot.write_type(field[:type], value)
+      end
+    end
+
+    # Writes the map, list or set +value+ described by +field+ to +oprot+.
+    # Raises RuntimeError if field[:type] is not a container type.
+    def write_container(oprot, value, field = {})
+      case field[:type]
+      when Types::MAP
+        oprot.write_map_begin(field[:key][:type], field[:value][:type], value.size)
+        value.each do |k, v|
+          write_data(oprot, k, field[:key])
+          write_data(oprot, v, field[:value])
+        end
+        oprot.write_map_end
+      when Types::LIST
+        oprot.write_list_begin(field[:element][:type], value.size)
+        value.each do |elem|
+          write_data(oprot, elem, field[:element])
+        end
+        oprot.write_list_end
+      when Types::SET
+        oprot.write_set_begin(field[:element][:type], value.size)
+        value.each do |v,| # the , is to preserve compatibility with the old Hash-style sets
+          write_data(oprot, v, field[:element])
+        end
+        oprot.write_set_end
+      else
+        raise "Not a container type: #{field[:type]}"
+      end
+    end
+
+    # Lookup table indexed by type id; true only for LIST, MAP and SET.
+    # Frozen so no caller can change which types count as containers.
+    CONTAINER_TYPES = []
+    CONTAINER_TYPES[Types::LIST] = true
+    CONTAINER_TYPES[Types::MAP] = true
+    CONTAINER_TYPES[Types::SET] = true
+    CONTAINER_TYPES.freeze
+
+    # Returns true if +type+ is a container type id, nil otherwise.
+    def is_container?(type)
+      CONTAINER_TYPES[type]
+    end
+
+    # Returns a new hash holding only the :type, :class, :key, :value and
+    # :element entries of +field+.
+    def field_info(field)
+      { :type => field[:type],
+        :class => field[:class],
+        :key => field[:key],
+        :value => field[:value],
+        :element => field[:element] }
+    end
+  end
+end
diff --git a/lib/rb/lib/thrift/types.rb b/lib/rb/lib/thrift/types.rb
index 20e4ca2..cac5269 100644
--- a/lib/rb/lib/thrift/types.rb
+++ b/lib/rb/lib/thrift/types.rb
@@ -57,7 +57,7 @@
               when Types::STRING
                 String
               when Types::STRUCT
-                Struct
+                [Struct, Union]
               when Types::MAP
                 Hash
               when Types::SET
diff --git a/lib/rb/lib/thrift/union.rb b/lib/rb/lib/thrift/union.rb
new file mode 100644
index 0000000..0b41ed4
--- /dev/null
+++ b/lib/rb/lib/thrift/union.rb
@@ -0,0 +1,163 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+module Thrift
+  # Base class for generated Thrift unions. A union holds at most one field:
+  # @setfield names the set field and @value holds its value.
+  class Union
+    # Creates a union, optionally with the field +name+ set to +value+.
+    #
+    # name  - field name to set (Symbol or String), or nil for an empty union.
+    # value - the field's value; must be non-nil when +name+ is given and nil
+    #         when it is not.
+    #
+    # Raises ArgumentError when +name+ and +value+ disagree as described above.
+    def initialize(name=nil, value=nil)
+      if name
+        if value.nil?
+          raise ArgumentError, "Union #{self.class} cannot be instantiated with setfield and nil value!"
+        end
+
+        Thrift.check_type(value, struct_fields[name_to_id(name.to_s)], name) if Thrift.type_checking
+      elsif !value.nil?
+        raise ArgumentError, "Value provided, but no name!"
+      end
+      @setfield = name
+      @value = value
+    end
+
+    def inspect
+      "<#{self.class} #{@setfield}: #{@value}>"
+    end
+
+    # Reads this union from +iprot+: at most one field followed by STOP (an
+    # empty union is just STOP). Unknown or mistyped fields are skipped, and
+    # #validate is called once the struct has been read.
+    #
+    # Raises RuntimeError if more than one field is present.
+    def read(iprot)
+      iprot.read_struct_begin
+      _fname, ftype, fid = iprot.read_field_begin
+      # Only consume a field body (and the trailing STOP) when a field is
+      # present; otherwise we would read past the end of the struct.
+      unless ftype == Types::STOP
+        handle_message(iprot, fid, ftype)
+        iprot.read_field_end
+
+        _fname, ftype, _fid = iprot.read_field_begin
+        raise "Too many fields for union" unless ftype == Types::STOP
+      end
+
+      iprot.read_struct_end
+      validate
+    end
+
+    # Writes this union to +oprot+ as a struct holding only the set field.
+    # Calls #validate first.
+    #
+    # Raises RuntimeError if the set field is not a field of this union.
+    def write(oprot)
+      validate
+
+      # Resolve the field before emitting anything so a bad name does not
+      # leave a half-written struct on the wire.
+      fid = name_to_id(@setfield.to_s)
+      field_info = struct_fields[fid]
+      raise "set_field is not valid for this union!" unless field_info
+
+      oprot.write_struct_begin(self.class.name)
+
+      type = field_info[:type]
+      if is_container? type
+        oprot.write_field_begin(@setfield, type, fid)
+        write_container(oprot, @value, field_info)
+        oprot.write_field_end
+      else
+        oprot.write_field(@setfield, type, fid, @value)
+      end
+
+      oprot.write_field_stop
+      oprot.write_struct_end
+    end
+
+    # Two unions are equal when they hold the same set field and value.
+    # Objects without get_set_field/get_value (including nil) never match,
+    # rather than raising NoMethodError.
+    def ==(other)
+      other.respond_to?(:get_set_field) && other.respond_to?(:get_value) &&
+        @setfield == other.get_set_field && @value == other.get_value
+    end
+
+    def eql?(other)
+      self.class == other.class && self == other
+    end
+
+    def hash
+      [self.class.name, @setfield, @value].hash
+    end
+
+    # Defines a reader and a writer on +klass+ for each of +fields+. The reader
+    # returns the value only while that field is the set field and raises
+    # RuntimeError otherwise; the writer (type-checked when
+    # Thrift.type_checking is on) makes that field the set field.
+    def self.field_accessor(klass, *fields)
+      fields.each do |field|
+        klass.send :define_method, "#{field}" do
+          if field == @setfield
+            @value
+          else
+            raise RuntimeError, "#{field} is not union's set field."
+          end
+        end
+
+        klass.send :define_method, "#{field}=" do |value|
+          Thrift.check_type(value, klass::FIELDS.values.find {|f| f[:name].to_s == field.to_s }, field) if Thrift.type_checking
+          @setfield = field
+          @value = value
+        end
+      end
+    end
+
+    # get the symbol that indicates what the currently set field type is.
+    def get_set_field
+      @setfield
+    end
+
+    # get the current value of this union, regardless of what the set field is.
+    # generally, you should only use this method when you don't know in advance
+    # what field to expect.
+    def get_value
+      @value
+    end
+
+    protected
+
+    # Reads one field body into this union when +fid+ names a field of type
+    # +ftype+; otherwise skips the field on +iprot+.
+    def handle_message(iprot, fid, ftype)
+      field = struct_fields[fid]
+      if field && field[:type] == ftype
+        @value = read_field(iprot, field)
+        @setfield = field[:name].to_sym
+      else
+        iprot.skip(ftype)
+      end
+    end
+  end
+end
diff --git a/lib/rb/spec/ThriftSpec.thrift b/lib/rb/spec/ThriftSpec.thrift
index fe5a8aa..f5c8c09 100644
--- a/lib/rb/spec/ThriftSpec.thrift
+++ b/lib/rb/spec/ThriftSpec.thrift
@@ -42,6 +42,38 @@
   1: string greeting = "hello world"
 }
 
+union My_union {  // one field each of bool, byte, i16, i32, i64, double and string, plus a second i32
+  1: bool im_true,
+  2: byte a_bite,
+  3: i16 integer16,
+  4: i32 integer32,
+  5: i64 integer64,
+  6: double double_precision,
+  7: string some_characters,
+  8: i32 other_i32
+}
+
+struct Struct_with_union {  // a struct with a My_union field next to plain fields
+  1: My_union fun_union
+  2: i32 integer32
+  3: string some_characters
+}
+
+enum SomeEnum {  // enum type used by TestUnion.enum_field
+  ONE
+  TWO
+}
+
+union TestUnion {  // mixes string, i32 and enum fields
+  /**
+   * A doc string
+   */
+  1: string string_field;
+  2: i32 i32_field;
+  3: i32 other_i32_field;
+  4: SomeEnum enum_field;
+}
+
 struct Foo {
   1: i32 simple = 53,
   2: string words = "words",
diff --git a/lib/rb/spec/binary_protocol_accelerated_spec.rb b/lib/rb/spec/binary_protocol_accelerated_spec.rb
index 48c22e4..b8518c8 100644
--- a/lib/rb/spec/binary_protocol_accelerated_spec.rb
+++ b/lib/rb/spec/binary_protocol_accelerated_spec.rb
@@ -20,22 +20,27 @@
 require File.dirname(__FILE__) + '/spec_helper'
 require File.dirname(__FILE__) + '/binary_protocol_spec_shared'
 
-class ThriftBinaryProtocolAcceleratedSpec < Spec::ExampleGroup
-  include Thrift
+if defined?(Thrift::BinaryProtocolAccelerated) # skip these specs when the accelerated protocol is unavailable
 
-  describe Thrift::BinaryProtocolAccelerated do
-    # since BinaryProtocolAccelerated should be directly equivalent to 
-    # BinaryProtocol, we don't need any custom specs!
-    it_should_behave_like 'a binary protocol'
+  class ThriftBinaryProtocolAcceleratedSpec < Spec::ExampleGroup
+    include Thrift
 
-    def protocol_class
-      BinaryProtocolAccelerated
+    describe Thrift::BinaryProtocolAccelerated do
+      # since BinaryProtocolAccelerated should be directly equivalent to
+      # BinaryProtocol, we don't need any custom specs!
+      it_should_behave_like 'a binary protocol'
+
+      def protocol_class
+        BinaryProtocolAccelerated
+      end
+    end
+
+    describe BinaryProtocolAcceleratedFactory do
+      it "should create a BinaryProtocolAccelerated" do
+        BinaryProtocolAcceleratedFactory.new.get_protocol(mock("MockTransport")).should be_instance_of(BinaryProtocolAccelerated)
+      end
     end
   end
-
-  describe BinaryProtocolAcceleratedFactory do
-    it "should create a BinaryProtocolAccelerated" do
-      BinaryProtocolAcceleratedFactory.new.get_protocol(mock("MockTransport")).should be_instance_of(BinaryProtocolAccelerated)
-    end
-  end
-end
+else
+  puts "skipping BinaryProtocolAccelerated spec because it is not defined."
+end
diff --git a/lib/rb/spec/binary_protocol_spec_shared.rb b/lib/rb/spec/binary_protocol_spec_shared.rb
index 84f5920..28da760 100644
--- a/lib/rb/spec/binary_protocol_spec_shared.rb
+++ b/lib/rb/spec/binary_protocol_spec_shared.rb
@@ -349,9 +349,9 @@
 
     # first block
     firstblock.call(client)
-    
+
     processor.process(serverproto, serverproto)
-    
+
     # second block
     secondblock.call(client)
   ensure
diff --git a/lib/rb/spec/union_spec.rb b/lib/rb/spec/union_spec.rb
new file mode 100644
index 0000000..4835288
--- /dev/null
+++ b/lib/rb/spec/union_spec.rb
@@ -0,0 +1,145 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+require File.dirname(__FILE__) + '/spec_helper'
+
+class ThriftUnionSpec < Spec::ExampleGroup
+  include Thrift
+  include SpecNamespace
+
+  describe Union do
+    it "should return nil value in unset union" do
+      union = My_union.new
+      union.get_set_field.should == nil
+      union.get_value.should == nil
+    end
+
+    it "should set a field and be accessible through get_value and the named field accessor" do
+      union = My_union.new
+      union.integer32 = 25
+      union.get_set_field.should == :integer32
+      union.get_value.should == 25
+      union.integer32.should == 25
+    end
+
+    it "should work correctly when instantiated with static field constructors" do
+      union = My_union.integer32(5)
+      union.get_set_field.should == :integer32
+      union.integer32.should == 5
+    end
+
+    it "should raise for wrong set field" do
+      union = My_union.new
+      union.integer32 = 25
+      lambda { union.some_characters }.should raise_error(RuntimeError, "some_characters is not union's set field.")
+    end
+
+    it "should not be equal to nil" do
+      union = My_union.new
+      union.should_not == nil
+    end
+
+    it "should not equate two different unions, i32 vs. string" do
+      union = My_union.new(:integer32, 25)
+      other_union = My_union.new(:some_characters, "blah!")
+      union.should_not == other_union
+    end
+
+    it "should properly reset setfield and setvalue" do
+      union = My_union.new(:integer32, 25)
+      union.get_set_field.should == :integer32
+      union.some_characters = "blah!"
+      union.get_set_field.should == :some_characters
+      union.get_value.should == "blah!"
+      lambda { union.integer32 }.should raise_error(RuntimeError, "integer32 is not union's set field.")
+    end
+
+    it "should not equate two different unions with different values" do
+      union = My_union.new(:integer32, 25)
+      other_union = My_union.new(:integer32, 400)
+      union.should_not == other_union
+    end
+
+    it "should not equate two different unions with different fields" do
+      union = My_union.new(:integer32, 25)
+      other_union = My_union.new(:other_i32, 25)
+      union.should_not == other_union
+    end
+
+    it "should inspect properly" do
+      union = My_union.new(:integer32, 25)
+      union.inspect.should == "<SpecNamespace::My_union integer32: 25>"
+    end
+
+    it "should not allow setting with instance_variable_set" do
+      union = My_union.new(:integer32, 27)
+      union.instance_variable_set(:@some_characters, "hallo!")
+      union.get_set_field.should == :integer32
+      union.get_value.should == 27
+      lambda { union.some_characters }.should raise_error(RuntimeError, "some_characters is not union's set field.")
+    end
+
+    it "should serialize correctly" do
+      trans = Thrift::MemoryBufferTransport.new
+      proto = Thrift::BinaryProtocol.new(trans)
+
+      union = My_union.new(:integer32, 25)
+      union.write(proto)
+      # Start from a different field and value so the comparison proves #read worked.
+      other_union = My_union.new(:some_characters, "hello there")
+      other_union.read(proto)
+      other_union.should == union
+    end
+
+    it "should raise when validating unset union" do
+      union = My_union.new
+      lambda { union.validate }.should raise_error(StandardError, "Union fields are not set.")
+
+      other_union = My_union.new(:integer32, 1)
+      lambda { other_union.validate }.should_not raise_error
+    end
+
+    it "should validate an enum field properly" do
+      union = TestUnion.new(:enum_field, 3)
+      union.get_set_field.should == :enum_field
+      lambda { union.validate }.should raise_error(ProtocolException, "Invalid value of field enum_field!")
+
+      other_union = TestUnion.new(:enum_field, 1)
+      lambda { other_union.validate }.should_not raise_error
+    end
+
+    it "should properly serialize and match structs with a union" do
+      union = My_union.new(:integer32, 26)
+      swu = Struct_with_union.new(:fun_union => union)
+
+      trans = Thrift::MemoryBufferTransport.new
+      proto = Thrift::CompactProtocol.new(trans)
+
+      swu.write(proto)
+
+      other_union = My_union.new(:some_characters, "hello there")
+      swu2 = Struct_with_union.new(:fun_union => other_union)
+
+      swu2.should_not == swu
+
+      swu2.read(proto)
+      swu2.should == swu
+    end
+  end
+end