blob: c19dbcda6b20b03af341cf555d866b6df7abc7c0 [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
19
Nobuaki Sukegawa355116e2016-02-11 18:01:20 +090020import logging
Nobuaki Sukegawaad835862015-12-23 23:32:09 +090021import os
22import platform
23import select
24import ssl
25import sys
26import threading
27import time
28import unittest
29import warnings
30
31import _import_local_thrift
32from thrift.transport.TSSLSocket import TSSLSocket, TSSLServerSocket
33
# Key material used by the tests. ROOT_DIR is three directory levels above
# the directory containing this script; the keys live in <ROOT_DIR>/test/keys.
SCRIPT_DIR = os.path.realpath(os.path.dirname(__file__))
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(SCRIPT_DIR)))
SERVER_PEM = os.path.join(ROOT_DIR, 'test', 'keys', 'server.pem')
SERVER_CERT = os.path.join(ROOT_DIR, 'test', 'keys', 'server.crt')
SERVER_KEY = os.path.join(ROOT_DIR, 'test', 'keys', 'server.key')
# NOTE(review): judging by the names, client.crt/client.key lack an IP entry
# that client_v3.* carries -- confirm against the key generation scripts.
CLIENT_CERT_NO_IP = os.path.join(ROOT_DIR, 'test', 'keys', 'client.crt')
CLIENT_KEY_NO_IP = os.path.join(ROOT_DIR, 'test', 'keys', 'client.key')
CLIENT_CERT = os.path.join(ROOT_DIR, 'test', 'keys', 'client_v3.crt')
CLIENT_KEY = os.path.join(ROOT_DIR, 'test', 'keys', 'client_v3.key')
CLIENT_CA = os.path.join(ROOT_DIR, 'test', 'keys', 'CA.pem')

TEST_PORT = 23458
TEST_ADDR = '/tmp/.thrift.domain.sock.%d' % TEST_PORT
# Pause between starting the acceptor thread and opening the client.
CONNECT_DELAY = 0.5
# NOTE(review): this value is handed both to client.setTimeout() and to
# select.select(); confirm the two APIs expect the same unit.
CONNECT_TIMEOUT = 20.0
# DES-CBC3-SHA alone cannot be selected on OpenSSL builds that omit weak
# (3DES) suites, so a modern RSA suite is offered as well. The original
# entry stays first so older builds keep negotiating it.
TEST_CIPHERS = 'DES-CBC3-SHA:ECDHE-RSA-AES128-GCM-SHA256'
50
51
class ServerAcceptor(threading.Thread):
    """Background thread that listens on a server and accepts one connection.

    Whatever ``server.accept()`` returns is stored in ``self.client``; the
    attribute stays ``None`` until ``accept()`` has returned.
    """

    def __init__(self, server):
        """Remember ``server``; nothing is opened until the thread is started."""
        super(ServerAcceptor, self).__init__()
        # accept() blocks until a client connects. Marking the thread as a
        # daemon keeps a never-completed accept() from holding the
        # interpreter open after the tests have finished.
        self.daemon = True
        self._server = server
        self.client = None

    def run(self):
        """Start listening, then block until a single connection is accepted."""
        self._server.listen()
        self.client = self._server.accept()
Nobuaki Sukegawaad835862015-12-23 23:32:09 +090061
62
# Python 2.6 compat
class AssertRaises(object):
    """Context manager standing in for ``TestCase.assertRaises``.

    The ``with`` body must raise ``expected`` (an exception class or a tuple
    of classes, as accepted by ``issubclass``). A matching exception is
    swallowed; any other exception propagates unchanged.

    Raises:
        AssertionError: if the ``with`` body finishes without raising.
    """

    def __init__(self, expected):
        self._expected = expected

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        if exc_type is None:
            # AssertionError makes unittest report a failure instead of an
            # error, and the message says what was expected.
            name = getattr(self._expected, '__name__', str(self._expected))
            raise AssertionError('%s not raised' % name)
        # True swallows the expected exception; False lets an unexpected
        # one surface with its original type and traceback.
        return issubclass(exc_type, self._expected)
