blob: fe03961a2dce5f05cf2454d80fad8b8da20f5e1a [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
19
20import os
21import platform
22import select
23import ssl
24import sys
25import threading
26import time
27import unittest
28import warnings
29
30import _import_local_thrift
31from thrift.transport.TSSLSocket import TSSLSocket, TSSLServerSocket
32
# Directory containing this test script (symlinks resolved).
SCRIPT_DIR = os.path.realpath(os.path.dirname(__file__))
# Repository root: three directory levels above SCRIPT_DIR.
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(SCRIPT_DIR)))
# Key/certificate fixtures shared with the other language test suites.
SERVER_PEM = os.path.join(ROOT_DIR, 'test', 'keys', 'server.pem')
SERVER_CERT = os.path.join(ROOT_DIR, 'test', 'keys', 'server.crt')
SERVER_KEY = os.path.join(ROOT_DIR, 'test', 'keys', 'server.key')
CLIENT_CERT = os.path.join(ROOT_DIR, 'test', 'keys', 'client.crt')
CLIENT_KEY = os.path.join(ROOT_DIR, 'test', 'keys', 'client.key')

TEST_PORT = 23458
# Unix domain socket path used by the unix-socket test (skipped on Windows).
TEST_ADDR = '/tmp/.thrift.domain.sock.%d' % TEST_PORT
# Seconds to wait for the acceptor thread to start listening before connecting.
CONNECT_DELAY = 0.5
# Socket timeout (seconds) applied to clients in the connection helpers.
CONNECT_TIMEOUT = 20.0
# A 3DES-only cipher list leaves no usable cipher on OpenSSL builds or
# security levels that reject 3DES. The AES-GCM fallback keeps the
# explicit-cipher tests runnable there, while 3DES stays preferred on older
# stacks.
TEST_CIPHERS = 'DES-CBC3-SHA:ECDHE-RSA-AES128-GCM-SHA256'
46
47
class ServerAcceptor(threading.Thread):
    """Background thread that accepts one connection on ``server``.

    ``run`` calls ``server.listen()`` and then ``server.accept()``, storing
    whatever ``accept()`` returns in ``self.client``.
    """

    def __init__(self, server):
        super(ServerAcceptor, self).__init__()
        # Daemon thread: if no client ever reaches the server, accept() can
        # block indefinitely, and a non-daemon thread would then keep the
        # test process from exiting.
        self.daemon = True
        self._server = server
        # Result of server.accept(); stays None until run() assigns it.
        self.client = None

    def run(self):
        self._server.listen()
        self.client = self._server.accept()
Nobuaki Sukegawaad835862015-12-23 23:32:09 +090057
58
# Python 2.6 compat: assertRaises cannot be used as a context manager before
# Python 2.7, so TSSLSocketTest._assert_raises falls back to this class there.
class AssertRaises(object):
    """Minimal context-manager stand-in for ``assertRaises(expected)``.

    ``expected`` is an exception class or a tuple of classes. If the ``with``
    body raises a matching exception, it is suppressed. If the body raises
    nothing, ``AssertionError`` is raised. Any other exception propagates
    unchanged.
    """

    def __init__(self, expected):
        self._expected = expected

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        if exc_type is None:
            name = getattr(self._expected, '__name__', repr(self._expected))
            raise AssertionError('%s not raised' % name)
        # Returning False lets an unexpected exception propagate with its own
        # type and traceback instead of masking it behind a generic error.
        return issubclass(exc_type, self._expected)
