THRIFT-3118: add http (for non-ssl and for ssl) to the python cross tests
diff --git a/lib/py/src/server/THttpServer.py b/lib/py/src/server/THttpServer.py
index 1b501a7..85cf400 100644
--- a/lib/py/src/server/THttpServer.py
+++ b/lib/py/src/server/THttpServer.py
@@ -17,6 +17,8 @@
# under the License.
#
+import ssl
+
from six.moves import BaseHTTPServer
from thrift.server import TServer
@@ -47,11 +49,17 @@
server_address,
inputProtocolFactory,
outputProtocolFactory=None,
- server_class=BaseHTTPServer.HTTPServer):
- """Set up protocol factories and HTTP server.
+ server_class=BaseHTTPServer.HTTPServer,
+ **kwargs):
+ """Set up protocol factories and HTTP (or HTTPS) server.
See BaseHTTPServer for server_address.
See TServer for protocol factories.
+
+ To make a secure server, provide the named arguments:
+ * cafile - to validate clients [optional]
+ * cert_file - the server cert
+ * key_file - the server's key
"""
if outputProtocolFactory is None:
outputProtocolFactory = inputProtocolFactory
@@ -83,5 +91,18 @@
self.httpd = server_class(server_address, RequestHander)
+ if (kwargs.get('cafile') or kwargs.get('cert_file') or kwargs.get('key_file')):
+ # This socket is the *server* side of the TLS handshake, so the context must be
+ # created for Purpose.CLIENT_AUTH; the default (Purpose.SERVER_AUTH) is for client sockets.
+ context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH, cafile=kwargs.get('cafile'))
+ context.check_hostname = False
+ context.load_cert_chain(kwargs.get('cert_file'), kwargs.get('key_file'))
+ context.verify_mode = ssl.CERT_REQUIRED if kwargs.get('cafile') else ssl.CERT_NONE
+ self.httpd.socket = context.wrap_socket(self.httpd.socket, server_side=True)
+
def serve(self):
self.httpd.serve_forever()
+
+ def shutdown(self):
+ self.httpd.socket.close()
+ # NOTE: self.httpd.shutdown() is deliberately not called: it hangs forever, python doesn't handle POLLNVAL properly!
diff --git a/lib/py/src/transport/THttpClient.py b/lib/py/src/transport/THttpClient.py
index fb33421..60ff226 100644
--- a/lib/py/src/transport/THttpClient.py
+++ b/lib/py/src/transport/THttpClient.py
@@ -20,6 +20,7 @@
from io import BytesIO
import os
import socket
+import ssl
import sys
import warnings
import base64
@@ -34,17 +35,20 @@
class THttpClient(TTransportBase):
"""Http implementation of TTransport base."""
- def __init__(self, uri_or_host, port=None, path=None):
- """THttpClient supports two different types constructor parameters.
+ def __init__(self, uri_or_host, port=None, path=None, cafile=None, cert_file=None, key_file=None, ssl_context=None):
+ """THttpClient supports two different types of construction:
THttpClient(host, port, path) - deprecated
- THttpClient(uri)
+ THttpClient(uri, [cafile=<filename>, cert_file=<filename>, key_file=<filename>, ssl_context=<context>])
- Only the second supports https.
+ Only the second supports https. To properly authenticate against the server,
+ provide the client's identity by specifying cert_file and key_file. To properly
+ authenticate the server, specify either cafile or ssl_context with a CA defined.
+ NOTE: if both cafile and ssl_context are defined, ssl_context will override cafile.
"""
if port is not None:
warnings.warn(
- "Please use the THttpClient('http://host:port/path') syntax",
+ "Please use the THttpClient('http{s}://host:port/path') constructor",
DeprecationWarning,
stacklevel=2)
self.host = uri_or_host
@@ -60,6 +64,9 @@
self.port = parsed.port or http_client.HTTP_PORT
elif self.scheme == 'https':
self.port = parsed.port or http_client.HTTPS_PORT
+ self.certfile = cert_file
+ self.keyfile = key_file
+ self.context = ssl.create_default_context(cafile=cafile) if (cafile and not ssl_context) else ssl_context
self.host = parsed.hostname
self.path = parsed.path
if parsed.query:
@@ -100,12 +107,17 @@
def open(self):
if self.scheme == 'http':
- self.__http = http_client.HTTPConnection(self.host, self.port)
+ self.__http = http_client.HTTPConnection(self.host, self.port,
+ timeout=self.__timeout)
elif self.scheme == 'https':
- self.__http = http_client.HTTPSConnection(self.host, self.port)
- if self.using_proxy():
- self.__http.set_tunnel(self.realhost, self.realport,
- {"Proxy-Authorization": self.proxy_auth})
+ self.__http = http_client.HTTPSConnection(self.host, self.port,
+ key_file=self.keyfile,
+ cert_file=self.certfile,
+ timeout=self.__timeout,
+ context=self.context)
+ if self.using_proxy():
+ self.__http.set_tunnel(self.realhost, self.realport,
+ {"Proxy-Authorization": self.proxy_auth})
def close(self):
self.__http.close()