blob: bd2a70aef0ef6908dc0f4095f6db79945d740949 [file] [log] [blame]
Nobuaki Sukegawaad835862015-12-23 23:32:09 +09001#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18#
19
20import os
21import platform
22import select
23import ssl
24import sys
25import threading
26import time
27import unittest
28import warnings
29
30import _import_local_thrift
31from thrift.transport.TSSLSocket import TSSLSocket, TSSLServerSocket
32
SCRIPT_DIR = os.path.realpath(os.path.dirname(__file__))
# Three directory levels above SCRIPT_DIR (presumably the repository root).
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(SCRIPT_DIR)))
# Key/certificate fixtures under <ROOT_DIR>/test/keys.
SERVER_PEM = os.path.join(ROOT_DIR, 'test', 'keys', 'server.pem')
SERVER_CERT = os.path.join(ROOT_DIR, 'test', 'keys', 'server.crt')
SERVER_KEY = os.path.join(ROOT_DIR, 'test', 'keys', 'server.key')
CLIENT_CERT = os.path.join(ROOT_DIR, 'test', 'keys', 'client.crt')
CLIENT_KEY = os.path.join(ROOT_DIR, 'test', 'keys', 'client.key')

TEST_PORT = 23458
# Unix domain socket path, derived from TEST_PORT.
TEST_ADDR = '/tmp/.thrift.domain.sock.%d' % TEST_PORT
CONNECT_TIMEOUT = 10.0
# OpenSSL cipher list given to both client and server in the cipher tests.
# 3DES (DES-CBC3-SHA) may be unavailable in modern OpenSSL builds, and a list
# naming only it then selects no cipher at all. The AES-GCM entry keeps the
# list usable there; 3DES is still listed first where it is supported.
TEST_CIPHERS = 'DES-CBC3-SHA:ECDHE-RSA-AES128-GCM-SHA256'
45
46
class ServerAcceptor(threading.Thread):
    """Background thread that calls ``listen()`` then ``accept()`` once on a server transport.

    ``client`` holds the value returned by ``server.accept()`` once that call
    returns. It stays ``None`` before then, and also if ``accept()`` raised.
    """

    def __init__(self, server):
        super(ServerAcceptor, self).__init__()
        # Daemonize so that an accept() which never returns (e.g. the client
        # failed before it ever connected) cannot keep the interpreter alive
        # after the test run has finished.
        self.daemon = True
        self._server = server
        self.client = None

    def run(self):
        self._server.listen()
        self.client = self._server.accept()
56
57
# Python 2.6 compat: stand-in for the context-manager form of
# unittest.TestCase.assertRaises, which only exists from Python 2.7 on.
class AssertRaises(object):
    """Context manager asserting that its body raises ``expected``.

    An exception of the expected type (or a subclass) is swallowed and kept
    in ``exception``. If the body raises nothing, AssertionError is raised.
    Any other exception propagates unchanged, so the real error is reported
    instead of being masked.
    """

    def __init__(self, expected):
        # ``expected`` may be a class or a tuple of classes (anything
        # accepted as the second argument of ``issubclass``).
        self._expected = expected
        self.exception = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        if exc_type is None:
            name = getattr(self._expected, '__name__', str(self._expected))
            raise AssertionError('%s not raised' % name)
        if not issubclass(exc_type, self._expected):
            # Returning False lets the unexpected exception propagate.
            return False
        self.exception = exc_value
        return True