Nobuaki Sukegawaad835862015-12-23 23:32:09 +090071
72
class TSSLSocketTest(unittest.TestCase):
    """Handshake tests for TSSLSocket / TSSLServerSocket.

    Most tests build a server/client pair on TEST_PORT (or the TEST_ADDR unix
    socket) and pass it to one of the ``_assert_connection_*`` helpers. Those
    helpers close both ends before returning, so the next pair can reuse the
    same address.
    """

    def _assert_connection_failure(self, server, client):
        """Assert that ``client.open()`` against ``server`` raises.

        Also checks that the acceptor thread has not stored an accepted
        client. ``client``, any accepted socket and ``server`` are closed
        afterwards regardless of the outcome.
        """
        acc = ServerAcceptor(server)
        try:
            acc.start()
            # Give the acceptor thread time to call listen() before connecting.
            time.sleep(CONNECT_DELAY / 2)
            client.setTimeout(CONNECT_TIMEOUT / 2)
            with self._assert_raises(Exception):
                client.open()
                select.select([], [client.handle], [], CONNECT_TIMEOUT / 2)
            # The acceptor is deliberately not joined here: if the client
            # fails before reaching the server, accept() may never return.
            # (assertIsNone is unavailable on Python 2.6.)
            self.assertTrue(acc.client is None)
        finally:
            client.close()
            if acc.client:
                acc.client.close()
            server.close()

    def _assert_raises(self, exc):
        """Return a context manager asserting that its body raises ``exc``."""
        # 0x020700F0 is Python 2.7.0 final, the first release where
        # assertRaises works as a context manager.
        if sys.hexversion >= 0x020700F0:
            return self.assertRaises(exc)
        else:
            return AssertRaises(exc)

    def _assert_connection_success(self, server, client):
        """Assert that ``client`` connects to ``server`` and is accepted.

        ``client``, the accepted socket and ``server`` are closed afterwards
        regardless of the outcome.
        """
        acc = ServerAcceptor(server)
        try:
            acc.start()
            # Give the acceptor thread time to call listen() before connecting.
            time.sleep(CONNECT_DELAY)
            client.setTimeout(CONNECT_TIMEOUT)
            client.open()
            select.select([], [client.handle], [], CONNECT_TIMEOUT)
            # client.open() returning does not guarantee that the acceptor
            # thread has already stored the accepted socket in acc.client.
            # Nothing else synchronizes the two threads, so wait (bounded) for
            # the acceptor to finish before checking.
            acc.join(CONNECT_TIMEOUT)
            # assertIsNotNone is unavailable on Python 2.6.
            self.assertTrue(acc.client is not None)
        finally:
            client.close()
            if acc.client:
                acc.client.close()
            server.close()

    # deprecated feature
    def test_deprecation(self):
        with warnings.catch_warnings(record=True) as w:
            warnings.filterwarnings('always', category=DeprecationWarning, module='thrift.*SSL.*')
            TSSLSocket('localhost', TEST_PORT, validate=True, ca_certs=SERVER_CERT)
            self.assertEqual(len(w), 1)

        with warnings.catch_warnings(record=True) as w:
            warnings.filterwarnings('always', category=DeprecationWarning, module='thrift.*SSL.*')
            # Deprecated signature
            # def __init__(self, host='localhost', port=9090, validate=True, ca_certs=None, keyfile=None, certfile=None, unix_socket=None, ciphers=None):
            client = TSSLSocket('localhost', TEST_PORT, True, SERVER_CERT, CLIENT_KEY, CLIENT_CERT, None, TEST_CIPHERS)
            self.assertEqual(len(w), 7)

        with warnings.catch_warnings(record=True) as w:
            warnings.filterwarnings('always', category=DeprecationWarning, module='thrift.*SSL.*')
            # Deprecated signature
            # def __init__(self, host=None, port=9090, certfile='cert.pem', unix_socket=None, ciphers=None):
            server = TSSLServerSocket(None, TEST_PORT, SERVER_PEM, None, TEST_CIPHERS)
            self.assertEqual(len(w), 3)

        self._assert_connection_success(server, client)

    # deprecated feature
    def test_set_cert_reqs_by_validate(self):
        c1 = TSSLSocket('localhost', TEST_PORT, validate=True, ca_certs=SERVER_CERT)
        self.assertEqual(c1.cert_reqs, ssl.CERT_REQUIRED)

        c1 = TSSLSocket('localhost', TEST_PORT, validate=False)
        self.assertEqual(c1.cert_reqs, ssl.CERT_NONE)

    # deprecated feature
    def test_set_validate_by_cert_reqs(self):
        c1 = TSSLSocket('localhost', TEST_PORT, cert_reqs=ssl.CERT_NONE)
        self.assertFalse(c1.validate)

        c2 = TSSLSocket('localhost', TEST_PORT, cert_reqs=ssl.CERT_REQUIRED, ca_certs=SERVER_CERT)
        self.assertTrue(c2.validate)

        c3 = TSSLSocket('localhost', TEST_PORT, cert_reqs=ssl.CERT_OPTIONAL, ca_certs=SERVER_CERT)
        self.assertTrue(c3.validate)

    def test_unix_domain_socket(self):
        if platform.system() == 'Windows':
            print('skipping test_unix_domain_socket')
            return
        server = TSSLServerSocket(unix_socket=TEST_ADDR, keyfile=SERVER_KEY, certfile=SERVER_CERT)
        client = TSSLSocket(None, None, TEST_ADDR, cert_reqs=ssl.CERT_NONE)
        self._assert_connection_success(server, client)

    def test_server_cert(self):
        server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT)
        client = TSSLSocket('localhost', TEST_PORT, cert_reqs=ssl.CERT_REQUIRED, ca_certs=SERVER_CERT)
        self._assert_connection_success(server, client)

        server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT)
        # server cert not in ca_certs
        client = TSSLSocket('localhost', TEST_PORT, cert_reqs=ssl.CERT_REQUIRED, ca_certs=CLIENT_CERT)
        self._assert_connection_failure(server, client)

        server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT)
        client = TSSLSocket('localhost', TEST_PORT, cert_reqs=ssl.CERT_NONE)
        self._assert_connection_success(server, client)

    def test_set_server_cert(self):
        server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=CLIENT_CERT)
        with self._assert_raises(Exception):
            server.certfile = 'foo'
        with self._assert_raises(Exception):
            server.certfile = None
        server.certfile = SERVER_CERT
        client = TSSLSocket('localhost', TEST_PORT, cert_reqs=ssl.CERT_REQUIRED, ca_certs=SERVER_CERT)
        self._assert_connection_success(server, client)

    def test_client_cert(self):
        server = TSSLServerSocket(
            port=TEST_PORT, cert_reqs=ssl.CERT_REQUIRED, keyfile=SERVER_KEY,
            certfile=SERVER_CERT, ca_certs=CLIENT_CERT)
        client = TSSLSocket('localhost', TEST_PORT, cert_reqs=ssl.CERT_NONE, certfile=CLIENT_CERT, keyfile=CLIENT_KEY)
        self._assert_connection_success(server, client)

    def test_ciphers(self):
        server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT, ciphers=TEST_CIPHERS)
        client = TSSLSocket('localhost', TEST_PORT, ca_certs=SERVER_CERT, ciphers=TEST_CIPHERS)
        self._assert_connection_success(server, client)

        if not TSSLSocket._has_ciphers:
            # unittest.skip is not available for Python 2.6
            print('skipping test_ciphers')
            return
        server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT)
        client = TSSLSocket('localhost', TEST_PORT, ca_certs=SERVER_CERT, ciphers='NULL')
        self._assert_connection_failure(server, client)

        server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT, ciphers=TEST_CIPHERS)
        client = TSSLSocket('localhost', TEST_PORT, ca_certs=SERVER_CERT, ciphers='NULL')
        self._assert_connection_failure(server, client)

    def test_ssl2_and_ssl3_disabled(self):
        if not hasattr(ssl, 'PROTOCOL_SSLv3'):
            print('PROTOCOL_SSLv3 is not available')
        else:
            server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT)
            client = TSSLSocket('localhost', TEST_PORT, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_SSLv3)
            self._assert_connection_failure(server, client)

            server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_SSLv3)
            client = TSSLSocket('localhost', TEST_PORT, ca_certs=SERVER_CERT)
            self._assert_connection_failure(server, client)

        if not hasattr(ssl, 'PROTOCOL_SSLv2'):
            print('PROTOCOL_SSLv2 is not available')
        else:
            server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT)
            client = TSSLSocket('localhost', TEST_PORT, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_SSLv2)
            self._assert_connection_failure(server, client)

            server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_SSLv2)
            client = TSSLSocket('localhost', TEST_PORT, ca_certs=SERVER_CERT)
            self._assert_connection_failure(server, client)

    def test_newer_tls(self):
        if not TSSLSocket._has_ssl_context:
            # unittest.skip is not available for Python 2.6
            print('skipping test_newer_tls')
            return
        if not hasattr(ssl, 'PROTOCOL_TLSv1_2'):
            print('PROTOCOL_TLSv1_2 is not available')
        else:
            server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_2)
            client = TSSLSocket('localhost', TEST_PORT, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_2)
            self._assert_connection_success(server, client)

        if not hasattr(ssl, 'PROTOCOL_TLSv1_1'):
            print('PROTOCOL_TLSv1_1 is not available')
        else:
            server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_1)
            client = TSSLSocket('localhost', TEST_PORT, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_1)
            self._assert_connection_success(server, client)

        if not hasattr(ssl, 'PROTOCOL_TLSv1_1') or not hasattr(ssl, 'PROTOCOL_TLSv1_2'):
            print('PROTOCOL_TLSv1_1 and/or PROTOCOL_TLSv1_2 is not available')
        else:
            server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_2)
            client = TSSLSocket('localhost', TEST_PORT, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_1)
            self._assert_connection_failure(server, client)

    def test_ssl_context(self):
        if not TSSLSocket._has_ssl_context:
            # unittest.skip is not available for Python 2.6
            print('skipping test_ssl_context')
            return
        server_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
        server_context.load_cert_chain(SERVER_CERT, SERVER_KEY)
        server_context.load_verify_locations(CLIENT_CERT)

        client_context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
        client_context.load_cert_chain(CLIENT_CERT, CLIENT_KEY)
        client_context.load_verify_locations(SERVER_CERT)

        server = TSSLServerSocket(port=TEST_PORT, ssl_context=server_context)
        client = TSSLSocket('localhost', TEST_PORT, ssl_context=client_context)
        self._assert_connection_success(server, client)
Nobuaki Sukegawaad835862015-12-23 23:32:09 +0900276
def _run_as_script():
    """Run this module's test cases when the file is executed directly.

    To get debug logging while diagnosing a failure, call
    ``logging.basicConfig(level=logging.DEBUG)`` before ``unittest.main()``.
    """
    unittest.main()


if __name__ == '__main__':
    _run_as_script()