THRIFT-1797 Python implementation of TSimpleJSONProtocol
Patch: Avi Flamholz
diff --git a/.gitignore b/.gitignore
old mode 100755
new mode 100644
index 9aa31aa..12a6b0b
--- a/.gitignore
+++ b/.gitignore
@@ -245,3 +245,5 @@
 /tutorial/py/Makefile
 /tutorial/py/Makefile.in
 /ylwrap
+.project
+.pydevproject
diff --git a/lib/py/src/protocol/TJSONProtocol.py b/lib/py/src/protocol/TJSONProtocol.py
index 5fb3ec7..3048197 100644
--- a/lib/py/src/protocol/TJSONProtocol.py
+++ b/lib/py/src/protocol/TJSONProtocol.py
@@ -17,10 +17,15 @@
 # under the License.
 #
 
-from TProtocol import *
-import json, base64, sys
+from TProtocol import TType, TProtocolBase, TProtocolException
+import base64
+import json
+import math
 
-__all__ = ['TJSONProtocol', 'TJSONProtocolFactory']
+__all__ = ['TJSONProtocol',
+           'TJSONProtocolFactory',
+           'TSimpleJSONProtocol',
+           'TSimpleJSONProtocolFactory']
 
 VERSION = 1
 
@@ -74,6 +79,9 @@
   def escapeNum(self):
     return False
 
+  def __str__(self):
+    return self.__class__.__name__
+
 
 class JSONListContext(JSONBaseContext):
     
@@ -91,14 +99,17 @@
 
 
 class JSONPairContext(JSONBaseContext):
-  colon = True
+  
+  def __init__(self, protocol):
+    super(JSONPairContext, self).__init__(protocol)
+    self.colon = True
 
   def doIO(self, function):
-    if self.first is True:
+    if self.first:
       self.first = False
       self.colon = True
     else:
-      function(COLON if self.colon == True else COMMA)
+      function(COLON if self.colon else COMMA)
       self.colon = not self.colon
 
   def write(self):
@@ -110,6 +121,9 @@
   def escapeNum(self):
     return self.colon
 
+  def __str__(self):
+    return '%s, colon=%s' % (self.__class__.__name__, self.colon)
+
 
 class LookaheadReader():
   hasData = False
@@ -139,8 +153,8 @@
     self.resetReadContext()
 
   def resetWriteContext(self):
-    self.contextStack = []
-    self.context = JSONBaseContext(self)        
+    self.context = JSONBaseContext(self)
+    self.contextStack = [self.context]
 
   def resetReadContext(self):
     self.resetWriteContext()
@@ -152,6 +166,10 @@
 
   def popContext(self):
     self.contextStack.pop()
+    if self.contextStack:
+      self.context = self.contextStack[-1]
+    else:
+      self.context = JSONBaseContext(self)
 
   def writeJSONString(self, string):
     self.context.write()
@@ -210,7 +228,7 @@
           self.readJSONSyntaxChar(ZERO)
           character = json.JSONDecoder().decode('"\u00%s"' % self.trans.read(2))
         else:
-          off = ESCAPE_CHAR.find(char)
+          off = ESCAPE_CHAR.find(character)
           if off == -1:
             raise TProtocolException(TProtocolException.INVALID_DATA,
                                      "Expected control char")
@@ -251,7 +269,9 @@
       string  = self.readJSONString(True)
       try:
         double = float(string)
-        if self.context.escapeNum is False and double != inf and double != nan:
+        if (self.context.escapeNum is False and
+            not math.isinf(double) and
+            not math.isnan(double)):
           raise TProtocolException(TProtocolException.INVALID_DATA,
                                    "Numeric data unexpectedly quoted")
         return double
@@ -445,9 +465,86 @@
   def writeBinary(self, binary):
     self.writeJSONBase64(binary)
 
+
 class TJSONProtocolFactory:
-  def __init__(self):
-    pass
 
   def getProtocol(self, trans):
     return TJSONProtocol(trans)
+
+
+class TSimpleJSONProtocol(TJSONProtocolBase):
+    """Simple, readable, write-only JSON protocol.
+    
+    Useful for interacting with scripting languages.
+    """
+
+    def readMessageBegin(self):
+        raise NotImplementedError()
+    
+    def readMessageEnd(self):
+        raise NotImplementedError()
+    
+    def readStructBegin(self):
+        raise NotImplementedError()
+    
+    def readStructEnd(self):
+        raise NotImplementedError()
+    
+    def writeMessageBegin(self, name, request_type, seqid):
+        self.resetWriteContext()
+    
+    def writeMessageEnd(self):
+        pass
+    
+    def writeStructBegin(self, name):
+        self.writeJSONObjectStart()
+    
+    def writeStructEnd(self):
+        self.writeJSONObjectEnd()
+      
+    def writeFieldBegin(self, name, ttype, fid):
+        self.writeJSONString(name)
+    
+    def writeFieldEnd(self):
+        pass
+    
+    def writeMapBegin(self, ktype, vtype, size):
+        self.writeJSONObjectStart()
+    
+    def writeMapEnd(self):
+        self.writeJSONObjectEnd()
+    
+    def _writeCollectionBegin(self, etype, size):
+        self.writeJSONArrayStart()
+    
+    def _writeCollectionEnd(self):
+        self.writeJSONArrayEnd()
+    writeListBegin = _writeCollectionBegin
+    writeListEnd = _writeCollectionEnd
+    writeSetBegin = _writeCollectionBegin
+    writeSetEnd = _writeCollectionEnd
+
+    def writeInteger(self, integer):
+        self.writeJSONNumber(integer)
+    writeByte = writeInteger
+    writeI16 = writeInteger
+    writeI32 = writeInteger
+    writeI64 = writeInteger
+    
+    def writeBool(self, boolean):
+        self.writeJSONNumber(1 if boolean is True else 0)
+
+    def writeDouble(self, dbl):
+        self.writeJSONNumber(dbl)
+    
+    def writeString(self, string):
+        self.writeJSONString(string)
+      
+    def writeBinary(self, binary):
+        self.writeJSONBase64(binary)
+
+
+class TSimpleJSONProtocolFactory(object):
+
+    def getProtocol(self, trans):
+        return TSimpleJSONProtocol(trans)
diff --git a/test/py/RunClientServer.py b/test/py/RunClientServer.py
index f9121c8..db0bfa4 100755
--- a/test/py/RunClientServer.py
+++ b/test/py/RunClientServer.py
@@ -46,7 +46,11 @@
 for gp_dir in options.genpydirs.split(','):
   generated_dirs.append('gen-py-%s' % (gp_dir))
 
