blob: 080293a08012dc6cf5e77ba14c1436884d86897e [file] [log] [blame]
Roger Meier0895dfe2012-12-26 22:09:55 +01001#!/usr/bin/env python
2
3#
4# Licensed to the Apache Software Foundation (ASF) under one
5# or more contributor license agreements. See the NOTICE file
6# distributed with this work for additional information
7# regarding copyright ownership. The ASF licenses this file
8# to you under the Apache License, Version 2.0 (the
9# "License"); you may not use this file except in compliance
10# with the License. You may obtain a copy of the License at
11#
12# http://www.apache.org/licenses/LICENSE-2.0
13#
14# Unless required by applicable law or agreed to in writing,
15# software distributed under the License is distributed on an
16# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
17# KIND, either express or implied. See the License for the
18# specific language governing permissions and limitations
19# under the License.
20#
21
22import sys
23import glob
24from optparse import OptionParser
# Parse this script's own options before handing control to unittest.
parser = OptionParser()
parser.add_option('--genpydir', type='string', dest='genpydir', default='gen-py',
                  help='directory holding the generated ThriftTest Python '
                       'code (default: %default)')
options, args = parser.parse_args()
# Clean-up hack: unittest.main() re-parses sys.argv and would complain about
# --genpydir, so drop every argument after the program name now that we have
# consumed them.
del sys.argv[1:]

# Make the generated ThriftTest package importable.
sys.path.insert(0, options.genpydir)

# Make the locally built thrift library importable. glob() returns matches in
# arbitrary order, so sort them to pick a deterministic directory when several
# builds exist. Fail with a readable message instead of a bare IndexError when
# nothing has been built yet.
_thrift_lib_dirs = sorted(glob.glob('../../lib/py/build/lib.*'))
if not _thrift_lib_dirs:
    sys.exit('Cannot find the built thrift Python library matching '
             '../../lib/py/build/lib.*; build the library in lib/py first.')
sys.path.insert(0, _thrift_lib_dirs[0])
31
32from ThriftTest.ttypes import *
33from thrift.protocol import TJSONProtocol
34from thrift.transport import TTransport
35
36import json
37import unittest
38
39
class SimpleJSONProtocolTest(unittest.TestCase):
    """Tests for TSimpleJSONProtocol.

    Thrift structs are serialized with the simple JSON protocol. The output
    is parsed with the json module and compared against dicts built by hand,
    keyed by field name. Reading through this protocol is expected to raise
    NotImplementedError (see testWriteOnly).
    """

    protocol_factory = TJSONProtocol.TSimpleJSONProtocolFactory()

    def _assertDictEqual(self, a, b, msg=None):
        """Assert that dicts *a* and *b* hold the same keys and values.

        Delegates to unittest's assertDictEqual when available (it was added
        in Python 2.7). Otherwise falls back to a simpler key-by-key check.
        """
        if hasattr(self, 'assertDictEqual'):
            self.assertDictEqual(a, b, msg)
            return

        # Substitute implementation, not as good as the unittest library's.
        # Use only APIs present on both Python 2 and 3:
        # - items() instead of the Python 2-only iteritems();
        # - assertEqual instead of the deprecated assertEquals alias.
        self.assertEqual(len(a), len(b), msg)
        for key, value in a.items():
            self.assertTrue(key in b, msg)
            self.assertEqual(b.get(key), value, msg)

    def _serialize(self, obj):
        """Write *obj* through the simple JSON protocol and return the buffer contents."""
        trans = TTransport.TMemoryBuffer()
        prot = self.protocol_factory.getProtocol(trans)
        obj.write(prot)
        return trans.getvalue()

    def _serialize_to_json(self, obj):
        """Serialize *obj* and parse the result with the json module."""
        data = self._serialize(obj)
        # json.loads only accepts bytes from Python 3.6 on, so decode
        # explicitly. On Python 2, bytes is str and this yields unicode.
        if isinstance(data, bytes):
            data = data.decode('utf-8')
        return json.loads(data)

    def _deserialize(self, objtype, data):
        """Read an *objtype* instance from *data* through the simple JSON protocol."""
        prot = self.protocol_factory.getProtocol(TTransport.TMemoryBuffer(data))
        ret = objtype()
        ret.read(prot)
        return ret

    def testWriteOnly(self):
        """Reading with the simple JSON protocol is unsupported."""
        # Bytes literal: identical to '{}' on Python 2, and presumably what
        # the memory transport expects on Python 3 (confirm in TTransport).
        self.assertRaises(NotImplementedError,
                          self._deserialize, VersioningTestV1, b'{}')

    def testSimpleMessage(self):
        """A struct with scalar fields serializes to a flat JSON object."""
        v1obj = VersioningTestV1(
            begin_in_both=12345,
            old_string='aaa',
            end_in_both=54321)
        expected = dict(begin_in_both=v1obj.begin_in_both,
                        old_string=v1obj.old_string,
                        end_in_both=v1obj.end_in_both)
        actual = self._serialize_to_json(v1obj)

        self._assertDictEqual(expected, actual)

    def testComplicated(self):
        """Nested structs and containers serialize to plain JSON values."""
        v2obj = VersioningTestV2(
            begin_in_both=12345,
            newint=1,
            newbyte=2,
            newshort=3,
            newlong=4,
            newdouble=5.0,
            newstruct=Bonk(message="Hello!", type=123),
            newlist=[7, 8, 9],
            newset=set([42, 1, 8]),
            newmap={1: 2, 2: 3},
            newstring="Hola!",
            end_in_both=54321)
        expected = dict(begin_in_both=v2obj.begin_in_both,
                        newint=v2obj.newint,
                        newbyte=v2obj.newbyte,
                        newshort=v2obj.newshort,
                        newlong=v2obj.newlong,
                        newdouble=v2obj.newdouble,
                        newstruct=dict(message=v2obj.newstruct.message,
                                       type=v2obj.newstruct.type),
                        newlist=v2obj.newlist,
                        newset=list(v2obj.newset),
                        newmap=v2obj.newmap,
                        newstring=v2obj.newstring,
                        end_in_both=v2obj.end_in_both)

        # JSON object keys are always strings. Round-trip the expected dict
        # through json so its int map keys are stringified the same way.
        expected = json.loads(json.dumps(expected))
        actual = self._serialize_to_json(v2obj)

        # A set has no defined element order, so compare it order-insensitively
        # rather than relying on the serializer's iteration order.
        self.assertEqual(sorted(expected.pop('newset')),
                         sorted(actual.pop('newset')))
        self._assertDictEqual(expected, actual)
115
116
117if __name__ == '__main__':
118 unittest.main()
119