THRIFT-723: java: Thrift buffers with set and map types in Java should implement Comparable
This makes structs that contain sets and maps in their hierarchy Comparable.
git-svn-id: https://svn.apache.org/repos/asf/incubator/thrift/trunk@928944 13f79535-47bb-0310-9956-ffa450edef68
diff --git a/compiler/cpp/src/generate/t_java_generator.cc b/compiler/cpp/src/generate/t_java_generator.cc
index c247921..bc2ac49 100644
--- a/compiler/cpp/src/generate/t_java_generator.cc
+++ b/compiler/cpp/src/generate/t_java_generator.cc
@@ -199,9 +199,6 @@
void generate_deep_copy_container(std::ofstream& out, std::string source_name_p1, std::string source_name_p2, std::string result_name, t_type* type);
void generate_deep_copy_non_container(std::ofstream& out, std::string source_name, std::string dest_name, t_type* type);
- bool is_comparable(t_struct* tstruct);
- bool is_comparable(t_type* type);
-
bool has_bit_vector(t_struct* tstruct);
/**
@@ -703,9 +700,7 @@
"public " << (is_final ? "final " : "") << "class " << tstruct->get_name()
<< " extends TUnion<" << tstruct->get_name() << "._Fields> ";
- if (is_comparable(tstruct)) {
- f_struct << "implements Comparable<" << type_name(tstruct) << "> ";
- }
+ f_struct << "implements Comparable<" << type_name(tstruct) << "> ";
scope_up(f_struct);
@@ -1002,22 +997,20 @@
indent(out) << "}" << endl;
out << endl;
- if (is_comparable(tstruct)) {
- indent(out) << "@Override" << endl;
- indent(out) << "public int compareTo(" << type_name(tstruct) << " other) {" << endl;
- indent(out) << " int lastComparison = TBaseHelper.compareTo(getSetField(), other.getSetField());" << endl;
- indent(out) << " if (lastComparison == 0) {" << endl;
- indent(out) << " Object myValue = getFieldValue();" << endl;
- indent(out) << " if (myValue instanceof byte[]) {" << endl;
- indent(out) << " return TBaseHelper.compareTo((byte[])myValue, (byte[])other.getFieldValue());" << endl;
- indent(out) << " } else {" << endl;
- indent(out) << " return TBaseHelper.compareTo((Comparable)myValue, (Comparable)other.getFieldValue());" << endl;
- indent(out) << " }" << endl;
- indent(out) << " }" << endl;
- indent(out) << " return lastComparison;" << endl;
- indent(out) << "}" << endl;
- out << endl;
- }
+ indent(out) << "@Override" << endl;
+ indent(out) << "public int compareTo(" << type_name(tstruct) << " other) {" << endl;
+ indent(out) << " int lastComparison = TBaseHelper.compareTo(getSetField(), other.getSetField());" << endl;
+ indent(out) << " if (lastComparison == 0) {" << endl;
+ indent(out) << " Object myValue = getFieldValue();" << endl;
+ indent(out) << " if (myValue instanceof byte[]) {" << endl;
+ indent(out) << " return TBaseHelper.compareTo((byte[])myValue, (byte[])other.getFieldValue());" << endl;
+ indent(out) << " } else {" << endl;
+ indent(out) << " return TBaseHelper.compareTo((Comparable)myValue, (Comparable)other.getFieldValue());" << endl;
+ indent(out) << " }" << endl;
+ indent(out) << " }" << endl;
+ indent(out) << " return lastComparison;" << endl;
+ indent(out) << "}" << endl;
+ out << endl;
}
void t_java_generator::generate_union_hashcode(ofstream& out, t_struct* tstruct) {
@@ -1077,9 +1070,7 @@
}
out << "implements TBase<" << tstruct->get_name() << "._Fields>, java.io.Serializable, Cloneable";
- if (is_comparable(tstruct)) {
- out << ", Comparable<" << type_name(tstruct) << ">";
- }
+ out << ", Comparable<" << type_name(tstruct) << ">";
out << " ";
@@ -1241,9 +1232,7 @@
generate_generic_isset_method(out, tstruct);
generate_java_struct_equality(out, tstruct);
- if (is_comparable(tstruct)) {
- generate_java_struct_compare_to(out, tstruct);
- }
+ generate_java_struct_compare_to(out, tstruct);
generate_java_struct_reader(out, tstruct);
if (is_result) {
@@ -3606,32 +3595,6 @@
indent(out) << "}" << endl;
}
-bool t_java_generator::is_comparable(t_struct* tstruct) {
- const vector<t_field*>& members = tstruct->get_members();
- vector<t_field*>::const_iterator m_iter;
-
- for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) {
- if (!is_comparable(get_true_type((*m_iter)->get_type()))) {
- return false;
- }
- }
- return true;
-}
-
-bool t_java_generator::is_comparable(t_type* type) {
- if (type->is_container()) {
- if (type->is_list()) {
- return is_comparable(get_true_type(((t_list*)type)->get_elem_type()));
- } else {
- return false;
- }
- } else if (type->is_struct() || type->is_xception()) {
- return is_comparable((t_struct*)type);
- } else {
- return true;
- }
-}
-
bool t_java_generator::has_bit_vector(t_struct* tstruct) {
const vector<t_field*>& members = tstruct->get_members();
vector<t_field*>::const_iterator m_iter;
diff --git a/lib/java/src/org/apache/thrift/TBaseHelper.java b/lib/java/src/org/apache/thrift/TBaseHelper.java
index b41daae..fccece8 100644
--- a/lib/java/src/org/apache/thrift/TBaseHelper.java
+++ b/lib/java/src/org/apache/thrift/TBaseHelper.java
@@ -17,14 +17,24 @@
*/
package org.apache.thrift;
+import java.util.Comparator;
+import java.util.Iterator;
import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.SortedMap;
+import java.util.SortedSet;
+import java.util.TreeMap;
+import java.util.TreeSet;
public class TBaseHelper {
-
+
+ private static final Comparator comparator = new NestedStructureComparator();
+
public static int compareTo(boolean a, boolean b) {
return Boolean.valueOf(a).compareTo(b);
}
-
+
public static int compareTo(byte a, byte b) {
if (a < b) {
return -1;
@@ -44,7 +54,7 @@
return 0;
}
}
-
+
public static int compareTo(int a, int b) {
if (a < b) {
return -1;
@@ -54,7 +64,7 @@
return 0;
}
}
-
+
public static int compareTo(long a, long b) {
if (a < b) {
return -1;
@@ -64,7 +74,7 @@
return 0;
}
}
-
+
public static int compareTo(double a, double b) {
if (a < b) {
return -1;
@@ -74,11 +84,11 @@
return 0;
}
}
-
+
public static int compareTo(String a, String b) {
return a.compareTo(b);
}
-
+
public static int compareTo(byte[] a, byte[] b) {
int sizeCompare = compareTo(a.length, b.length);
if (sizeCompare != 0) {
@@ -92,28 +102,103 @@
}
return 0;
}
-
+
public static int compareTo(Comparable a, Comparable b) {
return a.compareTo(b);
}
-
+
public static int compareTo(List a, List b) {
int lastComparison = compareTo(a.size(), b.size());
if (lastComparison != 0) {
return lastComparison;
}
for (int i = 0; i < a.size(); i++) {
- Object oA = a.get(i);
- Object oB = b.get(i);
- if (oA instanceof List) {
- lastComparison = compareTo((List)oA, (List)oB);
- } else {
- lastComparison = compareTo((Comparable)oA, (Comparable)oB);
- }
+ lastComparison = comparator.compare(a.get(i), b.get(i));
if (lastComparison != 0) {
return lastComparison;
}
}
return 0;
}
+
+ public static int compareTo(Set a, Set b) {
+ int lastComparison = compareTo(a.size(), b.size());
+ if (lastComparison != 0) {
+ return lastComparison;
+ }
+ SortedSet sortedA = new TreeSet(comparator);
+ sortedA.addAll(a);
+ SortedSet sortedB = new TreeSet(comparator);
+ sortedB.addAll(b);
+
+ Iterator iterA = sortedA.iterator();
+ Iterator iterB = sortedB.iterator();
+
+ // Compare each item.
+ while (iterA.hasNext() && iterB.hasNext()) {
+ lastComparison = comparator.compare(iterA.next(), iterB.next());
+ if (lastComparison != 0) {
+ return lastComparison;
+ }
+ }
+
+ return 0;
+ }
+
+ public static int compareTo(Map a, Map b) {
+ int lastComparison = compareTo(a.size(), b.size());
+ if (lastComparison != 0) {
+ return lastComparison;
+ }
+
+ // Sort a and b so we can compare them.
+ SortedMap sortedA = new TreeMap(comparator);
+ sortedA.putAll(a);
+ Iterator<Map.Entry> iterA = sortedA.entrySet().iterator();
+ SortedMap sortedB = new TreeMap(comparator);
+ sortedB.putAll(b);
+ Iterator<Map.Entry> iterB = sortedB.entrySet().iterator();
+
+ // Compare each item.
+ while (iterA.hasNext() && iterB.hasNext()) {
+ Map.Entry entryA = iterA.next();
+ Map.Entry entryB = iterB.next();
+ lastComparison = comparator.compare(entryA.getKey(), entryB.getKey());
+ if (lastComparison != 0) {
+ return lastComparison;
+ }
+ lastComparison = comparator.compare(entryA.getValue(), entryB.getValue());
+ if (lastComparison != 0) {
+ return lastComparison;
+ }
+ }
+
+ return 0;
+ }
+
+ /**
+ * Comparator to compare items inside a structure (e.g. a list, set, or map).
+ */
+ private static class NestedStructureComparator implements Comparator {
+ public int compare(Object oA, Object oB) {
+ if (oA == null && oB == null) {
+ return 0;
+ } else if (oA == null) {
+ return -1;
+ } else if (oB == null) {
+ return 1;
+ } else if (oA instanceof List) {
+ return compareTo((List)oA, (List)oB);
+ } else if (oA instanceof Set) {
+ return compareTo((Set)oA, (Set)oB);
+ } else if (oA instanceof Map) {
+ return compareTo((Map)oA, (Map)oB);
+ } else if (oA instanceof byte[]) {
+ return compareTo((byte[])oA, (byte[])oB);
+ } else {
+ return compareTo((Comparable)oA, (Comparable)oB);
+ }
+ }
+ }
+
}
diff --git a/lib/java/test/org/apache/thrift/TestStruct.java b/lib/java/test/org/apache/thrift/TestStruct.java
index 9465090..6ba48a4 100644
--- a/lib/java/test/org/apache/thrift/TestStruct.java
+++ b/lib/java/test/org/apache/thrift/TestStruct.java
@@ -4,6 +4,7 @@
import java.io.ByteArrayOutputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
+import java.util.HashMap;
import junit.framework.TestCase;
@@ -11,7 +12,9 @@
import thrift.test.Bonk;
import thrift.test.HolyMoley;
+import thrift.test.Insanity;
import thrift.test.Nesting;
+import thrift.test.Numberz;
import thrift.test.OneOfEach;
public class TestStruct extends TestCase {
@@ -135,4 +138,52 @@
bonk2.setMessage("m");
assertEquals(0, bonk1.compareTo(bonk2));
}
+
+ public void testCompareToWithDataStructures() {
+ Insanity insanity1 = new Insanity();
+ Insanity insanity2 = new Insanity();
+
+ // Both empty.
+ expectEquals(insanity1, insanity2);
+
+ insanity1.setUserMap(new HashMap<Numberz, Long>());
+ // insanity1.map = {}, insanity2.map = null
+ expectGreaterThan(insanity1, insanity2);
+
+ // insanity1.map = {2:1}, insanity2.map = null
+ insanity1.getUserMap().put(Numberz.TWO, 1l);
+ expectGreaterThan(insanity1, insanity2);
+
+ // insanity1.map = {2:1}, insanity2.map = {}
+ insanity2.setUserMap(new HashMap<Numberz, Long>());
+ expectGreaterThan(insanity1, insanity2);
+
+ // insanity1.map = {2:1}, insanity2.map = {2:2}
+ insanity2.getUserMap().put(Numberz.TWO, 2l);
+ expectLessThan(insanity1, insanity2);
+
+ // insanity1.map = {2:1, 3:5}, insanity2.map = {2:2}
+ insanity1.getUserMap().put(Numberz.THREE, 5l);
+ expectGreaterThan(insanity1, insanity2);
+
+ // insanity1.map = {2:1, 3:5}, insanity2.map = {2:1, 4:5}
+ insanity2.getUserMap().put(Numberz.TWO, 1l);
+ insanity2.getUserMap().put(Numberz.FIVE, 5l);
+ expectLessThan(insanity1, insanity2);
+ }
+
+ private void expectLessThan(Insanity insanity1, Insanity insanity2) {
+ int compareTo = insanity1.compareTo(insanity2);
+ assertTrue(insanity1 + " should be less than " + insanity2 + ", but is: " + compareTo, compareTo < 0);
+ }
+
+ private void expectGreaterThan(Insanity insanity1, Insanity insanity2) {
+ int compareTo = insanity1.compareTo(insanity2);
+ assertTrue(insanity1 + " should be greater than " + insanity2 + ", but is: " + compareTo, compareTo > 0);
+ }
+
+ private void expectEquals(Insanity insanity1, Insanity insanity2) {
+ int compareTo = insanity1.compareTo(insanity2);
+ assertEquals(insanity1 + " should be equal to " + insanity2 + ", but is: " + compareTo, 0, compareTo);
+ }
}
diff --git a/lib/java/test/org/apache/thrift/TestTBaseHelper.java b/lib/java/test/org/apache/thrift/TestTBaseHelper.java
new file mode 100644
index 0000000..e2d7869
--- /dev/null
+++ b/lib/java/test/org/apache/thrift/TestTBaseHelper.java
@@ -0,0 +1,124 @@
+package org.apache.thrift;
+
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+import junit.framework.TestCase;
+
+public class TestTBaseHelper extends TestCase {
+ public void testByteArrayComparison() {
+ assertTrue(TBaseHelper.compareTo(new byte[]{'a','b'}, new byte[]{'a','c'}) < 0);
+ }
+
+ public void testSets() {
+ Set<String> a = new HashSet<String>();
+ Set<String> b = new HashSet<String>();
+
+ assertTrue(TBaseHelper.compareTo(a, b) == 0);
+
+ a.add("test");
+
+ assertTrue(TBaseHelper.compareTo(a, b) > 0);
+
+ b.add("test");
+
+ assertTrue(TBaseHelper.compareTo(a, b) == 0);
+
+ b.add("aardvark");
+
+ assertTrue(TBaseHelper.compareTo(a, b) < 0);
+
+ a.add("test2");
+
+ assertTrue(TBaseHelper.compareTo(a, b) > 0);
+ }
+
+ public void testNestedStructures() {
+ Set<List<String>> a = new HashSet<List<String>>();
+ Set<List<String>> b = new HashSet<List<String>>();
+
+ a.add(Arrays.asList(new String[] {"a","b"}));
+ b.add(Arrays.asList(new String[] {"a","b", "c"}));
+ a.add(Arrays.asList(new String[] {"a","b"}));
+ b.add(Arrays.asList(new String[] {"a","b", "c"}));
+
+ assertTrue(TBaseHelper.compareTo(a, b) < 0);
+ }
+
+ public void testMapsInSets() {
+ Set<Map<String, Long>> a = new HashSet<Map<String, Long>>();
+ Set<Map<String, Long>> b = new HashSet<Map<String, Long>>();
+
+ assertTrue(TBaseHelper.compareTo(a, b) == 0);
+
+ Map<String, Long> innerA = new HashMap<String, Long>();
+ Map<String, Long> innerB = new HashMap<String, Long>();
+ a.add(innerA);
+ b.add(innerB);
+
+ innerA.put("a", 1l);
+ innerB.put("a", 2l);
+
+ assertTrue(TBaseHelper.compareTo(a, b) < 0);
+ }
+
+ public void testByteArraysInMaps() {
+ Map<byte[], Long> a = new HashMap<byte[], Long>();
+ Map<byte[], Long> b = new HashMap<byte[], Long>();
+
+ assertTrue(TBaseHelper.compareTo(a, b) == 0);
+
+ a.put(new byte[]{'a','b'}, 1000L);
+ b.put(new byte[]{'a','b'}, 1000L);
+ a.put(new byte[]{'a','b', 'd'}, 1000L);
+ b.put(new byte[]{'a','b', 'a'}, 1000L);
+ assertTrue(TBaseHelper.compareTo(a, b) > 0);
+ }
+
+ public void testMapsWithNulls() {
+ Map<String, String> a = new HashMap<String, String>();
+ Map<String, String> b = new HashMap<String, String>();
+ a.put("a", null);
+ a.put("b", null);
+ b.put("a", null);
+ b.put("b", null);
+
+ assertTrue(TBaseHelper.compareTo(a, b) == 0);
+ }
+
+ public void testMapKeyComparison() {
+ Map<String, String> a = new HashMap<String, String>();
+ Map<String, String> b = new HashMap<String, String>();
+ a.put("a", "a");
+ b.put("b", "a");
+
+ assertTrue(TBaseHelper.compareTo(a, b) < 0);
+ }
+
+ public void testMapValueComparison() {
+ Map<String, String> a = new HashMap<String, String>();
+ Map<String, String> b = new HashMap<String, String>();
+ a.put("a", "b");
+ b.put("a", "a");
+
+ assertTrue(TBaseHelper.compareTo(a, b) > 0);
+ }
+
+ public void testByteArraysInSets() {
+ Set<byte[]> a = new HashSet<byte[]>();
+ Set<byte[]> b = new HashSet<byte[]>();
+
+ if (TBaseHelper.compareTo(a, b) != 0)
+ throw new RuntimeException("Set compare failed:" + a + " vs. " + b);
+
+ a.add(new byte[]{'a','b'});
+ b.add(new byte[]{'a','b'});
+ a.add(new byte[]{'a','b', 'd'});
+ b.add(new byte[]{'a','b', 'a'});
+ assertTrue(TBaseHelper.compareTo(a, b) > 0);
+ }
+}