blob: 98d47ae310f52fc9be6acee0b61e6e1f84be359c [file] [log] [blame]
Nobuaki Sukegawaad835862015-12-23 23:32:09 +09001#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18#
19
20import os
21import platform
22import select
23import ssl
24import sys
25import threading
26import time
27import unittest
28import warnings
29
30import _import_local_thrift
31from thrift.transport.TSSLSocket import TSSLSocket, TSSLServerSocket
32
SCRIPT_DIR = os.path.realpath(os.path.dirname(__file__))
# Repository root: three directory levels above this test script.
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(SCRIPT_DIR)))
# Key material shared with the rest of the test suite (<root>/test/keys).
KEYS_DIR = os.path.join(ROOT_DIR, 'test', 'keys')
SERVER_PEM = os.path.join(KEYS_DIR, 'server.pem')
SERVER_CERT = os.path.join(KEYS_DIR, 'server.crt')
SERVER_KEY = os.path.join(KEYS_DIR, 'server.key')
# Client pair that a server verifying against CLIENT_CA is expected to reject
# (presumably because it lacks an IP subjectAltName, per its name -- confirm).
CLIENT_CERT_NO_IP = os.path.join(KEYS_DIR, 'client.crt')
CLIENT_KEY_NO_IP = os.path.join(KEYS_DIR, 'client.key')
CLIENT_CERT = os.path.join(KEYS_DIR, 'client_v3.crt')
CLIENT_KEY = os.path.join(KEYS_DIR, 'client_v3.key')
CLIENT_CA = os.path.join(KEYS_DIR, 'CA.pem')

TEST_PORT = 23458
# Unix domain socket path used by the unix-socket test.
TEST_ADDR = '/tmp/.thrift.domain.sock.%d' % TEST_PORT
# Seconds to let the acceptor thread start listening before the client connects.
CONNECT_DELAY = 0.5
# Seconds used for client socket timeouts and select() waits.
CONNECT_TIMEOUT = 20.0
# OpenSSL cipher list the "ciphers" tests configure on both ends. The legacy
# 3DES suite stays first, but many modern OpenSSL builds / default security
# levels refuse it, so an AES-GCM suite is listed as a fallback. The 'NULL'
# cipher list still shares no suite with this list, so the negative tests hold.
TEST_CIPHERS = 'DES-CBC3-SHA:ECDHE-RSA-AES128-GCM-SHA256'
49
50
class ServerAcceptor(threading.Thread):
    """Background thread that listens on ``server`` and accepts one client.

    ``client`` stays ``None`` until ``server.accept()`` returns, after which
    it holds whatever ``accept()`` returned. If ``listen()`` or ``accept()``
    raises, the exception is stored in ``error`` instead of escaping the
    thread.
    """

    def __init__(self, server):
        super(ServerAcceptor, self).__init__()
        # Daemon thread: if no client ever connects, a thread still blocked in
        # accept() must not keep the interpreter alive once the tests finish.
        self.daemon = True
        self._server = server
        self.client = None
        self.error = None

    def run(self):
        try:
            self._server.listen()
            self.client = self._server.accept()
        except Exception as e:
            # e.g. the test closed the listening socket while accept() was
            # still waiting; record it rather than dumping an unhandled
            # thread traceback to stderr.
            self.error = e
Nobuaki Sukegawaad835862015-12-23 23:32:09 +090060
61
# Python 2.6 compat: the context-manager form of
# unittest.TestCase.assertRaises only exists from Python 2.7.
class AssertRaises(object):
    """Context manager asserting that its body raises ``expected``.

    Mirrors ``assertRaises`` semantics:

    - a matching exception is suppressed and stored in ``self.exception``;
    - if the body raises nothing, ``AssertionError`` is raised, so the test is
      reported as a failure;
    - an exception of a different type propagates unchanged.

    :param expected: exception class (or tuple of classes) the body must raise.
    """

    def __init__(self, expected):
        self._expected = expected
        self.exception = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        if exc_type is None:
            name = getattr(self._expected, '__name__', str(self._expected))
            raise AssertionError('%s not raised' % name)
        if not issubclass(exc_type, self._expected):
            # Returning False lets the unexpected exception propagate with
            # its original traceback instead of masking it.
            return False
        self.exception = exc_value
        return True
