blob: b7c3802fe4c5cb1d49808846f03e76cfa61446e2 [file] [log] [blame]
Nobuaki Sukegawaad835862015-12-23 23:32:09 +09001#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18#
19
20import os
21import platform
22import select
23import ssl
24import sys
25import threading
26import time
27import unittest
28import warnings
29
30import _import_local_thrift
31from thrift.transport.TSSLSocket import TSSLSocket, TSSLServerSocket
32
# Key material lives under <ROOT_DIR>/test/keys; ROOT_DIR is three directory
# levels above this script (presumably the repository root -- confirm if the
# test is moved).
SCRIPT_DIR = os.path.realpath(os.path.dirname(__file__))
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(SCRIPT_DIR)))
SERVER_PEM = os.path.join(ROOT_DIR, 'test', 'keys', 'server.pem')
SERVER_CERT = os.path.join(ROOT_DIR, 'test', 'keys', 'server.crt')
SERVER_KEY = os.path.join(ROOT_DIR, 'test', 'keys', 'server.key')
CLIENT_CERT = os.path.join(ROOT_DIR, 'test', 'keys', 'client.crt')
CLIENT_KEY = os.path.join(ROOT_DIR, 'test', 'keys', 'client.key')

# TCP port used by every server/client pair in this module.
TEST_PORT = 23458
# Unix-domain socket path for the unix_socket variant of the tests.
TEST_ADDR = '/tmp/.thrift.domain.sock.%d' % TEST_PORT
# Seconds to let the acceptor thread reach listen() before the client connects.
CONNECT_DELAY = 0.5
# Upper bound (seconds) for client socket operations and waits in the helpers.
CONNECT_TIMEOUT = 10.0
# Cipher suite both ends are restricted to in the cipher/deprecation tests.
# The previous value, DES-CBC3-SHA (3DES, vulnerable to Sweet32), is refused
# by many modern OpenSSL builds and security levels, which made those tests
# fail for environmental reasons.  ECDHE-RSA needs an RSA server certificate;
# assumes test/keys/server.crt carries an RSA key -- confirm.
TEST_CIPHERS = 'ECDHE-RSA-AES128-GCM-SHA256'
46
47
class ServerAcceptor(threading.Thread):
    """Background thread that puts ``server`` in listening mode and accepts once.

    ``client`` is ``None`` until ``server.accept()`` returns, after which it
    holds whatever ``accept()`` returned.

    The thread is a daemon.  In the failure scenarios the client may never
    connect, which leaves ``accept()`` blocked.  Closing the listening socket
    from another thread may not wake it on every platform, and a non-daemon
    thread stuck there would stop the interpreter from exiting after the run.
    """

    def __init__(self, server):
        super(ServerAcceptor, self).__init__()
        self.daemon = True
        self._server = server
        self.client = None

    def run(self):
        self._server.listen()
        self.client = self._server.accept()
57
58
# Python 2.6 compat: TestCase.assertRaises only works as a context manager
# from Python 2.7 onwards.
class AssertRaises(object):
    """Context manager asserting that its body raises ``expected``.

    Mirrors ``unittest.TestCase.assertRaises`` in context-manager form:

    * if the body raises an instance of ``expected`` (a class or a tuple of
      classes), the exception is suppressed and stored on ``self.exception``;
    * if the body raises something else, that exception propagates unchanged
      so the real cause appears in the test report;
    * if the body raises nothing, ``AssertionError`` is raised so unittest
      reports a test failure rather than an error.
    """

    def __init__(self, expected):
        self._expected = expected
        self.exception = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        if exc_type is None:
            name = getattr(self._expected, '__name__', str(self._expected))
            raise AssertionError('%s not raised' % name)
        if not issubclass(exc_type, self._expected):
            # Returning False lets the unexpected exception propagate as-is.
            return False
        self.exception = exc_value
        return True
