THRIFT-1797 Python implementation of TSimpleJSONProtocol
Patch: Avi Flamholz
diff --git a/lib/py/src/protocol/TJSONProtocol.py b/lib/py/src/protocol/TJSONProtocol.py
index 5fb3ec7..3048197 100644
--- a/lib/py/src/protocol/TJSONProtocol.py
+++ b/lib/py/src/protocol/TJSONProtocol.py
@@ -17,10 +17,15 @@
 # under the License.
 #
 
-from TProtocol import *
-import json, base64, sys
+from TProtocol import TType, TProtocolBase, TProtocolException
+import base64
+import json
+import math  # isinf/isnan checks on quoted doubles
 
-__all__ = ['TJSONProtocol', 'TJSONProtocolFactory']
+__all__ = ['TJSONProtocol',
+           'TJSONProtocolFactory',
+           'TSimpleJSONProtocol',
+           'TSimpleJSONProtocolFactory']
 
 VERSION = 1
 
@@ -74,6 +79,9 @@
   def escapeNum(self):
     return False
 
+  def __str__(self):  # concrete class name, e.g. "JSONListContext"
+    return self.__class__.__name__
+
 
 class JSONListContext(JSONBaseContext):
     
@@ -91,14 +99,17 @@
 
 
 class JSONPairContext(JSONBaseContext):
-  colon = True
+  """Object context: first item bare, then alternating COLON / COMMA."""
+  def __init__(self, protocol):
+    super(JSONPairContext, self).__init__(protocol)
+    self.colon = True  # True while the current item is a key
 
   def doIO(self, function):
-    if self.first is True:
+    if self.first:
       self.first = False
       self.colon = True
     else:
-      function(COLON if self.colon == True else COMMA)
+      function(COLON if self.colon else COMMA)
       self.colon = not self.colon
 
   def write(self):
@@ -110,6 +121,9 @@
   def escapeNum(self):
     return self.colon
 
+  def __str__(self):  # e.g. "JSONPairContext, colon=True"
+    return '%s, colon=%s' % (self.__class__.__name__, self.colon)
+
 
 class LookaheadReader():
   hasData = False
@@ -139,8 +153,8 @@
     self.resetReadContext()
 
   def resetWriteContext(self):
-    self.contextStack = []
-    self.context = JSONBaseContext(self)        
+    self.context = JSONBaseContext(self)
+    self.contextStack = [self.context]  # seed the stack with the base context
 
   def resetReadContext(self):
     self.resetWriteContext()
@@ -152,6 +166,10 @@
 
   def popContext(self):
     self.contextStack.pop()
+    if self.contextStack:
+      self.context = self.contextStack[-1]  # new top of stack becomes current
+    else:  # the pop emptied the stack: fall back to a fresh base context
+      self.context = JSONBaseContext(self)
 
   def writeJSONString(self, string):
     self.context.write()
@@ -210,7 +228,7 @@
           self.readJSONSyntaxChar(ZERO)
           character = json.JSONDecoder().decode('"\u00%s"' % self.trans.read(2))
         else:
-          off = ESCAPE_CHAR.find(char)
+          off = ESCAPE_CHAR.find(character)  # -1 if not a recognized escape
           if off == -1:
             raise TProtocolException(TProtocolException.INVALID_DATA,
                                      "Expected control char")
@@ -251,7 +269,9 @@
       string  = self.readJSONString(True)
       try:
         double = float(string)
-        if self.context.escapeNum is False and double != inf and double != nan:
+        # Quoted numbers are only legal where escapeNum() allows, or inf/nan.
+        if not (self.context.escapeNum() or math.isinf(double) or
+                math.isnan(double)):
           raise TProtocolException(TProtocolException.INVALID_DATA,
                                    "Numeric data unexpectedly quoted")
         return double
@@ -445,9 +465,98 @@
   def writeBinary(self, binary):
     self.writeJSONBase64(binary)
 
+
 class TJSONProtocolFactory:
-  def __init__(self):
-    pass
 
   def getProtocol(self, trans):
     return TJSONProtocol(trans)
+
+
+class TSimpleJSONProtocol(TJSONProtocolBase):
+  """Simple, readable, write-only JSON protocol.
+
+  Useful for interacting with scripting languages. Structs and maps are
+  written as JSON objects, lists and sets as JSON arrays; Thrift type ids,
+  field ids and container sizes are not emitted, so the output cannot be
+  read back with this protocol (the message/struct read methods raise
+  NotImplementedError).
+  """
+
+  def readMessageBegin(self):
+    raise NotImplementedError()
+
+  def readMessageEnd(self):
+    raise NotImplementedError()
+
+  def readStructBegin(self):
+    raise NotImplementedError()
+
+  def readStructEnd(self):
+    raise NotImplementedError()
+
+  def writeMessageBegin(self, name, request_type, seqid):
+    # No message envelope is written; only the nesting state is reset.
+    self.resetWriteContext()
+
+  def writeMessageEnd(self):
+    pass
+
+  def writeStructBegin(self, name):
+    self.writeJSONObjectStart()
+
+  def writeStructEnd(self):
+    self.writeJSONObjectEnd()
+
+  def writeFieldBegin(self, name, ttype, fid):
+    # The field name becomes the object key; ttype and fid are dropped.
+    self.writeJSONString(name)
+
+  def writeFieldEnd(self):
+    pass
+
+  def writeMapBegin(self, ktype, vtype, size):
+    self.writeJSONObjectStart()
+
+  def writeMapEnd(self):
+    self.writeJSONObjectEnd()
+
+  def _writeCollectionBegin(self, etype, size):
+    self.writeJSONArrayStart()
+
+  def _writeCollectionEnd(self):
+    self.writeJSONArrayEnd()
+
+  # Lists and sets share one array encoding.
+  writeListBegin = _writeCollectionBegin
+  writeListEnd = _writeCollectionEnd
+  writeSetBegin = _writeCollectionBegin
+  writeSetEnd = _writeCollectionEnd
+
+  def writeInteger(self, integer):
+    self.writeJSONNumber(integer)
+
+  writeByte = writeInteger
+  writeI16 = writeInteger
+  writeI32 = writeInteger
+  writeI64 = writeInteger
+
+  def writeBool(self, boolean):
+    # Booleans are written as 1/0. Test truthiness rather than identity
+    # with True, so truthy non-bool values are not silently written as 0.
+    self.writeJSONNumber(1 if boolean else 0)
+
+  def writeDouble(self, dbl):
+    self.writeJSONNumber(dbl)
+
+  def writeString(self, string):
+    self.writeJSONString(string)
+
+  def writeBinary(self, binary):
+    self.writeJSONBase64(binary)
+
+
+class TSimpleJSONProtocolFactory(object):
+  """Factory that wraps a transport in a TSimpleJSONProtocol."""
+
+  def getProtocol(self, trans):
+    return TSimpleJSONProtocol(trans)