THRIFT-1339. java: Extend Tuple Protocol to TUnions
This patch implements TupleProtocol (and general Scheme support) to TUnion descendants.
Patch: Armaan Sarkar
git-svn-id: https://svn.apache.org/repos/asf/thrift/trunk@1173418 13f79535-47bb-0310-9956-ffa450edef68
diff --git a/compiler/cpp/src/generate/t_java_generator.cc b/compiler/cpp/src/generate/t_java_generator.cc
index db33e93..85090b8 100644
--- a/compiler/cpp/src/generate/t_java_generator.cc
+++ b/compiler/cpp/src/generate/t_java_generator.cc
@@ -142,8 +142,10 @@
void generate_union_is_set_methods(ofstream& out, t_struct* tstruct);
void generate_union_abstract_methods(ofstream& out, t_struct* tstruct);
void generate_check_type(ofstream& out, t_struct* tstruct);
- void generate_read_value(ofstream& out, t_struct* tstruct);
- void generate_write_value(ofstream& out, t_struct* tstruct);
+ void generate_standard_scheme_read_value(ofstream& out, t_struct* tstruct);
+ void generate_standard_scheme_write_value(ofstream& out, t_struct* tstruct);
+ void generate_tuple_scheme_read_value(ofstream& out, t_struct* tstruct);
+ void generate_tuple_scheme_write_value(ofstream& out, t_struct* tstruct);
void generate_get_field_desc(ofstream& out, t_struct* tstruct);
void generate_get_struct_desc(ofstream& out, t_struct* tstruct);
void generate_get_field_name(ofstream& out, t_struct* tstruct);
@@ -906,9 +908,13 @@
void t_java_generator::generate_union_abstract_methods(ofstream& out, t_struct* tstruct) {
generate_check_type(out, tstruct);
out << endl;
- generate_read_value(out, tstruct);
+ generate_standard_scheme_read_value(out, tstruct);
out << endl;
- generate_write_value(out, tstruct);
+ generate_standard_scheme_write_value(out, tstruct);
+ out << endl;
+ generate_tuple_scheme_read_value(out, tstruct);
+ out << endl;
+ generate_tuple_scheme_write_value(out, tstruct);
out << endl;
generate_get_field_desc(out, tstruct);
out << endl;
@@ -954,9 +960,9 @@
indent(out) << "}" << endl;
}
-void t_java_generator::generate_read_value(ofstream& out, t_struct* tstruct) {
+void t_java_generator::generate_standard_scheme_read_value(ofstream& out, t_struct* tstruct) {
indent(out) << "@Override" << endl;
- indent(out) << "protected Object readValue(org.apache.thrift.protocol.TProtocol iprot, org.apache.thrift.protocol.TField field) throws org.apache.thrift.TException {" << endl;
+ indent(out) << "protected Object standardSchemeReadValue(org.apache.thrift.protocol.TProtocol iprot, org.apache.thrift.protocol.TField field) throws org.apache.thrift.TException {" << endl;
indent_up();
@@ -995,8 +1001,7 @@
indent_down();
indent(out) << "} else {" << endl;
- indent_up();
- indent(out) << "org.apache.thrift.protocol.TProtocolUtil.skip(iprot, field.type);" << endl;
+ indent_up();
indent(out) << "return null;" << endl;
indent_down();
indent(out) << "}" << endl;
@@ -1005,9 +1010,9 @@
indent(out) << "}" << endl;
}
-void t_java_generator::generate_write_value(ofstream& out, t_struct* tstruct) {
+void t_java_generator::generate_standard_scheme_write_value(ofstream& out, t_struct* tstruct) {
indent(out) << "@Override" << endl;
- indent(out) << "protected void writeValue(org.apache.thrift.protocol.TProtocol oprot) throws org.apache.thrift.TException {" << endl;
+ indent(out) << "protected void standardSchemeWriteValue(org.apache.thrift.protocol.TProtocol oprot) throws org.apache.thrift.TException {" << endl;
indent_up();
@@ -1040,6 +1045,83 @@
indent(out) << "}" << endl;
}
+void t_java_generator::generate_tuple_scheme_read_value(ofstream& out, t_struct* tstruct) {
+ indent(out) << "@Override" << endl;
+ indent(out) << "protected Object tupleSchemeReadValue(org.apache.thrift.protocol.TProtocol iprot, short fieldID) throws org.apache.thrift.TException {" << endl;
+
+ indent_up();
+
+ indent(out) << "_Fields setField = _Fields.findByThriftId(fieldID);" << endl;
+ indent(out) << "if (setField != null) {" << endl;
+ indent_up();
+ indent(out) << "switch (setField) {" << endl;
+ indent_up();
+
+ const vector<t_field*>& members = tstruct->get_members();
+ vector<t_field*>::const_iterator m_iter;
+
+ for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) {
+ t_field* field = (*m_iter);
+
+ indent(out) << "case " << constant_name(field->get_name()) << ":" << endl;
+ indent_up();
+ indent(out) << type_name(field->get_type(), true, false) << " " << field->get_name() << ";" << endl;
+ generate_deserialize_field(out, field, "");
+ indent(out) << "return " << field->get_name() << ";" << endl;
+ indent_down();
+ }
+
+ indent(out) << "default:" << endl;
+ indent(out) << " throw new IllegalStateException(\"setField wasn't null, but didn't match any of the case statements!\");" << endl;
+
+ indent_down();
+ indent(out) << "}" << endl;
+
+ indent_down();
+ indent(out) << "} else {" << endl;
+ indent_up();
+ indent(out) << "return null;" << endl;
+ indent_down();
+ indent(out) << "}" << endl;
+ indent_down();
+ indent(out) << "}" << endl;
+}
+
+void t_java_generator::generate_tuple_scheme_write_value(ofstream& out, t_struct* tstruct) {
+ indent(out) << "@Override" << endl;
+ indent(out) << "protected void tupleSchemeWriteValue(org.apache.thrift.protocol.TProtocol oprot) throws org.apache.thrift.TException {" << endl;
+
+ indent_up();
+
+ indent(out) << "switch (setField_) {" << endl;
+ indent_up();
+
+ const vector<t_field*>& members = tstruct->get_members();
+ vector<t_field*>::const_iterator m_iter;
+
+ for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) {
+ t_field* field = (*m_iter);
+
+ indent(out) << "case " << constant_name(field->get_name()) << ":" << endl;
+ indent_up();
+ indent(out) << type_name(field->get_type(), true, false) << " " << field->get_name()
+ << " = (" << type_name(field->get_type(), true, false) << ")value_;" << endl;
+ generate_serialize_field(out, field, "");
+ indent(out) << "return;" << endl;
+ indent_down();
+ }
+
+ indent(out) << "default:" << endl;
+ indent(out) << " throw new IllegalStateException(\"Cannot write union with unknown field \" + setField_);" << endl;
+
+ indent_down();
+ indent(out) << "}" << endl;
+
+ indent_down();
+
+ indent(out) << "}" << endl;
+}
+
void t_java_generator::generate_get_field_desc(ofstream& out, t_struct* tstruct) {
indent(out) << "@Override" << endl;
indent(out) << "protected org.apache.thrift.protocol.TField getFieldDesc(_Fields setField) {" << endl;
diff --git a/lib/java/src/org/apache/thrift/TUnion.java b/lib/java/src/org/apache/thrift/TUnion.java
index 240163f..0173f9b 100644
--- a/lib/java/src/org/apache/thrift/TUnion.java
+++ b/lib/java/src/org/apache/thrift/TUnion.java
@@ -25,10 +25,15 @@
import java.util.Set;
import java.nio.ByteBuffer;
+import org.apache.thrift.TUnion.TUnionStandardScheme;
import org.apache.thrift.protocol.TField;
import org.apache.thrift.protocol.TProtocol;
import org.apache.thrift.protocol.TProtocolException;
import org.apache.thrift.protocol.TStruct;
+import org.apache.thrift.scheme.IScheme;
+import org.apache.thrift.scheme.SchemeFactory;
+import org.apache.thrift.scheme.StandardScheme;
+import org.apache.thrift.scheme.TupleScheme;
public abstract class TUnion<T extends TUnion<?,?>, F extends TFieldIdEnum> implements TBase<T, F> {
@@ -39,6 +44,12 @@
setField_ = null;
value_ = null;
}
+
+ private static final Map<Class<? extends IScheme>, SchemeFactory> schemes = new HashMap<Class<? extends IScheme>, SchemeFactory>();
+ static {
+ schemes.put(StandardScheme.class, new TUnionStandardSchemeFactory());
+ schemes.put(TupleScheme.class, new TUnionTupleSchemeFactory());
+ }
protected TUnion(F setField, Object value) {
setFieldValue(setField, value);
@@ -125,24 +136,7 @@
}
public void read(TProtocol iprot) throws TException {
- setField_ = null;
- value_ = null;
-
- iprot.readStructBegin();
-
- TField field = iprot.readFieldBegin();
-
- value_ = readValue(iprot, field);
- if (value_ != null) {
- setField_ = enumForId(field.id);
- }
-
- iprot.readFieldEnd();
- // this is so that we will eat the stop byte. we could put a check here to
- // make sure that it actually *is* the stop byte, but it's faster to do it
- // this way.
- iprot.readFieldBegin();
- iprot.readStructEnd();
+ schemes.get(iprot.getScheme()).getScheme().read(iprot, this);
}
public void setFieldValue(F fieldId, Object value) {
@@ -156,15 +150,7 @@
}
public void write(TProtocol oprot) throws TException {
- if (getSetField() == null || getFieldValue() == null) {
- throw new TProtocolException("Cannot write a TUnion with no set value!");
- }
- oprot.writeStructBegin(getStructDesc());
- oprot.writeFieldBegin(getFieldDesc(setField_));
- writeValue(oprot);
- oprot.writeFieldEnd();
- oprot.writeFieldStop();
- oprot.writeStructEnd();
+ schemes.get(oprot.getScheme()).getScheme().write(oprot, this);
}
/**
@@ -181,9 +167,11 @@
* @param field
* @return read Object based on the field header, as specified by the argument.
*/
- protected abstract Object readValue(TProtocol iprot, TField field) throws TException;
-
- protected abstract void writeValue(TProtocol oprot) throws TException;
+ protected abstract Object standardSchemeReadValue(TProtocol iprot, TField field) throws TException;
+ protected abstract void standardSchemeWriteValue(TProtocol oprot) throws TException;
+
+ protected abstract Object tupleSchemeReadValue(TProtocol iprot, short fieldID) throws TException;
+ protected abstract void tupleSchemeWriteValue(TProtocol oprot) throws TException;
protected abstract TStruct getStructDesc();
@@ -216,4 +204,77 @@
this.setField_ = null;
this.value_ = null;
}
+
+ private static class TUnionStandardSchemeFactory implements SchemeFactory {
+ public TUnionStandardScheme getScheme() {
+ return new TUnionStandardScheme();
+ }
+ }
+
+ public static class TUnionStandardScheme extends StandardScheme<TUnion> {
+
+ @Override
+ public void read(TProtocol iprot, TUnion struct) throws TException {
+ struct.setField_ = null;
+ struct.value_ = null;
+
+ iprot.readStructBegin();
+
+ TField field = iprot.readFieldBegin();
+
+ struct.value_ = struct.standardSchemeReadValue(iprot, field);
+ if (struct.value_ != null) {
+ struct.setField_ = struct.enumForId(field.id);
+ }
+
+ iprot.readFieldEnd();
+ // this is so that we will eat the stop byte. we could put a check here to
+ // make sure that it actually *is* the stop byte, but it's faster to do it
+ // this way.
+ iprot.readFieldBegin();
+ iprot.readStructEnd();
+ }
+
+ @Override
+ public void write(TProtocol oprot, TUnion struct) throws TException {
+ if (struct.getSetField() == null || struct.getFieldValue() == null) {
+ throw new TProtocolException("Cannot write a TUnion with no set value!");
+ }
+ oprot.writeStructBegin(struct.getStructDesc());
+ oprot.writeFieldBegin(struct.getFieldDesc(struct.setField_));
+ struct.standardSchemeWriteValue(oprot);
+ oprot.writeFieldEnd();
+ oprot.writeFieldStop();
+ oprot.writeStructEnd();
+ }
+ }
+
+ private static class TUnionTupleSchemeFactory implements SchemeFactory {
+ public TUnionStandardScheme getScheme() {
+ return new TUnionStandardScheme();
+ }
+ }
+
+ public static class TUnionTupleScheme extends TupleScheme<TUnion> {
+
+ @Override
+ public void read(TProtocol iprot, TUnion struct) throws TException {
+ struct.setField_ = null;
+ struct.value_ = null;
+ short fieldID = iprot.readI16();
+ struct.value_ = struct.tupleSchemeReadValue(iprot, fieldID);
+ if (struct.value_ != null) {
+ struct.setField_ = struct.enumForId(fieldID);
+ }
+ }
+
+ @Override
+ public void write(TProtocol oprot, TUnion struct) throws TException {
+ if (struct.getSetField() == null || struct.getFieldValue() == null) {
+ throw new TProtocolException("Cannot write a TUnion with no set value!");
+ }
+ oprot.writeI16(struct.setField_.getThriftFieldId());
+ struct.tupleSchemeWriteValue(oprot);
+ }
+ }
}
diff --git a/lib/java/test/org/apache/thrift/TestTUnion.java b/lib/java/test/org/apache/thrift/TestTUnion.java
index e9d9825..f1e6f0e 100644
--- a/lib/java/test/org/apache/thrift/TestTUnion.java
+++ b/lib/java/test/org/apache/thrift/TestTUnion.java
@@ -34,6 +34,7 @@
import org.apache.thrift.protocol.TBinaryProtocol;
import org.apache.thrift.protocol.TProtocol;
+import org.apache.thrift.protocol.TTupleProtocol;
import org.apache.thrift.transport.TMemoryBuffer;
import thrift.test.ComparableUnion;
@@ -185,6 +186,41 @@
swau.write(proto);
new Empty().read(proto);
}
+
+ public void testTupleProtocolSerialization () throws Exception {
+ TestUnion union = new TestUnion(TestUnion._Fields.I32_FIELD, 25);
+ union.setI32_set(Collections.singleton(42));
+
+ TMemoryBuffer buf = new TMemoryBuffer(0);
+ TProtocol proto = new TTupleProtocol(buf);
+
+ union.write(proto);
+
+ TestUnion u2 = new TestUnion();
+
+ u2.read(proto);
+
+ assertEquals(u2, union);
+
+ StructWithAUnion swau = new StructWithAUnion(u2);
+
+ buf = new TMemoryBuffer(0);
+ proto = new TBinaryProtocol(buf);
+
+ swau.write(proto);
+
+ StructWithAUnion swau2 = new StructWithAUnion();
+ assertFalse(swau2.equals(swau));
+ swau2.read(proto);
+ assertEquals(swau2, swau);
+
+ // this should NOT throw an exception.
+ buf = new TMemoryBuffer(0);
+ proto = new TTupleProtocol(buf);
+
+ swau.write(proto);
+ new Empty().read(proto);
+ }
public void testSkip() throws Exception {
TestUnion tu = TestUnion.string_field("string");