blob: 1ed8c15748ec9a27a479daedd4ee873feb01ed89 [file] [log] [blame]
#!/usr/bin/env python
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
from ThriftTest.ttypes import Bonk, VersioningTestV1, VersioningTestV2
from thrift.protocol import TJSONProtocol
from thrift.transport import TTransport
import json
import unittest
class SimpleJSONProtocolTest(unittest.TestCase):
    """Tests for the protocol produced by ``TSimpleJSONProtocolFactory``.

    Each serialization test writes a generated ThriftTest struct through the
    protocol into an in-memory transport, parses the resulting bytes with the
    standard ``json`` module and compares the parsed value with a hand-built
    dictionary of the same fields.
    """

    protocol_factory = TJSONProtocol.TSimpleJSONProtocolFactory()

    def _assertDictEqual(self, a, b, msg=None):
        """Assert that dictionaries ``a`` and ``b`` hold the same items.

        Delegates to ``unittest.TestCase.assertDictEqual`` when the running
        interpreter provides it; otherwise falls back to a simpler manual
        comparison with less helpful failure output.

        Args:
            a: The expected dictionary.
            b: The actual dictionary.
            msg: Optional message reported when the assertion fails.
        """
        if hasattr(self, 'assertDictEqual'):
            # assertDictEqual was added to unittest in Python 2.7; older
            # interpreters take the fallback path below.
            self.assertDictEqual(a, b, msg)
            return

        # Substitute implementation, not as good as the unittest library's.
        # items() and assertEqual are used instead of iteritems() and
        # assertEquals so this path works on both Python 2 and Python 3.
        self.assertEqual(len(a), len(b), msg)
        for k, v in a.items():
            # assertIn is itself 2.7+, so stick to assertTrue on this path.
            self.assertTrue(k in b, msg)
            self.assertEqual(b.get(k), v, msg)

    def _serialize(self, obj):
        """Write ``obj`` through the protocol and return the raw output.

        Args:
            obj: An object exposing ``write(protocol)``.

        Returns:
            The contents of the in-memory transport after the write.
        """
        trans = TTransport.TMemoryBuffer()
        prot = self.protocol_factory.getProtocol(trans)
        obj.write(prot)
        return trans.getvalue()

    def _deserialize(self, objtype, data):
        """Build an ``objtype`` instance and read it from ``data``.

        Args:
            objtype: Zero-argument callable returning an object that exposes
                ``read(protocol)``.
            data: Serialized input placed in an in-memory transport.

        Returns:
            The object after ``read`` has been called on it.
        """
        prot = self.protocol_factory.getProtocol(TTransport.TMemoryBuffer(data))
        ret = objtype()
        ret.read(prot)
        return ret

    def testWriteOnly(self):
        """Reading through the protocol must raise NotImplementedError."""
        self.assertRaises(NotImplementedError,
                          self._deserialize, VersioningTestV1, b'{}')

    def testSimpleMessage(self):
        """A struct with only scalar fields serializes to a flat JSON object."""
        v1obj = VersioningTestV1(
            begin_in_both=12345,
            old_string='aaa',
            end_in_both=54321)
        expected = dict(begin_in_both=v1obj.begin_in_both,
                        old_string=v1obj.old_string,
                        end_in_both=v1obj.end_in_both)
        actual = json.loads(self._serialize(v1obj).decode('ascii'))

        self._assertDictEqual(expected, actual)

    def testComplicated(self):
        """Nested struct, list, set and map fields serialize to plain JSON."""
        v2obj = VersioningTestV2(
            begin_in_both=12345,
            newint=1,
            newbyte=2,
            newshort=3,
            newlong=4,
            newdouble=5.0,
            newstruct=Bonk(message="Hello!", type=123),
            newlist=[7, 8, 9],
            newset=set([42, 1, 8]),
            newmap={1: 2, 2: 3},
            newstring="Hola!",
            end_in_both=54321)
        expected = dict(begin_in_both=v2obj.begin_in_both,
                        newint=v2obj.newint,
                        newbyte=v2obj.newbyte,
                        newshort=v2obj.newshort,
                        newlong=v2obj.newlong,
                        newdouble=v2obj.newdouble,
                        newstruct=dict(message=v2obj.newstruct.message,
                                       type=v2obj.newstruct.type),
                        newlist=v2obj.newlist,
                        # NOTE(review): relies on the serializer emitting the
                        # set in the same iteration order as list() -- confirm.
                        newset=list(v2obj.newset),
                        newmap=v2obj.newmap,
                        newstring=v2obj.newstring,
                        end_in_both=v2obj.end_in_both)

        # Need to load/dump because map keys get escaped: JSON object keys
        # are always strings, so the integer keys of newmap only compare
        # equal to the parsed output after the same round trip.
        expected = json.loads(json.dumps(expected))
        actual = json.loads(self._serialize(v2obj).decode('ascii'))

        self._assertDictEqual(expected, actual)
if __name__ == '__main__':
    # Executed as a script rather than imported: hand control to unittest's
    # command-line runner.
    unittest.main()