THRIFT-947. java: Provide a helper method to determine the TProtocol used to serialize some data.

Patch: Mathias Herberts

git-svn-id: https://svn.apache.org/repos/asf/incubator/thrift/trunk@1024455 13f79535-47bb-0310-9956-ffa450edef68
diff --git a/lib/java/src/org/apache/thrift/protocol/TProtocolUtil.java b/lib/java/src/org/apache/thrift/protocol/TProtocolUtil.java
index 9bf10f6..134eafc 100644
--- a/lib/java/src/org/apache/thrift/protocol/TProtocolUtil.java
+++ b/lib/java/src/org/apache/thrift/protocol/TProtocolUtil.java
@@ -155,4 +155,77 @@
       break;
     }
   }
+
+  /**
+   * Attempts to determine the protocol used to serialize some data.
+   *
+   * The guess is based on known specificities of supported protocols.
+   * In some cases no guess can be made (this includes null or empty data);
+   * the fallback TProtocolFactory is then returned.
+   * To be certain to correctly detect the protocol, the first encoded
+   * field should have a field id &lt; 256.
+   *
+   * @param data The serialized data to guess the protocol for.
+   * @param fallback The TProtocolFactory to return if no guess can be made.
+   * @return a TProtocolFactory instance which can be used to create a deserializer.
+   */
+  public static TProtocolFactory guessProtocolFactory(byte[] data, TProtocolFactory fallback) {
+    // Null or empty data cannot be inspected (indexing it would throw): use the fallback.
+    if (data == null || data.length == 0) {
+      return fallback;
+    }
+
+    //
+    // If the first and last bytes are opening/closing curly braces we guess the protocol as
+    // being TJSONProtocol.
+    // It could not be a TCompactProtocol encoding for a field of type 0xb (Map)
+    // with delta id 7 as the last byte for TCompactProtocol is always 0.
+    //
+    if ('{' == data[0] && '}' == data[data.length - 1]) {
+      return new TJSONProtocol.Factory();
+    }
+
+    //
+    // If the last byte is not 0, then it cannot be TCompactProtocol, it must be
+    // TBinaryProtocol.
+    //
+    if (data[data.length - 1] != 0) {
+      return new TBinaryProtocol.Factory();
+    }
+
+    //
+    // An unsigned first byte > 16 indicates TCompactProtocol was used: it encodes a delta
+    // field id (id <= 15) and a field type. Java bytes are signed, hence the 0xff mask.
+    //
+    if ((data[0] & 0xff) > 0x10) {
+      return new TCompactProtocol.Factory();
+    }
+
+    //
+    // If the second byte is 0 then it is a field id < 256 encoded by TBinaryProtocol.
+    // It cannot possibly be TCompactProtocol since a value of 0 would imply a field id
+    // of 0 as the zig zag varint encoding would end.
+    //
+    if (data.length > 1 && 0 == data[1]) {
+      return new TBinaryProtocol.Factory();
+    }
+
+    //
+    // If bit 7 of the first byte of the field id is set then we have two choices:
+    // 1. A field id > 63 was encoded with TCompactProtocol.
+    // 2. A field id > 0x7fff (32767) was encoded with TBinaryProtocol and the last byte of the
+    //    serialized data is 0.
+    // Option 2 is impossible since field ids are short and thus limited to 32767.
+    //
+    if (data.length > 1 && (data[1] & 0x80) != 0) {
+      return new TCompactProtocol.Factory();
+    }
+
+    //
+    // The remaining case is either a field id <= 63 encoded as TCompactProtocol,
+    // one >= 256 encoded with TBinaryProtocol with a last byte at 0, or an empty structure.
+    // As we cannot really decide, we return the fallback protocol.
+    //
+    return fallback;
+  }
 }
