THRIFT-774. java: TDeserializer should provide a partialDeserialize method for primitive types
This patch adds partialDeserialize* methods for each of the supported Thrift primitives.
git-svn-id: https://svn.apache.org/repos/asf/incubator/thrift/trunk@943679 13f79535-47bb-0310-9956-ffa450edef68
diff --git a/lib/java/src/org/apache/thrift/TDeserializer.java b/lib/java/src/org/apache/thrift/TDeserializer.java
index 3c43a45..3e9c7c3 100644
--- a/lib/java/src/org/apache/thrift/TDeserializer.java
+++ b/lib/java/src/org/apache/thrift/TDeserializer.java
@@ -91,54 +91,19 @@
/**
* Deserialize only a single Thrift object (addressed by recursively using field id)
- * from a byte record.
- * @param record The object to read from
+ * from a byte record.
* @param tb The object to read into
- * @param fieldIdPath The FieldId's that define a path tb
+ * @param bytes The serialized object to read from
+ * @param fieldIdPathFirst First of the FieldId's that define a path tb
+ * @param fieldIdPathRest The rest FieldId's that define a path tb
* @throws TException
*/
- public void partialDeserialize(TBase tb, byte[] bytes, TFieldIdEnum ... fieldIdPath) throws TException {
+ public void partialDeserialize(TBase tb, byte[] bytes, TFieldIdEnum fieldIdPathFirst, TFieldIdEnum ... fieldIdPathRest) throws TException {
try {
- // if there are no elements in the path, then the user is looking for the
- // regular deserialize method
- // TODO: it might be nice not to have to do this check every time to save
- // some performance.
- if (fieldIdPath.length == 0) {
- deserialize(tb, bytes);
- return;
- }
-
- trans_.reset(bytes);
-
- // index into field ID path being currently searched for
- int curPathIndex = 0;
-
- protocol_.readStructBegin();
-
- while (curPathIndex < fieldIdPath.length) {
- TField field = protocol_.readFieldBegin();
- // we can stop searching if we either see a stop or we go past the field
- // id we're looking for (since fields should now be serialized in asc
- // order).
- if (field.type == TType.STOP || field.id > fieldIdPath[curPathIndex].getThriftFieldId()) {
- return;
- }
-
- if (field.id != fieldIdPath[curPathIndex].getThriftFieldId()) {
- // Not the field we're looking for. Skip field.
- TProtocolUtil.skip(protocol_, field.type);
- protocol_.readFieldEnd();
- } else {
- // This field is the next step in the path. Step into field.
- curPathIndex++;
- if (curPathIndex < fieldIdPath.length) {
- protocol_.readStructBegin();
- }
- }
- }
-
- // when this line is reached, iprot will be positioned at the start of tb.
- tb.read(protocol_);
+ if (locateField(bytes, fieldIdPathFirst, fieldIdPathRest) != null) {
+ // if this line is reached, iprot will be positioned at the start of tb.
+ tb.read(protocol_);
+ }
} catch (Exception e) {
throw new TException(e);
} finally {
@@ -147,6 +112,223 @@
}
/**
+ * Deserialize only a boolean field (addressed by recursively using field id)
+ * from a byte record.
+ * @param bytes The serialized object to read from
+ * @param fieldIdPathFirst First of the FieldId's that define a path to a boolean field
+ * @param fieldIdPathRest The rest FieldId's that define a path to a boolean field
+ * @throws TException
+ */
+ public Boolean partialDeserializeBool(byte[] bytes, TFieldIdEnum fieldIdPathFirst, TFieldIdEnum ... fieldIdPathRest) throws TException {
+ return (Boolean) partialDeserializeField(TType.BOOL, bytes, fieldIdPathFirst, fieldIdPathRest);
+ }
+
+ /**
+ * Deserialize only a byte field (addressed by recursively using field id)
+ * from a byte record.
+ * @param bytes The serialized object to read from
+ * @param fieldIdPathFirst First of the FieldId's that define a path to a byte field
+ * @param fieldIdPathRest The rest FieldId's that define a path to a byte field
+ * @throws TException
+ */
+ public Byte partialDeserializeByte(byte[] bytes, TFieldIdEnum fieldIdPathFirst, TFieldIdEnum ... fieldIdPathRest) throws TException {
+ return (Byte) partialDeserializeField(TType.BYTE, bytes, fieldIdPathFirst, fieldIdPathRest);
+ }
+
+ /**
+ * Deserialize only a double field (addressed by recursively using field id)
+ * from a byte record.
+ * @param bytes The serialized object to read from
+ * @param fieldIdPathFirst First of the FieldId's that define a path to a double field
+ * @param fieldIdPathRest The rest FieldId's that define a path to a double field
+ * @throws TException
+ */
+ public Double partialDeserializeDouble(byte[] bytes, TFieldIdEnum fieldIdPathFirst, TFieldIdEnum ... fieldIdPathRest) throws TException {
+ return (Double) partialDeserializeField(TType.DOUBLE, bytes, fieldIdPathFirst, fieldIdPathRest);
+ }
+
+ /**
+ * Deserialize only an i16 field (addressed by recursively using field id)
+ * from a byte record.
+ * @param bytes The serialized object to read from
+ * @param fieldIdPathFirst First of the FieldId's that define a path to an i16 field
+ * @param fieldIdPathRest The rest FieldId's that define a path to an i16 field
+ * @throws TException
+ */
+ public Short partialDeserializeI16(byte[] bytes, TFieldIdEnum fieldIdPathFirst, TFieldIdEnum ... fieldIdPathRest) throws TException {
+ return (Short) partialDeserializeField(TType.I16, bytes, fieldIdPathFirst, fieldIdPathRest);
+ }
+
+ /**
+ * Deserialize only an i32 field (addressed by recursively using field id)
+ * from a byte record.
+ * @param bytes The serialized object to read from
+ * @param fieldIdPathFirst First of the FieldId's that define a path to an i32 field
+ * @param fieldIdPathRest The rest FieldId's that define a path to an i32 field
+ * @throws TException
+ */
+ public Integer partialDeserializeI32(byte[] bytes, TFieldIdEnum fieldIdPathFirst, TFieldIdEnum ... fieldIdPathRest) throws TException {
+ return (Integer) partialDeserializeField(TType.I32, bytes, fieldIdPathFirst, fieldIdPathRest);
+ }
+
+ /**
+ * Deserialize only an i64 field (addressed by recursively using field id)
+ * from a byte record.
+ * @param bytes The serialized object to read from
+ * @param fieldIdPathFirst First of the FieldId's that define a path to an i64 field
+ * @param fieldIdPathRest The rest FieldId's that define a path to an i64 field
+ * @throws TException
+ */
+ public Long partialDeserializeI64(byte[] bytes, TFieldIdEnum fieldIdPathFirst, TFieldIdEnum ... fieldIdPathRest) throws TException {
+ return (Long) partialDeserializeField(TType.I64, bytes, fieldIdPathFirst, fieldIdPathRest);
+ }
+
+ /**
+ * Deserialize only a string field (addressed by recursively using field id)
+ * from a byte record.
+ * @param bytes The serialized object to read from
+ * @param fieldIdPathFirst First of the FieldId's that define a path to a string field
+ * @param fieldIdPathRest The rest FieldId's that define a path to a string field
+ * @throws TException
+ */
+ public String partialDeserializeString(byte[] bytes, TFieldIdEnum fieldIdPathFirst, TFieldIdEnum ... fieldIdPathRest) throws TException {
+ return (String) partialDeserializeField(TType.STRING, bytes, fieldIdPathFirst, fieldIdPathRest);
+ }
+
+ /**
+ * Deserialize only a binary field (addressed by recursively using field id)
+ * from a byte record.
+ * @param bytes The serialized object to read from
+ * @param fieldIdPathFirst First of the FieldId's that define a path to a binary field
+ * @param fieldIdPathRest The rest FieldId's that define a path to a binary field
+ * @throws TException
+ */
+ public byte[] partialDeserializeByteArray(byte[] bytes, TFieldIdEnum fieldIdPathFirst, TFieldIdEnum ... fieldIdPathRest) throws TException {
+ // TType does not have binary, so we use the arbitrary num 100
+ return (byte[]) partialDeserializeField((byte)100, bytes, fieldIdPathFirst, fieldIdPathRest);
+ }
+
+ /**
+ * Deserialize only the id of the field set in a TUnion (addressed by recursively using field id)
+ * from a byte record.
+ * @param bytes The serialized object to read from
+ * @param fieldIdPathFirst First of the FieldId's that define a path to a TUnion
+ * @param fieldIdPathRest The rest FieldId's that define a path to a TUnion
+ * @throws TException
+ */
+ public Short partialDeserializeSetFieldIdInUnion(byte[] bytes, TFieldIdEnum fieldIdPathFirst, TFieldIdEnum ... fieldIdPathRest) throws TException {
+ try {
+ TField field = locateField(bytes, fieldIdPathFirst, fieldIdPathRest);
+ if (field != null){
+ protocol_.readStructBegin(); // The Union
+ return protocol_.readFieldBegin().id; // The field set in the union
+ }
+ return null;
+ } catch (Exception e) {
+ throw new TException(e);
+ } finally {
+ protocol_.reset();
+ }
+ }
+
+ private Object partialDeserializeField(byte ttype, byte[] bytes, TFieldIdEnum fieldIdPathFirst, TFieldIdEnum ... fieldIdPathRest) throws TException {
+ try {
+ TField field = locateField(bytes, fieldIdPathFirst, fieldIdPathRest);
+ if (field != null) {
+ // if this point is reached, iprot will be positioned at the start of the field.
+ switch(ttype){
+ case TType.BOOL:
+ if (field.type == TType.BOOL){
+ return protocol_.readBool();
+ }
+ break;
+ case TType.BYTE:
+ if (field.type == TType.BYTE) {
+ return protocol_.readByte();
+ }
+ break;
+ case TType.DOUBLE:
+ if (field.type == TType.DOUBLE) {
+ return protocol_.readDouble();
+ }
+ break;
+ case TType.I16:
+ if (field.type == TType.I16) {
+ return protocol_.readI16();
+ }
+ break;
+ case TType.I32:
+ if (field.type == TType.I32) {
+ return protocol_.readI32();
+ }
+ break;
+ case TType.I64:
+ if (field.type == TType.I64) {
+ return protocol_.readI64();
+ }
+ break;
+ case TType.STRING:
+ if (field.type == TType.STRING) {
+ return protocol_.readString();
+ }
+ break;
+ case 100: // hack to differentiate between string and binary
+ if (field.type == TType.STRING) {
+ return protocol_.readBinary();
+ }
+ break;
+ }
+ }
+ return null;
+ } catch (Exception e) {
+ throw new TException(e);
+ } finally {
+ protocol_.reset();
+ }
+ }
+
+ private TField locateField(byte[] bytes, TFieldIdEnum fieldIdPathFirst, TFieldIdEnum ... fieldIdPathRest) throws TException {
+ trans_.reset(bytes);
+
+ TFieldIdEnum[] fieldIdPath= new TFieldIdEnum[fieldIdPathRest.length + 1];
+ fieldIdPath[0] = fieldIdPathFirst;
+ for (int i = 0; i < fieldIdPathRest.length; i++){
+ fieldIdPath[i + 1] = fieldIdPathRest[i];
+ }
+
+ // index into field ID path being currently searched for
+ int curPathIndex = 0;
+
+ // this will be the located field, or null if it is not located
+ TField field = null;
+
+ protocol_.readStructBegin();
+
+ while (curPathIndex < fieldIdPath.length) {
+ field = protocol_.readFieldBegin();
+ // we can stop searching if we either see a stop or we go past the field
+ // id we're looking for (since fields should now be serialized in asc
+ // order).
+ if (field.type == TType.STOP || field.id > fieldIdPath[curPathIndex].getThriftFieldId()) {
+ return null;
+ }
+
+ if (field.id != fieldIdPath[curPathIndex].getThriftFieldId()) {
+ // Not the field we're looking for. Skip field.
+ TProtocolUtil.skip(protocol_, field.type);
+ protocol_.readFieldEnd();
+ } else {
+ // This field is the next step in the path. Step into field.
+ curPathIndex++;
+ if (curPathIndex < fieldIdPath.length) {
+ protocol_.readStructBegin();
+ }
+ }
+ }
+ return field;
+ }
+
+ /**
* Deserialize the Thrift object from a Java string, using the default JVM
* charset encoding.
*
diff --git a/lib/java/test/org/apache/thrift/TestTDeserializer.java b/lib/java/test/org/apache/thrift/TestTDeserializer.java
index 4bb3fe2..aae2ee4 100644
--- a/lib/java/test/org/apache/thrift/TestTDeserializer.java
+++ b/lib/java/test/org/apache/thrift/TestTDeserializer.java
@@ -18,6 +18,10 @@
*/
package org.apache.thrift;
+import java.util.Arrays;
+
+import junit.framework.TestCase;
+
import org.apache.thrift.protocol.TBinaryProtocol;
import org.apache.thrift.protocol.TCompactProtocol;
import org.apache.thrift.protocol.TJSONProtocol;
@@ -29,8 +33,6 @@
import thrift.test.StructWithAUnion;
import thrift.test.TestUnion;
-import junit.framework.TestCase;
-
public class TestTDeserializer extends TestCase {
private static final TProtocolFactory[] PROTOCOLS = new TProtocolFactory[] {
@@ -43,40 +45,82 @@
//Root:StructWithAUnion
// 1:Union
// 1.3:OneOfEach
- OneOfEach Level3OneOfEach = Fixtures.oneOfEach;
- TestUnion Level2TestUnion = new TestUnion(TestUnion._Fields.STRUCT_FIELD, Level3OneOfEach);
- StructWithAUnion Level1SWU = new StructWithAUnion(Level2TestUnion);
+ OneOfEach level3OneOfEach = Fixtures.oneOfEach;
+ TestUnion level2TestUnion = new TestUnion(TestUnion._Fields.STRUCT_FIELD, level3OneOfEach);
+ StructWithAUnion level1SWU = new StructWithAUnion(level2TestUnion);
Backwards bw = new Backwards(2, 1);
PrimitiveThenStruct pts = new PrimitiveThenStruct(12345, 67890, bw);
for (TProtocolFactory factory : PROTOCOLS) {
- //Full deserialization test
- testPartialDeserialize(factory, Level1SWU, new StructWithAUnion(), Level1SWU);
//Level 2 test
- testPartialDeserialize(factory, Level1SWU, new TestUnion(), Level2TestUnion, StructWithAUnion._Fields.TEST_UNION);
+ testPartialDeserialize(factory, level1SWU, new TestUnion(), level2TestUnion, StructWithAUnion._Fields.TEST_UNION);
//Level 3 on 3rd field test
- testPartialDeserialize(factory, Level1SWU, new OneOfEach(), Level3OneOfEach, StructWithAUnion._Fields.TEST_UNION, TestUnion._Fields.STRUCT_FIELD);
+ testPartialDeserialize(factory, level1SWU, new OneOfEach(), level3OneOfEach, StructWithAUnion._Fields.TEST_UNION, TestUnion._Fields.STRUCT_FIELD);
//Test early termination when traversed path Field.id exceeds the one being searched for
- testPartialDeserialize(factory, Level1SWU, new OneOfEach(), new OneOfEach(), StructWithAUnion._Fields.TEST_UNION, TestUnion._Fields.I32_FIELD);
+ testPartialDeserialize(factory, level1SWU, new OneOfEach(), new OneOfEach(), StructWithAUnion._Fields.TEST_UNION, TestUnion._Fields.I32_FIELD);
//Test that readStructBegin isn't called on primitive
testPartialDeserialize(factory, pts, new Backwards(), bw, PrimitiveThenStruct._Fields.BW);
+
+ //Test primitive types
+ TDeserializer deserializer = new TDeserializer(factory);
+
+ Boolean expectedBool = level3OneOfEach.isIm_true();
+ Boolean resultBool = deserializer.partialDeserializeBool(serialize(level1SWU, factory), StructWithAUnion._Fields.TEST_UNION, TestUnion._Fields.STRUCT_FIELD, OneOfEach._Fields.IM_TRUE);
+ assertEquals(expectedBool, resultBool);
+
+ Byte expectedByte = level3OneOfEach.getA_bite();
+ Byte resultByte = deserializer.partialDeserializeByte(serialize(level1SWU, factory), StructWithAUnion._Fields.TEST_UNION, TestUnion._Fields.STRUCT_FIELD, OneOfEach._Fields.A_BITE);
+ assertEquals(expectedByte, resultByte);
+
+ Double expectedDouble = level3OneOfEach.getDouble_precision();
+ Double resultDouble = deserializer.partialDeserializeDouble(serialize(level1SWU, factory), StructWithAUnion._Fields.TEST_UNION, TestUnion._Fields.STRUCT_FIELD, OneOfEach._Fields.DOUBLE_PRECISION);
+ assertEquals(expectedDouble, resultDouble);
+
+ Short expectedI16 = level3OneOfEach.getInteger16();
+ Short resultI16 = deserializer.partialDeserializeI16(serialize(level1SWU, factory), StructWithAUnion._Fields.TEST_UNION, TestUnion._Fields.STRUCT_FIELD, OneOfEach._Fields.INTEGER16);
+ assertEquals(expectedI16, resultI16);
+
+ Integer expectedI32 = level3OneOfEach.getInteger32();
+ Integer resultI32 = deserializer.partialDeserializeI32(serialize(level1SWU, factory), StructWithAUnion._Fields.TEST_UNION, TestUnion._Fields.STRUCT_FIELD, OneOfEach._Fields.INTEGER32);
+ assertEquals(expectedI32, resultI32);
+
+ Long expectedI64 = level3OneOfEach.getInteger64();
+ Long resultI64= deserializer.partialDeserializeI64(serialize(level1SWU, factory), StructWithAUnion._Fields.TEST_UNION, TestUnion._Fields.STRUCT_FIELD, OneOfEach._Fields.INTEGER64);
+ assertEquals(expectedI64, resultI64);
+
+ String expectedString = level3OneOfEach.getSome_characters();
+ String resultString = deserializer.partialDeserializeString(serialize(level1SWU, factory), StructWithAUnion._Fields.TEST_UNION, TestUnion._Fields.STRUCT_FIELD, OneOfEach._Fields.SOME_CHARACTERS);
+ assertEquals(expectedString, resultString);
+
+ byte[] expectedBinary = level3OneOfEach.getBase64();
+ byte[] resultBinary = deserializer.partialDeserializeByteArray(serialize(level1SWU, factory), StructWithAUnion._Fields.TEST_UNION, TestUnion._Fields.STRUCT_FIELD, OneOfEach._Fields.BASE64);
+ assertEquals(expectedBinary.length, resultBinary.length);
+ assertTrue(Arrays.equals(expectedBinary, resultBinary));
+
+ // Test field id in Union
+ short id = deserializer.partialDeserializeSetFieldIdInUnion(serialize(level1SWU, factory), StructWithAUnion._Fields.TEST_UNION);
+ assertEquals(level2TestUnion.getSetField().getThriftFieldId(), id);
}
}
- public static void testPartialDeserialize(TProtocolFactory protocolFactory, TBase input, TBase output, TBase expected, TFieldIdEnum ... fieldIdPath) throws TException {
- byte[] record = new TSerializer(protocolFactory).serialize(input);
+ public static void testPartialDeserialize(TProtocolFactory protocolFactory, TBase input, TBase output, TBase expected, TFieldIdEnum fieldIdPathFirst, TFieldIdEnum ... fieldIdPathRest) throws TException {
+ byte[] record = serialize(input, protocolFactory);
TDeserializer deserializer = new TDeserializer(protocolFactory);
for (int i = 0; i < 2; i++) {
TBase outputCopy = output.deepCopy();
- deserializer.partialDeserialize(outputCopy, record, fieldIdPath);
+ deserializer.partialDeserialize(outputCopy, record, fieldIdPathFirst, fieldIdPathRest);
assertEquals("on attempt " + i + ", with " + protocolFactory.toString()
+ ", expected " + expected + " but got " + outputCopy,
expected, outputCopy);
}
}
+
+ private static byte[] serialize(TBase input, TProtocolFactory protocolFactory) throws TException{
+ return new TSerializer(protocolFactory).serialize(input);
+ }
}