#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
19
20import unittest
21import random
22import string
23
24import _import_local_thrift # noqa
25from thrift.transport import TTransport
26from thrift.transport import TZlibTransport
27
28
def generate_random_buff(buf_len=1024 * 32):
    """Return a string of ``buf_len`` random printable ASCII characters.

    Each character is drawn independently and uniformly from
    ``string.printable``. All of those characters are ASCII, so the
    UTF-8 encoding of the result is exactly ``buf_len`` bytes long.

    :param buf_len: number of characters to generate (default 32 KiB).
        A value of 0 (or less) yields an empty string.
    :return: the generated string.
    """
    return ''.join(random.choice(string.printable) for _ in range(buf_len))
44
45
class TestTZlibTransport(unittest.TestCase):
    """Round-trip tests for TZlibTransport layered over an in-memory buffer."""

    @staticmethod
    def _compress(*chunks):
        """Write ``chunks`` through a fresh TZlibTransport and return the output.

        ``flush()`` is called after every chunk, so passing several chunks
        exercises a stream that was flushed more than once.

        :param chunks: byte strings to write, in order.
        :return: the compressed bytes accumulated in the memory buffer.
        """
        buff = TTransport.TMemoryBuffer()
        trans = TTransport.TBufferedTransportFactory().getTransport(buff)
        zlib_trans = TZlibTransport.TZlibTransport(trans)
        try:
            for chunk in chunks:
                zlib_trans.write(chunk)
                zlib_trans.flush()
            # Capture the bytes before close() runs in ``finally``; the
            # original test also read the buffer first, presumably because
            # closing the chain releases the underlying buffer.
            return buff.getvalue()
        finally:
            zlib_trans.close()

    @staticmethod
    def _decompress(value, size):
        """Decompress up to ``size`` bytes from the compressed bytes ``value``.

        :param value: bytes previously produced by ``_compress``.
        :param size: number of decompressed *bytes* to request.
        :return: the bytes returned by a single ``read(size)`` call.
        """
        buff = TTransport.TMemoryBuffer(value)
        trans = TTransport.TBufferedTransportFactory().getTransport(buff)
        zlib_trans = TZlibTransport.TZlibTransport(trans)
        try:
            return zlib_trans.read(size)
        finally:
            zlib_trans.close()

    def test_write_then_read(self):
        """A single large write survives a compress/decompress round trip."""
        data_w = generate_random_buff()
        payload = data_w.encode('utf-8')
        value = self._compress(payload)
        # Request the encoded byte length: transport reads are sized in
        # bytes, which only equals len(data_w) for pure-ASCII text.
        data_r = self._decompress(value, len(payload))
        self.assertEqual(data_w, data_r.decode('utf-8'))

    def test_after_flushd_write_then_read(self):
        """Data written across two flushes decompresses as one stream."""
        data_w_1 = "hello thrift !@#" * 50
        data_w_2 = "{'name': 'thrift', 1: ['abcd' , 233, ('a','c')]}" * 20
        payload_1 = data_w_1.encode('utf-8')
        payload_2 = data_w_2.encode('utf-8')
        value = self._compress(payload_1, payload_2)
        data_r = self._decompress(value, len(payload_1) + len(payload_2))
        self.assertEqual(data_w_1 + data_w_2, data_r.decode('utf-8'))
96
97
if __name__ == "__main__":
    # Executed directly as a script: hand control to the unittest runner,
    # which discovers and runs the TestCase classes in this module.
    unittest.main()