THRIFT-4977: Allow loading OpenSSL certificates from memory
Client: cpp
This closes #1860.
diff --git a/lib/cpp/src/thrift/transport/TSSLSocket.cpp b/lib/cpp/src/thrift/transport/TSSLSocket.cpp
index b413002..64f08dd 100644
--- a/lib/cpp/src/thrift/transport/TSSLSocket.cpp
+++ b/lib/cpp/src/thrift/transport/TSSLSocket.cpp
@@ -957,6 +957,28 @@
}
}
+void TSSLSocketFactory::loadCertificateFromBuffer(const char* aCertificate, const char* format) {
+ if (aCertificate == nullptr || format == nullptr) {
+ throw TTransportException(TTransportException::BAD_ARGS,
+ "loadCertificate: either <path> or <format> is NULL");
+ }
+ if (strcmp(format, "PEM") == 0) {
+ BIO* mem = BIO_new(BIO_s_mem());
+ BIO_puts(mem, aCertificate);
+ X509* cert = PEM_read_bio_X509(mem, NULL, 0, NULL);
+ BIO_free(mem);
+
+ if (SSL_CTX_use_certificate(ctx_->get(), cert) == 0) {
+ int errno_copy = THRIFT_GET_SOCKET_ERROR;
+ string errors;
+ buildErrors(errors, errno_copy);
+ throw TSSLException("SSL_CTX_use_certificate: " + errors);
+ }
+ } else {
+ throw TSSLException("Unsupported certificate format: " + string(format));
+ }
+}
+
void TSSLSocketFactory::loadPrivateKey(const char* path, const char* format) {
if (path == nullptr || format == nullptr) {
throw TTransportException(TTransportException::BAD_ARGS,
@@ -972,6 +994,28 @@
}
}
+void TSSLSocketFactory::loadPrivateKeyFromBuffer(const char* aPrivateKey, const char* format) {
+ if (aPrivateKey == nullptr || format == nullptr) {
+ throw TTransportException(TTransportException::BAD_ARGS,
+ "loadPrivateKey: either <path> or <format> is NULL");
+ }
+ if (strcmp(format, "PEM") == 0) {
+ BIO* mem = BIO_new(BIO_s_mem());
+ BIO_puts(mem, aPrivateKey);
+ EVP_PKEY* cert = PEM_read_bio_PrivateKey(mem, nullptr, nullptr, nullptr);
+
+ BIO_free(mem);
+ if (SSL_CTX_use_PrivateKey(ctx_->get(), cert) == 0) {
+ int errno_copy = THRIFT_GET_SOCKET_ERROR;
+ string errors;
+ buildErrors(errors, errno_copy);
+ throw TSSLException("SSL_CTX_use_PrivateKey: " + errors);
+ }
+ } else {
+ throw TSSLException("Unsupported certificate format: " + string(format));
+ }
+}
+
void TSSLSocketFactory::loadTrustedCertificates(const char* path, const char* capath) {
if (path == nullptr) {
throw TTransportException(TTransportException::BAD_ARGS,
@@ -985,6 +1029,39 @@
}
}
+void TSSLSocketFactory::loadTrustedCertificatesFromBuffer(const char* aCertificate, const char* aChain) {
+ if (aCertificate == nullptr) {
+ throw TTransportException(TTransportException::BAD_ARGS,
+ "loadTrustedCertificates: aCertificate is empty");
+ }
+ X509_STORE* vX509Store = SSL_CTX_get_cert_store(ctx_->get());
+ BIO* mem = BIO_new(BIO_s_mem());
+ BIO_puts(mem, aCertificate);
+ X509* cert = PEM_read_bio_X509(mem, NULL, 0, NULL);
+ BIO_free(mem);
+
+ if (X509_STORE_add_cert(vX509Store, cert) == 0) {
+ int errno_copy = THRIFT_GET_SOCKET_ERROR;
+ string errors;
+ buildErrors(errors, errno_copy);
+ throw TSSLException("X509_STORE_add_cert: " + errors);
+ }
+
+ if (aChain) {
+ mem = BIO_new(BIO_s_mem());
+ BIO_puts(mem, aChain);
+ cert = PEM_read_bio_X509(mem, NULL, 0, NULL);
+ BIO_free(mem);
+
+ if (SSL_CTX_add_extra_chain_cert(ctx_->get(), cert) == 0) {
+ int errno_copy = THRIFT_GET_SOCKET_ERROR;
+ string errors;
+ buildErrors(errors, errno_copy);
+ throw TSSLException("X509_STORE_add_cert: " + errors);
+ }
+ }
+}
+
void TSSLSocketFactory::randomize() {
RAND_poll();
}
diff --git a/lib/cpp/src/thrift/transport/TSSLSocket.h b/lib/cpp/src/thrift/transport/TSSLSocket.h
index 87a9601..a78112c 100644
--- a/lib/cpp/src/thrift/transport/TSSLSocket.h
+++ b/lib/cpp/src/thrift/transport/TSSLSocket.h
@@ -261,6 +261,7 @@
* @param format Certificate file format
*/
virtual void loadCertificate(const char* path, const char* format = "PEM");
+ virtual void loadCertificateFromBuffer(const char* aCertificate, const char* format = "PEM");
/**
* Load private key.
*
@@ -268,12 +269,14 @@
* @param format Private key file format
*/
virtual void loadPrivateKey(const char* path, const char* format = "PEM");
+ virtual void loadPrivateKeyFromBuffer(const char* aPrivateKey, const char* format = "PEM");
/**
* Load trusted certificates from specified file.
*
* @param path Path to trusted certificate file
*/
virtual void loadTrustedCertificates(const char* path, const char* capath = nullptr);
+ virtual void loadTrustedCertificatesFromBuffer(const char* aCertificate, const char* aChain = nullptr);
/**
* Default randomize method.
*/
diff --git a/lib/cpp/test/CMakeLists.txt b/lib/cpp/test/CMakeLists.txt
index ef08dbc..fba15f6 100644
--- a/lib/cpp/test/CMakeLists.txt
+++ b/lib/cpp/test/CMakeLists.txt
@@ -334,6 +334,17 @@
endif ()
add_test(NAME SecurityTest COMMAND SecurityTest -- "${CMAKE_CURRENT_SOURCE_DIR}/../../../test/keys")
+add_executable(SecurityFromBufferTest SecurityFromBufferTest.cpp)
+target_link_libraries(SecurityFromBufferTest
+ testgencpp
+ ${Boost_LIBRARIES}
+)
+LINK_AGAINST_THRIFT_LIBRARY(SecurityFromBufferTest thrift)
+if (NOT MSVC AND NOT ${CMAKE_SYSTEM_NAME} MATCHES "Darwin" AND NOT MINGW)
+target_link_libraries(SecurityFromBufferTest -lrt)
+endif ()
+add_test(NAME SecurityFromBufferTest COMMAND SecurityFromBufferTest -- "${CMAKE_CURRENT_SOURCE_DIR}/../../../test/keys")
+
endif()
if(WITH_QT5)
diff --git a/lib/cpp/test/Makefile.am b/lib/cpp/test/Makefile.am
index 2a0b9e6..8399d9e 100755
--- a/lib/cpp/test/Makefile.am
+++ b/lib/cpp/test/Makefile.am
@@ -99,6 +99,7 @@
TInterruptTest \
TServerIntegrationTest \
SecurityTest \
+ SecurityFromBufferTest \
ZlibTest \
TFileTransportTest \
link_test \
@@ -174,6 +175,17 @@
$(BOOST_SYSTEM_LDADD) \
$(BOOST_THREAD_LDADD)
+SecurityFromBufferTest_SOURCES = \
+ SecurityFromBufferTest.cpp
+
+SecurityFromBufferTest_LDADD = \
+ libtestgencpp.la \
+ libprocessortest.la \
+ $(BOOST_TEST_LDADD) \
+ $(BOOST_FILESYSTEM_LDADD) \
+ $(BOOST_SYSTEM_LDADD) \
+ $(BOOST_THREAD_LDADD)
+
TransportTest_SOURCES = \
TransportTest.cpp
diff --git a/lib/cpp/test/SecurityFromBufferTest.cpp b/lib/cpp/test/SecurityFromBufferTest.cpp
new file mode 100644
index 0000000..72a4c2a
--- /dev/null
+++ b/lib/cpp/test/SecurityFromBufferTest.cpp
@@ -0,0 +1,253 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#define BOOST_TEST_MODULE SecurityFromBufferTest
+#include <boost/filesystem.hpp>
+#include <boost/foreach.hpp>
+#include <boost/format.hpp>
+#include <boost/test/unit_test.hpp>
+#include <boost/thread.hpp>
+#include <stdexcept>
+#include <fstream>
+#include <memory>
+#include <thrift/transport/TSSLServerSocket.h>
+#include <thrift/transport/TSSLSocket.h>
+#include <thrift/transport/TTransport.h>
+#include <vector>
+#ifdef __linux__
+#include <signal.h>
+#endif
+
+using apache::thrift::transport::TServerTransport;
+using apache::thrift::transport::TSSLServerSocket;
+using apache::thrift::transport::TSSLSocket;
+using apache::thrift::transport::TSSLSocketFactory;
+using apache::thrift::transport::TTransport;
+using apache::thrift::transport::TTransportException;
+using apache::thrift::transport::TTransportFactory;
+
+using std::bind;
+using std::shared_ptr;
+
+boost::filesystem::path keyDir;
+boost::filesystem::path certFile(const std::string& filename) {
+ return keyDir / filename;
+}
+std::string certString(const std::string& filename) {
+ std::ifstream ifs(certFile(filename).string());
+ if(!ifs.is_open() || !ifs.good()) {
+ throw(std::runtime_error("Failed to open key file " + filename + " for reading"));
+ }
+ std::stringstream buffer;
+ buffer << ifs.rdbuf();
+ return buffer.str();
+}
+boost::mutex gMutex;
+
+struct GlobalFixture {
+ GlobalFixture() {
+ using namespace boost::unit_test::framework;
+ for (int i = 0; i < master_test_suite().argc; ++i) {
+ BOOST_TEST_MESSAGE(boost::format("argv[%1%] = \"%2%\"") % i % master_test_suite().argv[i]);
+ }
+
+#ifdef __linux__
+ // OpenSSL calls send() without MSG_NOSIGPIPE so writing to a socket that has
+ // disconnected can cause a SIGPIPE signal...
+ signal(SIGPIPE, SIG_IGN);
+#endif
+
+ TSSLSocketFactory::setManualOpenSSLInitialization(true);
+ apache::thrift::transport::initializeOpenSSL();
+
+ keyDir = boost::filesystem::current_path().parent_path().parent_path().parent_path() / "test" / "keys";
+ if (!boost::filesystem::exists(certFile("server.crt"))) {
+ keyDir = boost::filesystem::path(master_test_suite().argv[master_test_suite().argc - 1]);
+ if (!boost::filesystem::exists(certFile("server.crt"))) {
+ throw std::invalid_argument("The last argument to this test must be the directory containing the test certificate(s).");
+ }
+ }
+ }
+
+ virtual ~GlobalFixture() {
+ apache::thrift::transport::cleanupOpenSSL();
+#ifdef __linux__
+ signal(SIGPIPE, SIG_DFL);
+#endif
+ }
+};
+
+#if (BOOST_VERSION >= 105900)
+BOOST_GLOBAL_FIXTURE(GlobalFixture);
+#else
+BOOST_GLOBAL_FIXTURE(GlobalFixture)
+#endif
+
+struct SecurityFromBufferFixture {
+ void server(apache::thrift::transport::SSLProtocol protocol) {
+ try {
+ boost::mutex::scoped_lock lock(mMutex);
+
+ shared_ptr<TSSLSocketFactory> pServerSocketFactory;
+ shared_ptr<TSSLServerSocket> pServerSocket;
+
+ pServerSocketFactory.reset(new TSSLSocketFactory(static_cast<apache::thrift::transport::SSLProtocol>(protocol)));
+ pServerSocketFactory->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
+ pServerSocketFactory->loadCertificateFromBuffer(certString("server.crt").c_str());
+ pServerSocketFactory->loadPrivateKeyFromBuffer(certString("server.key").c_str());
+ pServerSocketFactory->server(true);
+ pServerSocket.reset(new TSSLServerSocket("localhost", 0, pServerSocketFactory));
+ shared_ptr<TTransport> connectedClient;
+
+ try {
+ pServerSocket->listen();
+ mPort = pServerSocket->getPort();
+ mCVar.notify_one();
+ lock.unlock();
+
+ connectedClient = pServerSocket->accept();
+ uint8_t buf[2];
+ buf[0] = 'O';
+ buf[1] = 'K';
+ connectedClient->write(&buf[0], 2);
+ connectedClient->flush();
+ }
+
+ catch (apache::thrift::transport::TTransportException& ex) {
+ boost::mutex::scoped_lock lock(gMutex);
+ BOOST_TEST_MESSAGE(boost::format("SRV %1% Exception: %2%") % boost::this_thread::get_id() % ex.what());
+ }
+
+ if (connectedClient) {
+ connectedClient->close();
+ connectedClient.reset();
+ }
+
+ pServerSocket->close();
+ pServerSocket.reset();
+ } catch (std::exception& ex) {
+ BOOST_FAIL(boost::format("%1%: %2%") % typeid(ex).name() % ex.what());
+ }
+ }
+
+ void client(apache::thrift::transport::SSLProtocol protocol) {
+ try {
+ shared_ptr<TSSLSocketFactory> pClientSocketFactory;
+ shared_ptr<TSSLSocket> pClientSocket;
+
+ try {
+ pClientSocketFactory.reset(new TSSLSocketFactory(static_cast<apache::thrift::transport::SSLProtocol>(protocol)));
+ pClientSocketFactory->authenticate(true);
+ pClientSocketFactory->loadCertificateFromBuffer(certString("client.crt").c_str());
+ pClientSocketFactory->loadPrivateKeyFromBuffer(certString("client.key").c_str());
+ pClientSocketFactory->loadTrustedCertificatesFromBuffer(certString("CA.pem").c_str());
+ pClientSocket = pClientSocketFactory->createSocket("localhost", mPort);
+ pClientSocket->open();
+
+ uint8_t buf[3];
+ buf[0] = 0;
+ buf[1] = 0;
+ BOOST_CHECK_EQUAL(2, pClientSocket->read(&buf[0], 2));
+ BOOST_CHECK_EQUAL(0, memcmp(&buf[0], "OK", 2));
+ mConnected = true;
+ } catch (apache::thrift::transport::TTransportException& ex) {
+ boost::mutex::scoped_lock lock(gMutex);
+ BOOST_TEST_MESSAGE(boost::format("CLI %1% Exception: %2%") % boost::this_thread::get_id() % ex.what());
+ }
+
+ if (pClientSocket) {
+ pClientSocket->close();
+ pClientSocket.reset();
+ }
+ } catch (std::exception& ex) {
+ BOOST_FAIL(boost::format("%1%: %2%") % typeid(ex).name() % ex.what());
+ }
+ }
+
+ static const char* protocol2str(size_t protocol) {
+ static const char* strings[apache::thrift::transport::LATEST + 1]
+ = {"SSLTLS", "SSLv2", "SSLv3", "TLSv1_0", "TLSv1_1", "TLSv1_2"};
+ return strings[protocol];
+ }
+
+ boost::mutex mMutex;
+ boost::condition_variable mCVar;
+ int mPort;
+ bool mConnected;
+};
+
+BOOST_FIXTURE_TEST_SUITE(BOOST_TEST_MODULE, SecurityFromBufferFixture)
+
+BOOST_AUTO_TEST_CASE(ssl_security_matrix) {
+ try {
+ // matrix of connection success between client and server with different SSLProtocol selections
+ static_assert(apache::thrift::transport::LATEST == 5, "Mismatch in assumed number of ssl protocols");
+ bool matrix[apache::thrift::transport::LATEST + 1][apache::thrift::transport::LATEST + 1] =
+ {
+ // server = SSLTLS SSLv2 SSLv3 TLSv1_0 TLSv1_1 TLSv1_2
+ // client
+ /* SSLTLS */ { true, false, false, true, true, true },
+ /* SSLv2 */ { false, false, false, false, false, false },
+ /* SSLv3 */ { false, false, true, false, false, false },
+ /* TLSv1_0 */ { true, false, false, true, false, false },
+ /* TLSv1_1 */ { true, false, false, false, true, false },
+ /* TLSv1_2 */ { true, false, false, false, false, true }
+ };
+
+ for (size_t si = 0; si <= apache::thrift::transport::LATEST; ++si) {
+ for (size_t ci = 0; ci <= apache::thrift::transport::LATEST; ++ci) {
+ if (si == 1 || ci == 1) {
+ // Skip all SSLv2 cases - protocol not supported
+ continue;
+ }
+
+#ifdef OPENSSL_NO_SSL3
+ if (si == 2 || ci == 2) {
+ // Skip all SSLv3 cases - protocol not supported
+ continue;
+ }
+#endif
+
+ boost::mutex::scoped_lock lock(mMutex);
+
+ BOOST_TEST_MESSAGE(boost::format("TEST: Server = %1%, Client = %2%") % protocol2str(si)
+ % protocol2str(ci));
+
+ mConnected = false;
+ // thread_group manages the thread lifetime - ignore the return value of create_thread
+ boost::thread_group threads;
+ (void)threads.create_thread(bind(&SecurityFromBufferFixture::server, this,
+ static_cast<apache::thrift::transport::SSLProtocol>(si)));
+ mCVar.wait(lock); // wait for listen() to succeed
+ lock.unlock();
+ (void)threads.create_thread(bind(&SecurityFromBufferFixture::client, this,
+ static_cast<apache::thrift::transport::SSLProtocol>(ci)));
+ threads.join_all();
+
+ BOOST_CHECK_MESSAGE(mConnected == matrix[ci][si],
+ boost::format(" Server = %1%, Client = %2% expected mConnected == %3% but was %4%")
+ % protocol2str(si) % protocol2str(ci) % matrix[ci][si] % mConnected);
+ }
+ }
+ } catch (std::exception& ex) {
+ BOOST_FAIL(boost::format("%1%: %2%") % typeid(ex).name() % ex.what());
+ }
+}
+
+BOOST_AUTO_TEST_SUITE_END()