70
71
class TSSLSocketTest(unittest.TestCase):
    """Connection tests for TSSLSocket / TSSLServerSocket pairs."""

    def _close_accepted(self, acc):
        """Close the server-side transport accepted by ``acc``, if any.

        Without this, the transport returned by ``server.accept()`` is never
        closed by the connection helpers below.
        """
        if acc is not None and acc.client is not None:
            acc.client.close()

    def _assert_connection_failure(self, server, client):
        """Assert that ``client`` cannot open a connection to ``server``.

        ``server`` and ``client`` are always closed afterwards.
        """
        acc = None
        try:
            acc = ServerAcceptor(server)
            acc.start()
            # Crude wait for the acceptor thread to reach listen().
            time.sleep(0.15)
            client.setTimeout(CONNECT_TIMEOUT)
            with self._assert_raises(Exception):
                client.open()
                select.select([], [client.handle], [], CONNECT_TIMEOUT)
            # assertIsNone is not available on Python 2.6.
            self.assertTrue(acc.client is None)
        finally:
            self._close_accepted(acc)
            server.close()
            client.close()

    def _assert_raises(self, exc):
        """Return a context manager asserting that its body raises ``exc``."""
        if sys.hexversion >= 0x020700F0:
            return self.assertRaises(exc)
        else:
            return AssertRaises(exc)

    def _assert_connection_success(self, server, client):
        """Assert that ``client`` connects to ``server`` and the server accepts it.

        ``server``, ``client`` and the accepted transport are always closed
        afterwards.
        """
        acc = None
        try:
            acc = ServerAcceptor(server)
            acc.start()
            # Crude wait for the acceptor thread to reach listen().
            time.sleep(0.15)
            client.setTimeout(CONNECT_TIMEOUT)
            client.open()
            select.select([], [client.handle], [], CONNECT_TIMEOUT)
            # acc.client is assigned on the acceptor thread with no other
            # synchronization, so wait for that thread before checking it.
            acc.join(CONNECT_TIMEOUT)
            # assertIsNotNone is not available on Python 2.6.
            self.assertTrue(acc.client is not None)
        finally:
            self._close_accepted(acc)
            server.close()
            client.close()

    # deprecated feature
    def test_deprecation(self):
        with warnings.catch_warnings(record=True) as w:
            warnings.filterwarnings('always', category=DeprecationWarning, module='thrift.*SSL.*')
            TSSLSocket('localhost', TEST_PORT, validate=True, ca_certs=SERVER_CERT)
            self.assertEqual(len(w), 1)

        with warnings.catch_warnings(record=True) as w:
            warnings.filterwarnings('always', category=DeprecationWarning, module='thrift.*SSL.*')
            # Deprecated signature
            # def __init__(self, host='localhost', port=9090, validate=True, ca_certs=None, keyfile=None, certfile=None, unix_socket=None, ciphers=None):
            client = TSSLSocket('localhost', TEST_PORT, True, SERVER_CERT, CLIENT_KEY, CLIENT_CERT, None, TEST_CIPHERS)
            self.assertEqual(len(w), 7)

        with warnings.catch_warnings(record=True) as w:
            warnings.filterwarnings('always', category=DeprecationWarning, module='thrift.*SSL.*')
            # Deprecated signature
            # def __init__(self, host=None, port=9090, certfile='cert.pem', unix_socket=None, ciphers=None):
            server = TSSLServerSocket(None, TEST_PORT, SERVER_PEM, None, TEST_CIPHERS)
            self.assertEqual(len(w), 3)

        self._assert_connection_success(server, client)

    # deprecated feature
    def test_set_cert_reqs_by_validate(self):
        c1 = TSSLSocket('localhost', TEST_PORT, validate=True, ca_certs=SERVER_CERT)
        self.assertEqual(c1.cert_reqs, ssl.CERT_REQUIRED)

        c1 = TSSLSocket('localhost', TEST_PORT, validate=False)
        self.assertEqual(c1.cert_reqs, ssl.CERT_NONE)

    # deprecated feature
    def test_set_validate_by_cert_reqs(self):
        c1 = TSSLSocket('localhost', TEST_PORT, cert_reqs=ssl.CERT_NONE)
        self.assertFalse(c1.validate)

        c2 = TSSLSocket('localhost', TEST_PORT, cert_reqs=ssl.CERT_REQUIRED, ca_certs=SERVER_CERT)
        self.assertTrue(c2.validate)

        c3 = TSSLSocket('localhost', TEST_PORT, cert_reqs=ssl.CERT_OPTIONAL, ca_certs=SERVER_CERT)
        self.assertTrue(c3.validate)

    def test_unix_domain_socket(self):
        if platform.system() == 'Windows':
            print('skipping test_unix_domain_socket')
            return
        server = TSSLServerSocket(unix_socket=TEST_ADDR, keyfile=SERVER_KEY, certfile=SERVER_CERT)
        client = TSSLSocket(None, None, TEST_ADDR, cert_reqs=ssl.CERT_NONE)
        self._assert_connection_success(server, client)

    def test_server_cert(self):
        server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT)
        client = TSSLSocket('localhost', TEST_PORT, cert_reqs=ssl.CERT_REQUIRED, ca_certs=SERVER_CERT)
        self._assert_connection_success(server, client)

        server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT)
        # server cert not in ca_certs
        client = TSSLSocket('localhost', TEST_PORT, cert_reqs=ssl.CERT_REQUIRED, ca_certs=CLIENT_CERT)
        self._assert_connection_failure(server, client)

        server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT)
        client = TSSLSocket('localhost', TEST_PORT, cert_reqs=ssl.CERT_NONE)
        self._assert_connection_success(server, client)

    def test_set_server_cert(self):
        server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=CLIENT_CERT)
        with self._assert_raises(Exception):
            server.certfile = 'foo'
        with self._assert_raises(Exception):
            server.certfile = None
        server.certfile = SERVER_CERT
        client = TSSLSocket('localhost', TEST_PORT, cert_reqs=ssl.CERT_REQUIRED, ca_certs=SERVER_CERT)
        self._assert_connection_success(server, client)

    def test_client_cert(self):
        server = TSSLServerSocket(
            port=TEST_PORT, cert_reqs=ssl.CERT_REQUIRED, keyfile=SERVER_KEY,
            certfile=SERVER_CERT, ca_certs=CLIENT_CERT)
        client = TSSLSocket('localhost', TEST_PORT, cert_reqs=ssl.CERT_NONE, certfile=CLIENT_CERT, keyfile=CLIENT_KEY)
        self._assert_connection_success(server, client)

    def test_ciphers(self):
        server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT, ciphers=TEST_CIPHERS)
        client = TSSLSocket('localhost', TEST_PORT, ca_certs=SERVER_CERT, ciphers=TEST_CIPHERS)
        self._assert_connection_success(server, client)

        if not TSSLSocket._has_ciphers:
            # unittest.skip is not available for Python 2.6
            print('skipping test_ciphers')
            return
        server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT)
        client = TSSLSocket('localhost', TEST_PORT, ca_certs=SERVER_CERT, ciphers='NULL')
        self._assert_connection_failure(server, client)

        server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT, ciphers=TEST_CIPHERS)
        client = TSSLSocket('localhost', TEST_PORT, ca_certs=SERVER_CERT, ciphers='NULL')
        self._assert_connection_failure(server, client)

    def test_ssl2_and_ssl3_disabled(self):
        if not hasattr(ssl, 'PROTOCOL_SSLv3'):
            print('PROTOCOL_SSLv3 is not available')
        else:
            server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT)
            client = TSSLSocket('localhost', TEST_PORT, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_SSLv3)
            self._assert_connection_failure(server, client)

            server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_SSLv3)
            client = TSSLSocket('localhost', TEST_PORT, ca_certs=SERVER_CERT)
            self._assert_connection_failure(server, client)

        if not hasattr(ssl, 'PROTOCOL_SSLv2'):
            print('PROTOCOL_SSLv2 is not available')
        else:
            server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT)
            client = TSSLSocket('localhost', TEST_PORT, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_SSLv2)
            self._assert_connection_failure(server, client)

            server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_SSLv2)
            client = TSSLSocket('localhost', TEST_PORT, ca_certs=SERVER_CERT)
            self._assert_connection_failure(server, client)

    def test_newer_tls(self):
        if not TSSLSocket._has_ssl_context:
            # unittest.skip is not available for Python 2.6
            print('skipping test_newer_tls')
            return
        if not hasattr(ssl, 'PROTOCOL_TLSv1_2'):
            print('PROTOCOL_TLSv1_2 is not available')
        else:
            server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_2)
            client = TSSLSocket('localhost', TEST_PORT, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_2)
            self._assert_connection_success(server, client)

        if not hasattr(ssl, 'PROTOCOL_TLSv1_1'):
            print('PROTOCOL_TLSv1_1 is not available')
        else:
            server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_1)
            client = TSSLSocket('localhost', TEST_PORT, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_1)
            self._assert_connection_success(server, client)

        if not hasattr(ssl, 'PROTOCOL_TLSv1_1') or not hasattr(ssl, 'PROTOCOL_TLSv1_2'):
            print('PROTOCOL_TLSv1_1 and/or PROTOCOL_TLSv1_2 is not available')
        else:
            server = TSSLServerSocket(port=TEST_PORT, keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_2)
            client = TSSLSocket('localhost', TEST_PORT, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_1)
            self._assert_connection_failure(server, client)

    def test_ssl_context(self):
        if not TSSLSocket._has_ssl_context:
            # unittest.skip is not available for Python 2.6
            print('skipping test_ssl_context')
            return
        server_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
        server_context.load_cert_chain(SERVER_CERT, SERVER_KEY)
        server_context.load_verify_locations(CLIENT_CERT)

        client_context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
        client_context.load_cert_chain(CLIENT_CERT, CLIENT_KEY)
        client_context.load_verify_locations(SERVER_CERT)

        server = TSSLServerSocket(port=TEST_PORT, ssl_context=server_context)
        client = TSSLSocket('localhost', TEST_PORT, ssl_context=client_context)
        self._assert_connection_success(server, client)
271
if __name__ == '__main__':
    # Script entry point: run every test in this module. For verbose
    # transport output, configure logging first, e.g.
    # logging.basicConfig(level=logging.DEBUG).
    unittest.main()