Nobuaki Sukegawaad835862015-12-23 23:32:09 +090075
76
class TSSLSocketTest(unittest.TestCase):
    """Connection tests for ``TSSLSocket`` / ``TSSLServerSocket``.

    Each scenario builds a server and a client transport and hands both to
    ``_assert_connection_success`` or ``_assert_connection_failure``, which
    run the server side on a ``ServerAcceptor`` thread and close everything
    afterwards.
    """

    def _assert_connection_failure(self, server, client):
        """Assert that ``client`` cannot establish a connection to ``server``.

        Opening the client (or waiting for it to become writable) must raise,
        and the acceptor must not have produced a client transport.
        """
        acc = ServerAcceptor(server)
        try:
            acc.start()
            # Give the acceptor thread time to start listening.
            time.sleep(CONNECT_DELAY / 2)
            client.setTimeout(CONNECT_TIMEOUT / 2)
            with self._assert_raises(Exception):
                # Silence the error logging triggered by the expected failure.
                logging.disable(logging.CRITICAL)
                client.open()
                select.select([], [client.handle], [], CONNECT_TIMEOUT / 2)
            # The acceptor is deliberately not joined here: it may still be
            # blocked waiting on the client, which is only closed below.
            # (assertIsNone is not available on Python 2.6.)
            self.assertTrue(acc.client is None)
        finally:
            logging.disable(logging.NOTSET)
            client.close()
            if acc.client:
                acc.client.close()
            server.close()

    def _assert_raises(self, exc):
        """Return a context manager asserting that ``exc`` is raised.

        Uses ``TestCase.assertRaises`` on Python 2.7 and later, and the
        ``AssertRaises`` fallback on older interpreters.
        """
        if sys.hexversion >= 0x020700F0:
            return self.assertRaises(exc)
        else:
            return AssertRaises(exc)

    def _assert_connection_success(self, server, client):
        """Assert that ``client`` connects to ``server`` and is accepted."""
        acc = ServerAcceptor(server)
        try:
            acc.start()
            # Give the acceptor thread time to start listening.
            time.sleep(CONNECT_DELAY)
            client.setTimeout(CONNECT_TIMEOUT)
            client.open()
            select.select([], [client.handle], [], CONNECT_TIMEOUT)
            # The acceptor assigns ``client`` only after accept() returns;
            # wait (bounded) for the thread so the check below does not race
            # with that assignment.
            acc.join(CONNECT_TIMEOUT)
            # (assertIsNotNone is not available on Python 2.6.)
            self.assertTrue(acc.client is not None)
        finally:
            client.close()
            if acc.client:
                acc.client.close()
            server.close()

    # deprecated feature
    def test_deprecation(self):
        """Deprecated arguments emit DeprecationWarning yet still work."""
        with warnings.catch_warnings(record=True) as w:
            warnings.filterwarnings('always', category=DeprecationWarning, module='thrift.*SSL.*')
            TSSLSocket('localhost', TEST_PORT, validate=True, ca_certs=SERVER_CERT)
            self.assertEqual(len(w), 1)

        with warnings.catch_warnings(record=True) as w:
            warnings.filterwarnings('always', category=DeprecationWarning, module='thrift.*SSL.*')
            # Deprecated signature
            # def __init__(self, host='localhost', port=9090, validate=True, ca_certs=None, keyfile=None, certfile=None, unix_socket=None, ciphers=None):
            client = TSSLSocket('localhost', TEST_PORT, True, SERVER_CERT, CLIENT_KEY, CLIENT_CERT, None, TEST_CIPHERS)
            self.assertEqual(len(w), 7)

        with warnings.catch_warnings(record=True) as w:
            warnings.filterwarnings('always', category=DeprecationWarning, module='thrift.*SSL.*')
            # Deprecated signature
            # def __init__(self, host=None, port=9090, certfile='cert.pem', unix_socket=None, ciphers=None):
            server = TSSLServerSocket(None, TEST_PORT, SERVER_PEM, None, TEST_CIPHERS)
            self.assertEqual(len(w), 3)

        self._assert_connection_success(server, client)

    # deprecated feature
    def test_set_cert_reqs_by_validate(self):
        """The deprecated ``validate`` flag is reflected in ``cert_reqs``."""
        c1 = TSSLSocket('localhost', TEST_PORT, validate=True, ca_certs=SERVER_CERT)
        self.assertEqual(c1.cert_reqs, ssl.CERT_REQUIRED)

        c1 = TSSLSocket('localhost', TEST_PORT, validate=False)
        self.assertEqual(c1.cert_reqs, ssl.CERT_NONE)

    # deprecated feature
    def test_set_validate_by_cert_reqs(self):
        """``cert_reqs`` is reflected in the deprecated ``validate`` flag."""
        c1 = TSSLSocket('localhost', TEST_PORT, cert_reqs=ssl.CERT_NONE)
        self.assertFalse(c1.validate)

        c2 = TSSLSocket('localhost', TEST_PORT, cert_reqs=ssl.CERT_REQUIRED, ca_certs=SERVER_CERT)
        self.assertTrue(c2.validate)

        c3 = TSSLSocket('localhost', TEST_PORT, cert_reqs=ssl.CERT_OPTIONAL, ca_certs=SERVER_CERT)
        self.assertTrue(c3.validate)

    def test_unix_domain_socket(self):
        """A connection over a unix domain socket succeeds (not run on Windows)."""
        if platform.system() == 'Windows':
            print('skipping test_unix_domain_socket')
            return
        server = TSSLServerSocket(unix_socket=TEST_ADDR, keyfile=SERVER_KEY, certfile=SERVER_CERT)
        client = TSSLSocket(None, None, TEST_ADDR, cert_reqs=ssl.CERT_NONE)
        self._assert_connection_success(server, client)

    def test_server_cert(self):
        """The client accepts the server only when its cert is trusted or unchecked."""
        server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT)
        client = TSSLSocket('localhost', TEST_PORT, cert_reqs=ssl.CERT_REQUIRED, ca_certs=SERVER_CERT)
        self._assert_connection_success(server, client)

        server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT)
        # server cert not in ca_certs
        client = TSSLSocket('localhost', TEST_PORT, cert_reqs=ssl.CERT_REQUIRED, ca_certs=CLIENT_CERT)
        self._assert_connection_failure(server, client)

        server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT)
        client = TSSLSocket('localhost', TEST_PORT, cert_reqs=ssl.CERT_NONE)
        self._assert_connection_success(server, client)

    def test_set_server_cert(self):
        """Assigning an invalid ``certfile`` raises; a valid one is then used."""
        server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=CLIENT_CERT)
        with self._assert_raises(Exception):
            server.certfile = 'foo'
        with self._assert_raises(Exception):
            server.certfile = None
        server.certfile = SERVER_CERT
        client = TSSLSocket('localhost', TEST_PORT, cert_reqs=ssl.CERT_REQUIRED, ca_certs=SERVER_CERT)
        self._assert_connection_success(server, client)

    def test_client_cert(self):
        """Client certificate handling with CERT_REQUIRED / CERT_OPTIONAL servers."""
        server = TSSLServerSocket(
            port=TEST_PORT, cert_reqs=ssl.CERT_REQUIRED, keyfile=SERVER_KEY,
            certfile=SERVER_CERT, ca_certs=CLIENT_CERT)
        client = TSSLSocket('localhost', TEST_PORT, cert_reqs=ssl.CERT_NONE, certfile=SERVER_CERT, keyfile=SERVER_KEY)
        self._assert_connection_failure(server, client)

        server = TSSLServerSocket(
            port=TEST_PORT, cert_reqs=ssl.CERT_REQUIRED, keyfile=SERVER_KEY,
            certfile=SERVER_CERT, ca_certs=CLIENT_CA)
        client = TSSLSocket('localhost', TEST_PORT, cert_reqs=ssl.CERT_NONE, certfile=CLIENT_CERT_NO_IP, keyfile=CLIENT_KEY_NO_IP)
        self._assert_connection_failure(server, client)

        server = TSSLServerSocket(
            port=TEST_PORT, cert_reqs=ssl.CERT_REQUIRED, keyfile=SERVER_KEY,
            certfile=SERVER_CERT, ca_certs=CLIENT_CA)
        client = TSSLSocket('localhost', TEST_PORT, cert_reqs=ssl.CERT_NONE, certfile=CLIENT_CERT, keyfile=CLIENT_KEY)
        self._assert_connection_success(server, client)

        server = TSSLServerSocket(
            port=TEST_PORT, cert_reqs=ssl.CERT_OPTIONAL, keyfile=SERVER_KEY,
            certfile=SERVER_CERT, ca_certs=CLIENT_CA)
        client = TSSLSocket('localhost', TEST_PORT, cert_reqs=ssl.CERT_NONE, certfile=CLIENT_CERT, keyfile=CLIENT_KEY)
        self._assert_connection_success(server, client)

    def test_ciphers(self):
        """Matching cipher lists connect; a client restricted to 'NULL' does not."""
        server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT, ciphers=TEST_CIPHERS)
        client = TSSLSocket('localhost', TEST_PORT, ca_certs=SERVER_CERT, ciphers=TEST_CIPHERS)
        self._assert_connection_success(server, client)

        if not TSSLSocket._has_ciphers:
            # unittest.skip is not available for Python 2.6
            print('skipping test_ciphers')
            return
        server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT)
        client = TSSLSocket('localhost', TEST_PORT, ca_certs=SERVER_CERT, ciphers='NULL')
        self._assert_connection_failure(server, client)

        server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT, ciphers=TEST_CIPHERS)
        client = TSSLSocket('localhost', TEST_PORT, ca_certs=SERVER_CERT, ciphers='NULL')
        self._assert_connection_failure(server, client)

    def test_ssl2_and_ssl3_disabled(self):
        """Connections fail when either side is pinned to SSLv3 or SSLv2."""
        if not hasattr(ssl, 'PROTOCOL_SSLv3'):
            print('PROTOCOL_SSLv3 is not available')
        else:
            server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT)
            client = TSSLSocket('localhost', TEST_PORT, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_SSLv3)
            self._assert_connection_failure(server, client)

            server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_SSLv3)
            client = TSSLSocket('localhost', TEST_PORT, ca_certs=SERVER_CERT)
            self._assert_connection_failure(server, client)

        if not hasattr(ssl, 'PROTOCOL_SSLv2'):
            print('PROTOCOL_SSLv2 is not available')
        else:
            server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT)
            client = TSSLSocket('localhost', TEST_PORT, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_SSLv2)
            self._assert_connection_failure(server, client)

            server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_SSLv2)
            client = TSSLSocket('localhost', TEST_PORT, ca_certs=SERVER_CERT)
            self._assert_connection_failure(server, client)

    def test_newer_tls(self):
        """TLS 1.1 / 1.2 connect when both sides agree and fail when they differ."""
        if not TSSLSocket._has_ssl_context:
            # unittest.skip is not available for Python 2.6
            print('skipping test_newer_tls')
            return
        if not hasattr(ssl, 'PROTOCOL_TLSv1_2'):
            print('PROTOCOL_TLSv1_2 is not available')
        else:
            server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_2)
            client = TSSLSocket('localhost', TEST_PORT, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_2)
            self._assert_connection_success(server, client)

        if not hasattr(ssl, 'PROTOCOL_TLSv1_1'):
            print('PROTOCOL_TLSv1_1 is not available')
        else:
            server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_1)
            client = TSSLSocket('localhost', TEST_PORT, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_1)
            self._assert_connection_success(server, client)

        if not hasattr(ssl, 'PROTOCOL_TLSv1_1') or not hasattr(ssl, 'PROTOCOL_TLSv1_2'):
            print('PROTOCOL_TLSv1_1 and/or PROTOCOL_TLSv1_2 is not available')
        else:
            server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_2)
            client = TSSLSocket('localhost', TEST_PORT, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_1)
            self._assert_connection_failure(server, client)

    def test_ssl_context(self):
        """Explicit ``ssl.SSLContext`` objects on both sides connect with mutual verification."""
        if not TSSLSocket._has_ssl_context:
            # unittest.skip is not available for Python 2.6
            print('skipping test_ssl_context')
            return
        server_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
        server_context.load_cert_chain(SERVER_CERT, SERVER_KEY)
        server_context.load_verify_locations(CLIENT_CA)
        server_context.verify_mode = ssl.CERT_REQUIRED
        server = TSSLServerSocket(port=TEST_PORT, ssl_context=server_context)

        client_context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
        client_context.load_cert_chain(CLIENT_CERT, CLIENT_KEY)
        client_context.load_verify_locations(SERVER_CERT)
        client_context.verify_mode = ssl.CERT_REQUIRED
        client = TSSLSocket('localhost', TEST_PORT, ssl_context=client_context)

        self._assert_connection_success(server, client)
Nobuaki Sukegawaad835862015-12-23 23:32:09 +0900302
if __name__ == '__main__':
    # For verbose diagnostics while debugging, enable debug logging first:
    # logging.basicConfig(level=logging.DEBUG)
    unittest.main()