THRIFT-3420 C++: TSSLSockets are not interruptable
Client: C++
Patch: Martin Haimberger
This closes #690
diff --git a/lib/cpp/src/thrift/transport/TSSLServerSocket.cpp b/lib/cpp/src/thrift/transport/TSSLServerSocket.cpp
index 7e1484d..89423b4 100644
--- a/lib/cpp/src/thrift/transport/TSSLServerSocket.cpp
+++ b/lib/cpp/src/thrift/transport/TSSLServerSocket.cpp
@@ -48,7 +48,12 @@
}
boost::shared_ptr<TSocket> TSSLServerSocket::createSocket(THRIFT_SOCKET client) {
- return factory_->createSocket(client);
+ // Hand the shared child-interrupt reader to the new socket only when
+ // interruptable children are enabled; otherwise build a plain SSL socket.
+ if (!interruptableChildren_) {
+ return factory_->createSocket(client);
+ }
+ return factory_->createSocket(client, pChildInterruptSockReader_);
}
}
}
diff --git a/lib/cpp/src/thrift/transport/TSSLSocket.cpp b/lib/cpp/src/thrift/transport/TSSLSocket.cpp
index 98c5326..6e9a4de 100644
--- a/lib/cpp/src/thrift/transport/TSSLSocket.cpp
+++ b/lib/cpp/src/thrift/transport/TSSLSocket.cpp
@@ -28,6 +28,14 @@
#ifdef HAVE_SYS_SOCKET_H
#include <sys/socket.h>
#endif
+#ifdef HAVE_SYS_POLL_H
+#include <sys/poll.h>
+#endif
+#ifdef HAVE_FCNTL_H
+#include <fcntl.h>
+#endif
+
+
#include <boost/lexical_cast.hpp>
#include <boost/shared_array.hpp>
#include <openssl/err.h>
@@ -189,14 +197,28 @@
: TSocket(), server_(false), ssl_(NULL), ctx_(ctx) {
}
+TSSLSocket::TSSLSocket(boost::shared_ptr<SSLContext> ctx, boost::shared_ptr<THRIFT_SOCKET> interruptListener)
+ : TSocket(), server_(false), ssl_(NULL), ctx_(ctx) {
+ interruptListener_ = interruptListener; // polled by waitForEvent() to abort a blocked SSL call
+}
+
TSSLSocket::TSSLSocket(boost::shared_ptr<SSLContext> ctx, THRIFT_SOCKET socket)
: TSocket(socket), server_(false), ssl_(NULL), ctx_(ctx) {
}
+TSSLSocket::TSSLSocket(boost::shared_ptr<SSLContext> ctx, THRIFT_SOCKET socket, boost::shared_ptr<THRIFT_SOCKET> interruptListener)
+ : TSocket(socket, interruptListener), server_(false), ssl_(NULL), ctx_(ctx) { // listener is forwarded to the TSocket constructor
+}
+
TSSLSocket::TSSLSocket(boost::shared_ptr<SSLContext> ctx, string host, int port)
: TSocket(host, port), server_(false), ssl_(NULL), ctx_(ctx) {
}
+TSSLSocket::TSSLSocket(boost::shared_ptr<SSLContext> ctx, string host, int port, boost::shared_ptr<THRIFT_SOCKET> interruptListener)
+ : TSocket(host, port), server_(false), ssl_(NULL), ctx_(ctx) {
+ interruptListener_ = interruptListener; // polled by waitForEvent() to abort a blocked SSL call
+}
+
TSSLSocket::~TSSLSocket() {
close();
}
@@ -222,16 +244,32 @@
checkHandshake();
int rc;
uint8_t byte;
- rc = SSL_peek(ssl_, &byte, 1);
- if (rc < 0) {
- int errno_copy = THRIFT_GET_SOCKET_ERROR;
- string errors;
- buildErrors(errors, errno_copy);
- throw TSSLException("SSL_peek: " + errors);
- }
- if (rc == 0) {
- ERR_clear_error();
- }
+ do {
+ rc = SSL_peek(ssl_, &byte, 1);
+ if (rc < 0) {
+ // Classify the failure: transient conditions wait and retry, the rest throw.
+ int errno_copy = THRIFT_GET_SOCKET_ERROR;
+ int error = SSL_get_error(ssl_, rc);
+ switch (error) {
+ case SSL_ERROR_SYSCALL:
+ // Only EINTR/EAGAIN are transient; they fall through to wait for readability.
+ if ((errno_copy != THRIFT_EINTR) && (errno_copy != THRIFT_EAGAIN)) {
+ break;
+ }
+ case SSL_ERROR_WANT_READ:
+ case SSL_ERROR_WANT_WRITE:
+ waitForEvent(error != SSL_ERROR_WANT_WRITE);
+ continue;
+ default:;// do nothing
+ }
+ string errors;
+ buildErrors(errors, errno_copy);
+ throw TSSLException("SSL_peek: " + errors);
+ } else if (rc == 0) {
+ ERR_clear_error();
+ }
+ break; // rc >= 0 is a definitive answer; without this a successful peek would loop forever
+ } while (true);
return (rc > 0);
}
@@ -244,7 +282,28 @@
void TSSLSocket::close() {
if (ssl_ != NULL) {
- int rc = SSL_shutdown(ssl_);
+ int rc;
+ // Repeat SSL_shutdown while it is only waiting on the socket; rc == 2 marks "try again".
+ do {
+ rc = SSL_shutdown(ssl_);
+ if (rc <= 0) {
+ int errno_copy = THRIFT_GET_SOCKET_ERROR;
+ int error = SSL_get_error(ssl_, rc);
+ switch (error) {
+ case SSL_ERROR_SYSCALL:
+ // Only EINTR/EAGAIN are transient; any other errno ends the retry loop.
+ if ((errno_copy != THRIFT_EINTR) && (errno_copy != THRIFT_EAGAIN)) {
+ break;
+ }
+ case SSL_ERROR_WANT_READ:
+ case SSL_ERROR_WANT_WRITE:
+ waitForEvent(error == SSL_ERROR_WANT_READ);
+ rc = 2;
+ default:;// do nothing
+ }
+ }
+ } while (rc == 2);
+
if (rc < 0) {
int errno_copy = THRIFT_GET_SOCKET_ERROR;
string errors;
@@ -262,14 +321,36 @@
checkHandshake();
int32_t bytes = 0;
for (int32_t retries = 0; retries < maxRecvRetries_; retries++) {
+ ERR_clear_error();
bytes = SSL_read(ssl_, buf, len);
if (bytes >= 0)
break;
- int errno_copy = THRIFT_GET_SOCKET_ERROR;
- if (SSL_get_error(ssl_, bytes) == SSL_ERROR_SYSCALL) {
- if (ERR_get_error() == 0 && errno_copy == THRIFT_EINTR) {
+ int32_t errno_copy = THRIFT_GET_SOCKET_ERROR;
+ int32_t error = SSL_get_error(ssl_, bytes);
+ switch (error) {
+ case SSL_ERROR_SYSCALL:
+ // Only EINTR/EAGAIN are transient; any other errno is reported as an error.
+ if ((errno_copy != THRIFT_EINTR) && (errno_copy != THRIFT_EAGAIN)) {
+ break;
+ }
+ if (retries++ >= maxRecvRetries_) {
+ // THRIFT_EINTR needs to be handled manually and we can tolerate
+ // a certain number
+ break;
+ }
+ case SSL_ERROR_WANT_READ:
+ case SSL_ERROR_WANT_WRITE:
+ if (waitForEvent(error != SSL_ERROR_WANT_WRITE) == TSSL_EINTR ) {
+ // repeat operation
+ if (retries++ < maxRecvRetries_) {
+ // THRIFT_EINTR needs to be handled manually and we can tolerate
+ // a certain number
+ continue;
+ }
+ throw TTransportException(TTransportException::INTERNAL_ERROR, "too much recv retries");
+ }
continue;
- }
+ default:;// do nothing
}
string errors;
buildErrors(errors, errno_copy);
@@ -283,9 +364,23 @@
// loop in case SSL_MODE_ENABLE_PARTIAL_WRITE is set in SSL_CTX.
uint32_t written = 0;
while (written < len) {
+ ERR_clear_error();
int32_t bytes = SSL_write(ssl_, &buf[written], len - written);
if (bytes <= 0) {
int errno_copy = THRIFT_GET_SOCKET_ERROR;
+ int error = SSL_get_error(ssl_, bytes);
+ switch (error) {
+ case SSL_ERROR_SYSCALL:
+ // Only EINTR/EAGAIN are transient; any other errno is reported as an error.
+ if ((errno_copy != THRIFT_EINTR) && (errno_copy != THRIFT_EAGAIN)) {
+ break;
+ }
+ case SSL_ERROR_WANT_READ:
+ case SSL_ERROR_WANT_WRITE:
+ waitForEvent(error == SSL_ERROR_WANT_READ);
+ continue;
+ default:;// do nothing
+ }
string errors;
buildErrors(errors, errno_copy);
throw TSSLException("SSL_write: " + errors);
@@ -319,13 +414,65 @@
if (ssl_ != NULL) {
return;
}
+
+ // set underlying socket to non-blocking
+ int flags;
+ if ((flags = THRIFT_FCNTL(socket_, THRIFT_F_GETFL, 0)) < 0
+ || THRIFT_FCNTL(socket_, THRIFT_F_SETFL, flags | THRIFT_O_NONBLOCK) < 0) {
+ GlobalOutput.perror("thriftServerEventHandler: set THRIFT_O_NONBLOCK (THRIFT_FCNTL) ",
+ THRIFT_GET_SOCKET_ERROR);
+ ::THRIFT_CLOSESOCKET(socket_);
+ // NOTE(review): returns with ssl_ still NULL, yet callers go on to use ssl_ - confirm intended.
+ return;
+ }
+
ssl_ = ctx_->createSSL();
+
+ // SSL_set_fd() installs its own socket BIO for both directions and frees any
+ // BIO attached earlier, so no separate read/write BIOs are created here.
SSL_set_fd(ssl_, static_cast<int>(socket_));
int rc;
if (server()) {
- rc = SSL_accept(ssl_);
+ // Drive the non-blocking handshake to completion; rc == 2 marks "try again".
+ do {
+ rc = SSL_accept(ssl_);
+ if (rc <= 0) {
+ int errno_copy = THRIFT_GET_SOCKET_ERROR;
+ int error = SSL_get_error(ssl_, rc);
+ switch (error) {
+ case SSL_ERROR_SYSCALL:
+ // Only EINTR/EAGAIN are transient; any other errno ends the retry loop.
+ if ((errno_copy != THRIFT_EINTR) && (errno_copy != THRIFT_EAGAIN)) {
+ break;
+ }
+ case SSL_ERROR_WANT_READ:
+ case SSL_ERROR_WANT_WRITE:
+ waitForEvent(error == SSL_ERROR_WANT_READ);
+ rc = 2;
+ default:;// do nothing
+ }
+ }
+ } while (rc == 2);
} else {
- rc = SSL_connect(ssl_);
+ do {
+ rc = SSL_connect(ssl_);
+ if (rc <= 0) {
+ int errno_copy = THRIFT_GET_SOCKET_ERROR;
+ int error = SSL_get_error(ssl_, rc);
+ switch (error) {
+ case SSL_ERROR_SYSCALL:
+ // Only EINTR/EAGAIN are transient; any other errno ends the retry loop.
+ if ((errno_copy != THRIFT_EINTR) && (errno_copy != THRIFT_EAGAIN)) {
+ break;
+ }
+ case SSL_ERROR_WANT_READ:
+ case SSL_ERROR_WANT_WRITE:
+ waitForEvent(error == SSL_ERROR_WANT_READ);
+ rc = 2;
+ default:;// do nothing
+ }
+ }
+ } while (rc == 2);
}
if (rc <= 0) {
int errno_copy = THRIFT_GET_SOCKET_ERROR;
@@ -443,6 +601,54 @@
}
}
+unsigned int TSSLSocket::waitForEvent(bool wantRead) {
+ // Block until the socket behind the SSL connection is ready in the requested
+ // direction, or until the interrupt listener (if any) becomes readable.
+ BIO* const pending = wantRead ? SSL_get_rbio(ssl_) : SSL_get_wbio(ssl_);
+ if (pending == NULL) {
+ throw TSSLException("SSL_get_?bio returned NULL");
+ }
+
+ // Recover the descriptor behind that BIO so it can be polled.
+ int polledFd;
+ if (BIO_get_fd(pending, &polledFd) <= 0) {
+ throw TSSLException("BIO_get_fd failed");
+ }
+
+ // Entry 0 is the transport socket; entry 1 is the optional interrupt
+ // listener and stays zeroed (and unpolled) when none is configured.
+ struct THRIFT_POLLFD watched[2];
+ std::memset(watched, 0, sizeof(watched));
+ watched[0].fd = polledFd;
+ watched[0].events = wantRead ? THRIFT_POLLIN : THRIFT_POLLOUT;
+ if (interruptListener_) {
+ watched[1].fd = *(interruptListener_.get());
+ watched[1].events = THRIFT_POLLIN;
+ }
+
+ // The timeout argument is -1, i.e. no time limit is requested.
+ const int ready = THRIFT_POLL(watched, interruptListener_ ? 2 : 1, -1);
+
+ if (ready > 0) {
+ // The interrupt listener is checked first, so an interrupt wins over socket data.
+ if (watched[1].revents & THRIFT_POLLIN) {
+ throw TTransportException(TTransportException::INTERRUPTED, "Interrupted");
+ }
+ return TSSL_DATA;
+ }
+ if (ready == 0) {
+ throw TTransportException(TTransportException::TIMED_OUT, "THRIFT_POLL (timed out)");
+ }
+
+ // ready < 0: the poll call itself failed.
+ const int errno_copy = THRIFT_GET_SOCKET_ERROR;
+ if (errno_copy == THRIFT_EINTR) {
+ return TSSL_EINTR; // repeat operation
+ }
+ GlobalOutput.perror("TSSLSocket::read THRIFT_POLL() ", errno_copy);
+ throw TTransportException(TTransportException::UNKNOWN, "Unknown", errno_copy);
+}
+
// TSSLSocketFactory implementation
uint64_t TSSLSocketFactory::count_ = 0;
Mutex TSSLSocketFactory::mutex_;
@@ -475,18 +681,37 @@
return ssl;
}
+boost::shared_ptr<TSSLSocket> TSSLSocketFactory::createSocket(boost::shared_ptr<THRIFT_SOCKET> interruptListener) {
+ boost::shared_ptr<TSSLSocket> created(new TSSLSocket(ctx_, interruptListener));
+ setup(created); // apply this factory's settings before handing the socket out
+ return created;
+}
+
boost::shared_ptr<TSSLSocket> TSSLSocketFactory::createSocket(THRIFT_SOCKET socket) {
boost::shared_ptr<TSSLSocket> ssl(new TSSLSocket(ctx_, socket));
setup(ssl);
return ssl;
}
+boost::shared_ptr<TSSLSocket> TSSLSocketFactory::createSocket(THRIFT_SOCKET socket, boost::shared_ptr<THRIFT_SOCKET> interruptListener) {
+ boost::shared_ptr<TSSLSocket> created(new TSSLSocket(ctx_, socket, interruptListener));
+ setup(created); // apply this factory's settings before handing the socket out
+ return created;
+}
+
boost::shared_ptr<TSSLSocket> TSSLSocketFactory::createSocket(const string& host, int port) {
boost::shared_ptr<TSSLSocket> ssl(new TSSLSocket(ctx_, host, port));
setup(ssl);
return ssl;
}
+boost::shared_ptr<TSSLSocket> TSSLSocketFactory::createSocket(const string& host, int port, boost::shared_ptr<THRIFT_SOCKET> interruptListener) {
+ boost::shared_ptr<TSSLSocket> created(new TSSLSocket(ctx_, host, port, interruptListener));
+ setup(created); // apply this factory's settings before handing the socket out
+ return created;
+}
+
+
void TSSLSocketFactory::setup(boost::shared_ptr<TSSLSocket> ssl) {
ssl->server(server());
if (access_ == NULL && !server()) {
diff --git a/lib/cpp/src/thrift/transport/TSSLSocket.h b/lib/cpp/src/thrift/transport/TSSLSocket.h
index 02d414b..ba8abf4 100644
--- a/lib/cpp/src/thrift/transport/TSSLSocket.h
+++ b/lib/cpp/src/thrift/transport/TSSLSocket.h
@@ -43,6 +43,9 @@
LATEST = TLSv1_2
};
+#define TSSL_EINTR 0
+#define TSSL_DATA 1
+
/**
* Initialize OpenSSL library. This function, or some other
* equivalent function to initialize OpenSSL, must be called before
@@ -99,18 +102,35 @@
*/
TSSLSocket(boost::shared_ptr<SSLContext> ctx);
/**
+ * Constructor with an interrupt signal.
+ */
+ TSSLSocket(boost::shared_ptr<SSLContext> ctx, boost::shared_ptr<THRIFT_SOCKET> interruptListener);
+ /**
* Constructor, create an instance of TSSLSocket given an existing socket.
*
* @param socket An existing socket
*/
TSSLSocket(boost::shared_ptr<SSLContext> ctx, THRIFT_SOCKET socket);
/**
+ * Constructor, create an instance of TSSLSocket given an existing socket that can be interrupted.
+ *
+ * @param socket An existing socket
+ */
+ TSSLSocket(boost::shared_ptr<SSLContext> ctx, THRIFT_SOCKET socket, boost::shared_ptr<THRIFT_SOCKET> interruptListener);
+ /**
* Constructor.
*
* @param host Remote host name
* @param port Remote port number
*/
TSSLSocket(boost::shared_ptr<SSLContext> ctx, std::string host, int port);
+ /**
+ * Constructor with an interrupt signal.
+ *
+ * @param host Remote host name
+ * @param port Remote port number
+ */
+ TSSLSocket(boost::shared_ptr<SSLContext> ctx, std::string host, int port, boost::shared_ptr<THRIFT_SOCKET> interruptListener);
/**
* Authorize peer access after SSL handshake completes.
*/
@@ -119,6 +139,15 @@
* Initiate SSL handshake if not already initiated.
*/
void checkHandshake();
+ /**
+ * Waits for a socket event or an interrupt (shutdown) signal.
+ * @param wantRead true to wait for readability, false to wait for writability.
+ * @throw TTransportException::INTERRUPTED if an interrupt is signaled.
+ *
+ * @return TSSL_EINTR if the wait itself was cut short by EINTR (caller should retry)
+ * TSSL_DATA if the socket reported an event.
+ */
+ unsigned int waitForEvent(bool wantRead);
bool server_;
SSL* ssl_;
@@ -144,12 +173,22 @@
*/
virtual boost::shared_ptr<TSSLSocket> createSocket();
/**
+ * Create an instance of TSSLSocket with a fresh new socket, which is interruptable.
+ */
+ virtual boost::shared_ptr<TSSLSocket> createSocket(boost::shared_ptr<THRIFT_SOCKET> interruptListener);
+ /**
* Create an instance of TSSLSocket with the given socket.
*
* @param socket An existing socket.
*/
virtual boost::shared_ptr<TSSLSocket> createSocket(THRIFT_SOCKET socket);
/**
+ * Create an instance of TSSLSocket with the given socket which is interruptable.
+ *
+ * @param socket An existing socket.
+ */
+ virtual boost::shared_ptr<TSSLSocket> createSocket(THRIFT_SOCKET socket, boost::shared_ptr<THRIFT_SOCKET> interruptListener);
+ /**
* Create an instance of TSSLSocket.
*
* @param host Remote host to be connected to
@@ -157,6 +196,13 @@
*/
virtual boost::shared_ptr<TSSLSocket> createSocket(const std::string& host, int port);
/**
+ * Create an instance of TSSLSocket that is given an interrupt listener.
+ *
+ * @param host Remote host to be connected to
+ * @param port Remote port to be connected to
+ */
+ virtual boost::shared_ptr<TSSLSocket> createSocket(const std::string& host, int port, boost::shared_ptr<THRIFT_SOCKET> interruptListener);
+ /**
* Set ciphers to be used in SSL handshake process.
*
* @param ciphers A list of ciphers
diff --git a/lib/cpp/src/thrift/transport/TServerSocket.cpp b/lib/cpp/src/thrift/transport/TServerSocket.cpp
index 215cda6..137dc32 100644
--- a/lib/cpp/src/thrift/transport/TServerSocket.cpp
+++ b/lib/cpp/src/thrift/transport/TServerSocket.cpp
@@ -119,7 +119,8 @@
using boost::shared_ptr;
TServerSocket::TServerSocket(int port)
- : port_(port),
+ : interruptableChildren_(true),
+ port_(port),
serverSocket_(THRIFT_INVALID_SOCKET),
acceptBacklog_(DEFAULT_BACKLOG),
sendTimeout_(0),
@@ -130,7 +131,6 @@
tcpSendBuffer_(0),
tcpRecvBuffer_(0),
keepAlive_(false),
- interruptableChildren_(true),
listening_(false),
interruptSockWriter_(THRIFT_INVALID_SOCKET),
interruptSockReader_(THRIFT_INVALID_SOCKET),
@@ -138,7 +138,8 @@
}
TServerSocket::TServerSocket(int port, int sendTimeout, int recvTimeout)
- : port_(port),
+ : interruptableChildren_(true),
+ port_(port),
serverSocket_(THRIFT_INVALID_SOCKET),
acceptBacklog_(DEFAULT_BACKLOG),
sendTimeout_(sendTimeout),
@@ -149,7 +150,6 @@
tcpSendBuffer_(0),
tcpRecvBuffer_(0),
keepAlive_(false),
- interruptableChildren_(true),
listening_(false),
interruptSockWriter_(THRIFT_INVALID_SOCKET),
interruptSockReader_(THRIFT_INVALID_SOCKET),
@@ -157,7 +157,8 @@
}
TServerSocket::TServerSocket(const string& address, int port)
- : port_(port),
+ : interruptableChildren_(true),
+ port_(port),
address_(address),
serverSocket_(THRIFT_INVALID_SOCKET),
acceptBacklog_(DEFAULT_BACKLOG),
@@ -169,7 +170,6 @@
tcpSendBuffer_(0),
tcpRecvBuffer_(0),
keepAlive_(false),
- interruptableChildren_(true),
listening_(false),
interruptSockWriter_(THRIFT_INVALID_SOCKET),
interruptSockReader_(THRIFT_INVALID_SOCKET),
@@ -177,7 +177,8 @@
}
TServerSocket::TServerSocket(const string& path)
- : port_(0),
+ : interruptableChildren_(true),
+ port_(0),
path_(path),
serverSocket_(THRIFT_INVALID_SOCKET),
acceptBacklog_(DEFAULT_BACKLOG),
@@ -189,7 +190,6 @@
tcpSendBuffer_(0),
tcpRecvBuffer_(0),
keepAlive_(false),
- interruptableChildren_(true),
listening_(false),
interruptSockWriter_(THRIFT_INVALID_SOCKET),
interruptSockReader_(THRIFT_INVALID_SOCKET),
diff --git a/lib/cpp/src/thrift/transport/TServerSocket.h b/lib/cpp/src/thrift/transport/TServerSocket.h
index 58e4afd..20a37e7 100644
--- a/lib/cpp/src/thrift/transport/TServerSocket.h
+++ b/lib/cpp/src/thrift/transport/TServerSocket.h
@@ -123,6 +123,8 @@
protected:
boost::shared_ptr<TTransport> acceptImpl();
virtual boost::shared_ptr<TSocket> createSocket(THRIFT_SOCKET client);
+ bool interruptableChildren_;
+ boost::shared_ptr<THRIFT_SOCKET> pChildInterruptSockReader_; // if interruptableChildren_ this is shared with child TSockets
private:
void notify(THRIFT_SOCKET notifySock);
@@ -140,13 +142,11 @@
int tcpSendBuffer_;
int tcpRecvBuffer_;
bool keepAlive_;
- bool interruptableChildren_;
bool listening_;
THRIFT_SOCKET interruptSockWriter_; // is notified on interrupt()
THRIFT_SOCKET interruptSockReader_; // is used in select/poll with serverSocket_ for interruptability
THRIFT_SOCKET childInterruptSockWriter_; // is notified on interruptChildren()
- boost::shared_ptr<THRIFT_SOCKET> pChildInterruptSockReader_; // if interruptableChildren_ this is shared with child TSockets
socket_func_t listenCallback_;
socket_func_t acceptCallback_;
diff --git a/lib/cpp/test/CMakeLists.txt b/lib/cpp/test/CMakeLists.txt
index 02932cb..491d343 100644
--- a/lib/cpp/test/CMakeLists.txt
+++ b/lib/cpp/test/CMakeLists.txt
@@ -92,7 +92,10 @@
endif ( MSVC )
-set( TInterruptTest_SOURCES TSocketInterruptTest.cpp )
+set( TInterruptTest_SOURCES
+ TSocketInterruptTest.cpp
+ TSSLSocketInterruptTest.cpp
+)
if (WIN32)
list(APPEND TInterruptTest_SOURCES
TPipeInterruptTest.cpp
@@ -108,7 +111,7 @@
if (NOT MSVC AND NOT ${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
target_link_libraries(TInterruptTest -lrt)
endif ()
-add_test(NAME TInterruptTest COMMAND TInterruptTest)
+add_test(NAME TInterruptTest COMMAND TInterruptTest "${CMAKE_CURRENT_SOURCE_DIR}/../../../test/keys")
add_executable(TServerIntegrationTest TServerIntegrationTest.cpp)
target_link_libraries(TServerIntegrationTest
diff --git a/lib/cpp/test/Makefile.am b/lib/cpp/test/Makefile.am
index 4b7a99f..1895afc 100755
--- a/lib/cpp/test/Makefile.am
+++ b/lib/cpp/test/Makefile.am
@@ -128,11 +128,13 @@
$(BOOST_TEST_LDADD)
TInterruptTest_SOURCES = \
- TSocketInterruptTest.cpp
+ TSocketInterruptTest.cpp \
+ TSSLSocketInterruptTest.cpp
TInterruptTest_LDADD = \
libtestgencpp.la \
$(BOOST_TEST_LDADD) \
+ $(BOOST_FILESYSTEM_LDADD) \
$(BOOST_CHRONO_LDADD) \
$(BOOST_SYSTEM_LDADD) \
$(BOOST_THREAD_LDADD)
diff --git a/lib/cpp/test/TSSLSocketInterruptTest.cpp b/lib/cpp/test/TSSLSocketInterruptTest.cpp
new file mode 100644
index 0000000..c723d0e
--- /dev/null
+++ b/lib/cpp/test/TSSLSocketInterruptTest.cpp
@@ -0,0 +1,283 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <boost/test/auto_unit_test.hpp>
+#include <boost/test/unit_test_suite.hpp>
+#include <boost/bind.hpp>
+#include <boost/chrono/duration.hpp>
+#include <boost/date_time/posix_time/posix_time_duration.hpp>
+#include <boost/thread/thread.hpp>
+#include <boost/filesystem.hpp>
+#include <boost/format.hpp>
+#include <boost/shared_ptr.hpp>
+#include <thrift/transport/TSSLSocket.h>
+#include <thrift/transport/TSSLServerSocket.h>
+#include "TestPortFixture.h"
+#ifdef __linux__
+#include <signal.h>
+#endif
+
+using apache::thrift::transport::TSSLServerSocket;
+using apache::thrift::transport::TSSLSocket;
+using apache::thrift::transport::TTransport;
+using apache::thrift::transport::TTransportException;
+using apache::thrift::transport::TSSLSocketFactory;
+
+boost::filesystem::path keyDir;
+boost::filesystem::path certFile(const std::string& filename)
+{
+ return keyDir / filename;
+}
+boost::mutex gMutex;
+
+struct GlobalFixtureSSL
+{
+ GlobalFixtureSSL()
+ {
+ using namespace boost::unit_test::framework;
+ for (int i = 0; i < master_test_suite().argc; ++i)
+ {
+ BOOST_TEST_MESSAGE(boost::format("argv[%1%] = \"%2%\"") % i % master_test_suite().argv[i]);
+ }
+
+#ifdef __linux__
+ // OpenSSL calls send() without MSG_NOSIGPIPE so writing to a socket that has
+ // disconnected can cause a SIGPIPE signal...
+ signal(SIGPIPE, SIG_IGN);
+#endif
+
+ TSSLSocketFactory::setManualOpenSSLInitialization(true);
+ apache::thrift::transport::initializeOpenSSL();
+
+ keyDir = boost::filesystem::current_path().parent_path().parent_path().parent_path() / "test" / "keys";
+ if (!boost::filesystem::exists(certFile("server.crt")))
+ {
+ keyDir = boost::filesystem::path(master_test_suite().argv[master_test_suite().argc - 1]);
+ if (!boost::filesystem::exists(certFile("server.crt")))
+ {
+ throw std::invalid_argument("The last argument to this test must be the directory containing the test certificate(s).");
+ }
+ }
+ }
+
+ virtual ~GlobalFixtureSSL()
+ {
+ apache::thrift::transport::cleanupOpenSSL();
+#ifdef __linux__
+ signal(SIGPIPE, SIG_DFL);
+#endif
+ }
+};
+
+#if (BOOST_VERSION >= 105900)
+BOOST_GLOBAL_FIXTURE(GlobalFixtureSSL);
+#else
+BOOST_GLOBAL_FIXTURE(GlobalFixtureSSL)
+#endif
+
+BOOST_FIXTURE_TEST_SUITE(TSSLSocketInterruptTest, TestPortFixture)
+
+void readerWorker(boost::shared_ptr<TTransport> tt, uint32_t expectedResult) {
+ uint8_t buf[4];
+ try {
+ tt->read(buf, 1);
+ BOOST_CHECK_EQUAL(expectedResult, tt->read(buf, 4));
+ } catch (const TTransportException& tx) {
+ BOOST_CHECK_EQUAL(TTransportException::INTERNAL_ERROR, tx.getType());
+ }
+}
+
+void readerWorkerMustThrow(boost::shared_ptr<TTransport> tt) {
+ try {
+ uint8_t buf[400];
+ tt->read(buf, 1);
+ tt->read(buf, 400);
+ BOOST_ERROR("should not have gotten here");
+ } catch (const TTransportException& tx) {
+ BOOST_CHECK_EQUAL(TTransportException::INTERRUPTED, tx.getType());
+ }
+}
+
+boost::shared_ptr<TSSLSocketFactory> createServerSocketFactory() {
+ boost::shared_ptr<TSSLSocketFactory> pServerSocketFactory;
+
+ pServerSocketFactory.reset(new TSSLSocketFactory());
+ pServerSocketFactory->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
+ pServerSocketFactory->loadCertificate(certFile("server.crt").native().c_str());
+ pServerSocketFactory->loadPrivateKey(certFile("server.key").native().c_str());
+ pServerSocketFactory->server(true);
+ return pServerSocketFactory;
+}
+
+boost::shared_ptr<TSSLSocketFactory> createClientSocketFactory() {
+ boost::shared_ptr<TSSLSocketFactory> pClientSocketFactory;
+
+ pClientSocketFactory.reset(new TSSLSocketFactory());
+ pClientSocketFactory->authenticate(true);
+ pClientSocketFactory->loadCertificate(certFile("client.crt").native().c_str());
+ pClientSocketFactory->loadPrivateKey(certFile("client.key").native().c_str());
+ pClientSocketFactory->loadTrustedCertificates(certFile("CA.pem").native().c_str());
+ return pClientSocketFactory;
+}
+
+BOOST_AUTO_TEST_CASE(test_ssl_interruptable_child_read_while_handshaking) {
+ boost::shared_ptr<TSSLSocketFactory> pServerSocketFactory = createServerSocketFactory();
+ TSSLServerSocket sock1("localhost", m_serverPort, pServerSocketFactory);
+ sock1.listen();
+ boost::shared_ptr<TSSLSocketFactory> pClientSocketFactory = createClientSocketFactory();
+ boost::shared_ptr<TSSLSocket> clientSock = pClientSocketFactory->createSocket("localhost", m_serverPort);
+ clientSock->open();
+ boost::shared_ptr<TTransport> accepted = sock1.accept();
+ boost::thread readThread(boost::bind(readerWorkerMustThrow, accepted));
+ boost::this_thread::sleep(boost::posix_time::milliseconds(50));
+ // readThread is practically guaranteed to be blocking now
+ sock1.interruptChildren();
+ BOOST_CHECK_MESSAGE(readThread.try_join_for(boost::chrono::milliseconds(20)),
+ "server socket interruptChildren did not interrupt child read");
+ clientSock->close();
+ accepted->close();
+ sock1.close();
+}
+
+BOOST_AUTO_TEST_CASE(test_ssl_interruptable_child_read) {
+ boost::shared_ptr<TSSLSocketFactory> pServerSocketFactory = createServerSocketFactory();
+ TSSLServerSocket sock1("localhost", m_serverPort, pServerSocketFactory);
+ sock1.listen();
+ boost::shared_ptr<TSSLSocketFactory> pClientSocketFactory = createClientSocketFactory();
+ boost::shared_ptr<TSSLSocket> clientSock = pClientSocketFactory->createSocket("localhost", m_serverPort);
+ clientSock->open();
+ boost::shared_ptr<TTransport> accepted = sock1.accept();
+ boost::thread readThread(boost::bind(readerWorkerMustThrow, accepted));
+ clientSock->write((const uint8_t*)"0", 1);
+ boost::this_thread::sleep(boost::posix_time::milliseconds(50));
+ // readThread is practically guaranteed to be blocking now
+ sock1.interruptChildren();
+ BOOST_CHECK_MESSAGE(readThread.try_join_for(boost::chrono::milliseconds(20)),
+ "server socket interruptChildren did not interrupt child read");
+ accepted->close();
+ clientSock->close();
+ sock1.close();
+}
+
+BOOST_AUTO_TEST_CASE(test_ssl_non_interruptable_child_read) {
+ std::cout << "An error message from SSL_Shutdown on the console is expected:" << std::endl;
+ boost::shared_ptr<TSSLSocketFactory> pServerSocketFactory = createServerSocketFactory();
+ TSSLServerSocket sock1("localhost", m_serverPort, pServerSocketFactory);
+ sock1.setInterruptableChildren(false); // returns to pre-THRIFT-2441 behavior
+ sock1.listen();
+ boost::shared_ptr<TSSLSocketFactory> pClientSocketFactory = createClientSocketFactory();
+ boost::shared_ptr<TSSLSocket> clientSock = pClientSocketFactory->createSocket("localhost", m_serverPort);
+ clientSock->open();
+ boost::shared_ptr<TTransport> accepted = sock1.accept();
+ boost::thread readThread(boost::bind(readerWorker, accepted, 0));
+ clientSock->write((const uint8_t*)"0", 1);
+ boost::this_thread::sleep(boost::posix_time::milliseconds(50));
+ // readThread is practically guaranteed to be blocking here
+ sock1.interruptChildren();
+ BOOST_CHECK_MESSAGE(!readThread.try_join_for(boost::chrono::milliseconds(200)),
+ "server socket interruptChildren interrupted child read");
+
+ // only way to proceed is to have the client disconnect
+ clientSock->close();
+ readThread.join();
+ accepted->close();
+ sock1.close();
+}
+
+BOOST_AUTO_TEST_CASE(test_ssl_cannot_change_after_listen) {
+ boost::shared_ptr<TSSLSocketFactory> pServerSocketFactory = createServerSocketFactory();
+ TSSLServerSocket sock1("localhost", m_serverPort, pServerSocketFactory);
+ sock1.listen();
+ BOOST_CHECK_THROW(sock1.setInterruptableChildren(false), std::logic_error);
+ sock1.close();
+}
+
+void peekerWorker(boost::shared_ptr<TTransport> tt, bool expectedResult) {
+ uint8_t buf[400];
+
+ tt->read(buf, 1);
+ BOOST_CHECK_EQUAL(expectedResult, tt->peek());
+}
+
+void peekerWorkerInterrupt(boost::shared_ptr<TTransport> tt) {
+ uint8_t buf[400];
+ try {
+ tt->read(buf, 1);
+ tt->peek();
+ } catch (const TTransportException& tx) {
+ BOOST_CHECK_EQUAL(TTransportException::INTERRUPTED, tx.getType());
+ }
+}
+
+BOOST_AUTO_TEST_CASE(test_ssl_interruptable_child_peek) {
+ std::cout << "An error message from SSL_Shutdown on the console is expected:" << std::endl;
+ boost::shared_ptr<TSSLSocketFactory> pServerSocketFactory = createServerSocketFactory();
+ TSSLServerSocket sock1("localhost", m_serverPort, pServerSocketFactory);
+ sock1.listen();
+ boost::shared_ptr<TSSLSocketFactory> pClientSocketFactory = createClientSocketFactory();
+ boost::shared_ptr<TSSLSocket> clientSock = pClientSocketFactory->createSocket("localhost", m_serverPort);
+ clientSock->open();
+ boost::shared_ptr<TTransport> accepted = sock1.accept();
+ // the worker expects peek() to throw TTransportException::INTERRUPTED once the child is interrupted
+ boost::thread peekThread(boost::bind(peekerWorkerInterrupt, accepted));
+ clientSock->write((const uint8_t*)"0", 1);
+ boost::this_thread::sleep(boost::posix_time::milliseconds(50));
+ // peekThread is practically guaranteed to be blocking now
+ sock1.interruptChildren();
+ BOOST_CHECK_MESSAGE(peekThread.try_join_for(boost::chrono::milliseconds(200)),
+ "server socket interruptChildren did not interrupt child peek");
+#ifdef __linux__
+ signal(SIGPIPE, SIG_IGN);
+#endif
+ clientSock->close();
+ accepted->close();
+ sock1.close();
+}
+
+BOOST_AUTO_TEST_CASE(test_ssl_non_interruptable_child_peek) {
+ std::cout << "An error message from SSL_Shutdown on the console is expected:" << std::endl;
+ boost::shared_ptr<TSSLSocketFactory> pServerSocketFactory = createServerSocketFactory();
+ TSSLServerSocket sock1("localhost", m_serverPort, pServerSocketFactory);
+ sock1.setInterruptableChildren(false); // returns to pre-THRIFT-2441 behavior
+ sock1.listen();
+ boost::shared_ptr<TSSLSocketFactory> pClientSocketFactory = createClientSocketFactory();
+ boost::shared_ptr<TSSLSocket> clientSock = pClientSocketFactory->createSocket("localhost", m_serverPort);
+ clientSock->open();
+ boost::shared_ptr<TTransport> accepted = sock1.accept();
+ // peek() will return false when remote side is closed
+ boost::thread peekThread(boost::bind(peekerWorker, accepted, false));
+ //boost::thread peekThread(boost::bind(peekerWorkerRead, clientSock, false));
+ clientSock->write((const uint8_t*)"0", 1);
+ boost::this_thread::sleep(boost::posix_time::milliseconds(50));
+ // peekThread is practically guaranteed to be blocking now
+ sock1.interruptChildren();
+ BOOST_CHECK_MESSAGE(!peekThread.try_join_for(boost::chrono::milliseconds(200)),
+ "server socket interruptChildren interrupted child peek");
+
+ // only way to proceed is to have the client disconnect
+#ifdef __linux__
+ signal(SIGPIPE, SIG_IGN);
+#endif
+ clientSock->close();
+ peekThread.join();
+ accepted->close();
+ sock1.close();
+}
+
+BOOST_AUTO_TEST_SUITE_END()