THRIFT-765. java: Improved string encoding and decoding performance

This patch fixes a regression introduced by the previous 'fast' implementation: Unicode characters outside the Basic Multilingual Plane, which Java represents as surrogate pairs, were not encoded or decoded correctly. Performance stays about the same.

git-svn-id: https://svn.apache.org/repos/asf/incubator/thrift/trunk@939822 13f79535-47bb-0310-9956-ffa450edef68
diff --git a/lib/java/src/org/apache/thrift/Utf8Helper.java b/lib/java/src/org/apache/thrift/Utf8Helper.java
index e754517..2d3fd26 100644
--- a/lib/java/src/org/apache/thrift/Utf8Helper.java
+++ b/lib/java/src/org/apache/thrift/Utf8Helper.java
@@ -5,15 +5,26 @@
 
   public static final int getByteLength(final String s) {
     int byteLength = 0;
-    int c;
+    int codePoint;
     for (int i = 0; i < s.length(); i++) {
-      c = s.charAt(i);
-      if (c <= 0x007F) {
+      codePoint = s.charAt(i);
+      if (codePoint >= 0x07FF) {
+        codePoint = s.codePointAt(i);
+        if (Character.isSupplementaryCodePoint(codePoint)) {
+          i++;
+        }
+      }
+      if (codePoint >= 0 && codePoint <= 0x007F) {
         byteLength++;
-      } else if (c > 0x07FF) {
+      } else if (codePoint >= 0x80 && codePoint <= 0x07FF) {
+        byteLength += 2;
+      } else if ((codePoint >= 0x0800 && codePoint < 0xD800) || (codePoint > 0xDFFF && codePoint <= 0xFFFD)) {
         byteLength+=3;
+      } else if (codePoint >= 0x10000 && codePoint <= 0x10FFFF) {
+        byteLength+=4;
       } else {
-        byteLength+=2;
+        throw new RuntimeException("Unknown unicode codepoint in string! "
+            + Integer.toHexString(codePoint));
       }
     }
     return byteLength;
@@ -25,62 +36,89 @@
     return buf;
   }
 
-  public static void encode(String s, byte[] buf, int offset) {
+  public static void encode(final String s, final byte[] buf, final int offset) {
     int nextByte = 0;
-    int c;
-    for (int i = 0; i < s.length(); i++) {
-      c = s.charAt(i);
-      if (c <= 0x007F) {
-        buf[offset + nextByte] = (byte)c;
+    int codePoint;
+    final int strLen = s.length();
+    for (int i = 0; i < strLen; i++) {
+      codePoint = s.charAt(i);
+      if (codePoint >= 0x07FF) {
+        codePoint = s.codePointAt(i);
+        if (Character.isSupplementaryCodePoint(codePoint)) {
+          i++;
+        }
+      }
+      if (codePoint <= 0x007F) {
+        buf[offset + nextByte] = (byte)codePoint;
         nextByte++;
-      } else if (c > 0x07FF) {
-        buf[offset + nextByte    ] = (byte)(0xE0 | c >> 12 & 0x0F);
-        buf[offset + nextByte + 1] = (byte)(0x80 | c >>  6 & 0x3F);
-        buf[offset + nextByte + 2] = (byte)(0x80 | c       & 0x3F);
-        nextByte+=3;
-      } else {
-        buf[offset + nextByte    ] = (byte)(0xC0 | c >> 6 & 0x1F);
-        buf[offset + nextByte + 1] = (byte)(0x80 | c      & 0x3F);
+      } else if (codePoint <= 0x7FF) {
+        buf[offset + nextByte    ] = (byte)(0xC0 | ((codePoint >> 6) & 0x1F));
+        buf[offset + nextByte + 1] = (byte)(0x80 | ((codePoint >> 0) & 0x3F));
         nextByte+=2;
+      } else if ((codePoint < 0xD800) || (codePoint > 0xDFFF && codePoint <= 0xFFFD)) {
+        buf[offset + nextByte    ] = (byte)(0xE0 | ((codePoint >> 12) & 0x0F));
+        buf[offset + nextByte + 1] = (byte)(0x80 | ((codePoint >>  6) & 0x3F));
+        buf[offset + nextByte + 2] = (byte)(0x80 | ((codePoint >>  0) & 0x3F));
+        nextByte+=3;
+      } else if (codePoint >= 0x10000 && codePoint <= 0x10FFFF) {
+        buf[offset + nextByte    ] = (byte)(0xF0 | ((codePoint >> 18) & 0x07));
+        buf[offset + nextByte + 1] = (byte)(0x80 | ((codePoint >> 12) & 0x3F));
+        buf[offset + nextByte + 2] = (byte)(0x80 | ((codePoint >>  6) & 0x3F));
+        buf[offset + nextByte + 3] = (byte)(0x80 | ((codePoint >>  0) & 0x3F));
+        nextByte+=4;
+      } else {
+        throw new RuntimeException("Unknown unicode codepoint in string! "
+            + Integer.toHexString(codePoint));
       }
     }
   }
 
   public static String decode(byte[] buf) {
-    return decode(buf, 0, buf.length);
+    char[] charBuf = new char[buf.length];
+    int charsDecoded = decode(buf, 0, buf.length, charBuf);
+    return new String(charBuf, 0, charsDecoded);
   }
 
-  public static String decode(byte[] buf, int offset, int byteLength) {
-    int charCount = 0;
-    char[] chars = new char[byteLength];
-    int c;
-    int byteIndex = offset;
-    int charIndex = 0;
-    while (byteIndex < offset + byteLength) {
-      c = buf[byteIndex++] & 0xFF;
-      switch (c >> 4) {
-        case 0:
-        case 1:
-        case 2:
-        case 3:
-        case 4:
-        case 5:
-        case 6:
-        case 7:
-          chars[charIndex++] = (char) c;
-          break;
-        case 12:
-        case 13:
-          chars[charIndex++] = (char) ((c & 0x1F) << 6 | (buf[byteIndex++] & 0x3F));
-          break;
-        case 14:
-          chars[charIndex++] = (char) ((c & 0x0F) << 12 | (buf[byteIndex++] & 0x3F) << 6 | (buf[byteIndex++] & 0x3F) << 0);
-          break;
+  public static final int UNI_SUR_HIGH_START = 0xD800;
+  public static final int UNI_SUR_HIGH_END = 0xDBFF;
+  public static final int UNI_SUR_LOW_START = 0xDC00;
+  public static final int UNI_SUR_LOW_END = 0xDFFF;
+  public static final int UNI_REPLACEMENT_CHAR = 0xFFFD;
+
+  private static final int HALF_BASE = 0x0010000;
+  private static final long HALF_SHIFT = 10;
+  private static final long HALF_MASK = 0x3FFL;
+
+  public static int decode(final byte[] buf, final int offset, final int byteLength, final char[] charBuf) {
+    int curByteIdx = offset;
+    int endByteIdx = offset + byteLength;
+
+    int curCharIdx = 0;
+
+    while (curByteIdx < endByteIdx) {
+      final int b = buf[curByteIdx++]&0xff;
+      final int ch;
+
+      if (b < 0xC0) {
+        ch = b;
+      } else if (b < 0xE0) {
+        ch = ((b & 0x1F) << 6) + (buf[curByteIdx++] & 0x3F);
+      } else if (b < 0xf0) {
+        ch = ((b & 0xF) << 12) + ((buf[curByteIdx++] & 0x3F) << 6) + (buf[curByteIdx++] & 0x3F);
+      } else {
+        ch = ((b & 0x7) << 18) + ((buf[curByteIdx++]& 0x3F) << 12) + ((buf[curByteIdx++] & 0x3F) << 6) + (buf[curByteIdx++] & 0x3F);
       }
-      charCount++;
-    }
-    return new String(chars, 0, charCount);
 
+      if (ch <= 0xFFFF) {
+        // target is a character <= 0xFFFF
+        charBuf[curCharIdx++] = (char) ch;
+      } else {
+        // target is a supplementary character (0x10000 - 0x10FFFF); emit it as a surrogate pair
+        final int chHalf = ch - HALF_BASE;
+        charBuf[curCharIdx++] = (char) ((chHalf >> HALF_SHIFT) + UNI_SUR_HIGH_START);
+        charBuf[curCharIdx++] = (char) ((chHalf & HALF_MASK) + UNI_SUR_LOW_START);
+      }
+    }
+    return curCharIdx;
   }
-  
 }
diff --git a/lib/java/src/org/apache/thrift/protocol/TBinaryProtocol.java b/lib/java/src/org/apache/thrift/protocol/TBinaryProtocol.java
index 8c9fbf5..9e76348 100644
--- a/lib/java/src/org/apache/thrift/protocol/TBinaryProtocol.java
+++ b/lib/java/src/org/apache/thrift/protocol/TBinaryProtocol.java
@@ -328,9 +328,10 @@
     int size = readI32();
 
     if (trans_.getBytesRemainingInBuffer() >= size) {
-      String s = Utf8Helper.decode(trans_.getBuffer(), trans_.getBufferPosition(), size);
+      char[] charBuf = new char[size];
+      int charsDecoded = Utf8Helper.decode(trans_.getBuffer(), trans_.getBufferPosition(), size, charBuf);
       trans_.consumeBuffer(size);
-      return s;
+      return new String(charBuf, 0, charsDecoded);
     }
 
     return readStringBody(size);
diff --git a/lib/java/src/org/apache/thrift/protocol/TCompactProtocol.java b/lib/java/src/org/apache/thrift/protocol/TCompactProtocol.java
index f50ef1b..e81ed82 100755
--- a/lib/java/src/org/apache/thrift/protocol/TCompactProtocol.java
+++ b/lib/java/src/org/apache/thrift/protocol/TCompactProtocol.java
@@ -606,9 +606,10 @@
     }
 
     if (trans_.getBytesRemainingInBuffer() >= length) {
-      String str = Utf8Helper.decode(trans_.getBuffer(), trans_.getBufferPosition(), length);
+      char[] charBuf = new char[length];
+      int charsDecoded = Utf8Helper.decode(trans_.getBuffer(), trans_.getBufferPosition(), length, charBuf);
       trans_.consumeBuffer(length);
-      return str;
+      return new String(charBuf, 0, charsDecoded);
     } else {
       return Utf8Helper.decode(readBinary(length));
     }
diff --git a/lib/java/test/org/apache/thrift/BenchStringEncoding.java b/lib/java/test/org/apache/thrift/BenchStringEncoding.java
new file mode 100644
index 0000000..3ae22c7
--- /dev/null
+++ b/lib/java/test/org/apache/thrift/BenchStringEncoding.java
@@ -0,0 +1,67 @@
+package org.apache.thrift;
+
+import java.io.UnsupportedEncodingException;
+
+public class BenchStringEncoding {
+  private static final String STRING = "a moderately long (but not overly long) string";
+  private static final int HOW_MANY = 100000;
+  private static final byte[] BYTES;
+  static {
+    try {
+      BYTES = STRING.getBytes("UTF-8");
+    } catch (UnsupportedEncodingException e) {
+      throw new RuntimeException(e);
+    }
+  }
+
+  public static void main(String[] args) throws UnsupportedEncodingException {
+    for (int trial = 0; trial < 5; trial++) {
+      benchGetBytes();
+      benchFromBytes();
+      benchEncode();
+      benchDecode();
+    }
+  }
+
+  private static void benchDecode() {
+    char[] charBuf = new char[256];
+    long start = System.currentTimeMillis();
+    for (int i = 0; i < HOW_MANY; i++) {
+      Utf8Helper.decode(BYTES, 0, BYTES.length, charBuf);
+    }
+    long end = System.currentTimeMillis();
+    System.out.println("decode: decode: " + (end-start) + "ms");
+  }
+
+  private static void benchFromBytes() {
+    long start = System.currentTimeMillis();
+    for (int i = 0; i < HOW_MANY; i++) {
+      try {
+        new String(BYTES, "UTF-8");
+      } catch (UnsupportedEncodingException e) {
+        throw new RuntimeException(e);
+      }
+    }
+    long end = System.currentTimeMillis();
+    System.out.println("decode: fromBytes: " + (end-start) + "ms");
+  }
+
+  private static void benchEncode() {
+    long start = System.currentTimeMillis();
+    byte[] outbuf = new byte[256];
+    for (int i = 0; i < HOW_MANY; i++) {
+      Utf8Helper.encode(STRING, outbuf, 0);
+    }
+    long end = System.currentTimeMillis();
+    System.out.println("encode: directEncode: " + (end-start) + "ms");
+  }
+
+  private static void benchGetBytes() throws UnsupportedEncodingException {
+    long start = System.currentTimeMillis();
+    for (int i = 0; i < HOW_MANY; i++) {
+      STRING.getBytes("UTF-8");
+    }
+    long end = System.currentTimeMillis();
+    System.out.println("encode: getBytes(UTF-8): " + (end-start) + "ms");
+  }
+}
diff --git a/lib/java/test/org/apache/thrift/TestUtf8Helper.java b/lib/java/test/org/apache/thrift/TestUtf8Helper.java
index 155f55c..bdfd35a 100644
--- a/lib/java/test/org/apache/thrift/TestUtf8Helper.java
+++ b/lib/java/test/org/apache/thrift/TestUtf8Helper.java
@@ -25,15 +25,19 @@
   private static final String UNICODE_STRING_2;
   private static final byte[] UNICODE_STRING_BYTES_2;
 
-  private static final String REALLY_WHACKY_ONE = "\u20491";
+  private static final String REALLY_WHACKY_ONE = "\uD841\uDC91";
   private static final byte[] REALLY_WHACKY_ONE_BYTES;
 
+  private static final String TWO_CHAR_CHAR = "\uD801\uDC00";
+  private static final byte[] TWO_CHAR_CHAR_BYTES;
+
   static {
     try {
       UNICODE_STRING_BYTES = UNICODE_STRING.getBytes("UTF-8");
       UNICODE_STRING_2 = new String(kUnicodeBytes, "UTF-8");
       UNICODE_STRING_BYTES_2 = UNICODE_STRING_2.getBytes("UTF-8");
       REALLY_WHACKY_ONE_BYTES = REALLY_WHACKY_ONE.getBytes("UTF-8");
+      TWO_CHAR_CHAR_BYTES = TWO_CHAR_CHAR.getBytes("UTF-8");
     } catch (UnsupportedEncodingException e) {
       throw new RuntimeException(e);
     }
@@ -53,6 +57,9 @@
 
     otherBytes = Utf8Helper.encode(REALLY_WHACKY_ONE);
     assertTrue(Arrays.equals(REALLY_WHACKY_ONE_BYTES, otherBytes));
+
+    otherBytes = Utf8Helper.encode(TWO_CHAR_CHAR);
+    assertTrue(Arrays.equals(TWO_CHAR_CHAR_BYTES, otherBytes));
   }
 
   public void testDecode() throws Exception {
@@ -62,5 +69,6 @@
     assertEquals(UNICODE_STRING, Utf8Helper.decode(UNICODE_STRING_BYTES));
     assertEquals(UNICODE_STRING_2, Utf8Helper.decode(UNICODE_STRING_BYTES_2));
     assertEquals(REALLY_WHACKY_ONE, Utf8Helper.decode(REALLY_WHACKY_ONE_BYTES));
+    assertEquals(TWO_CHAR_CHAR, Utf8Helper.decode(TWO_CHAR_CHAR_BYTES));
   }
 }