THRIFT-3357: Generate EnumSet/EnumMap where elements/keys are enums
Client: Java

This closes #1253
diff --git a/compiler/cpp/src/thrift/generate/t_java_generator.cc b/compiler/cpp/src/thrift/generate/t_java_generator.cc
index 3408bf6..a5fb0e5 100644
--- a/compiler/cpp/src/thrift/generate/t_java_generator.cc
+++ b/compiler/cpp/src/thrift/generate/t_java_generator.cc
@@ -347,6 +347,44 @@
     return annotations.find("deprecated") != annotations.end();
   }
 
+  bool is_enum_set(t_type* ttype) {
+    if (!sorted_containers_) {
+      ttype = get_true_type(ttype);
+      if (ttype->is_set()) {
+        t_set* tset = (t_set*)ttype;
+        t_type* elem_type = get_true_type(tset->get_elem_type());
+        return elem_type->is_enum();
+      }
+    }
+    return false;
+  }
+
+  bool is_enum_map(t_type* ttype) {
+    if (!sorted_containers_) {
+      ttype = get_true_type(ttype);
+      if (ttype->is_map()) {
+        t_map* tmap = (t_map*)ttype;
+        t_type* key_type = get_true_type(tmap->get_key_type());
+        return key_type->is_enum();
+      }
+    }
+    return false;
+  }
+
+  std::string inner_enum_type_name(t_type* ttype) {
+    ttype = get_true_type(ttype);
+    if (ttype->is_map()) {
+      t_map* tmap = (t_map*)ttype;
+      t_type* key_type = get_true_type(tmap->get_key_type());
+      return type_name(key_type, true) + ".class";
+    } else if (ttype->is_set()) {
+      t_set* tset = (t_set*)ttype;
+      t_type* elem_type = get_true_type(tset->get_elem_type());
+      return type_name(elem_type, true) + ".class";
+    }
+    return "";
+  }
+
   std::string constant_name(std::string name);
 
 private:
@@ -611,7 +649,11 @@
     }
     out << endl;
   } else if (type->is_map()) {
-    out << name << " = new " << type_name(type, false, true) << "();" << endl;
+    std::string constructor_args;
+    if (is_enum_map(type)) {
+      constructor_args = inner_enum_type_name(type);
+    }
+    out << name << " = new " << type_name(type, false, true) << "(" << constructor_args << ");" << endl;
     if (!in_static) {
       indent(out) << "static {" << endl;
       indent_up();
@@ -631,7 +673,11 @@
     }
     out << endl;
   } else if (type->is_list() || type->is_set()) {
-    out << name << " = new " << type_name(type, false, true) << "();" << endl;
+    if (is_enum_set(type)) {
+      out << name << " = " << type_name(type, false, true, true) << ".noneOf(" << inner_enum_type_name(type) << ");" << endl;
+    } else {
+      out << name << " = new " << type_name(type, false, true) << "();" << endl;
+    }
     if (!in_static) {
       indent(out) << "static {" << endl;
       indent_up();
@@ -2294,8 +2340,12 @@
       indent_up();
       indent(out) << "if (this." << field_name << " == null) {" << endl;
       indent_up();
-      indent(out) << "this." << field_name << " = new " << type_name(type, false, true) << "();"
-                  << endl;
+      indent(out) << "this." << field_name;
+      if (is_enum_set(type)) {
+        out << " = " << type_name(type, false, true, true) << ".noneOf(" << inner_enum_type_name(type) << ");" << endl;
+      } else {
+        out << " = new " << type_name(type, false, true) << "();" << endl;
+      }
       indent_down();
       indent(out) << "}" << endl;
       indent(out) << "this." << field_name << ".add(elem);" << endl;
@@ -2316,7 +2366,11 @@
       indent_up();
       indent(out) << "if (this." << field_name << " == null) {" << endl;
       indent_up();
-      indent(out) << "this." << field_name << " = new " << type_name(type, false, true) << "();"
+      std::string constructor_args;
+      if (is_enum_map(type)) {
+        constructor_args = inner_enum_type_name(type);
+      }
+      indent(out) << "this." << field_name << " = new " << type_name(type, false, true) << "(" << constructor_args << ");"
                   << endl;
       indent_down();
       indent(out) << "}" << endl;
@@ -3758,11 +3812,17 @@
     indent_up();
   }
 
-  out << indent() << prefix << " = new " << type_name(ttype, false, true);
+  if (is_enum_set(ttype)) {
+    out << indent() << prefix << " = " << type_name(ttype, false, true, true) << ".noneOf";
+  } else {
+    out << indent() << prefix << " = new " << type_name(ttype, false, true);
+  }
 
-  // size the collection correctly
-  if (sorted_containers_ && (ttype->is_map() || ttype->is_set())) {
-    // TreeSet and TreeMap don't have any constructor which takes a capactity as an argument
+  // construct the collection correctly i.e. with appropriate size/type
+  if (is_enum_set(ttype) || is_enum_map(ttype)) {
+    out << "(" << inner_enum_type_name(ttype) << ");" << endl;
+  } else if (sorted_containers_ && (ttype->is_map() || ttype->is_set())) {
+    // TreeSet and TreeMap don't have any constructor which takes a capacity as an argument
     out << "();" << endl;
   } else {
     out << "(" << (ttype->is_list() ? "" : "2*") << obj << ".size"
@@ -3824,8 +3884,17 @@
   generate_deserialize_field(out, &fkey, "", has_metadata);
   generate_deserialize_field(out, &fval, "", has_metadata);
 
+  if (get_true_type(fkey.get_type())->is_enum()) {
+    indent(out) << "if (" << key << " != null)" << endl;
+    scope_up(out);
+  }
+
   indent(out) << prefix << ".put(" << key << ", " << val << ");" << endl;
 
+  if (get_true_type(fkey.get_type())->is_enum()) {
+    scope_down(out);
+  }
+
   if (reuse_objects_ && !get_true_type(fkey.get_type())->is_base_type()) {
     indent(out) << key << " = null;" << endl;
   }
@@ -3857,8 +3926,17 @@
 
   generate_deserialize_field(out, &felem, "", has_metadata);
 
+  if (get_true_type(felem.get_type())->is_enum()) {
+    indent(out) << "if (" << elem << " != null)" << endl;
+    scope_up(out);
+  }
+
   indent(out) << prefix << ".add(" << elem << ");" << endl;
 
+  if (get_true_type(felem.get_type())->is_enum()) {
+    scope_down(out);
+  }
+
   if (reuse_objects_ && !get_true_type(felem.get_type())->is_base_type()) {
     indent(out) << elem << " = null;" << endl;
   }
@@ -3886,8 +3964,17 @@
 
   generate_deserialize_field(out, &felem, "", has_metadata);
 
+  if (get_true_type(felem.get_type())->is_enum()) {
+    indent(out) << "if (" << elem << " != null)" << endl;
+    scope_up(out);
+  }
+
   indent(out) << prefix << ".add(" << elem << ");" << endl;
 
+  if (get_true_type(felem.get_type())->is_enum()) {
+    scope_down(out);
+  }
+
   if (reuse_objects_ && !get_true_type(felem.get_type())->is_base_type()) {
     indent(out) << elem << " = null;" << endl;
   }
@@ -4103,7 +4190,9 @@
   } else if (ttype->is_map()) {
     t_map* tmap = (t_map*)ttype;
     if (in_init) {
-      if (sorted_containers_) {
+      if (is_enum_map(tmap)) {
+        prefix = "java.util.EnumMap";
+      } else if (sorted_containers_) {
         prefix = "java.util.TreeMap";
       } else {
         prefix = "java.util.HashMap";
@@ -4116,7 +4205,9 @@
   } else if (ttype->is_set()) {
     t_set* tset = (t_set*)ttype;
     if (in_init) {
-      if (sorted_containers_) {
+      if (is_enum_set(tset)) {
+        prefix = "java.util.EnumSet";
+      } else if (sorted_containers_) {
         prefix = "java.util.TreeSet";
       } else {
         prefix = "java.util.HashSet";
@@ -4542,13 +4633,21 @@
     return;
   }
 
-  std::string capacity;
-  if (!(sorted_containers_ && (container->is_map() || container->is_set()))) {
+  std::string constructor_args;
+  if (is_enum_set(container) || is_enum_map(container)) {
+    constructor_args = inner_enum_type_name(container);
+  } else if (!(sorted_containers_ && (container->is_map() || container->is_set()))) {
     // unsorted containers accept a capacity value
-    capacity = source_name + ".size()";
+    constructor_args = source_name + ".size()";
   }
-  indent(out) << type_name(type, true, false) << " " << result_name << " = new "
-              << type_name(container, false, true) << "(" << capacity << ");" << endl;
+
+  if (is_enum_set(container)) {
+    indent(out) << type_name(type, true, false) << " " << result_name << " = "
+                << type_name(container, false, true, true) << ".noneOf(" << constructor_args << ");" << endl;
+  } else {
+    indent(out) << type_name(type, true, false) << " " << result_name << " = new "
+                << type_name(container, false, true) << "(" << constructor_args << ");" << endl;
+  }
 
   std::string iterator_element_name = source_name_p1 + "_element";
   std::string result_element_name = result_name + "_copy";
diff --git a/lib/java/build.xml b/lib/java/build.xml
index e40eafa..512aec7 100644
--- a/lib/java/build.xml
+++ b/lib/java/build.xml
@@ -290,6 +290,9 @@
     <exec executable="${thrift.compiler}" failonerror="true">
       <arg line="--gen java ${test.thrift.home}/JavaDeepCopyTest.thrift"/>
     </exec>
+    <exec executable="${thrift.compiler}" failonerror="true">
+      <arg line="--gen java ${test.thrift.home}/EnumContainersTest.thrift"/>
+    </exec>
   </target>
 
   <target name="proxy" if="proxy.enabled">
diff --git a/lib/java/test/org/apache/thrift/TestEnumContainers.java b/lib/java/test/org/apache/thrift/TestEnumContainers.java
new file mode 100644
index 0000000..683246b
--- /dev/null
+++ b/lib/java/test/org/apache/thrift/TestEnumContainers.java
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.thrift;
+
+import junit.framework.TestCase;
+import thrift.test.enumcontainers.EnumContainersTestConstants;
+import thrift.test.enumcontainers.GodBean;
+import thrift.test.enumcontainers.GreekGodGoddess;
+
+import java.util.EnumMap;
+import java.util.EnumSet;
+import java.util.HashMap;
+import java.util.HashSet;
+
+public class TestEnumContainers extends TestCase {
+
+    public void testEnumContainers() throws Exception {
+        final GodBean b1 = new GodBean();
+        b1.addToGoddess(GreekGodGoddess.HERA);
+        b1.getGoddess().add(GreekGodGoddess.APHRODITE);
+        b1.putToPower(GreekGodGoddess.ZEUS, 1000);
+        b1.getPower().put(GreekGodGoddess.HERA, 333);
+        b1.putToByAlias("Mr. Z", GreekGodGoddess.ZEUS);
+        b1.addToImages("Baths of Aphrodite 01.jpeg");
+
+        final GodBean b2 = new GodBean(b1);
+
+        final GodBean b3 = new GodBean();
+        {
+            final TSerializer serializer = new TSerializer();
+            final TDeserializer deserializer = new TDeserializer();
+
+            final byte[] bytes = serializer.serialize(b1);
+            deserializer.deserialize(b3, bytes);
+        }
+
+        assertTrue(b1.getGoddess() != b2.getGoddess());
+        assertTrue(b1.getPower() != b2.getPower());
+
+        assertTrue(b1.getGoddess() != b3.getGoddess());
+        assertTrue(b1.getPower() != b3.getPower());
+
+        for (GodBean each : new GodBean[]{b1, b2, b3}) {
+            assertTrue(each.getGoddess().contains(GreekGodGoddess.HERA));
+            assertFalse(each.getGoddess().contains(GreekGodGoddess.POSEIDON));
+            assertTrue(each.getGoddess() instanceof EnumSet);
+
+            assertEquals(Integer.valueOf(1000), each.getPower().get(GreekGodGoddess.ZEUS));
+            assertEquals(Integer.valueOf(333), each.getPower().get(GreekGodGoddess.HERA));
+            assertTrue(each.getPower() instanceof EnumMap);
+
+            assertTrue(each.getByAlias() instanceof HashMap);
+            assertTrue(each.getImages() instanceof HashSet);
+        }
+    }
+
+    public void testEnumConstants() {
+        assertEquals("lightning bolt", EnumContainersTestConstants.ATTRIBUTES.get(GreekGodGoddess.ZEUS));
+        assertTrue(EnumContainersTestConstants.ATTRIBUTES instanceof EnumMap);
+
+        assertTrue(EnumContainersTestConstants.BEAUTY.contains(GreekGodGoddess.APHRODITE));
+        assertTrue(EnumContainersTestConstants.BEAUTY instanceof EnumSet);
+    }
+}
diff --git a/test/EnumContainersTest.thrift b/test/EnumContainersTest.thrift
new file mode 100644
index 0000000..3b6408f
--- /dev/null
+++ b/test/EnumContainersTest.thrift
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+namespace java thrift.test.enumcontainers
+
+enum GreekGodGoddess {
+    ARES,
+    APHRODITE,
+    ZEUS,
+    POSEIDON,
+    HERA,
+}
+
+typedef GreekGodGoddess GreekGodGoddessType
+typedef i32 Power
+
+struct GodBean {
+    1: optional map<GreekGodGoddessType, Power> power,
+    2: optional set<GreekGodGoddessType> goddess,
+    3: optional map<string, GreekGodGoddess> byAlias,
+    4: optional set<string> images,
+}
+
+const map<GreekGodGoddessType, string> ATTRIBUTES =
+{
+    GreekGodGoddess.ZEUS: "lightning bolt",
+    GreekGodGoddess.POSEIDON: "trident",
+}
+
+const set<GreekGodGoddessType> BEAUTY = [ GreekGodGoddess.APHRODITE, GreekGodGoddess.HERA ]
diff --git a/test/Makefile.am b/test/Makefile.am
index 5e4ebcf..335bae6 100755
--- a/test/Makefile.am
+++ b/test/Makefile.am
@@ -143,6 +143,7 @@
 	OptionalRequiredTest.thrift \
 	Recursive.thrift \
 	ReuseObjects.thrift \
+	EnumContainersTest.thrift \
 	SmallTest.thrift \
 	StressTest.thrift \
 	ThriftTest.thrift \