71
72
class TSSLSocketTest(unittest.TestCase):
    """Round-trip tests for TSSLSocket (client) and TSSLServerSocket (server).

    Each scenario runs the server in a ServerAcceptor thread, opens a client
    against it, and asserts whether the server accepted a connection.  The
    helpers deliberately avoid assertion/skip APIs newer than Python 2.6.
    """

    def _start_acceptor(self, server):
        """Start a ServerAcceptor thread for *server* and return it.

        ServerAcceptor exposes no "listening" signal, so a fixed CONNECT_DELAY
        sleep gives the thread time to reach listen() before the client
        connects.
        """
        acc = ServerAcceptor(server)
        # If the client fails before connecting, accept() never returns; a
        # daemon thread cannot keep the interpreter alive at exit.
        acc.daemon = True
        acc.start()
        time.sleep(CONNECT_DELAY)
        return acc

    def _close_all(self, acc, server, client):
        """Close the accepted server-side socket (if any), the server and the client."""
        if acc is not None and acc.client is not None:
            # The end returned by accept() would otherwise never be closed.
            acc.client.close()
        server.close()
        client.close()

    def _assert_connection_failure(self, server, client):
        """Assert that *client* fails to open and *server* accepts nothing."""
        acc = None
        try:
            acc = self._start_acceptor(server)
            client.setTimeout(CONNECT_TIMEOUT)
            with self._assert_raises(Exception):
                client.open()
            # Give the server a bounded chance to finish processing whatever
            # reached it; the client may never have connected, in which case
            # the thread stays blocked in accept() and we stop waiting.
            acc.join(CONNECT_DELAY)
            self.assertTrue(acc.client is None)
        finally:
            self._close_all(acc, server, client)

    def _assert_raises(self, exc):
        """Return an assertRaises-style context manager on every Python version."""
        if sys.hexversion >= 0x020700F0:
            return self.assertRaises(exc)
        else:
            return AssertRaises(exc)

    def _assert_connection_success(self, server, client):
        """Assert that *client* opens a connection that *server* accepts."""
        acc = None
        try:
            acc = self._start_acceptor(server)
            client.setTimeout(CONNECT_TIMEOUT)
            client.open()
            # The client can finish its handshake before the server thread
            # returns from accept(), so wait for the thread instead of racing
            # on acc.client.
            acc.join(CONNECT_TIMEOUT)
            self.assertTrue(acc.client is not None)
        finally:
            self._close_all(acc, server, client)

    def _skip(self, reason):
        """Report the current test as skipped (Python 2.6 has no skipTest: print)."""
        if hasattr(self, 'skipTest'):
            self.skipTest(reason)  # raises SkipTest
        print(reason)

    # deprecated feature
    def test_deprecation(self):
        with warnings.catch_warnings(record=True) as w:
            warnings.filterwarnings('always', category=DeprecationWarning, module='thrift.*SSL.*')
            TSSLSocket('localhost', TEST_PORT, validate=True, ca_certs=SERVER_CERT)
            self.assertEqual(len(w), 1)

        with warnings.catch_warnings(record=True) as w:
            warnings.filterwarnings('always', category=DeprecationWarning, module='thrift.*SSL.*')
            # Deprecated signature
            # def __init__(self, host='localhost', port=9090, validate=True, ca_certs=None, keyfile=None, certfile=None, unix_socket=None, ciphers=None):
            client = TSSLSocket('localhost', TEST_PORT, True, SERVER_CERT, CLIENT_KEY, CLIENT_CERT, None, TEST_CIPHERS)
            self.assertEqual(len(w), 7)

        with warnings.catch_warnings(record=True) as w:
            warnings.filterwarnings('always', category=DeprecationWarning, module='thrift.*SSL.*')
            # Deprecated signature
            # def __init__(self, host=None, port=9090, certfile='cert.pem', unix_socket=None, ciphers=None):
            server = TSSLServerSocket(None, TEST_PORT, SERVER_PEM, None, TEST_CIPHERS)
            self.assertEqual(len(w), 3)

        self._assert_connection_success(server, client)

    # deprecated feature
    def test_set_cert_reqs_by_validate(self):
        c1 = TSSLSocket('localhost', TEST_PORT, validate=True, ca_certs=SERVER_CERT)
        self.assertEqual(c1.cert_reqs, ssl.CERT_REQUIRED)

        c1 = TSSLSocket('localhost', TEST_PORT, validate=False)
        self.assertEqual(c1.cert_reqs, ssl.CERT_NONE)

    # deprecated feature
    def test_set_validate_by_cert_reqs(self):
        c1 = TSSLSocket('localhost', TEST_PORT, cert_reqs=ssl.CERT_NONE)
        self.assertFalse(c1.validate)

        c2 = TSSLSocket('localhost', TEST_PORT, cert_reqs=ssl.CERT_REQUIRED, ca_certs=SERVER_CERT)
        self.assertTrue(c2.validate)

        c3 = TSSLSocket('localhost', TEST_PORT, cert_reqs=ssl.CERT_OPTIONAL, ca_certs=SERVER_CERT)
        self.assertTrue(c3.validate)

    def test_unix_domain_socket(self):
        if platform.system() == 'Windows':
            self._skip('skipping test_unix_domain_socket')
            return
        server = TSSLServerSocket(unix_socket=TEST_ADDR, keyfile=SERVER_KEY, certfile=SERVER_CERT)
        client = TSSLSocket(None, None, TEST_ADDR, cert_reqs=ssl.CERT_NONE)
        self._assert_connection_success(server, client)

    def test_server_cert(self):
        server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT)
        client = TSSLSocket('localhost', TEST_PORT, cert_reqs=ssl.CERT_REQUIRED, ca_certs=SERVER_CERT)
        self._assert_connection_success(server, client)

        server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT)
        # server cert not in ca_certs
        client = TSSLSocket('localhost', TEST_PORT, cert_reqs=ssl.CERT_REQUIRED, ca_certs=CLIENT_CERT)
        self._assert_connection_failure(server, client)

        server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT)
        client = TSSLSocket('localhost', TEST_PORT, cert_reqs=ssl.CERT_NONE)
        self._assert_connection_success(server, client)

    def test_set_server_cert(self):
        server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=CLIENT_CERT)
        with self._assert_raises(Exception):
            server.certfile = 'foo'
        with self._assert_raises(Exception):
            server.certfile = None
        server.certfile = SERVER_CERT
        client = TSSLSocket('localhost', TEST_PORT, cert_reqs=ssl.CERT_REQUIRED, ca_certs=SERVER_CERT)
        self._assert_connection_success(server, client)

    def test_client_cert(self):
        server = TSSLServerSocket(
            port=TEST_PORT, cert_reqs=ssl.CERT_REQUIRED, keyfile=SERVER_KEY,
            certfile=SERVER_CERT, ca_certs=CLIENT_CERT)
        client = TSSLSocket('localhost', TEST_PORT, cert_reqs=ssl.CERT_NONE, certfile=CLIENT_CERT, keyfile=CLIENT_KEY)
        self._assert_connection_success(server, client)

    def test_ciphers(self):
        server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT, ciphers=TEST_CIPHERS)
        client = TSSLSocket('localhost', TEST_PORT, ca_certs=SERVER_CERT, ciphers=TEST_CIPHERS)
        self._assert_connection_success(server, client)

        if not TSSLSocket._has_ciphers:
            # unittest.skip is not available for Python 2.6
            self._skip('skipping test_ciphers')
            return
        server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT)
        client = TSSLSocket('localhost', TEST_PORT, ca_certs=SERVER_CERT, ciphers='NULL')
        self._assert_connection_failure(server, client)

        server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT, ciphers=TEST_CIPHERS)
        client = TSSLSocket('localhost', TEST_PORT, ca_certs=SERVER_CERT, ciphers='NULL')
        self._assert_connection_failure(server, client)

    def test_ssl2_and_ssl3_disabled(self):
        # Each protocol is probed independently, so a missing one only prints.
        if not hasattr(ssl, 'PROTOCOL_SSLv3'):
            print('PROTOCOL_SSLv3 is not available')
        else:
            server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT)
            client = TSSLSocket('localhost', TEST_PORT, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_SSLv3)
            self._assert_connection_failure(server, client)

            server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_SSLv3)
            client = TSSLSocket('localhost', TEST_PORT, ca_certs=SERVER_CERT)
            self._assert_connection_failure(server, client)

        if not hasattr(ssl, 'PROTOCOL_SSLv2'):
            print('PROTOCOL_SSLv2 is not available')
        else:
            server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT)
            client = TSSLSocket('localhost', TEST_PORT, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_SSLv2)
            self._assert_connection_failure(server, client)

            server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_SSLv2)
            client = TSSLSocket('localhost', TEST_PORT, ca_certs=SERVER_CERT)
            self._assert_connection_failure(server, client)

    def test_newer_tls(self):
        if not TSSLSocket._has_ssl_context:
            # unittest.skip is not available for Python 2.6
            self._skip('skipping test_newer_tls')
            return
        if not hasattr(ssl, 'PROTOCOL_TLSv1_2'):
            print('PROTOCOL_TLSv1_2 is not available')
        else:
            server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_2)
            client = TSSLSocket('localhost', TEST_PORT, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_2)
            self._assert_connection_success(server, client)

        if not hasattr(ssl, 'PROTOCOL_TLSv1_1'):
            print('PROTOCOL_TLSv1_1 is not available')
        else:
            server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_1)
            client = TSSLSocket('localhost', TEST_PORT, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_1)
            self._assert_connection_success(server, client)

        if not hasattr(ssl, 'PROTOCOL_TLSv1_1') or not hasattr(ssl, 'PROTOCOL_TLSv1_2'):
            print('PROTOCOL_TLSv1_1 and/or PROTOCOL_TLSv1_2 is not available')
        else:
            # A TLS 1.2-only server must reject a TLS 1.1-only client.
            server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_2)
            client = TSSLSocket('localhost', TEST_PORT, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_1)
            self._assert_connection_failure(server, client)

    def test_ssl_context(self):
        if not TSSLSocket._has_ssl_context:
            # unittest.skip is not available for Python 2.6
            self._skip('skipping test_ssl_context')
            return
        server_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
        server_context.load_cert_chain(SERVER_CERT, SERVER_KEY)
        server_context.load_verify_locations(CLIENT_CERT)

        client_context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
        client_context.load_cert_chain(CLIENT_CERT, CLIENT_KEY)
        client_context.load_verify_locations(SERVER_CERT)

        server = TSSLServerSocket(port=TEST_PORT, ssl_context=server_context)
        client = TSSLSocket('localhost', TEST_PORT, ssl_context=client_context)
        self._assert_connection_success(server, client)
272
if __name__ == '__main__':
    # For verbose transport/SSL diagnostics while debugging a failing case,
    # enable logging before the suite runs:
    #     import logging
    #     logging.basicConfig(level=logging.DEBUG)
    unittest.main()