-SCRIPTS = ['SerializationTest.py', 'TestEof.py', 'TestSyntax.py', 'TestSocket.py']
+SCRIPTS = ['TSimpleJSONProtocolTest.py',
+           'SerializationTest.py',
+           'TestEof.py',
+           'TestSyntax.py',
+           'TestSocket.py']
 FRAMED = ["TNonblockingServer"]
 SKIP_ZLIB = ['TNonblockingServer', 'THttpServer']
 SKIP_SSL = ['TNonblockingServer', 'THttpServer']
diff --git a/test/py/TSimpleJSONProtocolTest.py b/test/py/TSimpleJSONProtocolTest.py
new file mode 100644
index 0000000..080293a
--- /dev/null
+++ b/test/py/TSimpleJSONProtocolTest.py
@@ -0,0 +1,119 @@
+#!/usr/bin/env python
+
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+import sys
+import glob
+from optparse import OptionParser
+parser = OptionParser()
+parser.add_option('--genpydir', type='string', dest='genpydir', default='gen-py')
+options, args = parser.parse_args()
+del sys.argv[1:] # clean up hack so unittest doesn't complain
+sys.path.insert(0, options.genpydir)
+sys.path.insert(0, glob.glob('../../lib/py/build/lib.*')[0])
+
+from ThriftTest.ttypes import *
+from thrift.protocol import TJSONProtocol
+from thrift.transport import TTransport
+
+import json
+import unittest
+
+
+class SimpleJSONProtocolTest(unittest.TestCase):
+  protocol_factory = TJSONProtocol.TSimpleJSONProtocolFactory()
+
+  def _assertDictEqual(self, a ,b, msg=None):
+    if hasattr(self, 'assertDictEqual'):
+      # assertDictEqual only in Python 2.7. Depends on your machine.
+      self.assertDictEqual(a, b, msg)
+      return
+    
+    # Substitute implementation not as good as unittest library's
+    self.assertEquals(len(a), len(b), msg)
+    for k, v in a.iteritems():
+      self.assertTrue(k in b, msg)
+      self.assertEquals(b.get(k), v, msg)
+
+  def _serialize(self, obj):
+    trans = TTransport.TMemoryBuffer()
+    prot = self.protocol_factory.getProtocol(trans)
+    obj.write(prot)
+    return trans.getvalue()
+
+  def _deserialize(self, objtype, data):
+    prot = self.protocol_factory.getProtocol(TTransport.TMemoryBuffer(data))
+    ret = objtype()
+    ret.read(prot)
+    return ret
+
+  def testWriteOnly(self):
+    self.assertRaises(NotImplementedError,
+                      self._deserialize, VersioningTestV1, '{}')
+
+  def testSimpleMessage(self):
+      v1obj = VersioningTestV1(
+          begin_in_both=12345,
+          old_string='aaa',
+          end_in_both=54321)
+      expected = dict(begin_in_both=v1obj.begin_in_both,
+                      old_string=v1obj.old_string,
+                      end_in_both=v1obj.end_in_both)
+      actual = json.loads(self._serialize(v1obj))
+
+      self._assertDictEqual(expected, actual)
+     
+  def testComplicated(self):
+      v2obj = VersioningTestV2(
+          begin_in_both=12345,
+          newint=1,
+          newbyte=2,
+          newshort=3,
+          newlong=4,
+          newdouble=5.0,
+          newstruct=Bonk(message="Hello!", type=123),
+          newlist=[7,8,9],
+          newset=set([42,1,8]),
+          newmap={1:2,2:3},
+          newstring="Hola!",
+          end_in_both=54321)
+      expected = dict(begin_in_both=v2obj.begin_in_both,
+                      newint=v2obj.newint,
+                      newbyte=v2obj.newbyte,
+                      newshort=v2obj.newshort,
+                      newlong=v2obj.newlong,
+                      newdouble=v2obj.newdouble,
+                      newstruct=dict(message=v2obj.newstruct.message,
+                                     type=v2obj.newstruct.type),
+                      newlist=v2obj.newlist,
+                      newset=list(v2obj.newset),
+                      newmap=v2obj.newmap,
+                      newstring=v2obj.newstring,
+                      end_in_both=v2obj.end_in_both)
+      
+      # Need to load/dump because map keys get escaped.
+      expected = json.loads(json.dumps(expected))
+      actual = json.loads(self._serialize(v2obj))
+      self._assertDictEqual(expected, actual)
+
+
+if __name__ == '__main__':
+  unittest.main()
+