THRIFT-3599 Validate client IP address against cert's SubjectAltName
diff --git a/lib/py/test/test_sslsocket.py b/lib/py/test/test_sslsocket.py
index fe03961..98d47ae 100644
--- a/lib/py/test/test_sslsocket.py
+++ b/lib/py/test/test_sslsocket.py
@@ -35,8 +35,11 @@
SERVER_PEM = os.path.join(ROOT_DIR, 'test', 'keys', 'server.pem')
SERVER_CERT = os.path.join(ROOT_DIR, 'test', 'keys', 'server.crt')
SERVER_KEY = os.path.join(ROOT_DIR, 'test', 'keys', 'server.key')
-CLIENT_CERT = os.path.join(ROOT_DIR, 'test', 'keys', 'client.crt')
-CLIENT_KEY = os.path.join(ROOT_DIR, 'test', 'keys', 'client.key')
+CLIENT_CERT_NO_IP = os.path.join(ROOT_DIR, 'test', 'keys', 'client.crt')
+CLIENT_KEY_NO_IP = os.path.join(ROOT_DIR, 'test', 'keys', 'client.key')
+CLIENT_CERT = os.path.join(ROOT_DIR, 'test', 'keys', 'client_v3.crt')
+CLIENT_KEY = os.path.join(ROOT_DIR, 'test', 'keys', 'client_v3.key')
+CLIENT_CA = os.path.join(ROOT_DIR, 'test', 'keys', 'CA.pem')
TEST_PORT = 23458
TEST_ADDR = '/tmp/.thrift.domain.sock.%d' % TEST_PORT
@@ -188,6 +191,24 @@
server = TSSLServerSocket(
port=TEST_PORT, cert_reqs=ssl.CERT_REQUIRED, keyfile=SERVER_KEY,
certfile=SERVER_CERT, ca_certs=CLIENT_CERT)
+ client = TSSLSocket('localhost', TEST_PORT, cert_reqs=ssl.CERT_NONE, certfile=SERVER_CERT, keyfile=SERVER_KEY)
+ self._assert_connection_failure(server, client)
+
+ server = TSSLServerSocket(
+ port=TEST_PORT, cert_reqs=ssl.CERT_REQUIRED, keyfile=SERVER_KEY,
+ certfile=SERVER_CERT, ca_certs=CLIENT_CA)
+ client = TSSLSocket('localhost', TEST_PORT, cert_reqs=ssl.CERT_NONE, certfile=CLIENT_CERT_NO_IP, keyfile=CLIENT_KEY_NO_IP)
+ self._assert_connection_failure(server, client)
+
+ server = TSSLServerSocket(
+ port=TEST_PORT, cert_reqs=ssl.CERT_REQUIRED, keyfile=SERVER_KEY,
+ certfile=SERVER_CERT, ca_certs=CLIENT_CA)
+ client = TSSLSocket('localhost', TEST_PORT, cert_reqs=ssl.CERT_NONE, certfile=CLIENT_CERT, keyfile=CLIENT_KEY)
+ self._assert_connection_success(server, client)
+
+ server = TSSLServerSocket(
+ port=TEST_PORT, cert_reqs=ssl.CERT_OPTIONAL, keyfile=SERVER_KEY,
+ certfile=SERVER_CERT, ca_certs=CLIENT_CA)
client = TSSLSocket('localhost', TEST_PORT, cert_reqs=ssl.CERT_NONE, certfile=CLIENT_CERT, keyfile=CLIENT_KEY)
self._assert_connection_success(server, client)
@@ -264,14 +285,16 @@
return
server_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
server_context.load_cert_chain(SERVER_CERT, SERVER_KEY)
- server_context.load_verify_locations(CLIENT_CERT)
+ server_context.load_verify_locations(CLIENT_CA)
+ server_context.verify_mode = ssl.CERT_REQUIRED
+ server = TSSLServerSocket(port=TEST_PORT, ssl_context=server_context)
client_context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
client_context.load_cert_chain(CLIENT_CERT, CLIENT_KEY)
client_context.load_verify_locations(SERVER_CERT)
-
- server = TSSLServerSocket(port=TEST_PORT, ssl_context=server_context)
+ client_context.verify_mode = ssl.CERT_REQUIRED
client = TSSLSocket('localhost', TEST_PORT, ssl_context=client_context)
+
self._assert_connection_success(server, client)
if __name__ == '__main__':