THRIFT-4977: Allow loading OpenSSL certificates from memory

Client: cpp

This closes #1860.
diff --git a/lib/cpp/src/thrift/transport/TSSLSocket.cpp b/lib/cpp/src/thrift/transport/TSSLSocket.cpp
index b413002..64f08dd 100644
--- a/lib/cpp/src/thrift/transport/TSSLSocket.cpp
+++ b/lib/cpp/src/thrift/transport/TSSLSocket.cpp
@@ -957,6 +957,28 @@
   }
 }
 
+void TSSLSocketFactory::loadCertificateFromBuffer(const char* aCertificate, const char* format) {
+  if (aCertificate == nullptr || format == nullptr) {
+    throw TTransportException(TTransportException::BAD_ARGS,
+                              "loadCertificate: either <path> or <format> is NULL");
+  }
+  if (strcmp(format, "PEM") == 0) {
+    BIO* mem = BIO_new(BIO_s_mem());
+    BIO_puts(mem, aCertificate);
+    X509* cert = PEM_read_bio_X509(mem, NULL, 0, NULL);
+    BIO_free(mem);
+
+    if (SSL_CTX_use_certificate(ctx_->get(), cert) == 0) {
+      int errno_copy = THRIFT_GET_SOCKET_ERROR;
+      string errors;
+      buildErrors(errors, errno_copy);
+      throw TSSLException("SSL_CTX_use_certificate: " + errors);
+    }
+  } else {
+    throw TSSLException("Unsupported certificate format: " + string(format));
+  }
+}
+
 void TSSLSocketFactory::loadPrivateKey(const char* path, const char* format) {
   if (path == nullptr || format == nullptr) {
     throw TTransportException(TTransportException::BAD_ARGS,
@@ -972,6 +994,28 @@
   }
 }
 
+void TSSLSocketFactory::loadPrivateKeyFromBuffer(const char* aPrivateKey, const char* format) {
+  if (aPrivateKey == nullptr || format == nullptr) {
+    throw TTransportException(TTransportException::BAD_ARGS,
+                              "loadPrivateKey: either <path> or <format> is NULL");
+  }
+  if (strcmp(format, "PEM") == 0) {
+    BIO* mem = BIO_new(BIO_s_mem());
+    BIO_puts(mem, aPrivateKey);
+    EVP_PKEY* cert = PEM_read_bio_PrivateKey(mem, nullptr, nullptr, nullptr);
+
+    BIO_free(mem);
+    if (SSL_CTX_use_PrivateKey(ctx_->get(), cert) == 0) {
+      int errno_copy = THRIFT_GET_SOCKET_ERROR;
+      string errors;
+      buildErrors(errors, errno_copy);
+      throw TSSLException("SSL_CTX_use_PrivateKey: " + errors);
+    }
+  } else {
+    throw TSSLException("Unsupported certificate format: " + string(format));
+  }
+}
+
 void TSSLSocketFactory::loadTrustedCertificates(const char* path, const char* capath) {
   if (path == nullptr) {
     throw TTransportException(TTransportException::BAD_ARGS,
@@ -985,6 +1029,39 @@
   }
 }
 
+void TSSLSocketFactory::loadTrustedCertificatesFromBuffer(const char* aCertificate, const char* aChain) {
+  if (aCertificate == nullptr) {
+    throw TTransportException(TTransportException::BAD_ARGS,
+                              "loadTrustedCertificates: aCertificate is empty");
+  }
+  X509_STORE* vX509Store = SSL_CTX_get_cert_store(ctx_->get());
+  BIO* mem = BIO_new(BIO_s_mem());
+  BIO_puts(mem, aCertificate);
+  X509* cert = PEM_read_bio_X509(mem, NULL, 0, NULL);
+  BIO_free(mem);
+
+  if (X509_STORE_add_cert(vX509Store, cert) == 0) {
+    int errno_copy = THRIFT_GET_SOCKET_ERROR;
+    string errors;
+    buildErrors(errors, errno_copy);
+    throw TSSLException("X509_STORE_add_cert: " + errors);
+  }
+
+  if (aChain) {
+    mem = BIO_new(BIO_s_mem());
+    BIO_puts(mem, aChain);
+    cert = PEM_read_bio_X509(mem, NULL, 0, NULL);
+    BIO_free(mem);
+
+    if (SSL_CTX_add_extra_chain_cert(ctx_->get(), cert) == 0) {
+      int errno_copy = THRIFT_GET_SOCKET_ERROR;
+      string errors;
+      buildErrors(errors, errno_copy);
+      throw TSSLException("X509_STORE_add_cert: " + errors);
+    }
+  }
+}
+
 void TSSLSocketFactory::randomize() {
   RAND_poll();
 }
diff --git a/lib/cpp/src/thrift/transport/TSSLSocket.h b/lib/cpp/src/thrift/transport/TSSLSocket.h
index 87a9601..a78112c 100644
--- a/lib/cpp/src/thrift/transport/TSSLSocket.h
+++ b/lib/cpp/src/thrift/transport/TSSLSocket.h
@@ -261,6 +261,7 @@
    * @param format Certificate file format
    */
   virtual void loadCertificate(const char* path, const char* format = "PEM");
+  virtual void loadCertificateFromBuffer(const char* aCertificate, const char* format = "PEM");
   /**
    * Load private key.
    *
@@ -268,12 +269,14 @@
    * @param format Private key file format
    */
   virtual void loadPrivateKey(const char* path, const char* format = "PEM");
+  virtual void loadPrivateKeyFromBuffer(const char* aPrivateKey, const char* format = "PEM");
   /**
    * Load trusted certificates from specified file.
    *
    * @param path Path to trusted certificate file
    */
   virtual void loadTrustedCertificates(const char* path, const char* capath = nullptr);
+  virtual void loadTrustedCertificatesFromBuffer(const char* aCertificate, const char* aChain = nullptr);
   /**
    * Default randomize method.
    */
diff --git a/lib/cpp/test/CMakeLists.txt b/lib/cpp/test/CMakeLists.txt
index ef08dbc..fba15f6 100644
--- a/lib/cpp/test/CMakeLists.txt
+++ b/lib/cpp/test/CMakeLists.txt
@@ -334,6 +334,17 @@
 endif ()
 add_test(NAME SecurityTest COMMAND SecurityTest -- "${CMAKE_CURRENT_SOURCE_DIR}/../../../test/keys")
 
+add_executable(SecurityFromBufferTest SecurityFromBufferTest.cpp)
+target_link_libraries(SecurityFromBufferTest
+    testgencpp
+    ${Boost_LIBRARIES}
+)
+LINK_AGAINST_THRIFT_LIBRARY(SecurityFromBufferTest thrift)
+if (NOT MSVC AND NOT ${CMAKE_SYSTEM_NAME} MATCHES "Darwin" AND NOT MINGW)
+target_link_libraries(SecurityFromBufferTest -lrt)
+endif ()
+add_test(NAME SecurityFromBufferTest COMMAND SecurityFromBufferTest -- "${CMAKE_CURRENT_SOURCE_DIR}/../../../test/keys")
+
 endif()
 
 if(WITH_QT5)
diff --git a/lib/cpp/test/Makefile.am b/lib/cpp/test/Makefile.am
index 2a0b9e6..8399d9e 100755
--- a/lib/cpp/test/Makefile.am
+++ b/lib/cpp/test/Makefile.am
@@ -99,6 +99,7 @@
 	TInterruptTest \
 	TServerIntegrationTest \
 	SecurityTest \
+	SecurityFromBufferTest \
 	ZlibTest \
 	TFileTransportTest \
 	link_test \
@@ -174,6 +175,17 @@
   $(BOOST_SYSTEM_LDADD) \
   $(BOOST_THREAD_LDADD)
 
+SecurityFromBufferTest_SOURCES = \
+	SecurityFromBufferTest.cpp
+
+SecurityFromBufferTest_LDADD = \
+  libtestgencpp.la \
+  libprocessortest.la \
+  $(BOOST_TEST_LDADD) \
+  $(BOOST_FILESYSTEM_LDADD) \
+  $(BOOST_SYSTEM_LDADD) \
+  $(BOOST_THREAD_LDADD)
+
 TransportTest_SOURCES = \
 	TransportTest.cpp
 
diff --git a/lib/cpp/test/SecurityFromBufferTest.cpp b/lib/cpp/test/SecurityFromBufferTest.cpp
new file mode 100644
index 0000000..72a4c2a
--- /dev/null
+++ b/lib/cpp/test/SecurityFromBufferTest.cpp
@@ -0,0 +1,253 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#define BOOST_TEST_MODULE SecurityFromBufferTest
+#include <boost/filesystem.hpp>
+#include <boost/foreach.hpp>
+#include <boost/format.hpp>
+#include <boost/test/unit_test.hpp>
+#include <boost/thread.hpp>
+#include <stdexcept>
+#include <fstream>
+#include <memory>
+#include <thrift/transport/TSSLServerSocket.h>
+#include <thrift/transport/TSSLSocket.h>
+#include <thrift/transport/TTransport.h>
+#include <vector>
+#ifdef __linux__
+#include <signal.h>
+#endif
+
+using apache::thrift::transport::TServerTransport;
+using apache::thrift::transport::TSSLServerSocket;
+using apache::thrift::transport::TSSLSocket;
+using apache::thrift::transport::TSSLSocketFactory;
+using apache::thrift::transport::TTransport;
+using apache::thrift::transport::TTransportException;
+using apache::thrift::transport::TTransportFactory;
+
+using std::bind;
+using std::shared_ptr;
+
+boost::filesystem::path keyDir;
+boost::filesystem::path certFile(const std::string& filename) {
+  return keyDir / filename;
+}
+std::string certString(const std::string& filename) {
+  std::ifstream ifs(certFile(filename).string());
+  if(!ifs.is_open() || !ifs.good()) {
+    throw(std::runtime_error("Failed to open key file " + filename + " for reading"));
+  }
+  std::stringstream buffer;
+  buffer << ifs.rdbuf();
+  return buffer.str();
+}
+boost::mutex gMutex;
+
+struct GlobalFixture {
+  GlobalFixture() {
+    using namespace boost::unit_test::framework;
+    for (int i = 0; i < master_test_suite().argc; ++i) {
+      BOOST_TEST_MESSAGE(boost::format("argv[%1%] = \"%2%\"") % i % master_test_suite().argv[i]);
+    }
+
+#ifdef __linux__
+    // OpenSSL calls send() without MSG_NOSIGPIPE so writing to a socket that has
+    // disconnected can cause a SIGPIPE signal...
+    signal(SIGPIPE, SIG_IGN);
+#endif
+
+    TSSLSocketFactory::setManualOpenSSLInitialization(true);
+    apache::thrift::transport::initializeOpenSSL();
+
+    keyDir = boost::filesystem::current_path().parent_path().parent_path().parent_path() / "test" / "keys";
+    if (!boost::filesystem::exists(certFile("server.crt"))) {
+      keyDir = boost::filesystem::path(master_test_suite().argv[master_test_suite().argc - 1]);
+      if (!boost::filesystem::exists(certFile("server.crt"))) {
+        throw std::invalid_argument("The last argument to this test must be the directory containing the test certificate(s).");
+      }
+    }
+  }
+
+  virtual ~GlobalFixture() {
+    apache::thrift::transport::cleanupOpenSSL();
+#ifdef __linux__
+    signal(SIGPIPE, SIG_DFL);
+#endif
+  }
+};
+
+#if (BOOST_VERSION >= 105900)
+BOOST_GLOBAL_FIXTURE(GlobalFixture);
+#else
+BOOST_GLOBAL_FIXTURE(GlobalFixture)
+#endif
+
+struct SecurityFromBufferFixture {
+  void server(apache::thrift::transport::SSLProtocol protocol) {
+    try {
+      boost::mutex::scoped_lock lock(mMutex);
+
+      shared_ptr<TSSLSocketFactory> pServerSocketFactory;
+      shared_ptr<TSSLServerSocket> pServerSocket;
+
+      pServerSocketFactory.reset(new TSSLSocketFactory(static_cast<apache::thrift::transport::SSLProtocol>(protocol)));
+      pServerSocketFactory->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
+      pServerSocketFactory->loadCertificateFromBuffer(certString("server.crt").c_str());
+      pServerSocketFactory->loadPrivateKeyFromBuffer(certString("server.key").c_str());
+      pServerSocketFactory->server(true);
+      pServerSocket.reset(new TSSLServerSocket("localhost", 0, pServerSocketFactory));
+      shared_ptr<TTransport> connectedClient;
+
+      try {
+        pServerSocket->listen();
+        mPort = pServerSocket->getPort();
+        mCVar.notify_one();
+        lock.unlock();
+
+        connectedClient = pServerSocket->accept();
+        uint8_t buf[2];
+        buf[0] = 'O';
+        buf[1] = 'K';
+        connectedClient->write(&buf[0], 2);
+        connectedClient->flush();
+      }
+
+      catch (apache::thrift::transport::TTransportException& ex) {
+        boost::mutex::scoped_lock lock(gMutex);
+        BOOST_TEST_MESSAGE(boost::format("SRV %1% Exception: %2%") % boost::this_thread::get_id() % ex.what());
+      }
+
+      if (connectedClient) {
+        connectedClient->close();
+        connectedClient.reset();
+      }
+
+      pServerSocket->close();
+      pServerSocket.reset();
+    } catch (std::exception& ex) {
+      BOOST_FAIL(boost::format("%1%: %2%") % typeid(ex).name() % ex.what());
+    }
+  }
+
+  void client(apache::thrift::transport::SSLProtocol protocol) {
+    try {
+      shared_ptr<TSSLSocketFactory> pClientSocketFactory;
+      shared_ptr<TSSLSocket> pClientSocket;
+
+      try {
+        pClientSocketFactory.reset(new TSSLSocketFactory(static_cast<apache::thrift::transport::SSLProtocol>(protocol)));
+        pClientSocketFactory->authenticate(true);
+        pClientSocketFactory->loadCertificateFromBuffer(certString("client.crt").c_str());
+        pClientSocketFactory->loadPrivateKeyFromBuffer(certString("client.key").c_str());
+        pClientSocketFactory->loadTrustedCertificatesFromBuffer(certString("CA.pem").c_str());
+        pClientSocket = pClientSocketFactory->createSocket("localhost", mPort);
+        pClientSocket->open();
+
+        uint8_t buf[3];
+        buf[0] = 0;
+        buf[1] = 0;
+        BOOST_CHECK_EQUAL(2, pClientSocket->read(&buf[0], 2));
+        BOOST_CHECK_EQUAL(0, memcmp(&buf[0], "OK", 2));
+        mConnected = true;
+      } catch (apache::thrift::transport::TTransportException& ex) {
+        boost::mutex::scoped_lock lock(gMutex);
+        BOOST_TEST_MESSAGE(boost::format("CLI %1% Exception: %2%") % boost::this_thread::get_id() % ex.what());
+      }
+
+      if (pClientSocket) {
+        pClientSocket->close();
+        pClientSocket.reset();
+      }
+    } catch (std::exception& ex) {
+      BOOST_FAIL(boost::format("%1%: %2%") % typeid(ex).name() % ex.what());
+    }
+  }
+
+  static const char* protocol2str(size_t protocol) {
+    static const char* strings[apache::thrift::transport::LATEST + 1]
+        = {"SSLTLS", "SSLv2", "SSLv3", "TLSv1_0", "TLSv1_1", "TLSv1_2"};
+    return strings[protocol];
+  }
+
+  boost::mutex mMutex;
+  boost::condition_variable mCVar;
+  int mPort;
+  bool mConnected;
+};
+
+BOOST_FIXTURE_TEST_SUITE(BOOST_TEST_MODULE, SecurityFromBufferFixture)
+
+BOOST_AUTO_TEST_CASE(ssl_security_matrix) {
+  try {
+    // matrix of connection success between client and server with different SSLProtocol selections
+        static_assert(apache::thrift::transport::LATEST == 5, "Mismatch in assumed number of ssl protocols");
+        bool matrix[apache::thrift::transport::LATEST + 1][apache::thrift::transport::LATEST + 1] =
+        {
+    //   server    = SSLTLS   SSLv2    SSLv3    TLSv1_0  TLSv1_1  TLSv1_2
+    // client
+    /* SSLTLS  */  { true,    false,   false,   true,    true,    true    },
+    /* SSLv2   */  { false,   false,   false,   false,   false,   false   },
+    /* SSLv3   */  { false,   false,   true,    false,   false,   false   },
+    /* TLSv1_0 */  { true,    false,   false,   true,    false,   false   },
+    /* TLSv1_1 */  { true,    false,   false,   false,   true,    false   },
+    /* TLSv1_2 */  { true,    false,   false,   false,   false,   true    }
+        };
+
+    for (size_t si = 0; si <= apache::thrift::transport::LATEST; ++si) {
+      for (size_t ci = 0; ci <= apache::thrift::transport::LATEST; ++ci) {
+        if (si == 1 || ci == 1) {
+          // Skip all SSLv2 cases - protocol not supported
+          continue;
+        }
+
+#ifdef OPENSSL_NO_SSL3
+        if (si == 2 || ci == 2) {
+          // Skip all SSLv3 cases - protocol not supported
+          continue;
+        }
+#endif
+
+        boost::mutex::scoped_lock lock(mMutex);
+
+        BOOST_TEST_MESSAGE(boost::format("TEST: Server = %1%, Client = %2%") % protocol2str(si)
+                           % protocol2str(ci));
+
+        mConnected = false;
+        // thread_group manages the thread lifetime - ignore the return value of create_thread
+        boost::thread_group threads;
+        (void)threads.create_thread(bind(&SecurityFromBufferFixture::server, this,
+                                         static_cast<apache::thrift::transport::SSLProtocol>(si)));
+        mCVar.wait(lock); // wait for listen() to succeed
+        lock.unlock();
+        (void)threads.create_thread(bind(&SecurityFromBufferFixture::client, this,
+                                         static_cast<apache::thrift::transport::SSLProtocol>(ci)));
+        threads.join_all();
+
+        BOOST_CHECK_MESSAGE(mConnected == matrix[ci][si],
+            boost::format("      Server = %1%, Client = %2% expected mConnected == %3% but was %4%")
+                % protocol2str(si) % protocol2str(ci) % matrix[ci][si] % mConnected);
+      }
+    }
+  } catch (std::exception& ex) {
+    BOOST_FAIL(boost::format("%1%: %2%") % typeid(ex).name() % ex.what());
+  }
+}
+
+BOOST_AUTO_TEST_SUITE_END()