Nobuaki Sukegawaad835862015-12-23 23:32:09 +090074
75
class TSSLSocketTest(unittest.TestCase):
    """End-to-end tests of TSSLSocket / TSSLServerSocket option handling.

    Each scenario starts a ServerAcceptor thread on a fresh server socket,
    connects a client to it, and checks whether the TLS connection succeeded.
    """

    def _assert_connection_failure(self, server, client):
        """Assert that ``client`` cannot open a TLS connection to ``server``.

        The client, the server and any accepted connection are closed
        afterwards.
        """
        acc = ServerAcceptor(server)
        try:
            acc.start()
            # Give the acceptor thread a head start so listen() has run.
            time.sleep(CONNECT_DELAY / 2)
            client.setTimeout(CONNECT_TIMEOUT / 2)
            with self._assert_raises(Exception):
                client.open()
                select.select([], [client.handle], [], CONNECT_TIMEOUT / 2)
            # Let the acceptor finish rejecting the handshake so the check
            # below sees its final state. The wait is bounded: a client that
            # failed before connecting leaves accept() blocked.
            acc.join(CONNECT_DELAY)
            # assertIsNone is unavailable on Python 2.6.
            self.assertTrue(acc.client is None)
        finally:
            client.close()
            if acc.client:
                acc.client.close()
            server.close()

    def _assert_raises(self, exc):
        """Return an assertRaises context manager, with a Python 2.6 fallback."""
        if sys.hexversion >= 0x020700F0:
            return self.assertRaises(exc)
        else:
            return AssertRaises(exc)

    def _assert_connection_success(self, server, client):
        """Assert that ``client`` opens a TLS connection accepted by ``server``.

        The client, the server and any accepted connection are closed
        afterwards.
        """
        acc = ServerAcceptor(server)
        try:
            acc.start()
            # Give the acceptor thread a head start so listen() has run.
            time.sleep(CONNECT_DELAY)
            client.setTimeout(CONNECT_TIMEOUT)
            client.open()
            select.select([], [client.handle], [], CONNECT_TIMEOUT)
            # Nothing orders the acceptor thread's accept() before this point:
            # the client-side open() can return first. Wait for the acceptor
            # before inspecting its result.
            acc.join(CONNECT_TIMEOUT)
            # assertIsNotNone is unavailable on Python 2.6.
            self.assertTrue(acc.client is not None)
        finally:
            client.close()
            if acc.client:
                acc.client.close()
            server.close()

    # deprecated feature
    def test_deprecation(self):
        """Deprecated arguments/signatures emit DeprecationWarnings but still work."""
        with warnings.catch_warnings(record=True) as w:
            warnings.filterwarnings('always', category=DeprecationWarning, module='thrift.*SSL.*')
            TSSLSocket('localhost', TEST_PORT, validate=True, ca_certs=SERVER_CERT)
            self.assertEqual(len(w), 1)

        with warnings.catch_warnings(record=True) as w:
            warnings.filterwarnings('always', category=DeprecationWarning, module='thrift.*SSL.*')
            # Deprecated signature
            # def __init__(self, host='localhost', port=9090, validate=True, ca_certs=None, keyfile=None, certfile=None, unix_socket=None, ciphers=None):
            client = TSSLSocket('localhost', TEST_PORT, True, SERVER_CERT, CLIENT_KEY, CLIENT_CERT, None, TEST_CIPHERS)
            self.assertEqual(len(w), 7)

        with warnings.catch_warnings(record=True) as w:
            warnings.filterwarnings('always', category=DeprecationWarning, module='thrift.*SSL.*')
            # Deprecated signature
            # def __init__(self, host=None, port=9090, certfile='cert.pem', unix_socket=None, ciphers=None):
            server = TSSLServerSocket(None, TEST_PORT, SERVER_PEM, None, TEST_CIPHERS)
            self.assertEqual(len(w), 3)

        self._assert_connection_success(server, client)

    # deprecated feature
    def test_set_cert_reqs_by_validate(self):
        """The deprecated ``validate`` flag maps onto ``cert_reqs``."""
        c1 = TSSLSocket('localhost', TEST_PORT, validate=True, ca_certs=SERVER_CERT)
        self.assertEqual(c1.cert_reqs, ssl.CERT_REQUIRED)

        c1 = TSSLSocket('localhost', TEST_PORT, validate=False)
        self.assertEqual(c1.cert_reqs, ssl.CERT_NONE)

    # deprecated feature
    def test_set_validate_by_cert_reqs(self):
        """``cert_reqs`` is reflected back into the deprecated ``validate`` flag."""
        c1 = TSSLSocket('localhost', TEST_PORT, cert_reqs=ssl.CERT_NONE)
        self.assertFalse(c1.validate)

        c2 = TSSLSocket('localhost', TEST_PORT, cert_reqs=ssl.CERT_REQUIRED, ca_certs=SERVER_CERT)
        self.assertTrue(c2.validate)

        c3 = TSSLSocket('localhost', TEST_PORT, cert_reqs=ssl.CERT_OPTIONAL, ca_certs=SERVER_CERT)
        self.assertTrue(c3.validate)

    def test_unix_domain_socket(self):
        """TLS over a unix domain socket (skipped on Windows)."""
        if platform.system() == 'Windows':
            print('skipping test_unix_domain_socket')
            return
        server = TSSLServerSocket(unix_socket=TEST_ADDR, keyfile=SERVER_KEY, certfile=SERVER_CERT)
        client = TSSLSocket(None, None, TEST_ADDR, cert_reqs=ssl.CERT_NONE)
        self._assert_connection_success(server, client)

    def test_server_cert(self):
        """Client-side verification of the server certificate."""
        server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT)
        client = TSSLSocket('localhost', TEST_PORT, cert_reqs=ssl.CERT_REQUIRED, ca_certs=SERVER_CERT)
        self._assert_connection_success(server, client)

        server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT)
        # server cert not in ca_certs
        client = TSSLSocket('localhost', TEST_PORT, cert_reqs=ssl.CERT_REQUIRED, ca_certs=CLIENT_CERT)
        self._assert_connection_failure(server, client)

        server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT)
        client = TSSLSocket('localhost', TEST_PORT, cert_reqs=ssl.CERT_NONE)
        self._assert_connection_success(server, client)

    def test_set_server_cert(self):
        """Invalid ``certfile`` assignments are rejected; valid ones take effect."""
        server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=CLIENT_CERT)
        with self._assert_raises(Exception):
            server.certfile = 'foo'
        with self._assert_raises(Exception):
            server.certfile = None
        server.certfile = SERVER_CERT
        client = TSSLSocket('localhost', TEST_PORT, cert_reqs=ssl.CERT_REQUIRED, ca_certs=SERVER_CERT)
        self._assert_connection_success(server, client)

    def test_client_cert(self):
        """Server-side verification of client certificates."""
        # Server trusts only CLIENT_CERT; client presents the server's cert.
        server = TSSLServerSocket(
            port=TEST_PORT, cert_reqs=ssl.CERT_REQUIRED, keyfile=SERVER_KEY,
            certfile=SERVER_CERT, ca_certs=CLIENT_CERT)
        client = TSSLSocket('localhost', TEST_PORT, cert_reqs=ssl.CERT_NONE, certfile=SERVER_CERT, keyfile=SERVER_KEY)
        self._assert_connection_failure(server, client)

        server = TSSLServerSocket(
            port=TEST_PORT, cert_reqs=ssl.CERT_REQUIRED, keyfile=SERVER_KEY,
            certfile=SERVER_CERT, ca_certs=CLIENT_CA)
        client = TSSLSocket('localhost', TEST_PORT, cert_reqs=ssl.CERT_NONE, certfile=CLIENT_CERT_NO_IP, keyfile=CLIENT_KEY_NO_IP)
        self._assert_connection_failure(server, client)

        server = TSSLServerSocket(
            port=TEST_PORT, cert_reqs=ssl.CERT_REQUIRED, keyfile=SERVER_KEY,
            certfile=SERVER_CERT, ca_certs=CLIENT_CA)
        client = TSSLSocket('localhost', TEST_PORT, cert_reqs=ssl.CERT_NONE, certfile=CLIENT_CERT, keyfile=CLIENT_KEY)
        self._assert_connection_success(server, client)

        server = TSSLServerSocket(
            port=TEST_PORT, cert_reqs=ssl.CERT_OPTIONAL, keyfile=SERVER_KEY,
            certfile=SERVER_CERT, ca_certs=CLIENT_CA)
        client = TSSLSocket('localhost', TEST_PORT, cert_reqs=ssl.CERT_NONE, certfile=CLIENT_CERT, keyfile=CLIENT_KEY)
        self._assert_connection_success(server, client)

    def test_ciphers(self):
        """Matching cipher lists connect; disjoint ones ('NULL') do not."""
        server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT, ciphers=TEST_CIPHERS)
        client = TSSLSocket('localhost', TEST_PORT, ca_certs=SERVER_CERT, ciphers=TEST_CIPHERS)
        self._assert_connection_success(server, client)

        if not TSSLSocket._has_ciphers:
            # unittest.skip is not available for Python 2.6
            print('skipping test_ciphers')
            return
        server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT)
        client = TSSLSocket('localhost', TEST_PORT, ca_certs=SERVER_CERT, ciphers='NULL')
        self._assert_connection_failure(server, client)

        server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT, ciphers=TEST_CIPHERS)
        client = TSSLSocket('localhost', TEST_PORT, ca_certs=SERVER_CERT, ciphers='NULL')
        self._assert_connection_failure(server, client)

    def test_ssl2_and_ssl3_disabled(self):
        """SSLv2/SSLv3 on either end fails against the default configuration."""
        if not hasattr(ssl, 'PROTOCOL_SSLv3'):
            print('PROTOCOL_SSLv3 is not available')
        else:
            server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT)
            client = TSSLSocket('localhost', TEST_PORT, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_SSLv3)
            self._assert_connection_failure(server, client)

            server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_SSLv3)
            client = TSSLSocket('localhost', TEST_PORT, ca_certs=SERVER_CERT)
            self._assert_connection_failure(server, client)

        if not hasattr(ssl, 'PROTOCOL_SSLv2'):
            print('PROTOCOL_SSLv2 is not available')
        else:
            server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT)
            client = TSSLSocket('localhost', TEST_PORT, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_SSLv2)
            self._assert_connection_failure(server, client)

            server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_SSLv2)
            client = TSSLSocket('localhost', TEST_PORT, ca_certs=SERVER_CERT)
            self._assert_connection_failure(server, client)

    def test_newer_tls(self):
        """Pinned TLS 1.1 / 1.2 versions connect to themselves but not to each other."""
        if not TSSLSocket._has_ssl_context:
            # unittest.skip is not available for Python 2.6
            print('skipping test_newer_tls')
            return
        if not hasattr(ssl, 'PROTOCOL_TLSv1_2'):
            print('PROTOCOL_TLSv1_2 is not available')
        else:
            server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_2)
            client = TSSLSocket('localhost', TEST_PORT, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_2)
            self._assert_connection_success(server, client)

        if not hasattr(ssl, 'PROTOCOL_TLSv1_1'):
            print('PROTOCOL_TLSv1_1 is not available')
        else:
            server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_1)
            client = TSSLSocket('localhost', TEST_PORT, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_1)
            self._assert_connection_success(server, client)

        if not hasattr(ssl, 'PROTOCOL_TLSv1_1') or not hasattr(ssl, 'PROTOCOL_TLSv1_2'):
            print('PROTOCOL_TLSv1_1 and/or PROTOCOL_TLSv1_2 is not available')
        else:
            server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_2)
            client = TSSLSocket('localhost', TEST_PORT, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_1)
            self._assert_connection_failure(server, client)

    def test_ssl_context(self):
        """Caller-supplied SSLContext objects are used on both ends (mutual TLS)."""
        if not TSSLSocket._has_ssl_context:
            # unittest.skip is not available for Python 2.6
            print('skipping test_ssl_context')
            return
        server_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
        server_context.load_cert_chain(SERVER_CERT, SERVER_KEY)
        server_context.load_verify_locations(CLIENT_CA)
        server_context.verify_mode = ssl.CERT_REQUIRED
        server = TSSLServerSocket(port=TEST_PORT, ssl_context=server_context)

        client_context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
        client_context.load_cert_chain(CLIENT_CERT, CLIENT_KEY)
        client_context.load_verify_locations(SERVER_CERT)
        client_context.verify_mode = ssl.CERT_REQUIRED
        client = TSSLSocket('localhost', TEST_PORT, ssl_context=client_context)

        self._assert_connection_success(server, client)
Nobuaki Sukegawaad835862015-12-23 23:32:09 +0900299
if __name__ == '__main__':
    # For verbose TLS diagnostics while debugging, configure logging first,
    # e.g. `import logging; logging.basicConfig(level=logging.DEBUG)`.
    unittest.main(module='__main__')