THRIFT-1797 Python implementation of TSimpleJSONProtocol
Patch: Avi Flamholz
diff --git a/.gitignore b/.gitignore
old mode 100755
new mode 100644
index 9aa31aa..12a6b0b
--- a/.gitignore
+++ b/.gitignore
@@ -245,3 +245,5 @@
/tutorial/py/Makefile
/tutorial/py/Makefile.in
/ylwrap
+.project
+.pydevproject
diff --git a/lib/py/src/protocol/TJSONProtocol.py b/lib/py/src/protocol/TJSONProtocol.py
index 5fb3ec7..3048197 100644
--- a/lib/py/src/protocol/TJSONProtocol.py
+++ b/lib/py/src/protocol/TJSONProtocol.py
@@ -17,10 +17,15 @@
# under the License.
#
-from TProtocol import *
-import json, base64, sys
+from TProtocol import TType, TProtocolBase, TProtocolException
+import base64
+import json
+import math
-__all__ = ['TJSONProtocol', 'TJSONProtocolFactory']
+__all__ = ['TJSONProtocol',
+ 'TJSONProtocolFactory',
+ 'TSimpleJSONProtocol',
+ 'TSimpleJSONProtocolFactory']
VERSION = 1
@@ -74,6 +79,9 @@
def escapeNum(self):
return False
+ def __str__(self):
+ return self.__class__.__name__
+
class JSONListContext(JSONBaseContext):
@@ -91,14 +99,17 @@
class JSONPairContext(JSONBaseContext):
- colon = True
+
+ def __init__(self, protocol):
+ super(JSONPairContext, self).__init__(protocol)
+ self.colon = True
def doIO(self, function):
- if self.first is True:
+ if self.first:
self.first = False
self.colon = True
else:
- function(COLON if self.colon == True else COMMA)
+ function(COLON if self.colon else COMMA)
self.colon = not self.colon
def write(self):
@@ -110,6 +121,9 @@
def escapeNum(self):
return self.colon
+ def __str__(self):
+ return '%s, colon=%s' % (self.__class__.__name__, self.colon)
+
class LookaheadReader():
hasData = False
@@ -139,8 +153,8 @@
self.resetReadContext()
def resetWriteContext(self):
- self.contextStack = []
- self.context = JSONBaseContext(self)
+ self.context = JSONBaseContext(self)
+ self.contextStack = [self.context]
def resetReadContext(self):
self.resetWriteContext()
@@ -152,6 +166,10 @@
def popContext(self):
self.contextStack.pop()
+ if self.contextStack:
+ self.context = self.contextStack[-1]
+ else:
+ self.context = JSONBaseContext(self)
def writeJSONString(self, string):
self.context.write()
@@ -210,7 +228,7 @@
self.readJSONSyntaxChar(ZERO)
character = json.JSONDecoder().decode('"\u00%s"' % self.trans.read(2))
else:
- off = ESCAPE_CHAR.find(char)
+ off = ESCAPE_CHAR.find(character)
if off == -1:
raise TProtocolException(TProtocolException.INVALID_DATA,
"Expected control char")
@@ -251,7 +269,9 @@
string = self.readJSONString(True)
try:
double = float(string)
- if self.context.escapeNum is False and double != inf and double != nan:
+ if (self.context.escapeNum is False and
+ not math.isinf(double) and
+ not math.isnan(double)):
raise TProtocolException(TProtocolException.INVALID_DATA,
"Numeric data unexpectedly quoted")
return double
@@ -445,9 +465,86 @@
def writeBinary(self, binary):
self.writeJSONBase64(binary)
+
class TJSONProtocolFactory:
- def __init__(self):
- pass
def getProtocol(self, trans):
return TJSONProtocol(trans)
+
+
+class TSimpleJSONProtocol(TJSONProtocolBase):
+ """Simple, readable, write-only JSON protocol.
+
+ Useful for interacting with scripting languages.
+ """
+
+ def readMessageBegin(self):
+ raise NotImplementedError()
+
+ def readMessageEnd(self):
+ raise NotImplementedError()
+
+ def readStructBegin(self):
+ raise NotImplementedError()
+
+ def readStructEnd(self):
+ raise NotImplementedError()
+
+ def writeMessageBegin(self, name, request_type, seqid):
+ self.resetWriteContext()
+
+ def writeMessageEnd(self):
+ pass
+
+ def writeStructBegin(self, name):
+ self.writeJSONObjectStart()
+
+ def writeStructEnd(self):
+ self.writeJSONObjectEnd()
+
+ def writeFieldBegin(self, name, ttype, fid):
+ self.writeJSONString(name)
+
+ def writeFieldEnd(self):
+ pass
+
+ def writeMapBegin(self, ktype, vtype, size):
+ self.writeJSONObjectStart()
+
+ def writeMapEnd(self):
+ self.writeJSONObjectEnd()
+
+ def _writeCollectionBegin(self, etype, size):
+ self.writeJSONArrayStart()
+
+ def _writeCollectionEnd(self):
+ self.writeJSONArrayEnd()
+ writeListBegin = _writeCollectionBegin
+ writeListEnd = _writeCollectionEnd
+ writeSetBegin = _writeCollectionBegin
+ writeSetEnd = _writeCollectionEnd
+
+ def writeInteger(self, integer):
+ self.writeJSONNumber(integer)
+ writeByte = writeInteger
+ writeI16 = writeInteger
+ writeI32 = writeInteger
+ writeI64 = writeInteger
+
+ def writeBool(self, boolean):
+ self.writeJSONNumber(1 if boolean is True else 0)
+
+ def writeDouble(self, dbl):
+ self.writeJSONNumber(dbl)
+
+ def writeString(self, string):
+ self.writeJSONString(string)
+
+ def writeBinary(self, binary):
+ self.writeJSONBase64(binary)
+
+
+class TSimpleJSONProtocolFactory(object):
+
+ def getProtocol(self, trans):
+ return TSimpleJSONProtocol(trans)
diff --git a/test/py/RunClientServer.py b/test/py/RunClientServer.py
index f9121c8..db0bfa4 100755
--- a/test/py/RunClientServer.py
+++ b/test/py/RunClientServer.py
@@ -46,7 +46,11 @@
for gp_dir in options.genpydirs.split(','):
generated_dirs.append('gen-py-%s' % (gp_dir))
-SCRIPTS = ['SerializationTest.py', 'TestEof.py', 'TestSyntax.py', 'TestSocket.py']
+SCRIPTS = ['TSimpleJSONProtocolTest.py',
+ 'SerializationTest.py',
+ 'TestEof.py',
+ 'TestSyntax.py',
+ 'TestSocket.py']
FRAMED = ["TNonblockingServer"]
SKIP_ZLIB = ['TNonblockingServer', 'THttpServer']
SKIP_SSL = ['TNonblockingServer', 'THttpServer']
diff --git a/test/py/TSimpleJSONProtocolTest.py b/test/py/TSimpleJSONProtocolTest.py
new file mode 100644
index 0000000..080293a
--- /dev/null
+++ b/test/py/TSimpleJSONProtocolTest.py
@@ -0,0 +1,119 @@
+#!/usr/bin/env python
+
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+import sys
+import glob
+from optparse import OptionParser
+parser = OptionParser()
+parser.add_option('--genpydir', type='string', dest='genpydir', default='gen-py')
+options, args = parser.parse_args()
+del sys.argv[1:] # clean up hack so unittest doesn't complain
+sys.path.insert(0, options.genpydir)
+sys.path.insert(0, glob.glob('../../lib/py/build/lib.*')[0])
+
+from ThriftTest.ttypes import *
+from thrift.protocol import TJSONProtocol
+from thrift.transport import TTransport
+
+import json
+import unittest
+
+
+class SimpleJSONProtocolTest(unittest.TestCase):
+ protocol_factory = TJSONProtocol.TSimpleJSONProtocolFactory()
+
+ def _assertDictEqual(self, a ,b, msg=None):
+ if hasattr(self, 'assertDictEqual'):
+ # assertDictEqual only in Python 2.7. Depends on your machine.
+ self.assertDictEqual(a, b, msg)
+ return
+
+ # Substitute implementation not as good as unittest library's
+ self.assertEquals(len(a), len(b), msg)
+ for k, v in a.iteritems():
+ self.assertTrue(k in b, msg)
+ self.assertEquals(b.get(k), v, msg)
+
+ def _serialize(self, obj):
+ trans = TTransport.TMemoryBuffer()
+ prot = self.protocol_factory.getProtocol(trans)
+ obj.write(prot)
+ return trans.getvalue()
+
+ def _deserialize(self, objtype, data):
+ prot = self.protocol_factory.getProtocol(TTransport.TMemoryBuffer(data))
+ ret = objtype()
+ ret.read(prot)
+ return ret
+
+ def testWriteOnly(self):
+ self.assertRaises(NotImplementedError,
+ self._deserialize, VersioningTestV1, '{}')
+
+ def testSimpleMessage(self):
+ v1obj = VersioningTestV1(
+ begin_in_both=12345,
+ old_string='aaa',
+ end_in_both=54321)
+ expected = dict(begin_in_both=v1obj.begin_in_both,
+ old_string=v1obj.old_string,
+ end_in_both=v1obj.end_in_both)
+ actual = json.loads(self._serialize(v1obj))
+
+ self._assertDictEqual(expected, actual)
+
+ def testComplicated(self):
+ v2obj = VersioningTestV2(
+ begin_in_both=12345,
+ newint=1,
+ newbyte=2,
+ newshort=3,
+ newlong=4,
+ newdouble=5.0,
+ newstruct=Bonk(message="Hello!", type=123),
+ newlist=[7,8,9],
+ newset=set([42,1,8]),
+ newmap={1:2,2:3},
+ newstring="Hola!",
+ end_in_both=54321)
+ expected = dict(begin_in_both=v2obj.begin_in_both,
+ newint=v2obj.newint,
+ newbyte=v2obj.newbyte,
+ newshort=v2obj.newshort,
+ newlong=v2obj.newlong,
+ newdouble=v2obj.newdouble,
+ newstruct=dict(message=v2obj.newstruct.message,
+ type=v2obj.newstruct.type),
+ newlist=v2obj.newlist,
+ newset=list(v2obj.newset),
+ newmap=v2obj.newmap,
+ newstring=v2obj.newstring,
+ end_in_both=v2obj.end_in_both)
+
+ # Need to load/dump because map keys get escaped.
+ expected = json.loads(json.dumps(expected))
+ actual = json.loads(self._serialize(v2obj))
+ self._assertDictEqual(expected, actual)
+
+
+if __name__ == '__main__':
+ unittest.main()
+