THRIFT-151. cpp: TSSLServerSocket and TSSLSocket implementation
This patch adds OpenSSL-based implementations of the SSL sockets named above (TSSLSocket and TSSLServerSocket).
Patch: Ping Li, Kevin Worth, Rowan Kerr
git-svn-id: https://svn.apache.org/repos/asf/thrift/trunk@1073441 13f79535-47bb-0310-9956-ffa450edef68
diff --git a/lib/cpp/Makefile.am b/lib/cpp/Makefile.am
index 402a294..b085ac4 100644
--- a/lib/cpp/Makefile.am
+++ b/lib/cpp/Makefile.am
@@ -60,8 +60,10 @@
src/transport/THttpClient.cpp \
src/transport/THttpServer.cpp \
src/transport/TSocket.cpp \
+ src/transport/TSSLSocket.cpp \
src/transport/TSocketPool.cpp \
src/transport/TServerSocket.cpp \
+ src/transport/TSSLServerSocket.cpp \
src/transport/TTransportUtils.cpp \
src/transport/TBufferTransports.cpp \
src/server/TServer.cpp \
@@ -125,11 +127,13 @@
src/transport/TFileTransport.h \
src/transport/TSimpleFileTransport.h \
src/transport/TServerSocket.h \
+ src/transport/TSSLServerSocket.h \
src/transport/TServerTransport.h \
src/transport/THttpTransport.h \
src/transport/THttpClient.h \
src/transport/THttpServer.h \
src/transport/TSocket.h \
+ src/transport/TSSLSocket.h \
src/transport/TSocketPool.h \
src/transport/TVirtualTransport.h \
src/transport/TTransport.h \
diff --git a/lib/cpp/src/transport/TSSLServerSocket.cpp b/lib/cpp/src/transport/TSSLServerSocket.cpp
new file mode 100644
index 0000000..ed4b648
--- /dev/null
+++ b/lib/cpp/src/transport/TSSLServerSocket.cpp
@@ -0,0 +1,36 @@
+// Copyright (c) 2009- Facebook
+// Distributed under the Thrift Software License
+//
+// See accompanying file LICENSE or visit the Thrift site at:
+// http://developers.facebook.com/thrift/
+
+#include "TSSLServerSocket.h"
+#include "TSSLSocket.h"
+
+namespace apache { namespace thrift { namespace transport {
+
+using namespace boost;
+
+/**
+ * SSL server socket implementation.
+ *
+ * @author Ping Li <pingli@facebook.com>
+ */
+// Listen on `port`; the supplied factory is switched into server mode so
+// that the sockets it produces perform the server side of the handshake.
+TSSLServerSocket::TSSLServerSocket(int port,
+                                   shared_ptr<TSSLSocketFactory> factory)
+  : TServerSocket(port),
+    factory_(factory) {
+  factory_->server(true);
+}
+
+// Same as the two-argument constructor, but also forwards the send and
+// receive timeouts to TServerSocket.
+TSSLServerSocket::TSSLServerSocket(int port,
+                                   int sendTimeout,
+                                   int recvTimeout,
+                                   shared_ptr<TSSLSocketFactory> factory)
+  : TServerSocket(port, sendTimeout, recvTimeout),
+    factory_(factory) {
+  factory_->server(true);
+}
+
+// Wrap an accepted client descriptor in a socket produced by the SSL factory.
+shared_ptr<TSocket> TSSLServerSocket::createSocket(int client) {
+  shared_ptr<TSocket> accepted = factory_->createSocket(client);
+  return accepted;
+}
+
+}}}
diff --git a/lib/cpp/src/transport/TSSLServerSocket.h b/lib/cpp/src/transport/TSSLServerSocket.h
new file mode 100644
index 0000000..36f895c
--- /dev/null
+++ b/lib/cpp/src/transport/TSSLServerSocket.h
@@ -0,0 +1,48 @@
+// Copyright (c) 2009- Facebook
+// Distributed under the Thrift Software License
+//
+// See accompanying file LICENSE or visit the Thrift site at:
+// http://developers.facebook.com/thrift/
+
+#ifndef _THRIFT_TRANSPORT_TSSLSERVERSOCKET_H_
+#define _THRIFT_TRANSPORT_TSSLSERVERSOCKET_H_ 1
+
+#include <boost/shared_ptr.hpp>
+#include "TServerSocket.h"
+
+namespace apache { namespace thrift { namespace transport {
+
+class TSSLSocketFactory;
+
+/**
+ * Server socket that accepts SSL connections.
+ *
+ * @author Ping Li <pingli@facebook.com>
+ */
+class TSSLServerSocket: public TServerSocket {
+ public:
+ /**
+ * Constructor. The factory is switched into server mode.
+ *
+ * @param port Listening port
+ * @param factory SSL socket factory implementation
+ */
+ TSSLServerSocket(int port, boost::shared_ptr<TSSLSocketFactory> factory);
+ /**
+ * Constructor. The factory is switched into server mode.
+ *
+ * @param port Listening port
+ * @param sendTimeout Socket send timeout
+ * @param recvTimeout Socket receive timeout
+ * @param factory SSL socket factory implementation
+ */
+ TSSLServerSocket(int port, int sendTimeout, int recvTimeout,
+ boost::shared_ptr<TSSLSocketFactory> factory);
+ protected:
+ /**
+ * Build the transport for an accepted connection by delegating to the
+ * SSL socket factory.
+ *
+ * @param socket Descriptor of the accepted connection
+ */
+ boost::shared_ptr<TSocket> createSocket(int socket);
+ // Factory used to create the socket for every accepted connection.
+ boost::shared_ptr<TSSLSocketFactory> factory_;
+};
+
+}}}
+
+#endif
diff --git a/lib/cpp/src/transport/TSSLSocket.cpp b/lib/cpp/src/transport/TSSLSocket.cpp
new file mode 100644
index 0000000..f84f806
--- /dev/null
+++ b/lib/cpp/src/transport/TSSLSocket.cpp
@@ -0,0 +1,645 @@
+// Copyright (c) 2009- Facebook
+// Distributed under the Thrift Software License
+//
+// See accompanying file LICENSE or visit the Thrift site at:
+// http://developers.facebook.com/thrift/
+
+#include <errno.h>
+#include <string>
+#include <arpa/inet.h>
+#include <boost/lexical_cast.hpp>
+#include <boost/shared_array.hpp>
+#include <openssl/err.h>
+#include <openssl/rand.h>
+#include <openssl/ssl.h>
+#include <openssl/x509v3.h>
+#include "concurrency/Mutex.h"
+#include "TSSLSocket.h"
+
+// OpenSSL version-number threshold: when building against an OpenSSL older
+// than this, a thread-id callback is registered (see the #if guards below).
+#define OPENSSL_VERSION_NO_THREAD_ID 0x10000000L
+
+using namespace std;
+using namespace boost;
+using namespace apache::thrift::concurrency;
+
+// Lock object handed to the dynamic-lock callbacks defined in this file
+// (dyn_create / dyn_lock / dyn_destroy); it wraps a single Thrift Mutex.
+struct CRYPTO_dynlock_value {
+ Mutex mutex;
+};
+
+namespace apache { namespace thrift { namespace transport {
+
+
+static void buildErrors(string& message, int error = 0);
+static bool matchName(const char* host, const char* pattern, int size);
+static char uppercase(char c);
+
+// SSLContext implementation
+/**
+ * Create a TLSv1 SSL_CTX with SSL_MODE_AUTO_RETRY enabled.
+ * Throws TSSLException if the context cannot be created.
+ */
+SSLContext::SSLContext() {
+  ctx_ = SSL_CTX_new(TLSv1_method());
+  if (ctx_ != NULL) {
+    SSL_CTX_set_mode(ctx_, SSL_MODE_AUTO_RETRY);
+    return;
+  }
+  string reason;
+  buildErrors(reason);
+  throw TSSLException("SSL_CTX_new: " + reason);
+}
+
+// Release the wrapped SSL_CTX, if one is held.
+SSLContext::~SSLContext() {
+  if (ctx_ == NULL) {
+    return;
+  }
+  SSL_CTX_free(ctx_);
+  ctx_ = NULL;
+}
+
+/**
+ * Allocate a new SSL session object from this context.
+ * Throws TSSLException if SSL_new fails.
+ */
+SSL* SSLContext::createSSL() {
+  SSL* session = SSL_new(ctx_);
+  if (session != NULL) {
+    return session;
+  }
+  string reason;
+  buildErrors(reason);
+  throw TSSLException("SSL_new: " + reason);
+}
+
+// TSSLSocket implementation
+// Unconnected client-mode socket; the SSL session is created lazily.
+TSSLSocket::TSSLSocket(shared_ptr<SSLContext> ctx)
+  : TSocket(),
+    server_(false),
+    ssl_(NULL),
+    ctx_(ctx) {
+}
+
+// Adopt an existing socket descriptor; the SSL session is created lazily.
+TSSLSocket::TSSLSocket(shared_ptr<SSLContext> ctx, int socket)
+  : TSocket(socket),
+    server_(false),
+    ssl_(NULL),
+    ctx_(ctx) {
+}
+
+// Client-mode socket for host:port; the SSL session is created lazily.
+TSSLSocket::TSSLSocket(shared_ptr<SSLContext> ctx, string host, int port)
+  : TSocket(host, port),
+    server_(false),
+    ssl_(NULL),
+    ctx_(ctx) {
+}
+
+// Shut down the SSL session (if any) and close the underlying socket.
+TSSLSocket::~TSSLSocket() {
+ close();
+}
+
+/**
+ * The socket counts as open once an SSL session exists, the underlying
+ * TSocket is open, and the session has not been shut down in both directions.
+ */
+bool TSSLSocket::isOpen() {
+  if (ssl_ == NULL || !TSocket::isOpen()) {
+    return false;
+  }
+  const int state = SSL_get_shutdown(ssl_);
+  const bool received = (state & SSL_RECEIVED_SHUTDOWN) != 0;
+  const bool sent = (state & SSL_SENT_SHUTDOWN) != 0;
+  return !(received && sent);
+}
+
+/**
+ * Report whether at least one byte can be read without consuming it.
+ * Returns false when the socket is not open; throws TSSLException when
+ * SSL_peek reports an error.
+ */
+bool TSSLSocket::peek() {
+  if (!isOpen()) {
+    return false;
+  }
+  checkHandshake();
+  uint8_t probe;
+  const int rc = SSL_peek(ssl_, &probe, 1);
+  if (rc < 0) {
+    const int errno_copy = errno;
+    string errors;
+    buildErrors(errors, errno_copy);
+    throw TSSLException("SSL_peek: " + errors);
+  }
+  if (rc > 0) {
+    return true;
+  }
+  // nothing available: drop whatever SSL_peek left in the error queue
+  ERR_clear_error();
+  return false;
+}
+
+/**
+ * Open the underlying TCP connection. Only valid for a client-mode socket
+ * that is not already open; otherwise throws BAD_ARGS.
+ */
+void TSSLSocket::open() {
+  const bool usable = !isOpen() && !server();
+  if (!usable) {
+    throw TTransportException(TTransportException::BAD_ARGS);
+  }
+  TSocket::open();
+}
+
+/**
+ * Shut down and free the SSL session, if one exists, then close the
+ * underlying socket. Shutdown failures are logged, not thrown.
+ */
+void TSSLSocket::close() {
+  if (ssl_ != NULL) {
+    int status = SSL_shutdown(ssl_);
+    if (status == 0) {
+      // first call returned 0: call once more
+      status = SSL_shutdown(ssl_);
+    }
+    if (status < 0) {
+      const int errno_copy = errno;
+      string detail;
+      buildErrors(detail, errno_copy);
+      GlobalOutput(("SSL_shutdown: " + detail).c_str());
+    }
+    SSL_free(ssl_);
+    ssl_ = NULL;
+    ERR_remove_state(0);
+  }
+  TSocket::close();
+}
+
+/**
+ * Read up to len bytes of decrypted data into buf.
+ *
+ * A read interrupted by a signal (SSL_ERROR_SYSCALL with an empty OpenSSL
+ * error queue and errno == EINTR) is retried, up to maxRecvRetries_ attempts.
+ *
+ * @param buf Destination buffer
+ * @param len Capacity of buf in bytes
+ * @return Number of bytes read (0 is passed through from SSL_read)
+ * @throws TSSLException on any other SSL_read failure, or when every
+ *         attempt was interrupted
+ */
+uint32_t TSSLSocket::read(uint8_t* buf, uint32_t len) {
+  checkHandshake();
+  int32_t bytes = 0;
+  int errno_copy = 0;
+  for (int32_t retries = 0; retries < maxRecvRetries_; retries++) {
+    bytes = SSL_read(ssl_, buf, len);
+    if (bytes >= 0) {
+      break;
+    }
+    errno_copy = errno;
+    if (SSL_get_error(ssl_, bytes) == SSL_ERROR_SYSCALL &&
+        ERR_get_error() == 0 && errno_copy == EINTR) {
+      continue;
+    }
+    string errors;
+    buildErrors(errors, errno_copy);
+    throw TSSLException("SSL_read: " + errors);
+  }
+  if (bytes < 0) {
+    // Retries exhausted while still being interrupted. Returning the negative
+    // count through the unsigned return type would look like a huge read.
+    string errors;
+    buildErrors(errors, errno_copy);
+    throw TSSLException("SSL_read: " + errors);
+  }
+  return bytes;
+}
+
+/**
+ * Write all len bytes of buf, looping in case SSL_write accepts only part
+ * of the data (SSL_MODE_ENABLE_PARTIAL_WRITE set in SSL_CTX).
+ * Throws TSSLException if SSL_write fails.
+ */
+void TSSLSocket::write(const uint8_t* buf, uint32_t len) {
+  checkHandshake();
+  for (uint32_t sent = 0; sent < len; ) {
+    const int32_t accepted = SSL_write(ssl_, &buf[sent], len - sent);
+    if (accepted <= 0) {
+      const int errno_copy = errno;
+      string detail;
+      buildErrors(detail, errno_copy);
+      throw TSSLException("SSL_write: " + detail);
+    }
+    sent += accepted;
+  }
+}
+
+/**
+ * Flush the write BIO of the SSL session. A socket without a session is
+ * silently ignored because Thrift servers close the socket twice.
+ */
+void TSSLSocket::flush() {
+  // Don't throw exception if not open. Thrift servers close socket twice.
+  if (ssl_ == NULL) {
+    return;
+  }
+  checkHandshake();
+  BIO* out = SSL_get_wbio(ssl_);
+  if (out == NULL) {
+    throw TSSLException("SSL_get_wbio returns NULL");
+  }
+  if (BIO_flush(out) == 1) {
+    return;
+  }
+  const int errno_copy = errno;
+  string detail;
+  buildErrors(detail, errno_copy);
+  throw TSSLException("BIO_flush: " + detail);
+}
+
+/**
+ * Perform the SSL handshake on first use.
+ *
+ * Does nothing when a session already exists. Otherwise creates the session,
+ * binds it to the socket descriptor, runs SSL_accept (server mode) or
+ * SSL_connect (client mode), and finally calls authorize().
+ *
+ * If the handshake or the authorization fails, the session is freed and
+ * ssl_ is reset before the exception propagates. Otherwise a later call
+ * would see a non-NULL ssl_, return early, and let I/O proceed on a session
+ * that was never authorized.
+ *
+ * @throws TTransportException (NOT_OPEN) if the underlying socket is closed
+ * @throws TSSLException if the handshake or authorization fails
+ */
+void TSSLSocket::checkHandshake() {
+  if (!TSocket::isOpen()) {
+    throw TTransportException(TTransportException::NOT_OPEN);
+  }
+  if (ssl_ != NULL) {
+    return;
+  }
+  ssl_ = ctx_->createSSL();
+  SSL_set_fd(ssl_, socket_);
+  int rc;
+  if (server()) {
+    rc = SSL_accept(ssl_);
+  } else {
+    rc = SSL_connect(ssl_);
+  }
+  if (rc <= 0) {
+    int errno_copy = errno;
+    string fname(server() ? "SSL_accept" : "SSL_connect");
+    string errors;
+    // collect the error text before the session is released
+    buildErrors(errors, errno_copy);
+    SSL_free(ssl_);
+    ssl_ = NULL;
+    throw TSSLException(fname + ": " + errors);
+  }
+  try {
+    authorize();
+  } catch (...) {
+    SSL_free(ssl_);
+    ssl_ = NULL;
+    throw;
+  }
+}
+
+/**
+ * Decide whether the peer of the freshly negotiated session may proceed.
+ *
+ * Order of checks:
+ *   1. OpenSSL's verification result must be X509_V_OK.
+ *   2. No peer certificate: fail if the verify mode required one, or if this
+ *      is a server with an AccessManager installed; otherwise accept.
+ *   3. Certificate present but no AccessManager: accept.
+ *   4. Otherwise consult the AccessManager with the peer address, then with
+ *      each DNS / IP subjectAltName, then with each commonName, stopping at
+ *      the first ALLOW or DENY.
+ *
+ * The host name handed to the AccessManager is getPeerHost() in server mode
+ * and getHost() in client mode, for subjectAltName and commonName alike.
+ *
+ * @throws TSSLException when access is not granted
+ */
+void TSSLSocket::authorize() {
+  int rc = SSL_get_verify_result(ssl_);
+  if (rc != X509_V_OK) {  // verify authentication result
+    throw TSSLException(string("SSL_get_verify_result(), ") +
+                        X509_verify_cert_error_string(rc));
+  }
+
+  X509* cert = SSL_get_peer_certificate(ssl_);
+  if (cert == NULL) {
+    // certificate is not present
+    if (SSL_get_verify_mode(ssl_) & SSL_VERIFY_FAIL_IF_NO_PEER_CERT) {
+      throw TSSLException("authorize: required certificate not present");
+    }
+    // certificate was optional: didn't intend to authorize remote
+    if (server() && access_ != NULL) {
+      throw TSSLException("authorize: certificate required for authorization");
+    }
+    return;
+  }
+  // certificate is present
+  if (access_ == NULL) {
+    X509_free(cert);
+    return;
+  }
+  // both certificate and access manager are present
+
+  string host;
+  sockaddr_storage sa = {};
+  socklen_t saLength = sizeof(sa);
+
+  if (getpeername(socket_, (sockaddr*)&sa, &saLength) != 0) {
+    sa.ss_family = AF_UNSPEC;
+  }
+
+  AccessManager::Decision decision = access_->verify(sa);
+
+  if (decision != AccessManager::SKIP) {
+    X509_free(cert);
+    if (decision != AccessManager::ALLOW) {
+      throw TSSLException("authorize: access denied based on remote IP");
+    }
+    return;
+  }
+
+  // extract subjectAlternativeName
+  STACK_OF(GENERAL_NAME)* alternatives = (STACK_OF(GENERAL_NAME)*)
+                       X509_get_ext_d2i(cert, NID_subject_alt_name, NULL, NULL);
+  if (alternatives != NULL) {
+    const int count = sk_GENERAL_NAME_num(alternatives);
+    for (int i = 0; decision == AccessManager::SKIP && i < count; i++) {
+      const GENERAL_NAME* name = sk_GENERAL_NAME_value(alternatives, i);
+      if (name == NULL) {
+        continue;
+      }
+      // Only DNS and IP entries are consulted. Check the type before
+      // touching the union so other entry kinds are never read as a string.
+      if (name->type != GEN_DNS && name->type != GEN_IPADD) {
+        continue;
+      }
+      char* data = (char*)ASN1_STRING_data(name->d.ia5);
+      int length = ASN1_STRING_length(name->d.ia5);
+      if (name->type == GEN_DNS) {
+        if (host.empty()) {
+          host = (server() ? getPeerHost() : getHost());
+        }
+        decision = access_->verify(host, data, length);
+      } else {
+        decision = access_->verify(sa, data, length);
+      }
+    }
+    sk_GENERAL_NAME_pop_free(alternatives, GENERAL_NAME_free);
+  }
+
+  if (decision != AccessManager::SKIP) {
+    X509_free(cert);
+    if (decision != AccessManager::ALLOW) {
+      throw TSSLException("authorize: access denied");
+    }
+    return;
+  }
+
+  // extract commonName
+  X509_NAME* name = X509_get_subject_name(cert);
+  if (name != NULL) {
+    int last = -1;
+    while (decision == AccessManager::SKIP) {
+      last = X509_NAME_get_index_by_NID(name, NID_commonName, last);
+      if (last == -1)
+        break;
+      X509_NAME_ENTRY* entry = X509_NAME_get_entry(name, last);
+      if (entry == NULL)
+        continue;
+      ASN1_STRING* common = X509_NAME_ENTRY_get_data(entry);
+      unsigned char* utf8 = NULL;
+      int size = ASN1_STRING_to_UTF8(&utf8, common);
+      if (size < 0) {
+        // conversion failed: there is no buffer to inspect or to free
+        continue;
+      }
+      if (host.empty()) {
+        // server mode checks the name of the connecting peer
+        host = (server() ? getPeerHost() : getHost());
+      }
+      decision = access_->verify(host, (char*)utf8, size);
+      OPENSSL_free(utf8);
+    }
+  }
+  X509_free(cert);
+  if (decision != AccessManager::ALLOW) {
+    throw TSSLException("authorize: cannot authorize peer");
+  }
+}
+
+// TSSLSocketFactory implementation
+// Set by initializeOpenSSL(), cleared by cleanupOpenSSL().
+bool TSSLSocketFactory::initialized = false;
+// Number of live factory instances; incremented in the constructor and
+// decremented in the destructor.
+uint64_t TSSLSocketFactory::count_ = 0;
+// Held by the constructor and destructor while they update count_.
+Mutex TSSLSocketFactory::mutex_;
+
+/**
+ * Create a client-mode factory with its own SSLContext.
+ *
+ * The first live factory initializes OpenSSL and seeds the random number
+ * generator. The instance counter is bumped only after the SSLContext has
+ * been created. If that step throws, the destructor never runs, and a
+ * counter incremented earlier would stay too high and keep
+ * cleanupOpenSSL() from ever being called.
+ *
+ * @throws TSSLException if the SSL context cannot be created
+ */
+TSSLSocketFactory::TSSLSocketFactory(): server_(false) {
+  Guard guard(mutex_);
+  if (count_ == 0) {
+    initializeOpenSSL();
+    randomize();
+  }
+  ctx_ = shared_ptr<SSLContext>(new SSLContext);
+  count_++;
+}
+
+// Drop this instance from the live count; the last one out cleans up OpenSSL.
+TSSLSocketFactory::~TSSLSocketFactory() {
+  Guard guard(mutex_);
+  if (--count_ == 0) {
+    cleanupOpenSSL();
+  }
+}
+
+// Create an unconnected SSL socket configured by this factory.
+shared_ptr<TSSLSocket> TSSLSocketFactory::createSocket() {
+  shared_ptr<TSSLSocket> created(new TSSLSocket(ctx_));
+  setup(created);
+  return created;
+}
+
+// Create an SSL socket around an existing descriptor, configured by this factory.
+shared_ptr<TSSLSocket> TSSLSocketFactory::createSocket(int socket) {
+  shared_ptr<TSSLSocket> created(new TSSLSocket(ctx_, socket));
+  setup(created);
+  return created;
+}
+
+// Create an SSL socket for host:port, configured by this factory.
+shared_ptr<TSSLSocket> TSSLSocketFactory::createSocket(const string& host,
+                                                       int port) {
+  shared_ptr<TSSLSocket> created(new TSSLSocket(ctx_, host, port));
+  setup(created);
+  return created;
+}
+
+/**
+ * Copy the factory's mode and access manager onto a new socket. A client-mode
+ * factory without an access manager gets a DefaultClientAccessManager first.
+ */
+void TSSLSocketFactory::setup(shared_ptr<TSSLSocket> ssl) {
+  ssl->server(server());
+  if (access_ == NULL) {
+    if (!server()) {
+      access_ = shared_ptr<AccessManager>(new DefaultClientAccessManager);
+    }
+  }
+  if (access_ == NULL) {
+    return;
+  }
+  ssl->access(access_);
+}
+
+/**
+ * Install the cipher list on the shared SSL context.
+ * Throws TSSLException if OpenSSL queued an error or accepted no cipher.
+ */
+void TSSLSocketFactory::ciphers(const string& enable) {
+  const int accepted = SSL_CTX_set_cipher_list(ctx_->get(), enable.c_str());
+  if (ERR_peek_error() != 0) {
+    string detail;
+    buildErrors(detail);
+    throw TSSLException("SSL_CTX_set_cipher_list: " + detail);
+  }
+  if (accepted != 0) {
+    return;
+  }
+  throw TSSLException("None of specified ciphers are supported");
+}
+
+/**
+ * Select the peer verification mode of the shared SSL context: require a
+ * peer certificate when `required` is true, SSL_VERIFY_NONE otherwise.
+ */
+void TSSLSocketFactory::authenticate(bool required) {
+  const int mode = required
+    ? (SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT | SSL_VERIFY_CLIENT_ONCE)
+    : SSL_VERIFY_NONE;
+  SSL_CTX_set_verify(ctx_->get(), mode, NULL);
+}
+
+/**
+ * Load a certificate chain file into the shared SSL context.
+ *
+ * @param path   Path to the certificate file
+ * @param format File format; only "PEM" is supported
+ * @throws TTransportException (BAD_ARGS) if path or format is NULL
+ * @throws TSSLException if the file cannot be loaded or the format is
+ *         not supported
+ */
+void TSSLSocketFactory::loadCertificate(const char* path, const char* format) {
+  if (path == NULL || format == NULL) {
+    // message now names this function rather than "loadCertificateChain"
+    throw TTransportException(TTransportException::BAD_ARGS,
+         "loadCertificate: either <path> or <format> is NULL");
+  }
+  if (strcmp(format, "PEM") == 0) {
+    if (SSL_CTX_use_certificate_chain_file(ctx_->get(), path) == 0) {
+      int errno_copy = errno;
+      string errors;
+      buildErrors(errors, errno_copy);
+      throw TSSLException("SSL_CTX_use_certificate_chain_file: " + errors);
+    }
+  } else {
+    throw TSSLException("Unsupported certificate format: " + string(format));
+  }
+}
+
+/**
+ * Load a private key file into the shared SSL context.
+ *
+ * @param path   Path to the private key file
+ * @param format File format; only "PEM" is supported
+ * @throws TTransportException (BAD_ARGS) if path or format is NULL
+ * @throws TSSLException if the file cannot be loaded or the format is
+ *         not supported
+ */
+void TSSLSocketFactory::loadPrivateKey(const char* path, const char* format) {
+  if (path == NULL || format == NULL) {
+    throw TTransportException(TTransportException::BAD_ARGS,
+         "loadPrivateKey: either <path> or <format> is NULL");
+  }
+  if (strcmp(format, "PEM") == 0) {
+    if (SSL_CTX_use_PrivateKey_file(ctx_->get(), path, SSL_FILETYPE_PEM) == 0) {
+      int errno_copy = errno;
+      string errors;
+      buildErrors(errors, errno_copy);
+      throw TSSLException("SSL_CTX_use_PrivateKey_file: " + errors);
+    }
+  } else {
+    // Mirror loadCertificate(): report an unknown format instead of
+    // returning as if a key had been loaded.
+    throw TSSLException("Unsupported private key format: " + string(format));
+  }
+}
+
+/**
+ * Load trusted certificates from a file into the shared SSL context.
+ * Throws BAD_ARGS for a NULL path and TSSLException if loading fails.
+ */
+void TSSLSocketFactory::loadTrustedCertificates(const char* path) {
+  if (path == NULL) {
+    throw TTransportException(TTransportException::BAD_ARGS,
+         "loadTrustedCertificates: <path> is NULL");
+  }
+  const int loaded = SSL_CTX_load_verify_locations(ctx_->get(), path, NULL);
+  if (loaded != 0) {
+    return;
+  }
+  const int errno_copy = errno;
+  string detail;
+  buildErrors(detail, errno_copy);
+  throw TSSLException("SSL_CTX_load_verify_locations: " + detail);
+}
+
+// Default seeding of OpenSSL's random number generator via RAND_poll().
+void TSSLSocketFactory::randomize() {
+ RAND_poll();
+}
+
+// Route OpenSSL's password prompts for this context to passwordCallback(),
+// passing this factory as the callback's user data.
+void TSSLSocketFactory::overrideDefaultPasswordCallback() {
+  SSL_CTX* raw = ctx_->get();
+  SSL_CTX_set_default_passwd_cb(raw, passwordCallback);
+  SSL_CTX_set_default_passwd_cb_userdata(ctx_->get(), this);
+}
+
+/**
+ * OpenSSL password callback. Asks the factory passed as user data for the
+ * password via getPassword() and copies at most `size` characters of it
+ * into `password`.
+ *
+ * @return Number of characters copied
+ */
+int TSSLSocketFactory::passwordCallback(char* password,
+                                        int size,
+                                        int,
+                                        void* data) {
+  TSSLSocketFactory* self = static_cast<TSSLSocketFactory*>(data);
+  string secret;
+  self->getPassword(secret, size);
+  const int available = secret.size();
+  const int copied = (available > size) ? size : available;
+  strncpy(password, secret.c_str(), copied);
+  return copied;
+}
+
+// Mutex array indexed by callbackLocking(); allocated in initializeOpenSSL()
+// with CRYPTO_num_locks() entries and released in cleanupOpenSSL().
+static shared_array<Mutex> mutexes;
+
+// Static locking callback: lock or unlock mutex number n depending on
+// whether CRYPTO_LOCK is set in mode.
+static void callbackLocking(int mode, int n, const char*, int) {
+  Mutex& target = mutexes[n];
+  if ((mode & CRYPTO_LOCK) == 0) {
+    target.unlock();
+    return;
+  }
+  target.lock();
+}
+
+#if (OPENSSL_VERSION_NUMBER < OPENSSL_VERSION_NO_THREAD_ID)
+// Thread-id callback: returns pthread_self() converted to unsigned long.
+static unsigned long callbackThreadID() {
+ return reinterpret_cast<unsigned long>(pthread_self());
+}
+#endif
+
+// Dynamic-lock creation callback: allocate a new lock object.
+// The matching delete is in dyn_destroy().
+static CRYPTO_dynlock_value* dyn_create(const char*, int) {
+ return new CRYPTO_dynlock_value;
+}
+
+// Dynamic-lock callback: lock or unlock the given lock object depending on
+// whether CRYPTO_LOCK is set in mode. A NULL lock is ignored.
+static void dyn_lock(int mode,
+                     struct CRYPTO_dynlock_value* lock,
+                     const char*, int) {
+  if (lock == NULL) {
+    return;
+  }
+  if ((mode & CRYPTO_LOCK) == 0) {
+    lock->mutex.unlock();
+    return;
+  }
+  lock->mutex.lock();
+}
+
+// Dynamic-lock destruction callback: delete a lock created by dyn_create().
+static void dyn_destroy(struct CRYPTO_dynlock_value* lock, const char*, int) {
+ delete lock;
+}
+
+/**
+ * One-time OpenSSL setup: initialize the library, load its error strings,
+ * allocate the static-lock mutex array, and register the static and dynamic
+ * locking callbacks (plus the thread-id callback on older OpenSSL).
+ *
+ * Does nothing when already initialized. The flag is set only after every
+ * step has succeeded, so a failed attempt (for example an allocation that
+ * throws) is retried the next time instead of leaving the library marked
+ * initialized with no callbacks installed.
+ */
+void TSSLSocketFactory::initializeOpenSSL() {
+  if (initialized) {
+    return;
+  }
+  SSL_library_init();
+  SSL_load_error_strings();
+  // static locking
+  // new[] signals failure by throwing, so no NULL check is needed here
+  mutexes = shared_array<Mutex>(new Mutex[::CRYPTO_num_locks()]);
+#if (OPENSSL_VERSION_NUMBER < OPENSSL_VERSION_NO_THREAD_ID)
+  CRYPTO_set_id_callback(callbackThreadID);
+#endif
+  CRYPTO_set_locking_callback(callbackLocking);
+  // dynamic locking
+  CRYPTO_set_dynlock_create_callback(dyn_create);
+  CRYPTO_set_dynlock_lock_callback(dyn_lock);
+  CRYPTO_set_dynlock_destroy_callback(dyn_destroy);
+  initialized = true;
+}
+
+// Undo initializeOpenSSL(): unregister every callback, release OpenSSL's
+// global data, and free the static-lock mutex array. Does nothing when the
+// library is not marked as initialized.
+void TSSLSocketFactory::cleanupOpenSSL() {
+ if (!initialized) {
+ return;
+ }
+ initialized = false;
+ // unregister the callbacks before the mutexes they use are released below
+#if (OPENSSL_VERSION_NUMBER < OPENSSL_VERSION_NO_THREAD_ID)
+ CRYPTO_set_id_callback(NULL);
+#endif
+ CRYPTO_set_locking_callback(NULL);
+ CRYPTO_set_dynlock_create_callback(NULL);
+ CRYPTO_set_dynlock_lock_callback(NULL);
+ CRYPTO_set_dynlock_destroy_callback(NULL);
+ CRYPTO_cleanup_all_ex_data();
+ ERR_free_strings();
+ EVP_cleanup();
+ ERR_remove_state(0);
+ mutexes.reset();
+}
+
+// extract error messages from error queue
+/**
+ * Drain the OpenSSL error queue into `errors`, separating entries with "; ".
+ * If the queue was empty, fall back to the text for errno_copy, and finally
+ * to a plain "error code: N" message.
+ */
+void buildErrors(string& errors, int errno_copy) {
+  char numbered[256];
+
+  errors.reserve(512);
+  for (unsigned long code = ERR_get_error();
+       code != 0;
+       code = ERR_get_error()) {
+    if (!errors.empty()) {
+      errors += "; ";
+    }
+    const char* text = ERR_reason_error_string(code);
+    if (text == NULL) {
+      // no reason string available: report the numeric code instead
+      snprintf(numbered, sizeof(numbered) - 1, "SSL error # %lu", code);
+      text = numbered;
+    }
+    errors += text;
+  }
+  if (errors.empty() && errno_copy != 0) {
+    errors += TOutput::strerror_s(errno_copy);
+  }
+  if (errors.empty()) {
+    errors = "error code: " + lexical_cast<string>(errno_copy);
+  }
+}
+
+/**
+ * Default implementation of AccessManager
+ */
+// The peer address alone never decides: always defer to the certificate checks.
+Decision DefaultClientAccessManager::verify(const sockaddr_storage& sa)
+  throw() {
+  return SKIP;
+}
+
+// ALLOW when the certificate name matches the host (wildcards permitted);
+// SKIP for empty input or a non-matching name.
+Decision DefaultClientAccessManager::verify(const string& host,
+                                            const char* name,
+                                            int size) throw() {
+  const bool usable = !host.empty() && name != NULL && size > 0;
+  if (!usable) {
+    return SKIP;
+  }
+  if (matchName(host.c_str(), name, size)) {
+    return ALLOW;
+  }
+  return SKIP;
+}
+
+// ALLOW when the raw IP address from the certificate has the right length
+// for the socket's address family and equals the peer address; SKIP otherwise.
+Decision DefaultClientAccessManager::verify(const sockaddr_storage& sa,
+                                            const char* data,
+                                            int size) throw() {
+  switch (sa.ss_family) {
+    case AF_INET:
+      if (size == sizeof(in_addr) &&
+          memcmp(&((sockaddr_in*)&sa)->sin_addr, data, size) == 0) {
+        return ALLOW;
+      }
+      break;
+    case AF_INET6:
+      if (size == sizeof(in6_addr) &&
+          memcmp(&((sockaddr_in6*)&sa)->sin6_addr, data, size) == 0) {
+        return ALLOW;
+      }
+      break;
+  }
+  return SKIP;
+}
+
+/**
+ * Match a name with a pattern. The pattern may include wildcard. A single
+ * wildcard "*" can match up to one component in the domain name.
+ *
+ * @param host Host name, typically the name of the remote host
+ * @param pattern Name retrieved from certificate
+ * @param size Size of "pattern"
+ * @return True, if "host" matches "pattern". False otherwise.
+ */
+bool matchName(const char* host, const char* pattern, int size) {
+  int p = 0;  // position in pattern
+  int h = 0;  // position in host
+  while (p < size && host[h] != '\0') {
+    if (uppercase(pattern[p]) == uppercase(host[h])) {
+      // same character, ignoring ASCII case
+      ++p;
+      ++h;
+    } else if (pattern[p] == '*') {
+      // wildcard: swallow the rest of the current host component
+      while (host[h] != '.' && host[h] != '\0') {
+        ++h;
+      }
+      ++p;
+    } else {
+      break;
+    }
+  }
+  // a match requires both strings to be fully consumed
+  return p == size && host[h] == '\0';
+}
+
+// This is to work around the Turkish locale issue, i.e.,
+// toupper('i') != toupper('I') if locale is "tr_TR"
+char uppercase (char c) {
+  // ASCII-only mapping, independent of the current locale
+  const bool lower = ('a' <= c && c <= 'z');
+  return lower ? c + ('A' - 'a') : c;
+}
+
+}}}
diff --git a/lib/cpp/src/transport/TSSLSocket.h b/lib/cpp/src/transport/TSSLSocket.h
new file mode 100644
index 0000000..58e3934
--- /dev/null
+++ b/lib/cpp/src/transport/TSSLSocket.h
@@ -0,0 +1,304 @@
+// Copyright (c) 2009- Facebook
+// Distributed under the Thrift Software License
+//
+// See accompanying file LICENSE or visit the Thrift site at:
+// http://developers.facebook.com/thrift/
+
+#ifndef _THRIFT_TRANSPORT_TSSLSOCKET_H_
+#define _THRIFT_TRANSPORT_TSSLSOCKET_H_ 1
+
+#include <string>
+#include <boost/shared_ptr.hpp>
+#include <openssl/ssl.h>
+#include "concurrency/Mutex.h"
+#include "TSocket.h"
+
+namespace apache { namespace thrift { namespace transport {
+
+class AccessManager;
+class SSLContext;
+
+/**
+ * OpenSSL implementation for SSL socket interface.
+ *
+ * @author Ping Li <pingli@facebook.com>
+ */
+class TSSLSocket: public TSocket {
+ public:
+ ~TSSLSocket();
+ /**
+ * TTransport interface.
+ */
+ bool isOpen();
+ bool peek();
+ void open();
+ void close();
+ uint32_t read(uint8_t* buf, uint32_t len);
+ void write(const uint8_t* buf, uint32_t len);
+ void flush();
+ /**
+ * Set whether to use client or server side SSL handshake protocol.
+ *
+ * @param flag Use server side handshake protocol if true.
+ */
+ void server(bool flag) { server_ = flag; }
+ /**
+ * Determine whether the SSL socket is server or client mode.
+ */
+ bool server() const { return server_; }
+ /**
+ * Set AccessManager.
+ *
+ * @param manager Instance of AccessManager
+ */
+ virtual void access(boost::shared_ptr<AccessManager> manager) {
+ access_ = manager;
+ }
+protected:
+ /**
+ * Constructor.
+ *
+ * @param ctx SSL context used to create the SSL session
+ */
+ TSSLSocket(boost::shared_ptr<SSLContext> ctx);
+ /**
+ * Constructor, create an instance of TSSLSocket given an existing socket.
+ *
+ * @param ctx SSL context used to create the SSL session
+ * @param socket An existing socket
+ */
+ TSSLSocket(boost::shared_ptr<SSLContext> ctx, int socket);
+ /**
+ * Constructor.
+ *
+ * @param ctx SSL context used to create the SSL session
+ * @param host Remote host name
+ * @param port Remote port number
+ */
+ TSSLSocket(boost::shared_ptr<SSLContext> ctx,
+ std::string host,
+ int port);
+ /**
+ * Authorize peer access after SSL handshake completes.
+ */
+ virtual void authorize();
+ /**
+ * Initiate SSL handshake if not already initiated.
+ */
+ void checkHandshake();
+
+ // true: server side of the handshake, false: client side
+ bool server_;
+ // SSL session; NULL until checkHandshake() creates it, NULL again after close()
+ SSL* ssl_;
+ // context the SSL session is created from
+ boost::shared_ptr<SSLContext> ctx_;
+ // optional access manager consulted by authorize()
+ boost::shared_ptr<AccessManager> access_;
+ // the constructors are protected; the factory is the intended creator
+ friend class TSSLSocketFactory;
+};
+
+/**
+ * SSL socket factory. SSL sockets should be created via SSL factory.
+ */
+class TSSLSocketFactory {
+ public:
+ /**
+ * Constructor/Destructor
+ */
+ TSSLSocketFactory();
+ virtual ~TSSLSocketFactory();
+ /**
+ * Create an instance of TSSLSocket with a fresh new socket.
+ */
+ virtual boost::shared_ptr<TSSLSocket> createSocket();
+ /**
+ * Create an instance of TSSLSocket with the given socket.
+ *
+ * @param socket An existing socket.
+ */
+ virtual boost::shared_ptr<TSSLSocket> createSocket(int socket);
+ /**
+ * Create an instance of TSSLSocket.
+ *
+ * @param host Remote host to be connected to
+ * @param port Remote port to be connected to
+ */
+ virtual boost::shared_ptr<TSSLSocket> createSocket(const std::string& host,
+ int port);
+ /**
+ * Set ciphers to be used in SSL handshake process.
+ *
+ * @param enable A list of ciphers
+ */
+ virtual void ciphers(const std::string& enable);
+ /**
+ * Enable/Disable authentication.
+ *
+ * @param required Require peer to present valid certificate if true
+ */
+ virtual void authenticate(bool required);
+ /**
+ * Load server certificate.
+ *
+ * @param path Path to the certificate file
+ * @param format Certificate file format
+ */
+ virtual void loadCertificate(const char* path, const char* format = "PEM");
+ /**
+ * Load private key.
+ *
+ * @param path Path to the private key file
+ * @param format Private key file format
+ */
+ virtual void loadPrivateKey(const char* path, const char* format = "PEM");
+ /**
+ * Load trusted certificates from specified file.
+ *
+ * @param path Path to trusted certificate file
+ */
+ virtual void loadTrustedCertificates(const char* path);
+ /**
+ * Default randomize method.
+ */
+ virtual void randomize();
+ /**
+ * Override default OpenSSL password callback with getPassword().
+ */
+ void overrideDefaultPasswordCallback();
+ /**
+ * Set/Unset server mode.
+ *
+ * @param flag Server mode if true
+ */
+ virtual void server(bool flag) { server_ = flag; }
+ /**
+ * Determine whether the socket is in server or client mode.
+ *
+ * @return true, if server mode, or, false, if client mode
+ */
+ virtual bool server() const { return server_; }
+ /**
+ * Set AccessManager.
+ *
+ * @param manager The AccessManager instance
+ */
+ virtual void access(boost::shared_ptr<AccessManager> manager) {
+ access_ = manager;
+ }
+ protected:
+ // SSL context shared by every socket this factory creates
+ boost::shared_ptr<SSLContext> ctx_;
+
+ // Process-wide OpenSSL setup and teardown.
+ static void initializeOpenSSL();
+ static void cleanupOpenSSL();
+ /**
+ * Override this method for custom password callback. It may be called
+ * multiple times at any time during a session as necessary.
+ *
+ * @param password Pass collected password to OpenSSL
+ * @param size Maximum length of password including NULL character
+ */
+ virtual void getPassword(std::string& password, int size) { }
+ private:
+ // mode copied onto every socket created by this factory
+ bool server_;
+ // access manager copied onto every socket created by this factory
+ boost::shared_ptr<AccessManager> access_;
+ // whether the process-wide OpenSSL setup has been performed
+ static bool initialized;
+ // guards count_ in the constructor and destructor
+ static concurrency::Mutex mutex_;
+ // number of live factory instances
+ static uint64_t count_;
+ // apply server mode and access manager to a newly created socket
+ void setup(boost::shared_ptr<TSSLSocket> ssl);
+ // OpenSSL password callback; data is the owning factory
+ static int passwordCallback(char* password, int size, int, void* data);
+};
+
+/**
+ * SSL exception.
+ */
+class TSSLException: public TTransportException {
+ public:
+  /**
+   * @param message Description of the SSL failure; reported as an
+   *                INTERNAL_ERROR transport exception
+   */
+  TSSLException(const std::string& message)
+    : TTransportException(TTransportException::INTERNAL_ERROR, message) {}
+
+  /**
+   * @return The stored message, or "TSSLException" when it is empty
+   */
+  virtual const char* what() const throw() {
+    return message_.empty() ? "TSSLException" : message_.c_str();
+  }
+};
+
+/**
+ * Wrap OpenSSL SSL_CTX into a class.
+ */
+class SSLContext {
+ public:
+ // Creates the underlying SSL_CTX; throws TSSLException on failure.
+ SSLContext();
+ // Frees the underlying SSL_CTX.
+ virtual ~SSLContext();
+ // Creates a new SSL session from this context; throws TSSLException on failure.
+ SSL* createSSL();
+ // Raw access to the wrapped SSL_CTX; ownership stays with this object.
+ SSL_CTX* get() { return ctx_; }
+ private:
+ SSL_CTX* ctx_;
+};
+
+/**
+ * Callback interface for access control. It's meant to verify the remote host.
+ * It's constructed when application starts and set to TSSLSocketFactory
+ * instance. It's passed onto all TSSLSocket instances created by this factory
+ * object.
+ */
+class AccessManager {
+ public:
+ enum Decision {
+ DENY = -1, // deny access
+ SKIP = 0, // cannot make decision, move on to next (if any)
+ ALLOW = 1, // allow access
+ };
+ /**
+ * Destructor
+ */
+ virtual ~AccessManager() {}
+ /**
+ * Determine whether the peer should be granted access or not. It's called
+ * once after the SSL handshake completes successfully, before peer certificate
+ * is examined.
+ *
+ * If a valid decision (ALLOW or DENY) is returned, the peer certificate is
+ * not to be verified.
+ *
+ * @param sa Peer IP address
+ * @return ALLOW, DENY, or SKIP to leave the decision to later checks.
+ *         This base implementation returns DENY.
+ */
+ virtual Decision verify(const sockaddr_storage& sa) throw() { return DENY; }
+ /**
+ * Determine whether the peer should be granted access or not. It's called
+ * every time a DNS subjectAltName/common name is extracted from peer's
+ * certificate.
+ *
+ * @param host Client mode: host name returned by TSocket::getHost()
+ * Server mode: host name returned by TSocket::getPeerHost()
+ * @param name SubjectAltName or common name extracted from peer certificate
+ * @param size Length of name
+ * @return ALLOW, DENY, or SKIP to leave the decision to later checks.
+ *         This base implementation returns DENY.
+ *
+ * Note: The "name" parameter may be UTF8 encoded.
+ */
+ virtual Decision verify(const std::string& host, const char* name, int size)
+ throw() { return DENY; }
+ /**
+ * Determine whether the peer should be granted access or not. It's called
+ * every time an IP subjectAltName is extracted from peer's certificate.
+ *
+ * @param sa Peer IP address retrieved from the underlying socket
+ * @param data IP address extracted from certificate
+ * @param size Length of the IP address
+ * @return ALLOW, DENY, or SKIP to leave the decision to later checks.
+ *         This base implementation returns DENY.
+ */
+ virtual Decision verify(const sockaddr_storage& sa, const char* data, int size)
+ throw() { return DENY; }
+};
+
+// Namespace-level shorthand for AccessManager::Decision.
+typedef AccessManager::Decision Decision;
+
+// Access manager installed by TSSLSocketFactory for client-mode sockets when
+// none was supplied. Its checks return either ALLOW or SKIP.
+class DefaultClientAccessManager: public AccessManager {
+ public:
+ // AccessManager interface
+ // Always SKIP: the peer address alone is not used to decide.
+ Decision verify(const sockaddr_storage& sa) throw();
+ // ALLOW if the certificate name matches host, SKIP otherwise.
+ Decision verify(const std::string& host, const char* name, int size) throw();
+ // ALLOW if the certificate IP address equals the peer address, SKIP otherwise.
+ Decision verify(const sockaddr_storage& sa, const char* data, int size) throw();
+};
+
+
+}}}
+
+#endif
diff --git a/lib/cpp/src/transport/TServerSocket.cpp b/lib/cpp/src/transport/TServerSocket.cpp
index 8608898..276b060 100644
--- a/lib/cpp/src/transport/TServerSocket.cpp
+++ b/lib/cpp/src/transport/TServerSocket.cpp
@@ -386,7 +386,7 @@
throw TTransportException(TTransportException::UNKNOWN, "fcntl(F_SETFL)", errno_copy);
}
- shared_ptr<TSocket> client(new TSocket(clientSocket));
+ shared_ptr<TSocket> client = createSocket(clientSocket);
if (sendTimeout_ > 0) {
client->setSendTimeout(sendTimeout_);
}
@@ -398,6 +398,10 @@
return client;
}
+// Default hook used by acceptImpl(): wrap the accepted descriptor in a plain
+// TSocket. Subclasses may override it to return a different socket type.
+shared_ptr<TSocket> TServerSocket::createSocket(int clientSocket) {
+  shared_ptr<TSocket> plain(new TSocket(clientSocket));
+  return plain;
+}
+
void TServerSocket::interrupt() {
if (intSock1_ >= 0) {
int8_t byte = 0;
diff --git a/lib/cpp/src/transport/TServerSocket.h b/lib/cpp/src/transport/TServerSocket.h
index 8cd521f..40a1148 100644
--- a/lib/cpp/src/transport/TServerSocket.h
+++ b/lib/cpp/src/transport/TServerSocket.h
@@ -56,6 +56,7 @@
protected:
boost::shared_ptr<TTransport> acceptImpl();
+ virtual boost::shared_ptr<TSocket> createSocket(int client);
private:
int port_;
diff --git a/lib/cpp/src/transport/TSocket.h b/lib/cpp/src/transport/TSocket.h
index e89059f..5521491 100644
--- a/lib/cpp/src/transport/TSocket.h
+++ b/lib/cpp/src/transport/TSocket.h
@@ -70,12 +70,12 @@
*
* @return Is the socket alive?
*/
- bool isOpen();
+ virtual bool isOpen();
/**
* Calls select on the socket to see if there is more data available.
*/
- bool peek();
+ virtual bool peek();
/**
* Creates and opens the UNIX socket.
@@ -92,12 +92,12 @@
/**
* Reads from the underlying socket.
*/
- uint32_t read(uint8_t* buf, uint32_t len);
+ virtual uint32_t read(uint8_t* buf, uint32_t len);
/**
* Writes to the underlying socket. Loops until done or fail.
*/
- void write(const uint8_t* buf, uint32_t len);
+ virtual void write(const uint8_t* buf, uint32_t len);
/**
* Writes to the underlying socket. Does single send() and returns result.