diff --git a/lib/java/test/org/apache/thrift/protocol/TestTProtocolUtil.java b/lib/java/test/org/apache/thrift/protocol/TestTProtocolUtil.java
new file mode 100644
index 0000000..199c707
--- /dev/null
+++ b/lib/java/test/org/apache/thrift/protocol/TestTProtocolUtil.java
@@ -0,0 +1,97 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.thrift.protocol;
+
+import java.nio.ByteBuffer;
+import java.util.Arrays;
+import java.util.List;
+
+import junit.framework.TestCase;
+
+import org.apache.thrift.Fixtures;
+import org.apache.thrift.TBase;
+import org.apache.thrift.TDeserializer;
+import org.apache.thrift.TException;
+import org.apache.thrift.TSerializer;
+import org.apache.thrift.transport.TMemoryBuffer;
+
+import thrift.test.CompactProtoTestStruct;
+import thrift.test.HolyMoley;
+import thrift.test.Nesting;
+import thrift.test.OneOfEach;
+import thrift.test.Srv;
+import thrift.test.GuessProtocolStruct;
+
+public class TestTProtocolUtil extends TestCase {
+
+  /** Brace-delimited data is guessed as JSON; Compact data starting with '{' is not. */
+  public void testGuessProtocolFactory_JSON() throws Exception {
+    TProtocolFactory compactFallback = new TCompactProtocol.Factory();
+
+    byte[] braced = "{foo}".getBytes();
+    assertTrue(
+        TProtocolUtil.guessProtocolFactory(braced, compactFallback) instanceof TJSONProtocol.Factory);
+
+    // Compact data for field id 7 of map type begins with '{' yet must not be guessed as JSON.
+    GuessProtocolStruct mapStruct = new GuessProtocolStruct();
+    mapStruct.putToMap_field("}","}");
+    byte[] compactBytes = new TSerializer(new TCompactProtocol.Factory()).serialize(mapStruct);
+    assertFalse(
+        TProtocolUtil.guessProtocolFactory(compactBytes, compactFallback) instanceof TJSONProtocol.Factory);
+  }
+
+  /** A non-zero last byte, or a zero second byte, is reported as Binary. */
+  public void testGuessProtocolFactory_Binary() throws Exception {
+    TProtocolFactory compactFallback = new TCompactProtocol.Factory();
+
+    // Every single-byte buffer holding a non-zero value ends with a non-zero byte.
+    for (int value = 1; value <= 0xff; value++) {
+      byte[] single = {(byte) value};
+      assertTrue(
+          TProtocolUtil.guessProtocolFactory(single, compactFallback) instanceof TBinaryProtocol.Factory);
+    }
+
+    // Two zero bytes: the last byte is 0 and so is the second byte.
+    byte[] zeroPair = new byte[2];
+    assertTrue(
+        TProtocolUtil.guessProtocolFactory(zeroPair, compactFallback) instanceof TBinaryProtocol.Factory);
+  }
+
+  /** A first byte above 0x10, or a second byte with bit 7 set, is reported as Compact. */
+  public void testGuessProtocolFactory_Compact() throws Exception {
+    TProtocolFactory binaryFallback = new TBinaryProtocol.Factory();
+
+    byte[] highFirst = {0x11, 0, 0};
+    assertTrue(
+        TProtocolUtil.guessProtocolFactory(highFirst, binaryFallback) instanceof TCompactProtocol.Factory);
+
+    for (int second = 0x80; second <= 0xff; second++) {
+      byte[] highSecond = {0, (byte) second, 0};
+      assertTrue(
+          TProtocolUtil.guessProtocolFactory(highSecond, binaryFallback) instanceof TCompactProtocol.Factory);
+    }
+  }
+
+  /** Data matching none of the heuristics yields the supplied fallback factory. */
+  public void testGuessProtocolFactory_Undecided() throws Exception {
+    byte[] ambiguous = {0, 0x7e, 0};
+    TProtocolFactory guessed = TProtocolUtil.guessProtocolFactory(ambiguous, new TSimpleJSONProtocol.Factory());
+    assertTrue(guessed instanceof TSimpleJSONProtocol.Factory);
+  }
+}
diff --git a/test/ThriftTest.thrift b/test/ThriftTest.thrift
index 66ec2b5..ce324ef 100644
--- a/test/ThriftTest.thrift
+++ b/test/ThriftTest.thrift
@@ -187,3 +187,7 @@
        1: list<string> strings;
        2: string hello;
 }
+
+struct GuessProtocolStruct { // fixture used by the Java TProtocolUtil.guessProtocolFactory test
+  7: map<string,string> map_field, // id 7 + map type: the test relies on the TCompactProtocol bytes starting with '{'
+}