THRIFT-623. java: Use a Java enum to represent field ids in generated structs
git-svn-id: https://svn.apache.org/repos/asf/incubator/thrift/trunk@835538 13f79535-47bb-0310-9956-ffa450edef68
diff --git a/compiler/cpp/src/generate/t_java_generator.cc b/compiler/cpp/src/generate/t_java_generator.cc
index 230d0d0..e3f5449 100644
--- a/compiler/cpp/src/generate/t_java_generator.cc
+++ b/compiler/cpp/src/generate/t_java_generator.cc
@@ -96,7 +96,6 @@
void generate_java_struct_writer(std::ofstream& out, t_struct* tstruct);
void generate_java_struct_tostring(std::ofstream& out, t_struct* tstruct);
void generate_java_meta_data_map(std::ofstream& out, t_struct* tstruct);
- void generate_java_field_name_map(std::ofstream& out, t_struct* tstruct);
void generate_field_value_meta_data(std::ofstream& out, t_type* type);
std::string get_java_type_string(t_type* type);
void generate_reflection_setters(std::ostringstream& out, t_type* type, std::string field_name, std::string cap_name);
@@ -309,8 +308,10 @@
"import java.util.ArrayList;\n" +
"import java.util.Map;\n" +
"import java.util.HashMap;\n" +
+ "import java.util.EnumMap;\n" +
"import java.util.Set;\n" +
"import java.util.HashSet;\n" +
+ "import java.util.EnumSet;\n" +
"import java.util.Collections;\n" +
"import java.util.BitSet;\n" +
"import java.util.Arrays;\n" +
@@ -671,7 +672,7 @@
indent(f_struct) <<
"public " << (is_final ? "final " : "") << "class " << tstruct->get_name()
- << " extends TUnion ";
+ << " extends TUnion<" << tstruct->get_name() << "._Fields> ";
if (is_comparable(tstruct)) {
f_struct << "implements Comparable<" << type_name(tstruct) << "> ";
@@ -690,8 +691,6 @@
generate_java_meta_data_map(f_struct, tstruct);
- generate_java_field_name_map(f_struct, tstruct);
-
generate_union_constructor(f_struct, tstruct);
f_struct << endl;
@@ -722,7 +721,7 @@
indent(out) << " super();" << endl;
indent(out) << "}" << endl << endl;
- indent(out) << "public " << type_name(tstruct) << "(int setField, Object value) {" << endl;
+ indent(out) << "public " << type_name(tstruct) << "(_Fields setField, Object value) {" << endl;
indent(out) << " super(setField, value);" << endl;
indent(out) << "}" << endl << endl;
@@ -762,7 +761,7 @@
generate_java_doc(out, field);
indent(out) << "public " << type_name(field->get_type()) << " get" << get_cap_name(field->get_name()) << "() {" << endl;
- indent(out) << " if (getSetField() == " << upcase_string(field->get_name()) << ") {" << endl;
+ indent(out) << " if (getSetField() == _Fields." << constant_name(field->get_name()) << ") {" << endl;
indent(out) << " return (" << type_name(field->get_type(), true) << ")getFieldValue();" << endl;
indent(out) << " } else {" << endl;
indent(out) << " throw new RuntimeException(\"Cannot get field '" << field->get_name()
@@ -777,7 +776,7 @@
if (type_can_be_null(field->get_type())) {
indent(out) << " if (value == null) throw new NullPointerException();" << endl;
}
- indent(out) << " setField_ = " << upcase_string(field->get_name()) << ";" << endl;
+ indent(out) << " setField_ = _Fields." << constant_name(field->get_name()) << ";" << endl;
indent(out) << " value_ = value;" << endl;
indent(out) << "}" << endl;
}
@@ -793,11 +792,16 @@
generate_get_field_desc(out, tstruct);
out << endl;
generate_get_struct_desc(out, tstruct);
+ out << endl;
+ indent(out) << "@Override" << endl;
+ indent(out) << "protected _Fields enumForId(short id) {" << endl;
+ indent(out) << " return _Fields.findByThriftIdOrThrow(id);" << endl;
+ indent(out) << "}" << endl;
}
void t_java_generator::generate_check_type(ofstream& out, t_struct* tstruct) {
indent(out) << "@Override" << endl;
- indent(out) << "protected void checkType(short setField, Object value) throws ClassCastException {" << endl;
+ indent(out) << "protected void checkType(_Fields setField, Object value) throws ClassCastException {" << endl;
indent_up();
indent(out) << "switch (setField) {" << endl;
@@ -809,7 +813,7 @@
for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) {
t_field* field = (*m_iter);
- indent(out) << "case " << upcase_string(field->get_name()) << ":" << endl;
+ indent(out) << "case " << constant_name(field->get_name()) << ":" << endl;
indent(out) << " if (value instanceof " << type_name(field->get_type(), true, false, true) << ") {" << endl;
indent(out) << " break;" << endl;
indent(out) << " }" << endl;
@@ -835,7 +839,7 @@
indent_up();
- indent(out) << "switch (field.id) {" << endl;
+ indent(out) << "switch (_Fields.findByThriftId(field.id)) {" << endl;
indent_up();
const vector<t_field*>& members = tstruct->get_members();
@@ -844,9 +848,9 @@
for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) {
t_field* field = (*m_iter);
- indent(out) << "case " << upcase_string(field->get_name()) << ":" << endl;
+ indent(out) << "case " << constant_name(field->get_name()) << ":" << endl;
indent_up();
- indent(out) << "if (field.type == " << upcase_string(field->get_name()) << "_FIELD_DESC.type) {" << endl;
+ indent(out) << "if (field.type == " << constant_name(field->get_name()) << "_FIELD_DESC.type) {" << endl;
indent_up();
indent(out) << type_name(field->get_type(), true, false) << " " << field->get_name() << ";" << endl;
generate_deserialize_field(out, field, "");
@@ -872,7 +876,7 @@
void t_java_generator::generate_write_value(ofstream& out, t_struct* tstruct) {
indent(out) << "@Override" << endl;
- indent(out) << "protected void writeValue(TProtocol oprot, short setField, Object value) throws TException {" << endl;
+ indent(out) << "protected void writeValue(TProtocol oprot, _Fields setField, Object value) throws TException {" << endl;
indent_up();
@@ -885,7 +889,7 @@
for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) {
t_field* field = (*m_iter);
- indent(out) << "case " << upcase_string(field->get_name()) << ":" << endl;
+ indent(out) << "case " << constant_name(field->get_name()) << ":" << endl;
indent_up();
indent(out) << type_name(field->get_type(), true, false) << " " << field->get_name()
<< " = (" << type_name(field->get_type(), true, false) << ")getFieldValue();" << endl;
@@ -909,7 +913,7 @@
void t_java_generator::generate_get_field_desc(ofstream& out, t_struct* tstruct) {
indent(out) << "@Override" << endl;
- indent(out) << "protected TField getFieldDesc(int setField) {" << endl;
+ indent(out) << "protected TField getFieldDesc(_Fields setField) {" << endl;
indent_up();
const vector<t_field*>& members = tstruct->get_members();
@@ -920,8 +924,8 @@
for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) {
t_field* field = (*m_iter);
- indent(out) << "case " << upcase_string(field->get_name()) << ":" << endl;
- indent(out) << " return " << upcase_string(field->get_name()) << "_FIELD_DESC;" << endl;
+ indent(out) << "case " << constant_name(field->get_name()) << ":" << endl;
+ indent(out) << " return " << constant_name(field->get_name()) << "_FIELD_DESC;" << endl;
}
indent(out) << "default:" << endl;
@@ -1020,7 +1024,7 @@
if (is_exception) {
out << "extends Exception ";
}
- out << "implements TBase, java.io.Serializable, Cloneable";
+ out << "implements TBase<" << tstruct->get_name() << "._Fields>, java.io.Serializable, Cloneable";
if (is_comparable(tstruct)) {
out << ", Comparable<" << type_name(tstruct) << ">";
@@ -1036,8 +1040,10 @@
const vector<t_field*>& members = tstruct->get_members();
vector<t_field*>::const_iterator m_iter;
+ out << endl;
+
generate_field_descs(out, tstruct);
-
+
out << endl;
for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) {
@@ -1050,8 +1056,10 @@
out << declare_field(*m_iter, false) << endl;
}
+ out << endl;
+
generate_field_name_constants(out, tstruct);
-
+
// isset data
if (members.size() > 0) {
out << endl;
@@ -1078,8 +1086,6 @@
bool all_optional_members = true;
- generate_java_field_name_map(out, tstruct);
-
// Default constructor
indent(out) <<
"public " << tstruct->get_name() << "() {" << endl;
@@ -1389,38 +1395,37 @@
"}" << endl;
// Switch statement on the field we are reading
- indent(out) <<
- "switch (field.id)" << endl;
+ indent(out) << "_Fields fieldId = _Fields.findByThriftId(field.id);" << endl;
+ indent(out) << "if (fieldId == null) {" << endl;
+ indent(out) << " TProtocolUtil.skip(iprot, field.type);" << endl;
+ indent(out) << "} else {" << endl;
+ indent_up();
- scope_up(out);
+ indent(out) << "switch (fieldId)" << endl;
- // Generate deserialization code for known cases
- for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) {
- indent(out) <<
- "case " << upcase_string((*f_iter)->get_name()) << ":" << endl;
- indent_up();
- indent(out) <<
- "if (field.type == " << type_to_enum((*f_iter)->get_type()) << ") {" << endl;
- indent_up();
+ scope_up(out);
- generate_deserialize_field(out, *f_iter, "this.");
- generate_isset_set(out, *f_iter);
- indent_down();
- out <<
- indent() << "} else { " << endl <<
- indent() << " TProtocolUtil.skip(iprot, field.type);" << endl <<
- indent() << "}" << endl <<
- indent() << "break;" << endl;
- indent_down();
- }
+ // Generate deserialization code for known cases
+ for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) {
+ indent(out) <<
+ "case " << constant_name((*f_iter)->get_name()) << ":" << endl;
+ indent_up();
+ indent(out) <<
+ "if (field.type == " << type_to_enum((*f_iter)->get_type()) << ") {" << endl;
+ indent_up();
- // In the default case we skip the field
+ generate_deserialize_field(out, *f_iter, "this.");
+ generate_isset_set(out, *f_iter);
+ indent_down();
out <<
- indent() << "default:" << endl <<
+ indent() << "} else { " << endl <<
indent() << " TProtocolUtil.skip(iprot, field.type);" << endl <<
- indent() << " break;" << endl;
+ indent() << "}" << endl <<
+ indent() << "break;" << endl;
+ indent_down();
+ }
- scope_down(out);
+ scope_down(out);
// Read field end marker
indent(out) <<
@@ -1428,8 +1433,11 @@
scope_down(out);
+ indent_down();
+ indent(out) << "}" << endl;
+
out <<
- indent() << "iprot.readStructEnd();" << endl << endl;
+ indent() << "iprot.readStructEnd();" << endl;
// in non-beans style, check for required fields of primitive type
// (which can be checked here but not in the general validate method)
@@ -1625,7 +1633,7 @@
}
void t_java_generator::generate_reflection_getters(ostringstream& out, t_type* type, string field_name, string cap_name) {
- indent(out) << "case " << upcase_string(field_name) << ":" << endl;
+ indent(out) << "case " << constant_name(field_name) << ":" << endl;
indent_up();
if (type->is_base_type() && !type->is_string()) {
@@ -1640,7 +1648,7 @@
}
void t_java_generator::generate_reflection_setters(ostringstream& out, t_type* type, string field_name, string cap_name) {
- indent(out) << "case " << upcase_string(field_name) << ":" << endl;
+ indent(out) << "case " << constant_name(field_name) << ":" << endl;
indent_up();
indent(out) << "if (value == null) {" << endl;
indent(out) << " unset" << get_cap_name(field_name) << "();" << endl;
@@ -1674,36 +1682,29 @@
// create the setter
- indent(out) << "public void setFieldValue(int fieldID, Object value) {" << endl;
- indent_up();
-
- indent(out) << "switch (fieldID) {" << endl;
-
+
+ indent(out) << "public void setFieldValue(_Fields field, Object value) {" << endl;
+ indent(out) << " switch (field) {" << endl;
out << setter_stream.str();
+ indent(out) << " }" << endl;
+ indent(out) << "}" << endl << endl;
- indent(out) << "default:" << endl;
- indent(out) << " throw new IllegalArgumentException(\"Field \" + fieldID + \" doesn't exist!\");" << endl;
-
- indent(out) << "}" << endl;
-
- indent_down();
+ indent(out) << "public void setFieldValue(int fieldID, Object value) {" << endl;
+ indent(out) << " setFieldValue(_Fields.findByThriftIdOrThrow(fieldID), value);" << endl;
indent(out) << "}" << endl << endl;
// create the getter
- indent(out) << "public Object getFieldValue(int fieldID) {" << endl;
+ indent(out) << "public Object getFieldValue(_Fields field) {" << endl;
indent_up();
-
- indent(out) << "switch (fieldID) {" << endl;
-
+ indent(out) << "switch (field) {" << endl;
out << getter_stream.str();
-
- indent(out) << "default:" << endl;
- indent(out) << " throw new IllegalArgumentException(\"Field \" + fieldID + \" doesn't exist!\");" << endl;
-
indent(out) << "}" << endl;
-
+ indent(out) << "throw new IllegalStateException();" << endl;
indent_down();
+ indent(out) << "}" << endl << endl;
+ indent(out) << "public Object getFieldValue(int fieldId) {" << endl;
+ indent(out) << " return getFieldValue(_Fields.findByThriftIdOrThrow(fieldId));" << endl;
indent(out) << "}" << endl << endl;
}
@@ -1713,26 +1714,27 @@
vector<t_field*>::const_iterator f_iter;
// create the isSet method
- indent(out) << "// Returns true if field corresponding to fieldID is set (has been asigned a value) and false otherwise" << endl;
- indent(out) << "public boolean isSet(int fieldID) {" << endl;
+ indent(out) << "/** Returns true if field corresponding to fieldID is set (has been asigned a value) and false otherwise */" << endl;
+ indent(out) << "public boolean isSet(_Fields field) {" << endl;
indent_up();
- indent(out) << "switch (fieldID) {" << endl;
+ indent(out) << "switch (field) {" << endl;
for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) {
t_field* field = *f_iter;
- indent(out) << "case " << upcase_string(field->get_name()) << ":" << endl;
+ indent(out) << "case " << constant_name(field->get_name()) << ":" << endl;
indent_up();
indent(out) << "return " << generate_isset_check(field) << ";" << endl;
indent_down();
}
- indent(out) << "default:" << endl;
- indent(out) << " throw new IllegalArgumentException(\"Field \" + fieldID + \" doesn't exist!\");" << endl;
-
indent(out) << "}" << endl;
-
+ indent(out) << "throw new IllegalStateException();" << endl;
indent_down();
indent(out) << "}" << endl << endl;
+
+ indent(out) << "public boolean isSet(int fieldID) {" << endl;
+ indent(out) << " return isSet(_Fields.findByThriftIdOrThrow(fieldID));" << endl;
+ indent(out) << "}" << endl << endl;
}
/**
@@ -1861,7 +1863,7 @@
indent(out) << "}" << endl << endl;
// isSet method
- indent(out) << "// Returns true if field " << field_name << " is set (has been asigned a value) and false otherwise" << endl;
+ indent(out) << "/** Returns true if field " << field_name << " is set (has been asigned a value) and false otherwise */" << endl;
indent(out) << "public boolean is" << get_cap_name("set") << cap_name << "() {" << endl;
indent_up();
if (type_can_be_null(type)) {
@@ -1979,14 +1981,14 @@
vector<t_field*>::const_iterator f_iter;
// Static Map with fieldID -> FieldMetaData mappings
- indent(out) << "public static final Map<Integer, FieldMetaData> metaDataMap = Collections.unmodifiableMap(new HashMap<Integer, FieldMetaData>() {{" << endl;
+ indent(out) << "public static final Map<_Fields, FieldMetaData> metaDataMap = Collections.unmodifiableMap(new EnumMap<_Fields, FieldMetaData>(_Fields.class) {{" << endl;
// Populate map
indent_up();
for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) {
t_field* field = *f_iter;
std::string field_name = field->get_name();
- indent(out) << "put(" << upcase_string(field_name) << ", new FieldMetaData(\"" << field_name << "\", ";
+ indent(out) << "put(_Fields." << constant_name(field_name) << ", new FieldMetaData(\"" << field_name << "\", ";
// Set field requirement type (required, optional, etc.)
if (field->get_req() == t_field::T_REQUIRED) {
@@ -1997,7 +1999,7 @@
out << "TFieldRequirementType.DEFAULT, ";
}
- // Create value meta data
+ // Create value meta data
generate_field_value_meta_data(out, field->get_type());
out << "));" << endl;
}
@@ -2012,30 +2014,6 @@
indent(out) << "}" << endl << endl;
}
-/**
- * Generates a static map from field names to field IDs
- *
- * @param tstruct The struct definition
- */
-void t_java_generator::generate_java_field_name_map(ofstream& out,
- t_struct* tstruct) {
- const vector<t_field*>& fields = tstruct->get_members();
- vector<t_field*>::const_iterator f_iter;
-
- // Static Map with fieldName -> fieldID
- indent(out) << "public static final Map<String, Integer> fieldNameMap = Collections.unmodifiableMap(new HashMap<String, Integer>() {{" << endl;
-
- // Populate map
- indent_up();
- for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) {
- t_field* field = *f_iter;
- std::string field_name = field->get_name();
- indent(out) << "put(\"" << field->get_name() << "\", new Integer(" << upcase_string(field->get_name()) << "));" << endl;
- }
- indent_down();
- indent(out) << "}});" << endl << endl;
-}
-
/**
* Returns a string with the java representation of the given thrift type
* (e.g. for the type struct it returns "TType.STRUCT")
@@ -3514,13 +3492,78 @@
}
void t_java_generator::generate_field_name_constants(ofstream& out, t_struct* tstruct) {
- // Members are public for -java, private for -javabean
+ indent(out) << "/** The set of fields this struct contains, along with convenience methods for finding and manipulating them. */" << endl;
+ indent(out) << "public enum _Fields implements TFieldIdEnum {" << endl;
+
+ indent_up();
+ bool first = true;
const vector<t_field*>& members = tstruct->get_members();
vector<t_field*>::const_iterator m_iter;
-
for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) {
- indent(out) << "public static final int " << upcase_string((*m_iter)->get_name()) << " = " << (*m_iter)->get_key() << ";" << endl;
+ if (!first) {
+ out << "," << endl;
+ }
+ first = false;
+ generate_java_doc(out, *m_iter);
+ indent(out) << constant_name((*m_iter)->get_name()) << "((short)" << (*m_iter)->get_key() << ", \"" << (*m_iter)->get_name() << "\")";
}
+
+ out << ";" << endl << endl;
+
+ indent(out) << "private static final Map<Integer, _Fields> byId = new HashMap<Integer, _Fields>();" << endl;
+ indent(out) << "private static final Map<String, _Fields> byName = new HashMap<String, _Fields>();" << endl;
+ out << endl;
+
+ indent(out) << "static {" << endl;
+ indent(out) << " for (_Fields field : EnumSet.allOf(_Fields.class)) {" << endl;
+ indent(out) << " byId.put((int)field._thriftId, field);" << endl;
+ indent(out) << " byName.put(field.getFieldName(), field);" << endl;
+ indent(out) << " }" << endl;
+ indent(out) << "}" << endl << endl;
+
+ indent(out) << "/**" << endl;
+ indent(out) << " * Find the _Fields constant that matches fieldId, or null if its not found." << endl;
+ indent(out) << " */" << endl;
+ indent(out) << "public static _Fields findByThriftId(int fieldId) {" << endl;
+ indent(out) << " return byId.get(fieldId);" << endl;
+ indent(out) << "}" << endl << endl;
+
+ indent(out) << "/**" << endl;
+ indent(out) << " * Find the _Fields constant that matches fieldId, throwing an exception" << endl;
+ indent(out) << " * if it is not found." << endl;
+ indent(out) << " */" << endl;
+ indent(out) << "public static _Fields findByThriftIdOrThrow(int fieldId) {" << endl;
+ indent(out) << " _Fields fields = findByThriftId(fieldId);" << endl;
+ indent(out) << " if (fields == null) throw new IllegalArgumentException(\"Field \" + fieldId + \" doesn't exist!\");" << endl;
+ indent(out) << " return fields;" << endl;
+ indent(out) << "}" << endl << endl;
+
+ indent(out) << "/**" << endl;
+ indent(out) << " * Find the _Fields constant that matches name, or null if its not found." << endl;
+ indent(out) << " */" << endl;
+ indent(out) << "public static _Fields findByName(String name) {" << endl;
+ indent(out) << " return byName.get(name);" << endl;
+ indent(out) << "}" << endl << endl;
+
+ indent(out) << "private final short _thriftId;" << endl;
+ indent(out) << "private final String _fieldName;" << endl << endl;
+
+ indent(out) << "_Fields(short thriftId, String fieldName) {" << endl;
+ indent(out) << " _thriftId = thriftId;" << endl;
+ indent(out) << " _fieldName = fieldName;" << endl;
+ indent(out) << "}" << endl << endl;
+
+ indent(out) << "public short getThriftFieldId() {" << endl;
+ indent(out) << " return _thriftId;" << endl;
+ indent(out) << "}" << endl << endl;
+
+ indent(out) << "public String getFieldName() {" << endl;
+ indent(out) << " return _fieldName;" << endl;
+ indent(out) << "}" << endl;
+
+ indent_down();
+
+ indent(out) << "}" << endl;
}
bool t_java_generator::is_comparable(t_struct* tstruct) {
diff --git a/lib/java/src/org/apache/thrift/TBase.java b/lib/java/src/org/apache/thrift/TBase.java
index 3c3b12f..bfa0abe 100644
--- a/lib/java/src/org/apache/thrift/TBase.java
+++ b/lib/java/src/org/apache/thrift/TBase.java
@@ -27,7 +27,7 @@
* Generic base interface for generated Thrift objects.
*
*/
-public interface TBase extends Serializable {
+public interface TBase<F extends TFieldIdEnum> extends Serializable {
/**
* Reads the TObject from the given input protocol.
@@ -48,23 +48,49 @@
*
* @param fieldId The field's id tag as found in the IDL.
*/
+ @Deprecated
public boolean isSet(int fieldId);
/**
+ * Check if a field is currently set or unset.
+ *
+ * @param field
+ */
+ public boolean isSet(F field);
+
+ /**
* Get a field's value by id. Primitive types will be wrapped in the
* appropriate "boxed" types.
*
* @param fieldId The field's id tag as found in the IDL.
*/
+ @Deprecated
public Object getFieldValue(int fieldId);
/**
+ * Get a field's value by field variable. Primitive types will be wrapped in
+ * the appropriate "boxed" types.
+ *
+ * @param field
+ */
+ public Object getFieldValue(F field);
+
+ /**
* Set a field's value by id. Primitive types must be "boxed" in the
* appropriate object wrapper type.
*
* @param fieldId The field's id tag as found in the IDL.
*/
+ @Deprecated
public void setFieldValue(int fieldId, Object value);
- public TBase deepCopy();
+ /**
+ * Set a field's value by field variable. Primitive types must be "boxed" in
+ * the appropriate object wrapper type.
+ *
+ * @param field
+ */
+ public void setFieldValue(F field, Object value);
+
+ public TBase<F> deepCopy();
}
diff --git a/lib/java/src/org/apache/thrift/TDeserializer.java b/lib/java/src/org/apache/thrift/TDeserializer.java
index 7b7d51d..750ea48 100644
--- a/lib/java/src/org/apache/thrift/TDeserializer.java
+++ b/lib/java/src/org/apache/thrift/TDeserializer.java
@@ -29,6 +29,7 @@
import org.apache.thrift.protocol.TProtocolUtil;
import org.apache.thrift.protocol.TType;
import org.apache.thrift.transport.TIOStreamTransport;
+import org.apache.thrift.TFieldIdEnum;
/**
* Generic utility for easily deserializing objects from a byte array or Java
@@ -92,7 +93,7 @@
* @param fieldIdPath The FieldId's that define a path tb
* @throws TException
*/
- public void partialDeserialize(TBase tb, byte[] bytes, int ... fieldIdPath) throws TException {
+ public void partialDeserialize(TBase tb, byte[] bytes, TFieldIdEnum ... fieldIdPath) throws TException {
// if there are no elements in the path, then the user is looking for the
// regular deserialize method
// TODO: it might be nice not to have to do this check every time to save
@@ -116,11 +117,11 @@
// we can stop searching if we either see a stop or we go past the field
// id we're looking for (since fields should now be serialized in asc
// order).
- if (field.type == TType.STOP || field.id > fieldIdPath[curPathIndex]) {
+ if (field.type == TType.STOP || field.id > fieldIdPath[curPathIndex].getThriftFieldId()) {
return;
}
- if (field.id != fieldIdPath[curPathIndex]) {
+ if (field.id != fieldIdPath[curPathIndex].getThriftFieldId()) {
// Not the field we're looking for. Skip field.
TProtocolUtil.skip(iprot, field.type);
iprot.readFieldEnd();
diff --git a/lib/java/src/org/apache/thrift/TFieldIdEnum.java b/lib/java/src/org/apache/thrift/TFieldIdEnum.java
new file mode 100644
index 0000000..6bcc9f2
--- /dev/null
+++ b/lib/java/src/org/apache/thrift/TFieldIdEnum.java
@@ -0,0 +1,16 @@
+package org.apache.thrift;
+
+/**
+ * Interface for all generated struct Fields objects.
+ */
+public interface TFieldIdEnum {
+ /**
+ * Get the Thrift field id for the named field.
+ */
+ public short getThriftFieldId();
+
+ /**
+ * Get the field's name, exactly as in the IDL.
+ */
+ public String getFieldName();
+}
diff --git a/lib/java/src/org/apache/thrift/TUnion.java b/lib/java/src/org/apache/thrift/TUnion.java
index 9375475..219669f 100644
--- a/lib/java/src/org/apache/thrift/TUnion.java
+++ b/lib/java/src/org/apache/thrift/TUnion.java
@@ -12,28 +12,28 @@
import org.apache.thrift.protocol.TProtocolException;
import org.apache.thrift.protocol.TStruct;
-public abstract class TUnion implements TBase {
+public abstract class TUnion<F extends TFieldIdEnum> implements TBase<F> {
protected Object value_;
- protected int setField_;
-
+ protected F setField_;
+
protected TUnion() {
- setField_ = 0;
+ setField_ = null;
value_ = null;
}
- protected TUnion(int setField, Object value) {
+ protected TUnion(F setField, Object value) {
setFieldValue(setField, value);
}
- protected TUnion(TUnion other) {
+ protected TUnion(TUnion<F> other) {
if (!other.getClass().equals(this.getClass())) {
throw new ClassCastException();
}
setField_ = other.setField_;
value_ = deepCopyObject(other.value_);
}
-
+
private static Object deepCopyObject(Object o) {
if (o instanceof TBase) {
return ((TBase)o).deepCopy();
@@ -52,7 +52,7 @@
return o;
}
}
-
+
private static Map deepCopyMap(Map<Object, Object> map) {
Map copy = new HashMap();
for (Map.Entry<Object, Object> entry : map.entrySet()) {
@@ -77,15 +77,15 @@
return copy;
}
- public int getSetField() {
+ public F getSetField() {
return setField_;
}
-
+
public Object getFieldValue() {
return value_;
}
-
- public Object getFieldValue(int fieldId) {
+
+ public Object getFieldValue(F fieldId) {
if (fieldId != setField_) {
throw new IllegalArgumentException("Cannot get the value of field " + fieldId + " because union's set field is " + setField_);
}
@@ -93,16 +93,24 @@
return getFieldValue();
}
+ public Object getFieldValue(int fieldId) {
+ return getFieldValue(enumForId((short)fieldId));
+ }
+
public boolean isSet() {
- return setField_ != 0;
+ return setField_ != null;
}
- public boolean isSet(int fieldId) {
+ public boolean isSet(F fieldId) {
return setField_ == fieldId;
}
+ public boolean isSet(int fieldId) {
+ return isSet(enumForId((short)fieldId));
+ }
+
public void read(TProtocol iprot) throws TException {
- setField_ = 0;
+ setField_ = null;
value_ = null;
iprot.readStructBegin();
@@ -111,7 +119,7 @@
value_ = readValue(iprot, field);
if (value_ != null) {
- setField_ = field.id;
+ setField_ = enumForId(field.id);
}
iprot.readFieldEnd();
@@ -122,19 +130,23 @@
iprot.readStructEnd();
}
- public void setFieldValue(int fieldId, Object value) {
- checkType((short)fieldId, value);
- setField_ = (short)fieldId;
+ public void setFieldValue(F fieldId, Object value) {
+ checkType(fieldId, value);
+ setField_ = fieldId;
value_ = value;
}
+ public void setFieldValue(int fieldId, Object value) {
+ setFieldValue(enumForId((short)fieldId), value);
+ }
+
public void write(TProtocol oprot) throws TException {
- if (getSetField() == 0 || getFieldValue() == null) {
+ if (getSetField() == null || getFieldValue() == null) {
throw new TProtocolException("Cannot write a TUnion with no set value!");
}
oprot.writeStructBegin(getStructDesc());
oprot.writeFieldBegin(getFieldDesc(setField_));
- writeValue(oprot, (short)setField_, value_);
+ writeValue(oprot, setField_, value_);
oprot.writeFieldEnd();
oprot.writeFieldStop();
oprot.writeStructEnd();
@@ -146,7 +158,7 @@
* @param setField
* @param value
*/
- protected abstract void checkType(short setField, Object value) throws ClassCastException;
+ protected abstract void checkType(F setField, Object value) throws ClassCastException;
/**
* Implementation should be generated to read the right stuff from the wire
@@ -156,11 +168,13 @@
*/
protected abstract Object readValue(TProtocol iprot, TField field) throws TException;
- protected abstract void writeValue(TProtocol oprot, short setField, Object value) throws TException;
+ protected abstract void writeValue(TProtocol oprot, F setField, Object value) throws TException;
protected abstract TStruct getStructDesc();
- protected abstract TField getFieldDesc(int setField);
+ protected abstract TField getFieldDesc(F setField);
+
+ protected abstract F enumForId(short id);
@Override
public String toString() {
diff --git a/lib/java/src/org/apache/thrift/meta_data/FieldMetaData.java b/lib/java/src/org/apache/thrift/meta_data/FieldMetaData.java
index 3e90a8b..b634291 100644
--- a/lib/java/src/org/apache/thrift/meta_data/FieldMetaData.java
+++ b/lib/java/src/org/apache/thrift/meta_data/FieldMetaData.java
@@ -22,6 +22,7 @@
import java.util.HashMap;
import java.util.Map;
import org.apache.thrift.TBase;
+import org.apache.thrift.TFieldIdEnum;
/**
* This class is used to store meta data about thrift fields. Every field in a
@@ -32,10 +33,10 @@
public final String fieldName;
public final byte requirementType;
public final FieldValueMetaData valueMetaData;
- private static Map<Class<? extends TBase>, Map<Integer, FieldMetaData>> structMap;
+ private static Map<Class<? extends TBase>, Map<? extends TFieldIdEnum, FieldMetaData>> structMap;
static {
- structMap = new HashMap<Class<? extends TBase>, Map<Integer, FieldMetaData>>();
+ structMap = new HashMap<Class<? extends TBase>, Map<? extends TFieldIdEnum, FieldMetaData>>();
}
public FieldMetaData(String name, byte req, FieldValueMetaData vMetaData){
@@ -44,7 +45,7 @@
this.valueMetaData = vMetaData;
}
- public static void addStructMetaDataMap(Class<? extends TBase> sClass, Map<Integer, FieldMetaData> map){
+ public static void addStructMetaDataMap(Class<? extends TBase> sClass, Map<? extends TFieldIdEnum, FieldMetaData> map){
structMap.put(sClass, map);
}
@@ -54,7 +55,7 @@
*
* @param sClass The TBase class for which the metadata map is requested
*/
- public static Map<Integer, FieldMetaData> getStructMetaDataMap(Class<? extends TBase> sClass){
+ public static Map<? extends TFieldIdEnum, FieldMetaData> getStructMetaDataMap(Class<? extends TBase> sClass){
if (!structMap.containsKey(sClass)){ // Load class if it hasn't been loaded
try{
sClass.newInstance();
diff --git a/lib/java/test/org/apache/thrift/test/MetaDataTest.java b/lib/java/test/org/apache/thrift/test/MetaDataTest.java
index 386ec9b..8bb9f2c 100644
--- a/lib/java/test/org/apache/thrift/test/MetaDataTest.java
+++ b/lib/java/test/org/apache/thrift/test/MetaDataTest.java
@@ -28,43 +28,43 @@
import org.apache.thrift.meta_data.SetMetaData;
import org.apache.thrift.meta_data.StructMetaData;
import org.apache.thrift.protocol.TType;
+import org.apache.thrift.TFieldIdEnum;
import thrift.test.*;
public class MetaDataTest {
-
public static void main(String[] args) throws Exception {
- Map<Integer, FieldMetaData> mdMap = CrazyNesting.metaDataMap;
-
+ Map<CrazyNesting._Fields, FieldMetaData> mdMap = CrazyNesting.metaDataMap;
+
// Check for struct fields existence
if (mdMap.size() != 3)
throw new RuntimeException("metadata map contains wrong number of entries!");
- if (!mdMap.containsKey(CrazyNesting.SET_FIELD) || !mdMap.containsKey(CrazyNesting.LIST_FIELD) || !mdMap.containsKey(CrazyNesting.STRING_FIELD))
+ if (!mdMap.containsKey(CrazyNesting._Fields.SET_FIELD) || !mdMap.containsKey(CrazyNesting._Fields.LIST_FIELD) || !mdMap.containsKey(CrazyNesting._Fields.STRING_FIELD))
throw new RuntimeException("metadata map doesn't contain entry for a struct field!");
-
+
// Check for struct fields contents
- if (!mdMap.get(CrazyNesting.STRING_FIELD).fieldName.equals("string_field") ||
- !mdMap.get(CrazyNesting.LIST_FIELD).fieldName.equals("list_field") ||
- !mdMap.get(CrazyNesting.SET_FIELD).fieldName.equals("set_field"))
+ if (!mdMap.get(CrazyNesting._Fields.STRING_FIELD).fieldName.equals("string_field") ||
+ !mdMap.get(CrazyNesting._Fields.LIST_FIELD).fieldName.equals("list_field") ||
+ !mdMap.get(CrazyNesting._Fields.SET_FIELD).fieldName.equals("set_field"))
throw new RuntimeException("metadata map contains a wrong fieldname");
- if (mdMap.get(CrazyNesting.STRING_FIELD).requirementType != TFieldRequirementType.DEFAULT ||
- mdMap.get(CrazyNesting.LIST_FIELD).requirementType != TFieldRequirementType.REQUIRED ||
- mdMap.get(CrazyNesting.SET_FIELD).requirementType != TFieldRequirementType.OPTIONAL)
+ if (mdMap.get(CrazyNesting._Fields.STRING_FIELD).requirementType != TFieldRequirementType.DEFAULT ||
+ mdMap.get(CrazyNesting._Fields.LIST_FIELD).requirementType != TFieldRequirementType.REQUIRED ||
+ mdMap.get(CrazyNesting._Fields.SET_FIELD).requirementType != TFieldRequirementType.OPTIONAL)
throw new RuntimeException("metadata map contains the wrong requirement type for a field");
- if (mdMap.get(CrazyNesting.STRING_FIELD).valueMetaData.type != TType.STRING ||
- mdMap.get(CrazyNesting.LIST_FIELD).valueMetaData.type != TType.LIST ||
- mdMap.get(CrazyNesting.SET_FIELD).valueMetaData.type != TType.SET)
+ if (mdMap.get(CrazyNesting._Fields.STRING_FIELD).valueMetaData.type != TType.STRING ||
+ mdMap.get(CrazyNesting._Fields.LIST_FIELD).valueMetaData.type != TType.LIST ||
+ mdMap.get(CrazyNesting._Fields.SET_FIELD).valueMetaData.type != TType.SET)
throw new RuntimeException("metadata map contains the wrong requirement type for a field");
-
+
// Check nested structures
- if (!mdMap.get(CrazyNesting.LIST_FIELD).valueMetaData.isContainer())
+ if (!mdMap.get(CrazyNesting._Fields.LIST_FIELD).valueMetaData.isContainer())
throw new RuntimeException("value metadata for a list is stored as non-container!");
- if (mdMap.get(CrazyNesting.LIST_FIELD).valueMetaData.isStruct())
+ if (mdMap.get(CrazyNesting._Fields.LIST_FIELD).valueMetaData.isStruct())
throw new RuntimeException("value metadata for a list is stored as a struct!");
- if (((MapMetaData)((ListMetaData)((SetMetaData)((MapMetaData)((MapMetaData)((ListMetaData)mdMap.get(CrazyNesting.LIST_FIELD).valueMetaData).elemMetaData).valueMetaData).valueMetaData).elemMetaData).elemMetaData).keyMetaData.type != TType.STRUCT)
+ if (((MapMetaData)((ListMetaData)((SetMetaData)((MapMetaData)((MapMetaData)((ListMetaData)mdMap.get(CrazyNesting._Fields.LIST_FIELD).valueMetaData).elemMetaData).valueMetaData).valueMetaData).elemMetaData).elemMetaData).keyMetaData.type != TType.STRUCT)
throw new RuntimeException("metadata map contains wrong type for a value in a deeply nested structure");
- if (((StructMetaData)((MapMetaData)((ListMetaData)((SetMetaData)((MapMetaData)((MapMetaData)((ListMetaData)mdMap.get(CrazyNesting.LIST_FIELD).valueMetaData).elemMetaData).valueMetaData).valueMetaData).elemMetaData).elemMetaData).keyMetaData).structClass != Insanity.class)
+ if (((StructMetaData)((MapMetaData)((ListMetaData)((SetMetaData)((MapMetaData)((MapMetaData)((ListMetaData)mdMap.get(CrazyNesting._Fields.LIST_FIELD).valueMetaData).elemMetaData).valueMetaData).valueMetaData).elemMetaData).elemMetaData).keyMetaData).structClass != Insanity.class)
throw new RuntimeException("metadata map contains wrong class for a struct in a deeply nested structure");
-
+
// Check that FieldMetaData contains a map with metadata for all generated struct classes
if (FieldMetaData.getStructMetaDataMap(CrazyNesting.class) == null ||
FieldMetaData.getStructMetaDataMap(Insanity.class) == null ||
@@ -74,12 +74,8 @@
FieldMetaData.getStructMetaDataMap(Insanity.class) != Insanity.metaDataMap)
throw new RuntimeException("global metadata map contains wrong entry for a loaded struct");
- Map<String, Integer> fnMap = CrazyNesting.fieldNameMap;
- if (fnMap.size() != 3) {
- throw new RuntimeException("Field Name Map contains wrong number of entries!");
- }
- for (Map.Entry<Integer, FieldMetaData> mdEntry : mdMap.entrySet()) {
- if (!fnMap.get(mdEntry.getValue().fieldName).equals(mdEntry.getKey())) {
+ for (Map.Entry<? extends TFieldIdEnum, FieldMetaData> mdEntry : mdMap.entrySet()) {
+ if (!CrazyNesting._Fields.findByName(mdEntry.getValue().fieldName).equals(mdEntry.getKey())) {
throw new RuntimeException("Field name map contained invalid Name <-> ID mapping");
}
}
diff --git a/lib/java/test/org/apache/thrift/test/PartialDeserializeTest.java b/lib/java/test/org/apache/thrift/test/PartialDeserializeTest.java
index d88a686..a7fa59b 100644
--- a/lib/java/test/org/apache/thrift/test/PartialDeserializeTest.java
+++ b/lib/java/test/org/apache/thrift/test/PartialDeserializeTest.java
@@ -24,6 +24,7 @@
import org.apache.thrift.TDeserializer;
import org.apache.thrift.TException;
import org.apache.thrift.TSerializer;
+import org.apache.thrift.TFieldIdEnum;
import org.apache.thrift.protocol.TBinaryProtocol;
import org.apache.thrift.protocol.TCompactProtocol;
import org.apache.thrift.protocol.TJSONProtocol;
@@ -48,7 +49,7 @@
// 1:Union
// 1.3:OneOfEach
OneOfEach Level3OneOfEach = Fixtures.oneOfEach;
- TestUnion Level2TestUnion = new TestUnion(TestUnion.STRUCT_FIELD, Level3OneOfEach);
+ TestUnion Level2TestUnion = new TestUnion(TestUnion._Fields.STRUCT_FIELD, Level3OneOfEach);
StructWithAUnion Level1SWU = new StructWithAUnion(Level2TestUnion);
Backwards bw = new Backwards(2, 1);
@@ -59,20 +60,20 @@
testPartialDeserialize(factory, Level1SWU, new StructWithAUnion(), Level1SWU);
//Level 2 test
- testPartialDeserialize(factory, Level1SWU, new TestUnion(), Level2TestUnion, StructWithAUnion.TEST_UNION);
+ testPartialDeserialize(factory, Level1SWU, new TestUnion(), Level2TestUnion, StructWithAUnion._Fields.TEST_UNION);
//Level 3 on 3rd field test
- testPartialDeserialize(factory, Level1SWU, new OneOfEach(), Level3OneOfEach, StructWithAUnion.TEST_UNION, TestUnion.STRUCT_FIELD);
+ testPartialDeserialize(factory, Level1SWU, new OneOfEach(), Level3OneOfEach, StructWithAUnion._Fields.TEST_UNION, TestUnion._Fields.STRUCT_FIELD);
//Test early termination when traversed path Field.id exceeds the one being searched for
- testPartialDeserialize(factory, Level1SWU, new OneOfEach(), new OneOfEach(), StructWithAUnion.TEST_UNION, TestUnion.I32_FIELD);
-
+ testPartialDeserialize(factory, Level1SWU, new OneOfEach(), new OneOfEach(), StructWithAUnion._Fields.TEST_UNION, TestUnion._Fields.I32_FIELD);
+
//Test that readStructBegin isn't called on primitive
- testPartialDeserialize(factory, pts, new Backwards(), bw, PrimitiveThenStruct.BW);
+ testPartialDeserialize(factory, pts, new Backwards(), bw, PrimitiveThenStruct._Fields.BW);
}
}
- public static void testPartialDeserialize(TProtocolFactory protocolFactory, TBase input, TBase output, TBase expected, int ... fieldIdPath) throws TException {
+ public static void testPartialDeserialize(TProtocolFactory protocolFactory, TBase input, TBase output, TBase expected, TFieldIdEnum ... fieldIdPath) throws TException {
byte[] record = new TSerializer(protocolFactory).serialize(input);
new TDeserializer(protocolFactory).partialDeserialize(output, record, fieldIdPath);
if(!output.equals(expected))
diff --git a/lib/java/test/org/apache/thrift/test/ReadStruct.java b/lib/java/test/org/apache/thrift/test/ReadStruct.java
index 2dc042c..ef36f4d 100644
--- a/lib/java/test/org/apache/thrift/test/ReadStruct.java
+++ b/lib/java/test/org/apache/thrift/test/ReadStruct.java
@@ -35,28 +35,27 @@
System.out.println("usage: java -cp build/classes org.apache.thrift.test.ReadStruct filename proto_factory_class");
System.out.println("Read in an instance of CompactProtocolTestStruct from 'file', making sure that it is equivalent to Fixtures.compactProtoTestStruct. Use a protocol from 'proto_factory_class'.");
}
-
+
TTransport trans = new TIOStreamTransport(new BufferedInputStream(new FileInputStream(args[0])));
-
+
TProtocolFactory factory = (TProtocolFactory)Class.forName(args[1]).newInstance();
-
+
TProtocol proto = factory.getProtocol(trans);
-
+
CompactProtoTestStruct cpts = new CompactProtoTestStruct();
-
- for (Integer fid : CompactProtoTestStruct.metaDataMap.keySet()) {
+
+ for (CompactProtoTestStruct._Fields fid : CompactProtoTestStruct.metaDataMap.keySet()) {
cpts.setFieldValue(fid, null);
}
-
+
cpts.read(proto);
-
+
if (cpts.equals(Fixtures.compactProtoTestStruct)) {
System.out.println("Object verified successfully!");
} else {
System.out.println("Object failed verification!");
System.out.println("Expected: " + Fixtures.compactProtoTestStruct + " but got " + cpts);
}
-
}
}
diff --git a/lib/java/test/org/apache/thrift/test/UnionTest.java b/lib/java/test/org/apache/thrift/test/UnionTest.java
index 04716c6..cb69063 100644
--- a/lib/java/test/org/apache/thrift/test/UnionTest.java
+++ b/lib/java/test/org/apache/thrift/test/UnionTest.java
@@ -32,18 +32,18 @@
throw new RuntimeException("unset union didn't return null for value");
}
- union = new TestUnion(TestUnion.I32_FIELD, 25);
+ union = new TestUnion(TestUnion._Fields.I32_FIELD, 25);
if ((Integer)union.getFieldValue() != 25) {
throw new RuntimeException("set i32 field didn't come out as planned");
}
- if ((Integer)union.getFieldValue(TestUnion.I32_FIELD) != 25) {
+ if ((Integer)union.getFieldValue(TestUnion._Fields.I32_FIELD) != 25) {
throw new RuntimeException("set i32 field didn't come out of TBase getFieldValue");
}
try {
- union.getFieldValue(TestUnion.STRING_FIELD);
+ union.getFieldValue(TestUnion._Fields.STRING_FIELD);
throw new RuntimeException("was expecting an exception around wrong set field");
} catch (IllegalArgumentException e) {
// cool!
@@ -73,21 +73,21 @@
public static void testEquality() throws Exception {
- TestUnion union = new TestUnion(TestUnion.I32_FIELD, 25);
+ TestUnion union = new TestUnion(TestUnion._Fields.I32_FIELD, 25);
- TestUnion otherUnion = new TestUnion(TestUnion.STRING_FIELD, "blah!!!");
+ TestUnion otherUnion = new TestUnion(TestUnion._Fields.STRING_FIELD, "blah!!!");
if (union.equals(otherUnion)) {
throw new RuntimeException("shouldn't be equal");
}
- otherUnion = new TestUnion(TestUnion.I32_FIELD, 400);
+ otherUnion = new TestUnion(TestUnion._Fields.I32_FIELD, 400);
if (union.equals(otherUnion)) {
throw new RuntimeException("shouldn't be equal");
}
- otherUnion = new TestUnion(TestUnion.OTHER_I32_FIELD, 25);
+ otherUnion = new TestUnion(TestUnion._Fields.OTHER_I32_FIELD, 25);
if (union.equals(otherUnion)) {
throw new RuntimeException("shouldn't be equal");
@@ -96,7 +96,7 @@
public static void testSerialization() throws Exception {
- TestUnion union = new TestUnion(TestUnion.I32_FIELD, 25);
+ TestUnion union = new TestUnion(TestUnion._Fields.I32_FIELD, 25);
TMemoryBuffer buf = new TMemoryBuffer(0);
TProtocol proto = new TBinaryProtocol(buf);