THRIFT-4276: Add SSL support to the C++ Nonblocking Server
Client: C++ Lib
Patch: Divya Thaluru
GitHub Pull Request: #1251
This closes #1251
diff --git a/lib/cpp/CMakeLists.txt b/lib/cpp/CMakeLists.txt
index 59da346..734af72 100755
--- a/lib/cpp/CMakeLists.txt
+++ b/lib/cpp/CMakeLists.txt
@@ -138,6 +138,8 @@
# Thrift non blocking server
set( thriftcppnb_SOURCES
src/thrift/server/TNonblockingServer.cpp
+ src/thrift/transport/TNonblockingServerSocket.cpp
+ src/thrift/transport/TNonblockingSSLServerSocket.cpp
src/thrift/async/TAsyncProtocolProcessor.cpp
src/thrift/async/TEvhttpServer.cpp
src/thrift/async/TEvhttpClientChannel.cpp
diff --git a/lib/cpp/Makefile.am b/lib/cpp/Makefile.am
index 2a1cca8..64cf98e 100755
--- a/lib/cpp/Makefile.am
+++ b/lib/cpp/Makefile.am
@@ -94,6 +94,8 @@
src/thrift/transport/TSocketPool.cpp \
src/thrift/transport/TServerSocket.cpp \
src/thrift/transport/TSSLServerSocket.cpp \
+ src/thrift/transport/TNonblockingServerSocket.cpp \
+ src/thrift/transport/TNonblockingSSLServerSocket.cpp \
src/thrift/transport/TTransportUtils.cpp \
src/thrift/transport/TBufferTransports.cpp \
src/thrift/server/TConnectedClient.cpp \
@@ -212,6 +214,9 @@
src/thrift/transport/TServerSocket.h \
src/thrift/transport/TSSLServerSocket.h \
src/thrift/transport/TServerTransport.h \
+ src/thrift/transport/TNonblockingServerTransport.h \
+ src/thrift/transport/TNonblockingServerSocket.h \
+ src/thrift/transport/TNonblockingSSLServerSocket.h \
src/thrift/transport/THttpTransport.h \
src/thrift/transport/THttpClient.h \
src/thrift/transport/THttpServer.h \
diff --git a/lib/cpp/libthriftnb.vcxproj b/lib/cpp/libthriftnb.vcxproj
index 259bb20..9a6ffe6 100755
--- a/lib/cpp/libthriftnb.vcxproj
+++ b/lib/cpp/libthriftnb.vcxproj
@@ -35,16 +35,21 @@
</ProjectConfiguration>
</ItemGroup>
<ItemGroup>
- <ClCompile Include="src\thrift\async\TAsyncProtocolProcessor.cpp"/>
- <ClCompile Include="src\thrift\async\TEvhttpClientChannel.cpp"/>
- <ClCompile Include="src\thrift\async\TEvhttpServer.cpp"/>
- <ClCompile Include="src\thrift\server\TNonblockingServer.cpp"/>
+ <ClCompile Include="src\thrift\async\TAsyncProtocolProcessor.cpp" />
+ <ClCompile Include="src\thrift\async\TEvhttpClientChannel.cpp" />
+ <ClCompile Include="src\thrift\async\TEvhttpServer.cpp" />
+ <ClCompile Include="src\thrift\server\TNonblockingServer.cpp" />
+ <ClCompile Include="src\thrift\transport\TNonblockingServerSocket.cpp" />
+ <ClCompile Include="src\thrift\transport\TNonblockingSSLServerSocket.cpp" />
</ItemGroup>
<ItemGroup>
<ClInclude Include="src\thrift\async\TAsyncProtocolProcessor.h" />
<ClInclude Include="src\thrift\async\TEvhttpClientChannel.h" />
<ClInclude Include="src\thrift\async\TEvhttpServer.h" />
<ClInclude Include="src\thrift\server\TNonblockingServer.h" />
+ <ClInclude Include="src\thrift\transport\TNonblockingServerSocket.h" />
+ <ClInclude Include="src\thrift\transport\TNonblockingServerTransport.h" />
+ <ClInclude Include="src\thrift\transport\TNonblockingSSLServerSocket.h" />
<ClInclude Include="src\thrift\windows\config.h" />
<ClInclude Include="src\thrift\windows\force_inc.h" />
<ClInclude Include="src\thrift\windows\TargetVersion.h" />
@@ -290,4 +295,4 @@
<Import Project="$(VCTargetsPath)\Microsoft.Cpp.targets" />
<ImportGroup Label="ExtensionTargets">
</ImportGroup>
-</Project>
+</Project>
\ No newline at end of file
diff --git a/lib/cpp/libthriftnb.vcxproj.filters b/lib/cpp/libthriftnb.vcxproj.filters
index 5245544..85703dd 100644
--- a/lib/cpp/libthriftnb.vcxproj.filters
+++ b/lib/cpp/libthriftnb.vcxproj.filters
@@ -10,6 +10,9 @@
<Filter Include="windows">
<UniqueIdentifier>{60fc9e5e-0866-4aba-8662-439bb4a461d3}</UniqueIdentifier>
</Filter>
+ <Filter Include="transport">
+ <UniqueIdentifier>{23fe2fde-a7c9-43ec-a409-7f53df5eee64}</UniqueIdentifier>
+ </Filter>
</ItemGroup>
<ItemGroup>
<ClCompile Include="src\thrift\server\TNonblockingServer.cpp">
@@ -27,6 +30,12 @@
<ClCompile Include="src\thrift\windows\StdAfx.cpp">
<Filter>windows</Filter>
</ClCompile>
+ <ClCompile Include="src\thrift\transport\TNonblockingServerSocket.cpp">
+ <Filter>transport</Filter>
+ </ClCompile>
+ <ClCompile Include="src\thrift\transport\TNonblockingSSLServerSocket.cpp">
+ <Filter>transport</Filter>
+ </ClCompile>
</ItemGroup>
<ItemGroup>
<ClInclude Include="src\thrift\server\TNonblockingServer.h">
@@ -53,5 +62,14 @@
<ClInclude Include="src\thrift\windows\force_inc.h">
<Filter>windows</Filter>
</ClInclude>
+ <ClInclude Include="src\thrift\transport\TNonblockingServerSocket.h">
+ <Filter>transport</Filter>
+ </ClInclude>
+ <ClInclude Include="src\thrift\transport\TNonblockingServerTransport.h">
+ <Filter>transport</Filter>
+ </ClInclude>
+ <ClInclude Include="src\thrift\transport\TNonblockingSSLServerSocket.h">
+ <Filter>transport</Filter>
+ </ClInclude>
</ItemGroup>
-</Project>
\ No newline at end of file
+</Project>
diff --git a/lib/cpp/src/thrift/server/TNonblockingServer.cpp b/lib/cpp/src/thrift/server/TNonblockingServer.cpp
index 97c4cd9..d5af12a 100644
--- a/lib/cpp/src/thrift/server/TNonblockingServer.cpp
+++ b/lib/cpp/src/thrift/server/TNonblockingServer.cpp
@@ -209,10 +209,8 @@
class Task;
/// Constructor
- TConnection(THRIFT_SOCKET socket,
- TNonblockingIOThread* ioThread,
- const sockaddr* addr,
- socklen_t addrLen) {
+ TConnection(boost::shared_ptr<TSocket> socket,
+ TNonblockingIOThread* ioThread) {
readBuffer_ = NULL;
readBufferSize_ = 0;
@@ -224,8 +222,10 @@
inputTransport_.reset(new TMemoryBuffer(readBuffer_, readBufferSize_));
outputTransport_.reset(
new TMemoryBuffer(static_cast<uint32_t>(server_->getWriteBufferDefaultSize())));
- tSocket_.reset(new TSocket());
- init(socket, ioThread, addr, addrLen);
+
+ tSocket_ = socket;
+
+ init(ioThread);
}
~TConnection() { std::free(readBuffer_); }
@@ -242,10 +242,10 @@
void checkIdleBufferMemLimit(size_t readLimit, size_t writeLimit);
/// Initialize
- void init(THRIFT_SOCKET socket,
- TNonblockingIOThread* ioThread,
- const sockaddr* addr,
- socklen_t addrLen);
+ void init(TNonblockingIOThread* ioThread);
+
+ /// set socket for connection
+ void setSocket(boost::shared_ptr<TSocket> socket);
/**
* This is called when the application transitions from one state into
@@ -367,13 +367,7 @@
void* connectionContext_;
};
-void TNonblockingServer::TConnection::init(THRIFT_SOCKET socket,
- TNonblockingIOThread* ioThread,
- const sockaddr* addr,
- socklen_t addrLen) {
- tSocket_->setSocketFD(socket);
- tSocket_->setCachedAddress(addr, addrLen);
-
+void TNonblockingServer::TConnection::init(TNonblockingIOThread* ioThread) {
ioThread_ = ioThread;
server_ = ioThread->getServer();
appState_ = APP_INIT;
@@ -416,6 +410,10 @@
processor_ = server_->getProcessor(inputProtocol_, outputProtocol_, tSocket_);
}
+void TNonblockingServer::TConnection::setSocket(boost::shared_ptr<TSocket> socket) {
+ tSocket_ = socket;
+}
+
void TNonblockingServer::TConnection::workSocket() {
int got = 0, left = 0, sent = 0;
uint32_t fetch = 0;
@@ -441,10 +439,14 @@
}
readBufferPos_ += fetch;
} catch (TTransportException& te) {
- GlobalOutput.printf("TConnection::workSocket(): %s", te.what());
- close();
+ // With a nonblocking SSL socket, some operations must be retried later.
+ // The retry case is detected by parsing the exception message, which is fragile; a more robust mechanism is still needed.
+ if(!strstr(te.what(), "retry")) {
+ GlobalOutput.printf("TConnection::workSocket(): %s", te.what());
+ close();
- return;
+ return;
+ }
}
if (readBufferPos_ < sizeof(framing.size)) {
@@ -481,8 +483,12 @@
fetch = readWant_ - readBufferPos_;
got = tSocket_->read(readBuffer_ + readBufferPos_, fetch);
} catch (TTransportException& te) {
- GlobalOutput.printf("TConnection::workSocket(): %s", te.what());
- close();
+ // With a nonblocking SSL socket, some operations must be retried later.
+ // The retry case is detected by parsing the exception message, which is fragile; a more robust mechanism is still needed.
+ if(!strstr(te.what(), "retry")) {
+ GlobalOutput.printf("TConnection::workSocket(): %s", te.what());
+ close();
+ }
return;
}
@@ -748,7 +754,7 @@
appState_ = APP_READ_REQUEST;
// Work the socket right away
- // workSocket();
+ workSocket();
return;
@@ -883,9 +889,7 @@
* Creates a new connection either by reusing an object off the stack or
* by allocating a new one entirely
*/
-TNonblockingServer::TConnection* TNonblockingServer::createConnection(THRIFT_SOCKET socket,
- const sockaddr* addr,
- socklen_t addrLen) {
+TNonblockingServer::TConnection* TNonblockingServer::createConnection(boost::shared_ptr<TSocket> socket) {
// Check the stack
Guard g(connMutex_);
@@ -899,12 +903,13 @@
// Check the connection stack to see if we can re-use
TConnection* result = NULL;
if (connectionStack_.empty()) {
- result = new TConnection(socket, ioThread, addr, addrLen);
+ result = new TConnection(socket, ioThread);
++numTConnections_;
} else {
result = connectionStack_.top();
connectionStack_.pop();
- result->init(socket, ioThread, addr, addrLen);
+ result->setSocket(socket);
+ result->init(ioThread);
}
activeConnections_.push_back(result);
return result;
@@ -939,53 +944,35 @@
// Make sure that libevent didn't mess up the socket handles
assert(fd == serverSocket_);
- // Server socket accepted a new connection
- socklen_t addrLen;
- sockaddr_storage addrStorage;
- sockaddr* addrp = (sockaddr*)&addrStorage;
- addrLen = sizeof(addrStorage);
-
// Going to accept a new client socket
- THRIFT_SOCKET clientSocket;
+ boost::shared_ptr<TSocket> clientSocket;
- // Accept as many new clients as possible, even though libevent signaled only
- // one, this helps us to avoid having to go back into the libevent engine so
- // many times
- while ((clientSocket = ::accept(fd, addrp, &addrLen)) != -1) {
+ clientSocket = serverTransport_->accept();
+ if (clientSocket) {
// If we're overloaded, take action here
if (overloadAction_ != T_OVERLOAD_NO_ACTION && serverOverloaded()) {
Guard g(connMutex_);
nConnectionsDropped_++;
nTotalConnectionsDropped_++;
if (overloadAction_ == T_OVERLOAD_CLOSE_ON_ACCEPT) {
- ::THRIFT_CLOSESOCKET(clientSocket);
+ clientSocket->close();
return;
} else if (overloadAction_ == T_OVERLOAD_DRAIN_TASK_QUEUE) {
if (!drainPendingTask()) {
// Nothing left to discard, so we drop connection instead.
- ::THRIFT_CLOSESOCKET(clientSocket);
+ clientSocket->close();
return;
}
}
}
- // Explicitly set this socket to NONBLOCK mode
- int flags;
- if ((flags = THRIFT_FCNTL(clientSocket, THRIFT_F_GETFL, 0)) < 0
- || THRIFT_FCNTL(clientSocket, THRIFT_F_SETFL, flags | THRIFT_O_NONBLOCK) < 0) {
- GlobalOutput.perror("thriftServerEventHandler: set THRIFT_O_NONBLOCK (THRIFT_FCNTL) ",
- THRIFT_GET_SOCKET_ERROR);
- ::THRIFT_CLOSESOCKET(clientSocket);
- return;
- }
-
// Create a new TConnection for this client socket.
- TConnection* clientConnection = createConnection(clientSocket, addrp, addrLen);
+ TConnection* clientConnection = createConnection(clientSocket);
// Fail fast if we could not create a TConnection object
if (clientConnection == NULL) {
GlobalOutput.printf("thriftServerEventHandler: failed TConnection factory");
- ::THRIFT_CLOSESOCKET(clientSocket);
+ clientSocket->close();
return;
}
@@ -1009,15 +996,6 @@
clientConnection->close();
}
}
-
- // addrLen is written by the accept() call, so needs to be set before the next call.
- addrLen = sizeof(addrStorage);
- }
-
- // Done looping accept, now we have to make sure the error is due to
- // blocking. Any other error is a problem
- if (THRIFT_GET_SOCKET_ERROR != THRIFT_EAGAIN && THRIFT_GET_SOCKET_ERROR != THRIFT_EWOULDBLOCK) {
- GlobalOutput.perror("thriftServerEventHandler: accept() ", THRIFT_GET_SOCKET_ERROR);
}
}
@@ -1025,130 +1003,10 @@
* Creates a socket to listen on and binds it to the local port.
*/
void TNonblockingServer::createAndListenOnSocket() {
-#ifdef _WIN32
- TWinsockSingleton::create();
-#endif // _WIN32
-
- THRIFT_SOCKET s;
-
- struct addrinfo hints, *res, *res0;
- int error;
-
- char port[sizeof("65536") + 1];
- memset(&hints, 0, sizeof(hints));
- hints.ai_family = PF_UNSPEC;
- hints.ai_socktype = SOCK_STREAM;
- hints.ai_flags = AI_PASSIVE | AI_ADDRCONFIG;
- sprintf(port, "%d", port_);
-
- // Wildcard address
- error = getaddrinfo(NULL, port, &hints, &res0);
- if (error) {
- throw TException("TNonblockingServer::serve() getaddrinfo "
- + string(THRIFT_GAI_STRERROR(error)));
- }
-
- // Pick the ipv6 address first since ipv4 addresses can be mapped
- // into ipv6 space.
- for (res = res0; res; res = res->ai_next) {
- if (res->ai_family == AF_INET6 || res->ai_next == NULL)
- break;
- }
-
- // Create the server socket
- s = socket(res->ai_family, res->ai_socktype, res->ai_protocol);
- if (s == -1) {
- freeaddrinfo(res0);
- throw TException("TNonblockingServer::serve() socket() -1");
- }
-
-#ifdef IPV6_V6ONLY
- if (res->ai_family == AF_INET6) {
- int zero = 0;
- if (-1 == setsockopt(s, IPPROTO_IPV6, IPV6_V6ONLY, const_cast_sockopt(&zero), sizeof(zero))) {
- GlobalOutput.perror("TServerSocket::listen() IPV6_V6ONLY", THRIFT_GET_SOCKET_ERROR);
- }
- }
-#endif // #ifdef IPV6_V6ONLY
-
- int one = 1;
-
- // Set THRIFT_NO_SOCKET_CACHING to avoid 2MSL delay on server restart
- setsockopt(s, SOL_SOCKET, THRIFT_NO_SOCKET_CACHING, const_cast_sockopt(&one), sizeof(one));
-
- if (::bind(s, res->ai_addr, static_cast<int>(res->ai_addrlen)) == -1) {
- ::THRIFT_CLOSESOCKET(s);
- freeaddrinfo(res0);
- throw TTransportException(TTransportException::NOT_OPEN,
- "TNonblockingServer::serve() bind",
- THRIFT_GET_SOCKET_ERROR);
- }
-
- // Done with the addr info
- freeaddrinfo(res0);
-
- // Set up this file descriptor for listening
- listenSocket(s);
+ serverTransport_->listen();
+ serverSocket_ = serverTransport_->getSocketFD();
}
-/**
- * Takes a socket created by listenSocket() and sets various options on it
- * to prepare for use in the server.
- */
-void TNonblockingServer::listenSocket(THRIFT_SOCKET s) {
- // Set socket to nonblocking mode
- int flags;
- if ((flags = THRIFT_FCNTL(s, THRIFT_F_GETFL, 0)) < 0
- || THRIFT_FCNTL(s, THRIFT_F_SETFL, flags | THRIFT_O_NONBLOCK) < 0) {
- ::THRIFT_CLOSESOCKET(s);
- throw TException("TNonblockingServer::serve() THRIFT_O_NONBLOCK");
- }
-
- int one = 1;
- struct linger ling = {0, 0};
-
- // Keepalive to ensure full result flushing
- setsockopt(s, SOL_SOCKET, SO_KEEPALIVE, const_cast_sockopt(&one), sizeof(one));
-
- // Turn linger off to avoid hung sockets
- setsockopt(s, SOL_SOCKET, SO_LINGER, const_cast_sockopt(&ling), sizeof(ling));
-
-// Set TCP nodelay if available, MAC OS X Hack
-// See http://lists.danga.com/pipermail/memcached/2005-March/001240.html
-#ifndef TCP_NOPUSH
- setsockopt(s, IPPROTO_TCP, TCP_NODELAY, const_cast_sockopt(&one), sizeof(one));
-#endif
-
-#ifdef TCP_LOW_MIN_RTO
- if (TSocket::getUseLowMinRto()) {
- setsockopt(s, IPPROTO_TCP, TCP_LOW_MIN_RTO, const_cast_sockopt(&one), sizeof(one));
- }
-#endif
-
- if (listen(s, LISTEN_BACKLOG) == -1) {
- ::THRIFT_CLOSESOCKET(s);
- throw TTransportException(TTransportException::NOT_OPEN, "TNonblockingServer::serve() listen");
- }
-
- // Cool, this socket is good to go, set it as the serverSocket_
- serverSocket_ = s;
-
- if (!port_) {
- struct sockaddr_storage addr;
- socklen_t size = sizeof(addr);
- if (!getsockname(serverSocket_, reinterpret_cast<sockaddr*>(&addr), &size)) {
- if (addr.ss_family == AF_INET6) {
- const struct sockaddr_in6* sin = reinterpret_cast<const struct sockaddr_in6*>(&addr);
- listenPort_ = ntohs(sin->sin6_port);
- } else {
- const struct sockaddr_in* sin = reinterpret_cast<const struct sockaddr_in*>(&addr);
- listenPort_ = ntohs(sin->sin_port);
- }
- } else {
- GlobalOutput.perror("TNonblocking: failed to get listen port: ", THRIFT_GET_SOCKET_ERROR);
- }
- }
-}
void TNonblockingServer::setThreadManager(boost::shared_ptr<ThreadManager> threadManager) {
threadManager_ = threadManager;
@@ -1205,10 +1063,7 @@
connection->forceClose();
}
-void TNonblockingServer::stop() {
- if (!port_) {
- listenPort_ = 0;
- }
+void TNonblockingServer::stop() {
// Breaks the event loop in all threads so that they end ASAP.
for (uint32_t i = 0; i < ioThreads_.size(); ++i) {
ioThreads_[i]->stop();
@@ -1249,8 +1104,7 @@
assert(ioThreads_.size() == numIOThreads_);
assert(ioThreads_.size() > 0);
- GlobalOutput.printf("TNonblockingServer: Serving on port %d, %d io threads.",
- listenPort_,
+ GlobalOutput.printf("TNonblockingServer: Serving with %d io threads.",
ioThreads_.size());
// Launch all the secondary IO threads in separate threads
diff --git a/lib/cpp/src/thrift/server/TNonblockingServer.h b/lib/cpp/src/thrift/server/TNonblockingServer.h
index 82d40e9..1f60048 100644
--- a/lib/cpp/src/thrift/server/TNonblockingServer.h
+++ b/lib/cpp/src/thrift/server/TNonblockingServer.h
@@ -25,6 +25,7 @@
#include <thrift/transport/PlatformSocket.h>
#include <thrift/transport/TBufferTransports.h>
#include <thrift/transport/TSocket.h>
+#include <thrift/transport/TNonblockingServerTransport.h>
#include <thrift/concurrency/ThreadManager.h>
#include <climits>
#include <thrift/concurrency/Thread.h>
@@ -47,6 +48,7 @@
using apache::thrift::transport::TMemoryBuffer;
using apache::thrift::transport::TSocket;
+using apache::thrift::transport::TNonblockingServerTransport;
using apache::thrift::protocol::TProtocol;
using apache::thrift::concurrency::Runnable;
using apache::thrift::concurrency::ThreadManager;
@@ -96,10 +98,6 @@
* operates a set of IO threads (by default only one). It assumes that
* all incoming requests are framed with a 4 byte length indicator and
* writes out responses using the same framing.
- *
- * It does not use the TServerTransport framework, but rather has socket
- * operations hardcoded for use with select.
- *
*/
/// Overload condition actions.
@@ -157,12 +155,6 @@
/// Server socket file descriptor
THRIFT_SOCKET serverSocket_;
- /// Port server runs on. Zero when letting OS decide actual port
- int port_;
-
- /// Port server actually runs on
- int listenPort_;
-
/// The optional user-provided event-base (for single-thread servers)
event_base* userEventBase_;
@@ -269,23 +261,24 @@
*/
std::vector<TConnection*> activeConnections_;
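+ /// Server transport (supplied by the constructor) that the server listens on
+ /// and accepts client connections from.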
+ boost::shared_ptr<TNonblockingServerTransport> serverTransport_;
+
/**
* Called when server socket had something happen. We accept all waiting
* client connections on listen socket fd and assign TConnection objects
* to handle those requests.
*
- * @param fd the listen socket.
* @param which the event flag that triggered the handler.
*/
void handleEvent(THRIFT_SOCKET fd, short which);
- void init(int port) {
+ void init() {
serverSocket_ = THRIFT_INVALID_SOCKET;
numIOThreads_ = DEFAULT_IO_THREADS;
nextIOThread_ = 0;
useHighPriorityIOThreads_ = false;
- port_ = port;
- listenPort_ = port;
userEventBase_ = NULL;
threadPoolProcessing_ = false;
numTConnections_ = 0;
@@ -307,38 +300,42 @@
}
public:
- TNonblockingServer(const boost::shared_ptr<TProcessorFactory>& processorFactory, int port)
- : TServer(processorFactory) {
- init(port);
+ TNonblockingServer(const boost::shared_ptr<TProcessorFactory>& processorFactory,
+ const boost::shared_ptr<apache::thrift::transport::TNonblockingServerTransport>& serverTransport)
+ : TServer(processorFactory), serverTransport_(serverTransport) {
+ init();
}
- TNonblockingServer(const boost::shared_ptr<TProcessor>& processor, int port)
- : TServer(processor) {
- init(port);
+ TNonblockingServer(const boost::shared_ptr<TProcessor>& processor,
+ const boost::shared_ptr<apache::thrift::transport::TNonblockingServerTransport>& serverTransport)
+ : TServer(processor), serverTransport_(serverTransport) {
+ init();
}
+
TNonblockingServer(const boost::shared_ptr<TProcessorFactory>& processorFactory,
const boost::shared_ptr<TProtocolFactory>& protocolFactory,
- int port,
+ const boost::shared_ptr<apache::thrift::transport::TNonblockingServerTransport>& serverTransport,
const boost::shared_ptr<ThreadManager>& threadManager
= boost::shared_ptr<ThreadManager>())
- : TServer(processorFactory) {
+ : TServer(processorFactory), serverTransport_(serverTransport) {
- init(port);
+ init();
setInputProtocolFactory(protocolFactory);
setOutputProtocolFactory(protocolFactory);
setThreadManager(threadManager);
}
+
TNonblockingServer(const boost::shared_ptr<TProcessor>& processor,
const boost::shared_ptr<TProtocolFactory>& protocolFactory,
- int port,
+ const boost::shared_ptr<apache::thrift::transport::TNonblockingServerTransport>& serverTransport,
const boost::shared_ptr<ThreadManager>& threadManager
= boost::shared_ptr<ThreadManager>())
- : TServer(processor) {
+ : TServer(processor), serverTransport_(serverTransport) {
- init(port);
+ init();
setInputProtocolFactory(protocolFactory);
setOutputProtocolFactory(protocolFactory);
@@ -350,12 +347,12 @@
const boost::shared_ptr<TTransportFactory>& outputTransportFactory,
const boost::shared_ptr<TProtocolFactory>& inputProtocolFactory,
const boost::shared_ptr<TProtocolFactory>& outputProtocolFactory,
- int port,
+ const boost::shared_ptr<apache::thrift::transport::TNonblockingServerTransport>& serverTransport,
const boost::shared_ptr<ThreadManager>& threadManager
= boost::shared_ptr<ThreadManager>())
- : TServer(processorFactory) {
+ : TServer(processorFactory), serverTransport_(serverTransport) {
- init(port);
+ init();
setInputTransportFactory(inputTransportFactory);
setOutputTransportFactory(outputTransportFactory);
@@ -369,12 +366,12 @@
const boost::shared_ptr<TTransportFactory>& outputTransportFactory,
const boost::shared_ptr<TProtocolFactory>& inputProtocolFactory,
const boost::shared_ptr<TProtocolFactory>& outputProtocolFactory,
- int port,
+ const boost::shared_ptr<apache::thrift::transport::TNonblockingServerTransport>& serverTransport,
const boost::shared_ptr<ThreadManager>& threadManager
= boost::shared_ptr<ThreadManager>())
- : TServer(processor) {
+ : TServer(processor), serverTransport_(serverTransport) {
- init(port);
+ init();
setInputTransportFactory(inputTransportFactory);
setOutputTransportFactory(outputTransportFactory);
@@ -387,7 +384,7 @@
void setThreadManager(boost::shared_ptr<ThreadManager> threadManager);
- int getListenPort() { return listenPort_; }
+ int getListenPort() { return serverTransport_->getListenPort(); }
boost::shared_ptr<ThreadManager> getThreadManager() { return threadManager_; }
@@ -687,15 +684,7 @@
/// Creates a socket to listen on and binds it to the local port.
void createAndListenOnSocket();
-
- /**
- * Takes a socket created by createAndListenOnSocket() and sets various
- * options on it to prepare for use in the server.
- *
- * @param fd descriptor of socket to be initialized/
- */
- void listenSocket(THRIFT_SOCKET fd);
-
+
/**
* Register the optional user-provided event-base (for single-thread servers)
*
@@ -736,7 +725,7 @@
* @param addrLen the length of addr
* @return pointer to initialized TConnection object.
*/
- TConnection* createConnection(THRIFT_SOCKET socket, const sockaddr* addr, socklen_t addrLen);
+ TConnection* createConnection(boost::shared_ptr<TSocket> socket);
/**
* Returns a connection to pool or deletion. If the connection pool
diff --git a/lib/cpp/src/thrift/transport/TNonblockingSSLServerSocket.cpp b/lib/cpp/src/thrift/transport/TNonblockingSSLServerSocket.cpp
new file mode 100644
index 0000000..8e8b897
--- /dev/null
+++ b/lib/cpp/src/thrift/transport/TNonblockingSSLServerSocket.cpp
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <thrift/transport/TNonblockingSSLServerSocket.h>
+#include <thrift/transport/TSSLSocket.h>
+
+namespace apache {
+namespace thrift {
+namespace transport {
+
+/**
+ * Nonblocking SSL server socket implementation.
+ */
+TNonblockingSSLServerSocket::TNonblockingSSLServerSocket(int port, boost::shared_ptr<TSSLSocketFactory> factory)
+ : TNonblockingServerSocket(port), factory_(factory) {
+ factory_->server(true);
+}
+
+TNonblockingSSLServerSocket::TNonblockingSSLServerSocket(const std::string& address,
+ int port,
+ boost::shared_ptr<TSSLSocketFactory> factory)
+ : TNonblockingServerSocket(address, port), factory_(factory) {
+ factory_->server(true);
+}
+
+TNonblockingSSLServerSocket::TNonblockingSSLServerSocket(int port,
+ int sendTimeout,
+ int recvTimeout,
+ boost::shared_ptr<TSSLSocketFactory> factory)
+ : TNonblockingServerSocket(port, sendTimeout, recvTimeout), factory_(factory) {
+ factory_->server(true);
+}
+
+boost::shared_ptr<TSocket> TNonblockingSSLServerSocket::createSocket(THRIFT_SOCKET client) {
+ boost::shared_ptr<TSSLSocket> tSSLSocket;
+ tSSLSocket = factory_->createSocket(client);
+ tSSLSocket->setLibeventSafe();
+ return tSSLSocket;
+}
+}
+}
+}
diff --git a/lib/cpp/src/thrift/transport/TNonblockingSSLServerSocket.h b/lib/cpp/src/thrift/transport/TNonblockingSSLServerSocket.h
new file mode 100644
index 0000000..66a8a70
--- /dev/null
+++ b/lib/cpp/src/thrift/transport/TNonblockingSSLServerSocket.h
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef _THRIFT_TRANSPORT_TNONBLOCKINGSSLSERVERSOCKET_H_
+#define _THRIFT_TRANSPORT_TNONBLOCKINGSSLSERVERSOCKET_H_ 1
+
+#include <boost/shared_ptr.hpp>
+#include <thrift/transport/TNonblockingServerSocket.h>
+
+namespace apache {
+namespace thrift {
+namespace transport {
+
+class TSSLSocketFactory;
+
+/**
+ * Nonblocking Server socket that accepts SSL connections.
+ */
+class TNonblockingSSLServerSocket : public TNonblockingServerSocket {
+public:
+ /**
+ * Constructor. Binds to all interfaces.
+ *
+ * @param port Listening port
+ * @param factory SSL socket factory implementation
+ */
+ TNonblockingSSLServerSocket(int port, boost::shared_ptr<TSSLSocketFactory> factory);
+
+ /**
+ * Constructor. Binds to the specified address.
+ *
+ * @param address Address to bind to
+ * @param port Listening port
+ * @param factory SSL socket factory implementation
+ */
+ TNonblockingSSLServerSocket(const std::string& address,
+ int port,
+ boost::shared_ptr<TSSLSocketFactory> factory);
+
+ /**
+ * Constructor. Binds to all interfaces.
+ *
+ * @param port Listening port
+ * @param sendTimeout Socket send timeout
+ * @param recvTimeout Socket receive timeout
+ * @param factory SSL socket factory implementation
+ */
+ TNonblockingSSLServerSocket(int port,
+ int sendTimeout,
+ int recvTimeout,
+ boost::shared_ptr<TSSLSocketFactory> factory);
+
+protected:
+ boost::shared_ptr<TSocket> createSocket(THRIFT_SOCKET socket);
+ boost::shared_ptr<TSSLSocketFactory> factory_;
+};
+}
+}
+}
+
+#endif
diff --git a/lib/cpp/src/thrift/transport/TNonblockingServerSocket.cpp b/lib/cpp/src/thrift/transport/TNonblockingServerSocket.cpp
new file mode 100644
index 0000000..73a458b
--- /dev/null
+++ b/lib/cpp/src/thrift/transport/TNonblockingServerSocket.cpp
@@ -0,0 +1,549 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <thrift/thrift-config.h>
+
+#include <cstring>
+#include <stdexcept>
+#include <sys/types.h>
+#ifdef HAVE_SYS_SOCKET_H
+#include <sys/socket.h>
+#endif
+#ifdef HAVE_SYS_UN_H
+#include <sys/un.h>
+#endif
+#ifdef HAVE_SYS_POLL_H
+#include <sys/poll.h>
+#endif
+#ifdef HAVE_NETINET_IN_H
+#include <netinet/in.h>
+#include <netinet/tcp.h>
+#endif
+#ifdef HAVE_NETDB_H
+#include <netdb.h>
+#endif
+#include <fcntl.h>
+#ifdef HAVE_UNISTD_H
+#include <unistd.h>
+#endif
+
+#include <thrift/transport/TSocket.h>
+#include <thrift/transport/TNonblockingServerSocket.h>
+#include <thrift/transport/PlatformSocket.h>
+#include <boost/shared_ptr.hpp>
+
+// Some platforms only provide AF_UNIX; alias AF_LOCAL to it there.
+#ifndef AF_LOCAL
+#define AF_LOCAL AF_UNIX
+#endif
+
+// Pointee type used for setsockopt() option values: `char` when building for
+// _WIN32, `void` everywhere else. A build may predefine SOCKOPT_CAST_T.
+#ifndef SOCKOPT_CAST_T
+#ifdef _WIN32
+#define SOCKOPT_CAST_T char
+#else
+#define SOCKOPT_CAST_T void
+#endif // _WIN32
+#endif
+
+// Reinterpret a read-only option value as the platform's setsockopt() pointer type.
+template <class T>
+inline const SOCKOPT_CAST_T* const_cast_sockopt(const T* v) {
+  typedef const SOCKOPT_CAST_T* result_type;
+  return reinterpret_cast<result_type>(v);
+}
+
+// Reinterpret a mutable option value as the platform's setsockopt() pointer type.
+template <class T>
+inline SOCKOPT_CAST_T* cast_sockopt(T* v) {
+  typedef SOCKOPT_CAST_T* result_type;
+  return reinterpret_cast<result_type>(v);
+}
+
+namespace apache {
+namespace thrift {
+namespace transport {
+
+using namespace std;
+using boost::shared_ptr;
+
+// TCP socket on the wildcard address; socket timeouts left unset (0).
+TNonblockingServerSocket::TNonblockingServerSocket(int port)
+  : port_(port), listenPort_(port),
+    serverSocket_(THRIFT_INVALID_SOCKET), acceptBacklog_(DEFAULT_BACKLOG),
+    sendTimeout_(0), recvTimeout_(0),
+    retryLimit_(0), retryDelay_(0),
+    tcpSendBuffer_(0), tcpRecvBuffer_(0),
+    keepAlive_(false), listening_(false) {
+}
+
+// TCP socket on the wildcard address with timeouts for accepted sockets.
+TNonblockingServerSocket::TNonblockingServerSocket(int port, int sendTimeout, int recvTimeout)
+  : port_(port), listenPort_(port),
+    serverSocket_(THRIFT_INVALID_SOCKET), acceptBacklog_(DEFAULT_BACKLOG),
+    sendTimeout_(sendTimeout), recvTimeout_(recvTimeout),
+    retryLimit_(0), retryDelay_(0),
+    tcpSendBuffer_(0), tcpRecvBuffer_(0),
+    keepAlive_(false), listening_(false) {
+}
+
+// TCP socket bound to one specific local address.
+TNonblockingServerSocket::TNonblockingServerSocket(const string& address, int port)
+  : port_(port), listenPort_(port), address_(address),
+    serverSocket_(THRIFT_INVALID_SOCKET), acceptBacklog_(DEFAULT_BACKLOG),
+    sendTimeout_(0), recvTimeout_(0),
+    retryLimit_(0), retryDelay_(0),
+    tcpSendBuffer_(0), tcpRecvBuffer_(0),
+    keepAlive_(false), listening_(false) {
+}
+
+// Unix domain socket at `path`; the port fields stay 0.
+TNonblockingServerSocket::TNonblockingServerSocket(const string& path)
+  : port_(0), listenPort_(0), path_(path),
+    serverSocket_(THRIFT_INVALID_SOCKET), acceptBacklog_(DEFAULT_BACKLOG),
+    sendTimeout_(0), recvTimeout_(0),
+    retryLimit_(0), retryDelay_(0),
+    tcpSendBuffer_(0), tcpRecvBuffer_(0),
+    keepAlive_(false), listening_(false) {
+}
+
+// Release the listening descriptor (if any) when the object goes away.
+TNonblockingServerSocket::~TNonblockingServerSocket() {
+  close();
+}
+
+// The setters below only record configuration. The timeouts and keepalive
+// flag are applied in acceptImpl() to sockets accepted afterwards; backlog,
+// retry and buffer settings are read by the next call to listen().
+void TNonblockingServerSocket::setSendTimeout(int sendTimeout) {
+  this->sendTimeout_ = sendTimeout;
+}
+
+void TNonblockingServerSocket::setRecvTimeout(int recvTimeout) {
+  this->recvTimeout_ = recvTimeout;
+}
+
+void TNonblockingServerSocket::setAcceptBacklog(int accBacklog) {
+  this->acceptBacklog_ = accBacklog;
+}
+
+void TNonblockingServerSocket::setRetryLimit(int retryLimit) {
+  this->retryLimit_ = retryLimit;
+}
+
+void TNonblockingServerSocket::setRetryDelay(int retryDelay) {
+  this->retryDelay_ = retryDelay;
+}
+
+void TNonblockingServerSocket::setTcpSendBuffer(int tcpSendBuffer) {
+  this->tcpSendBuffer_ = tcpSendBuffer;
+}
+
+void TNonblockingServerSocket::setTcpRecvBuffer(int tcpRecvBuffer) {
+  this->tcpRecvBuffer_ = tcpRecvBuffer;
+}
+
+/**
+ * Create, configure and bind the listening socket, then call ::listen().
+ *
+ * When a unix domain socket path was given, binds to that path. Otherwise
+ * binds to the first IPv6 entry returned by getaddrinfo() for address_ (the
+ * wildcard address when address_ is empty), or to the last entry if none is
+ * IPv6. bind() is retried up to retryLimit_ extra times, sleeping retryDelay_
+ * seconds in between. When port_ is 0 the port the OS actually assigned is
+ * stored in listenPort_. listenCallback_ (if set) runs right before ::listen().
+ *
+ * @throws TTransportException BAD_ARGS if port_ is outside [0, 65535], or
+ *         NOT_OPEN if any step fails. The socket is closed before throwing.
+ */
+void TNonblockingServerSocket::listen() {
+  listening_ = true;
+#ifdef _WIN32
+  TWinsockSingleton::create();
+#endif // _WIN32
+
+  // Validate port number
+  if (port_ < 0 || port_ > 0xFFFF) {
+    close(); // reset listening_ like every other failure path
+    throw TTransportException(TTransportException::BAD_ARGS, "Specified port is invalid");
+  }
+
+  const struct addrinfo* res;
+  int error;
+  char port[sizeof("65535")];
+  THRIFT_SNPRINTF(port, sizeof(port), "%d", port_);
+
+  struct addrinfo hints;
+  std::memset(&hints, 0, sizeof(hints));
+  hints.ai_family = PF_UNSPEC;
+  hints.ai_socktype = SOCK_STREAM;
+  hints.ai_flags = AI_PASSIVE | AI_ADDRCONFIG;
+
+  // If address is not specified use wildcard address (NULL)
+  TGetAddrInfoWrapper info(address_.empty() ? NULL : &address_[0], port, &hints);
+
+  error = info.init();
+  if (error) {
+    GlobalOutput.printf("getaddrinfo %d: %s", error, THRIFT_GAI_STRERROR(error));
+    close();
+    throw TTransportException(TTransportException::NOT_OPEN,
+                              "Could not resolve host for server socket.");
+  }
+
+  // Pick the ipv6 address first since ipv4 addresses can be mapped
+  // into ipv6 space; otherwise fall back to the last entry returned.
+  for (res = info.res(); res; res = res->ai_next) {
+    if (res->ai_family == AF_INET6 || res->ai_next == NULL)
+      break;
+  }
+
+  if (!path_.empty()) {
+    serverSocket_ = socket(PF_UNIX, SOCK_STREAM, IPPROTO_IP);
+  } else {
+    serverSocket_ = socket(res->ai_family, res->ai_socktype, res->ai_protocol);
+  }
+
+  if (serverSocket_ == THRIFT_INVALID_SOCKET) {
+    int errno_copy = THRIFT_GET_SOCKET_ERROR;
+    GlobalOutput.perror("TNonblockingServerSocket::listen() socket() ", errno_copy);
+    close();
+    throw TTransportException(TTransportException::NOT_OPEN,
+                              "Could not create server socket.",
+                              errno_copy);
+  }
+
+  // Set THRIFT_NO_SOCKET_CACHING to prevent 2MSL delay on accept
+  int one = 1;
+  if (-1 == setsockopt(serverSocket_,
+                       SOL_SOCKET,
+                       THRIFT_NO_SOCKET_CACHING,
+                       cast_sockopt(&one),
+                       sizeof(one))) {
+// ignore errors coming out of this setsockopt on Windows. This is because
+// SO_EXCLUSIVEADDRUSE requires admin privileges on WinXP, but we don't
+// want to force servers to be an admin.
+#ifndef _WIN32
+    int errno_copy = THRIFT_GET_SOCKET_ERROR;
+    GlobalOutput.perror("TNonblockingServerSocket::listen() setsockopt() THRIFT_NO_SOCKET_CACHING ",
+                        errno_copy);
+    close();
+    throw TTransportException(TTransportException::NOT_OPEN,
+                              "Could not set THRIFT_NO_SOCKET_CACHING",
+                              errno_copy);
+#endif
+  }
+
+  // Set TCP buffer sizes
+  if (tcpSendBuffer_ > 0) {
+    if (-1 == setsockopt(serverSocket_,
+                         SOL_SOCKET,
+                         SO_SNDBUF,
+                         cast_sockopt(&tcpSendBuffer_),
+                         sizeof(tcpSendBuffer_))) {
+      int errno_copy = THRIFT_GET_SOCKET_ERROR;
+      GlobalOutput.perror("TNonblockingServerSocket::listen() setsockopt() SO_SNDBUF ", errno_copy);
+      close();
+      throw TTransportException(TTransportException::NOT_OPEN,
+                                "Could not set SO_SNDBUF",
+                                errno_copy);
+    }
+  }
+
+  if (tcpRecvBuffer_ > 0) {
+    if (-1 == setsockopt(serverSocket_,
+                         SOL_SOCKET,
+                         SO_RCVBUF,
+                         cast_sockopt(&tcpRecvBuffer_),
+                         sizeof(tcpRecvBuffer_))) {
+      int errno_copy = THRIFT_GET_SOCKET_ERROR;
+      GlobalOutput.perror("TNonblockingServerSocket::listen() setsockopt() SO_RCVBUF ", errno_copy);
+      close();
+      throw TTransportException(TTransportException::NOT_OPEN,
+                                "Could not set SO_RCVBUF",
+                                errno_copy);
+    }
+  }
+
+#ifdef IPV6_V6ONLY
+  if (res->ai_family == AF_INET6 && path_.empty()) {
+    int zero = 0;
+    if (-1 == setsockopt(serverSocket_,
+                         IPPROTO_IPV6,
+                         IPV6_V6ONLY,
+                         cast_sockopt(&zero),
+                         sizeof(zero))) {
+      GlobalOutput.perror("TNonblockingServerSocket::listen() IPV6_V6ONLY ", THRIFT_GET_SOCKET_ERROR);
+    }
+  }
+#endif // #ifdef IPV6_V6ONLY
+
+  // Turn linger off, don't want to block on calls to close
+  struct linger ling = {0, 0};
+  if (-1 == setsockopt(serverSocket_, SOL_SOCKET, SO_LINGER, cast_sockopt(&ling), sizeof(ling))) {
+    int errno_copy = THRIFT_GET_SOCKET_ERROR;
+    GlobalOutput.perror("TNonblockingServerSocket::listen() setsockopt() SO_LINGER ", errno_copy);
+    close();
+    throw TTransportException(TTransportException::NOT_OPEN, "Could not set SO_LINGER", errno_copy);
+  }
+
+  // Keepalive to ensure full result flushing
+  if (-1 == setsockopt(serverSocket_, SOL_SOCKET, SO_KEEPALIVE, const_cast_sockopt(&one), sizeof(one))) {
+    int errno_copy = THRIFT_GET_SOCKET_ERROR;
+    GlobalOutput.perror("TNonblockingServerSocket::listen() setsockopt() SO_KEEPALIVE ", errno_copy);
+    close();
+    throw TTransportException(TTransportException::NOT_OPEN,
+                              "Could not set SO_KEEPALIVE",
+                              errno_copy);
+  }
+
+  // Set TCP nodelay if available, MAC OS X Hack
+  // See http://lists.danga.com/pipermail/memcached/2005-March/001240.html
+#ifndef TCP_NOPUSH
+  // Unix Sockets do not need that
+  if (path_.empty()) {
+    // TCP Nodelay, speed over bandwidth
+    if (-1
+        == setsockopt(serverSocket_, IPPROTO_TCP, TCP_NODELAY, cast_sockopt(&one), sizeof(one))) {
+      int errno_copy = THRIFT_GET_SOCKET_ERROR;
+      GlobalOutput.perror("TNonblockingServerSocket::listen() setsockopt() TCP_NODELAY ", errno_copy);
+      close();
+      throw TTransportException(TTransportException::NOT_OPEN,
+                                "Could not set TCP_NODELAY",
+                                errno_copy);
+    }
+  }
+#endif
+
+  // Set NONBLOCK on the accept socket
+  int flags = THRIFT_FCNTL(serverSocket_, THRIFT_F_GETFL, 0);
+  if (flags == -1) {
+    int errno_copy = THRIFT_GET_SOCKET_ERROR;
+    GlobalOutput.perror("TNonblockingServerSocket::listen() THRIFT_FCNTL() THRIFT_F_GETFL ", errno_copy);
+    close();
+    throw TTransportException(TTransportException::NOT_OPEN,
+                              "THRIFT_FCNTL() THRIFT_F_GETFL failed",
+                              errno_copy);
+  }
+
+  if (-1 == THRIFT_FCNTL(serverSocket_, THRIFT_F_SETFL, flags | THRIFT_O_NONBLOCK)) {
+    int errno_copy = THRIFT_GET_SOCKET_ERROR;
+    GlobalOutput.perror("TNonblockingServerSocket::listen() THRIFT_FCNTL() THRIFT_O_NONBLOCK ", errno_copy);
+    close();
+    throw TTransportException(TTransportException::NOT_OPEN,
+                              "THRIFT_FCNTL() THRIFT_F_SETFL THRIFT_O_NONBLOCK failed",
+                              errno_copy);
+  }
+
+#ifdef TCP_LOW_MIN_RTO
+  if (TSocket::getUseLowMinRto()) {
+    // Applies to the listening socket (the original referenced an undeclared `s`).
+    if (-1 == setsockopt(serverSocket_, IPPROTO_TCP, TCP_LOW_MIN_RTO, const_cast_sockopt(&one), sizeof(one))) {
+      int errno_copy = THRIFT_GET_SOCKET_ERROR;
+      GlobalOutput.perror("TNonblockingServerSocket::listen() setsockopt() TCP_LOW_MIN_RTO ", errno_copy);
+      close();
+      throw TTransportException(TTransportException::NOT_OPEN,
+                                "Could not set TCP_LOW_MIN_RTO",
+                                errno_copy);
+    }
+  }
+#endif
+
+  // prepare the port information
+  // we may want to try to bind more than once, since THRIFT_NO_SOCKET_CACHING doesn't
+  // always seem to work. The client can configure the retry variables.
+  int retries = 0;
+  int errno_copy = 0;
+
+  if (!path_.empty()) {
+
+#ifndef _WIN32
+
+    // Unix Domain Socket. Zero the whole struct so unused bytes (and sun_len
+    // on BSD-derived systems) are not left uninitialized.
+    struct sockaddr_un address;
+    std::memset(&address, 0, sizeof(address));
+    size_t len = path_.size() + 1;
+    if (len > sizeof(address.sun_path)) {
+      errno_copy = THRIFT_GET_SOCKET_ERROR;
+      GlobalOutput.perror("TNonblockingServerSocket::listen() Unix Domain socket path too long", errno_copy);
+      close();
+      throw TTransportException(TTransportException::NOT_OPEN,
+                                "Unix Domain socket path too long",
+                                errno_copy);
+    }
+
+    address.sun_family = AF_UNIX;
+    memcpy(address.sun_path, path_.c_str(), len);
+
+    socklen_t structlen = static_cast<socklen_t>(sizeof(address));
+
+    if (!address.sun_path[0]) { // abstract namespace socket
+#ifdef __linux__
+      // sun_path is not null-terminated in this case and structlen determines its length
+      structlen -= sizeof(address.sun_path) - len;
+#else
+      GlobalOutput.perror("TNonblockingServerSocket::listen() Abstract Namespace Domain sockets only supported on linux: ", -99);
+      close();
+      throw TTransportException(TTransportException::NOT_OPEN,
+                                " Abstract Namespace Domain socket path not supported");
+#endif
+    }
+
+    do {
+      if (0 == ::bind(serverSocket_, reinterpret_cast<struct sockaddr*>(&address), structlen)) {
+        break;
+      }
+      errno_copy = THRIFT_GET_SOCKET_ERROR;
+      // use short circuit evaluation here to only sleep if we need to
+    } while ((retries++ < retryLimit_) && (THRIFT_SLEEP_SEC(retryDelay_) == 0));
+#else
+    GlobalOutput.perror("TNonblockingServerSocket::listen() Unix Domain socket path not supported on windows", -99);
+    close();
+    throw TTransportException(TTransportException::NOT_OPEN,
+                              " Unix Domain socket path not supported");
+#endif
+  } else {
+    do {
+      if (0 == ::bind(serverSocket_, res->ai_addr, static_cast<int>(res->ai_addrlen))) {
+        break;
+      }
+      errno_copy = THRIFT_GET_SOCKET_ERROR;
+      // use short circuit evaluation here to only sleep if we need to
+    } while ((retries++ < retryLimit_) && (THRIFT_SLEEP_SEC(retryDelay_) == 0));
+
+    // retrieve bind info: record the OS-assigned port when port 0 was requested
+    if (port_ == 0 && retries <= retryLimit_) {
+      struct sockaddr_storage sa;
+      socklen_t len = sizeof(sa);
+      std::memset(&sa, 0, len);
+      if (::getsockname(serverSocket_, reinterpret_cast<struct sockaddr*>(&sa), &len) < 0) {
+        errno_copy = THRIFT_GET_SOCKET_ERROR;
+        GlobalOutput.perror("TNonblockingServerSocket::listen() getsockname() ", errno_copy);
+      } else {
+        if (sa.ss_family == AF_INET6) {
+          const struct sockaddr_in6* sin = reinterpret_cast<const struct sockaddr_in6*>(&sa);
+          listenPort_ = ntohs(sin->sin6_port);
+        } else {
+          const struct sockaddr_in* sin = reinterpret_cast<const struct sockaddr_in*>(&sa);
+          listenPort_ = ntohs(sin->sin_port);
+        }
+      }
+    }
+  }
+
+  // throw an error if we failed to bind properly
+  if (retries > retryLimit_) {
+    char errbuf[1024];
+    if (!path_.empty()) {
+      THRIFT_SNPRINTF(errbuf, sizeof(errbuf), "TNonblockingServerSocket::listen() PATH %s", path_.c_str());
+    } else {
+      THRIFT_SNPRINTF(errbuf, sizeof(errbuf), "TNonblockingServerSocket::listen() BIND %d", port_);
+    }
+    GlobalOutput(errbuf);
+    close();
+    throw TTransportException(TTransportException::NOT_OPEN,
+                              "Could not bind",
+                              errno_copy);
+  }
+
+  if (listenCallback_)
+    listenCallback_(serverSocket_);
+
+  // Call listen
+  if (-1 == ::listen(serverSocket_, acceptBacklog_)) {
+    errno_copy = THRIFT_GET_SOCKET_ERROR;
+    GlobalOutput.perror("TNonblockingServerSocket::listen() listen() ", errno_copy);
+    close();
+    throw TTransportException(TTransportException::NOT_OPEN, "Could not listen", errno_copy);
+  }
+
+  // The socket is now listening!
+}
+
+// Port requested at construction time (0 asks the OS to choose one).
+int TNonblockingServerSocket::getPort() {
+  return this->port_;
+}
+
+// Port actually bound; listen() overwrites it with the OS-assigned port when 0 was requested.
+int TNonblockingServerSocket::getListenPort() {
+  return this->listenPort_;
+}
+
+/**
+ * Accept one pending connection from the listening socket.
+ *
+ * The accepted descriptor is switched to non-blocking mode and wrapped via
+ * createSocket(). It then receives the configured send/receive timeouts (only
+ * when > 0), keepalive (only when enabled) and the peer address, and finally
+ * acceptCallback_ (if set) is invoked on the raw descriptor.
+ *
+ * listen() makes the listening socket non-blocking, so when no connection is
+ * pending ::accept() fails and this throws as well.
+ *
+ * @return the accepted socket
+ * @throws TTransportException NOT_OPEN if the server socket is not open, or
+ *         UNKNOWN if ::accept() or one of the fcntl() calls fails.
+ */
+shared_ptr<TSocket> TNonblockingServerSocket::acceptImpl() {
+  if (serverSocket_ == THRIFT_INVALID_SOCKET) {
+    throw TTransportException(TTransportException::NOT_OPEN, "TNonblockingServerSocket not listening");
+  }
+
+  struct sockaddr_storage clientAddress;
+  // Use socklen_t directly instead of punning an int through a socklen_t* cast.
+  socklen_t size = static_cast<socklen_t>(sizeof(clientAddress));
+  THRIFT_SOCKET clientSocket
+      = ::accept(serverSocket_, reinterpret_cast<struct sockaddr*>(&clientAddress), &size);
+
+  if (clientSocket == THRIFT_INVALID_SOCKET) {
+    int errno_copy = THRIFT_GET_SOCKET_ERROR;
+    GlobalOutput.perror("TNonblockingServerSocket::acceptImpl() ::accept() ", errno_copy);
+    throw TTransportException(TTransportException::UNKNOWN, "accept()", errno_copy);
+  }
+
+  // Explicitly set this socket to NONBLOCK mode
+  int flags = THRIFT_FCNTL(clientSocket, THRIFT_F_GETFL, 0);
+  if (flags == -1) {
+    int errno_copy = THRIFT_GET_SOCKET_ERROR;
+    ::THRIFT_CLOSESOCKET(clientSocket);
+    GlobalOutput.perror("TNonblockingServerSocket::acceptImpl() THRIFT_FCNTL() THRIFT_F_GETFL ", errno_copy);
+    throw TTransportException(TTransportException::UNKNOWN,
+                              "THRIFT_FCNTL(THRIFT_F_GETFL)",
+                              errno_copy);
+  }
+
+  if (-1 == THRIFT_FCNTL(clientSocket, THRIFT_F_SETFL, flags | THRIFT_O_NONBLOCK)) {
+    int errno_copy = THRIFT_GET_SOCKET_ERROR;
+    ::THRIFT_CLOSESOCKET(clientSocket);
+    // The flag is being set here, not cleared.
+    GlobalOutput
+        .perror("TNonblockingServerSocket::acceptImpl() THRIFT_FCNTL() THRIFT_F_SETFL THRIFT_O_NONBLOCK ",
+                errno_copy);
+    throw TTransportException(TTransportException::UNKNOWN,
+                              "THRIFT_FCNTL(THRIFT_F_SETFL)",
+                              errno_copy);
+  }
+
+  shared_ptr<TSocket> client = createSocket(clientSocket);
+  if (sendTimeout_ > 0) {
+    client->setSendTimeout(sendTimeout_);
+  }
+  if (recvTimeout_ > 0) {
+    client->setRecvTimeout(recvTimeout_);
+  }
+  if (keepAlive_) {
+    client->setKeepAlive(keepAlive_);
+  }
+  client->setCachedAddress(reinterpret_cast<sockaddr*>(&clientAddress), size);
+
+  if (acceptCallback_)
+    acceptCallback_(clientSocket);
+
+  return client;
+}
+
+// Factory hook for accepted descriptors: wraps the descriptor in a plain
+// TSocket. Subclasses such as TNonblockingSSLServerSocket override it.
+shared_ptr<TSocket> TNonblockingServerSocket::createSocket(THRIFT_SOCKET clientSocket) {
+  shared_ptr<TSocket> wrapped(new TSocket(clientSocket));
+  return wrapped;
+}
+
+// Shut down and close the listening descriptor if one is open; safe to call
+// repeatedly. Always leaves the object in the "not listening" state.
+void TNonblockingServerSocket::close() {
+  const bool haveSocket = (serverSocket_ != THRIFT_INVALID_SOCKET);
+  if (haveSocket) {
+    shutdown(serverSocket_, THRIFT_SHUT_RDWR);
+    ::THRIFT_CLOSESOCKET(serverSocket_);
+    serverSocket_ = THRIFT_INVALID_SOCKET;
+  }
+  listening_ = false;
+}
+}
+}
+} // apache::thrift::transport
diff --git a/lib/cpp/src/thrift/transport/TNonblockingServerSocket.h b/lib/cpp/src/thrift/transport/TNonblockingServerSocket.h
new file mode 100644
index 0000000..ff88ecb
--- /dev/null
+++ b/lib/cpp/src/thrift/transport/TNonblockingServerSocket.h
@@ -0,0 +1,138 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef _THRIFT_TRANSPORT_TNONBLOCKINGSERVERSOCKET_H_
+#define _THRIFT_TRANSPORT_TNONBLOCKINGSERVERSOCKET_H_ 1
+
+#include <thrift/transport/TNonblockingServerTransport.h>
+#include <thrift/transport/PlatformSocket.h>
+#include <thrift/cxxfunctional.h>
+#include <boost/shared_ptr.hpp>
+
+namespace apache {
+namespace thrift {
+namespace transport {
+
+class TSocket;
+
+/**
+ * TNonblockingServerTransport built on a plain socket. It wraps socket(),
+ * bind(), listen() and accept() for either a TCP port or a unix domain
+ * socket path; the listening socket and every accepted socket are put into
+ * non-blocking mode.
+ */
+class TNonblockingServerSocket : public TNonblockingServerTransport {
+public:
+  /** Callback receiving a raw socket descriptor. */
+  typedef apache::thrift::stdcxx::function<void(THRIFT_SOCKET fd)> socket_func_t;
+
+  /** Backlog passed to ::listen() unless setAcceptBacklog() is called. */
+  static const int DEFAULT_BACKLOG = 1024;
+
+  /**
+   * Listen on all interfaces.
+   *
+   * @param port TCP port to bind to; 0 lets the OS choose (see getListenPort())
+   */
+  TNonblockingServerSocket(int port);
+
+  /**
+   * Listen on all interfaces and set timeouts for accepted sockets.
+   *
+   * @param port        TCP port to bind to
+   * @param sendTimeout send timeout for accepted sockets (not applied if <= 0)
+   * @param recvTimeout receive timeout for accepted sockets (not applied if <= 0)
+   */
+  TNonblockingServerSocket(int port, int sendTimeout, int recvTimeout);
+
+  /**
+   * Listen on one local address.
+   *
+   * @param address host name or address to bind to
+   * @param port    TCP port to bind to
+   */
+  TNonblockingServerSocket(const std::string& address, int port);
+
+  /**
+   * Listen on a unix domain socket.
+   *
+   * @param path socket path; a leading NUL byte selects the Linux abstract namespace
+   */
+  TNonblockingServerSocket(const std::string& path);
+
+  /** Closes the listening socket. */
+  virtual ~TNonblockingServerSocket();
+
+  /** Send timeout for sockets accepted afterwards (not applied if <= 0). */
+  void setSendTimeout(int sendTimeout);
+  /** Receive timeout for sockets accepted afterwards (not applied if <= 0). */
+  void setRecvTimeout(int recvTimeout);
+
+  /** Backlog passed to ::listen(). */
+  void setAcceptBacklog(int accBacklog);
+
+  /** Number of extra bind() attempts listen() makes after a failure. */
+  void setRetryLimit(int retryLimit);
+  /** Seconds listen() sleeps between bind() attempts. */
+  void setRetryDelay(int retryDelay);
+
+  /** When true, keepalive is requested on every accepted socket. */
+  void setKeepAlive(bool keepAlive) { keepAlive_ = keepAlive; }
+
+  /** SO_SNDBUF for the listening socket (only applied if > 0). */
+  void setTcpSendBuffer(int tcpSendBuffer);
+  /** SO_RCVBUF for the listening socket (only applied if > 0). */
+  void setTcpRecvBuffer(int tcpRecvBuffer);
+
+  /**
+   * Hook run on the listening socket just before ::listen(), after all of
+   * Thrift's own setsockopt() calls. Put custom listening-socket options here.
+   */
+  void setListenCallback(const socket_func_t& listenCallback) { listenCallback_ = listenCallback; }
+
+  /**
+   * Hook run on each newly accepted socket, after all of Thrift's own
+   * setsockopt() calls. Put custom per-connection socket options here.
+   */
+  void setAcceptCallback(const socket_func_t& acceptCallback) { acceptCallback_ = acceptCallback; }
+
+  /** @return the listening descriptor (THRIFT_INVALID_SOCKET when not open) */
+  virtual THRIFT_SOCKET getSocketFD() { return serverSocket_; }
+
+  /** @return the port given to the constructor */
+  virtual int getPort();
+
+  /** @return the port actually bound (the OS-chosen one when port 0 was requested) */
+  virtual int getListenPort();
+
+  virtual void listen();
+  virtual void close();
+
+protected:
+  virtual boost::shared_ptr<TSocket> acceptImpl();
+  /** Wrap an accepted descriptor; override to return a TSocket subclass. */
+  virtual boost::shared_ptr<TSocket> createSocket(THRIFT_SOCKET client);
+
+private:
+  int port_;
+  int listenPort_;
+  std::string address_;
+  std::string path_;
+  THRIFT_SOCKET serverSocket_;
+  int acceptBacklog_;
+  int sendTimeout_;
+  int recvTimeout_;
+  int retryLimit_;
+  int retryDelay_;
+  int tcpSendBuffer_;
+  int tcpRecvBuffer_;
+  bool keepAlive_;
+  bool listening_;
+
+  socket_func_t listenCallback_;
+  socket_func_t acceptCallback_;
+};
+}
+}
+} // apache::thrift::transport
+
+#endif // #ifndef _THRIFT_TRANSPORT_TNONBLOCKINGSERVERSOCKET_H_
diff --git a/lib/cpp/src/thrift/transport/TNonblockingServerTransport.h b/lib/cpp/src/thrift/transport/TNonblockingServerTransport.h
new file mode 100644
index 0000000..21b8262
--- /dev/null
+++ b/lib/cpp/src/thrift/transport/TNonblockingServerTransport.h
@@ -0,0 +1,101 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef _THRIFT_TRANSPORT_TNONBLOCKINGSERVERTRANSPORT_H_
+#define _THRIFT_TRANSPORT_TNONBLOCKINGSERVERTRANSPORT_H_ 1
+
+#include <thrift/transport/TSocket.h>
+#include <thrift/transport/TTransportException.h>
+#include <boost/shared_ptr.hpp>
+
+namespace apache {
+namespace thrift {
+namespace transport {
+
+/**
+ * Server transport interface used by the nonblocking server. Implementations
+ * expose their listening socket through getSocketFD() and hand out accepted
+ * connections as TSocket objects via accept().
+ */
+class TNonblockingServerTransport {
+public:
+  virtual ~TNonblockingServerTransport() {}
+
+  /**
+   * Starts the server transport listening for new connections. Prior to this
+   * call most transports will not return anything when accept is called.
+   * The default implementation does nothing.
+   *
+   * @throws TTransportException if we were unable to listen
+   */
+  virtual void listen() {}
+
+  /**
+   * Accepts a new connection and returns it as a TSocket, which should be in
+   * the opened state. The result is reference counted (boost::shared_ptr), so
+   * the caller does not free it explicitly. NULL is never returned: failures
+   * are reported by throwing.
+   *
+   * @return the accepted socket (never NULL)
+   * @throws TTransportException if acceptImpl() throws or returns NULL
+   */
+  boost::shared_ptr<TSocket> accept() {
+    boost::shared_ptr<TSocket> result = acceptImpl();
+    if (!result) {
+      throw TTransportException("accept() may not return NULL");
+    }
+    return result;
+  }
+
+  /**
+   * @return the descriptor of the listening socket
+   */
+  virtual THRIFT_SOCKET getSocketFD() = 0;
+
+  /**
+   * @return the port this transport was configured with
+   */
+  virtual int getPort() = 0;
+
+  /**
+   * @return the port actually being listened on; implementation-defined, but
+   *         may differ from getPort() when the OS picked the port
+   */
+  virtual int getListenPort() = 0;
+
+  /**
+   * Closes this transport such that future calls to accept will do nothing.
+   */
+  virtual void close() = 0;
+
+protected:
+  TNonblockingServerTransport() {}
+
+  /**
+   * Subclasses should implement this function for accept.
+   *
+   * @return the newly accepted socket
+   * @throw TTransportException If an error occurs
+   */
+  virtual boost::shared_ptr<TSocket> acceptImpl() = 0;
+
+};
+}
+}
+} // apache::thrift::transport
+
+#endif // #ifndef _THRIFT_TRANSPORT_TNONBLOCKINGSERVERTRANSPORT_H_
diff --git a/lib/cpp/src/thrift/transport/TSSLSocket.cpp b/lib/cpp/src/thrift/transport/TSSLSocket.cpp
index 926a58f..e8f38dd 100644
--- a/lib/cpp/src/thrift/transport/TSSLSocket.cpp
+++ b/lib/cpp/src/thrift/transport/TSSLSocket.cpp
@@ -214,27 +214,33 @@
// TSSLSocket implementation
TSSLSocket::TSSLSocket(boost::shared_ptr<SSLContext> ctx)
: TSocket(), server_(false), ssl_(NULL), ctx_(ctx) {
+ init();
}
TSSLSocket::TSSLSocket(boost::shared_ptr<SSLContext> ctx, boost::shared_ptr<THRIFT_SOCKET> interruptListener)
: TSocket(), server_(false), ssl_(NULL), ctx_(ctx) {
+ init();
interruptListener_ = interruptListener;
}
TSSLSocket::TSSLSocket(boost::shared_ptr<SSLContext> ctx, THRIFT_SOCKET socket)
: TSocket(socket), server_(false), ssl_(NULL), ctx_(ctx) {
+ init();
}
TSSLSocket::TSSLSocket(boost::shared_ptr<SSLContext> ctx, THRIFT_SOCKET socket, boost::shared_ptr<THRIFT_SOCKET> interruptListener)
: TSocket(socket, interruptListener), server_(false), ssl_(NULL), ctx_(ctx) {
+ init();
}
TSSLSocket::TSSLSocket(boost::shared_ptr<SSLContext> ctx, string host, int port)
: TSocket(host, port), server_(false), ssl_(NULL), ctx_(ctx) {
+ init();
}
TSSLSocket::TSSLSocket(boost::shared_ptr<SSLContext> ctx, string host, int port, boost::shared_ptr<THRIFT_SOCKET> interruptListener)
: TSocket(host, port), server_(false), ssl_(NULL), ctx_(ctx) {
+ init();
interruptListener_ = interruptListener;
}
@@ -242,6 +248,12 @@
close();
}
+// Shared constructor tail: no handshake yet, no pending read retries, and
+// blocking (non-libevent) I/O semantics until setLibeventSafe() is called.
+void TSSLSocket::init() {
+  eventSafe_ = false;
+  readRetryCount_ = 0;
+  handshakeCompleted_ = false;
+}
+
bool TSSLSocket::isOpen() {
if (ssl_ == NULL || !TSocket::isOpen()) {
return false;
@@ -256,11 +268,16 @@
return true;
}
+/*
+ * Note: This method is not libevent safe.
+*/
bool TSSLSocket::peek() {
if (!isOpen()) {
return false;
}
- checkHandshake();
+ initializeHandshake();
+ if (!checkHandshake())
+ throw TSSLException("SSL_peek: Handshake is not completed");
int rc;
uint8_t byte;
do {
@@ -299,6 +316,9 @@
TSocket::open();
}
+/*
+ * Note: This method is not libevent safe.
+*/
void TSSLSocket::close() {
if (ssl_ != NULL) {
try {
@@ -339,37 +359,57 @@
}
SSL_free(ssl_);
ssl_ = NULL;
+ handshakeCompleted_ = false;
ERR_remove_state(0);
}
TSocket::close();
}
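+/*
+ * Returns the number of bytes read from the SSL socket.
+ * If the socket is libevent safe (see setLibeventSafe()) and the handshake or
+ * the read cannot complete without blocking, a TTransportException
+ * ("retry again") is thrown instead of waiting. The caller must then call
+ * read() again until it succeeds or throws an exception reporting a failure.
+*/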
uint32_t TSSLSocket::read(uint8_t* buf, uint32_t len) {
- checkHandshake();
+ initializeHandshake();
+ if (!checkHandshake())
+ throw TTransportException(TTransportException::UNKNOWN, "retry again");
int32_t bytes = 0;
- for (int32_t retries = 0; retries < maxRecvRetries_; retries++) {
+ while (readRetryCount_ < maxRecvRetries_) {
ERR_clear_error();
bytes = SSL_read(ssl_, buf, len);
- if (bytes >= 0)
- break;
- int32_t errno_copy = THRIFT_GET_SOCKET_ERROR;
int32_t error = SSL_get_error(ssl_, bytes);
+ readRetryCount_++;
+ if (bytes >= 0 && error == 0) {
+ readRetryCount_ = 0;
+ break;
+ }
+ int32_t errno_copy = THRIFT_GET_SOCKET_ERROR;
switch (error) {
case SSL_ERROR_SYSCALL:
if ((errno_copy != THRIFT_EINTR)
&& (errno_copy != THRIFT_EAGAIN)) {
break;
}
- if (retries++ >= maxRecvRetries_) {
+ if (readRetryCount_ >= maxRecvRetries_) {
// THRIFT_EINTR needs to be handled manually and we can tolerate
// a certain number
break;
}
case SSL_ERROR_WANT_READ:
case SSL_ERROR_WANT_WRITE:
- if (waitForEvent(error == SSL_ERROR_WANT_READ) == TSSL_EINTR ) {
+ if (isLibeventSafe()) {
+ if (readRetryCount_ < maxRecvRetries_) {
+ // THRIFT_EINTR needs to be handled manually and we can tolerate
+ // a certain number
+ throw TTransportException(TTransportException::UNKNOWN, "retry again");
+ }
+ throw TTransportException(TTransportException::INTERNAL_ERROR, "too much recv retries");
+ }
+ else if (waitForEvent(error == SSL_ERROR_WANT_READ) == TSSL_EINTR ) {
// repeat operation
- if (retries++ < maxRecvRetries_) {
+ if (readRetryCount_ < maxRecvRetries_) {
// THRIFT_EINTR needs to be handled manually and we can tolerate
// a certain number
continue;
@@ -387,7 +427,9 @@
}
void TSSLSocket::write(const uint8_t* buf, uint32_t len) {
- checkHandshake();
+ initializeHandshake();
+ if (!checkHandshake())
+ return;
// loop in case SSL_MODE_ENABLE_PARTIAL_WRITE is set in SSL_CTX.
uint32_t written = 0;
while (written < len) {
@@ -404,8 +446,13 @@
}
case SSL_ERROR_WANT_READ:
case SSL_ERROR_WANT_WRITE:
- waitForEvent(error == SSL_ERROR_WANT_READ);
- continue;
+ if (isLibeventSafe()) {
+ return;
+ }
+ else {
+ waitForEvent(error == SSL_ERROR_WANT_READ);
+ continue;
+ }
default:;// do nothing
}
string errors;
@@ -416,12 +463,58 @@
}
}
+/*
+ * Writes up to len bytes from buf to the SSL connection and returns the
+ * number of bytes SSL_write accepted.
+ * In blocking mode it waits (waitForEvent) until every byte is written.
+ * If the socket is libevent safe (see setLibeventSafe()) and SSL cannot make
+ * progress without blocking, it returns early with the number of bytes
+ * written so far (possibly 0). The caller must then call write_partial again
+ * with the remaining bytes (buf + returned, len - returned) once the socket
+ * is ready. A TSSLException is thrown on failure.
+*/
+uint32_t TSSLSocket::write_partial(const uint8_t* buf, uint32_t len) {
+  initializeHandshake();
+  if (!checkHandshake())
+    return 0;
+  // loop in case SSL_MODE_ENABLE_PARTIAL_WRITE is set in SSL_CTX.
+  uint32_t written = 0;
+  while (written < len) {
+    ERR_clear_error();
+    int32_t bytes = SSL_write(ssl_, &buf[written], len - written);
+    if (bytes <= 0) {
+      int errno_copy = THRIFT_GET_SOCKET_ERROR;
+      int error = SSL_get_error(ssl_, bytes);
+      switch (error) {
+      case SSL_ERROR_SYSCALL:
+        if ((errno_copy != THRIFT_EINTR)
+            && (errno_copy != THRIFT_EAGAIN)) {
+          break;
+        }
+        // EINTR/EAGAIN are retryable: fall through
+      case SSL_ERROR_WANT_READ:
+      case SSL_ERROR_WANT_WRITE:
+        if (isLibeventSafe()) {
+          // Report the bytes already handed to SSL; returning 0 here would
+          // make the caller resend data that is already on its way.
+          return written;
+        }
+        waitForEvent(error == SSL_ERROR_WANT_READ);
+        continue;
+      default:;// do nothing
+      }
+      string errors;
+      buildErrors(errors, errno_copy);
+      throw TSSLException("SSL_write: " + errors);
+    }
+    written += bytes;
+  }
+  return written;
+}
+
void TSSLSocket::flush() {
// Don't throw exception if not open. Thrift servers close socket twice.
if (ssl_ == NULL) {
return;
}
- checkHandshake();
+ initializeHandshake();
+ if (!checkHandshake())
+ throw TSSLException("BIO_flush: Handshake is not completed");
BIO* bio = SSL_get_wbio(ssl_);
if (bio == NULL) {
throw TSSLException("SSL_get_wbio returns NULL");
@@ -434,14 +527,7 @@
}
}
-void TSSLSocket::checkHandshake() {
- if (!TSocket::isOpen()) {
- throw TTransportException(TTransportException::NOT_OPEN);
- }
- if (ssl_ != NULL) {
- return;
- }
-
+void TSSLSocket::initializeHandshakeParams() {
// set underlying socket to non-blocking
int flags;
if ((flags = THRIFT_FCNTL(socket_, THRIFT_F_GETFL, 0)) < 0
@@ -451,10 +537,27 @@
::THRIFT_CLOSESOCKET(socket_);
return;
}
-
ssl_ = ctx_->createSSL();
SSL_set_fd(ssl_, static_cast<int>(socket_));
+}
+
+// True once initializeHandshake() finished the handshake and authorize()
+// accepted the peer; close() and init() reset it to false.
+bool TSSLSocket::checkHandshake() {
+  const bool completed = handshakeCompleted_;
+  return completed;
+}
+
+void TSSLSocket::initializeHandshake() {
+ if (!TSocket::isOpen()) {
+ throw TTransportException(TTransportException::NOT_OPEN);
+ }
+ if (checkHandshake()) {
+ return;
+ }
+
+ if (ssl_ == NULL) {
+ initializeHandshakeParams();
+ }
+
int rc;
if (server()) {
do {
@@ -470,8 +573,14 @@
}
case SSL_ERROR_WANT_READ:
case SSL_ERROR_WANT_WRITE:
- waitForEvent(error == SSL_ERROR_WANT_READ);
- rc = 2;
+ if (isLibeventSafe()) {
+ return;
+ }
+ else {
+ // repeat operation
+ waitForEvent(error == SSL_ERROR_WANT_READ);
+ rc = 2;
+ }
default:;// do nothing
}
}
@@ -495,8 +604,14 @@
}
case SSL_ERROR_WANT_READ:
case SSL_ERROR_WANT_WRITE:
- waitForEvent(error == SSL_ERROR_WANT_READ);
- rc = 2;
+ if (isLibeventSafe()) {
+ return;
+ }
+ else {
+ // repeat operation
+ waitForEvent(error == SSL_ERROR_WANT_READ);
+ rc = 2;
+ }
default:;// do nothing
}
}
@@ -510,6 +625,7 @@
throw TSSLException(fname + ": " + errors);
}
authorize();
+ handshakeCompleted_ = true;
}
void TSSLSocket::authorize() {
@@ -618,6 +734,9 @@
}
}
+/*
+ * Note: this method blocks while waiting for socket readiness, so it must not be used on libevent-driven sockets.
+ */
unsigned int TSSLSocket::waitForEvent(bool wantRead) {
int fdSocket;
BIO* bio;
@@ -801,12 +920,12 @@
}
}
-void TSSLSocketFactory::loadTrustedCertificates(const char* path) {
+void TSSLSocketFactory::loadTrustedCertificates(const char* path, const char* capath) {
if (path == NULL) {
throw TTransportException(TTransportException::BAD_ARGS,
"loadTrustedCertificates: <path> is NULL");
}
- if (SSL_CTX_load_verify_locations(ctx_->get(), path, NULL) == 0) {
+ if (SSL_CTX_load_verify_locations(ctx_->get(), path, capath) == 0) {
int errno_copy = THRIFT_GET_SOCKET_ERROR;
string errors;
buildErrors(errors, errno_copy);
diff --git a/lib/cpp/src/thrift/transport/TSSLSocket.h b/lib/cpp/src/thrift/transport/TSSLSocket.h
index 0462a20..023db94 100644
--- a/lib/cpp/src/thrift/transport/TSSLSocket.h
+++ b/lib/cpp/src/thrift/transport/TSSLSocket.h
@@ -79,6 +79,7 @@
void close();
uint32_t read(uint8_t* buf, uint32_t len);
void write(const uint8_t* buf, uint32_t len);
+ uint32_t write_partial(const uint8_t* buf, uint32_t len);
void flush();
/**
* Set whether to use client or server side SSL handshake protocol.
@@ -96,6 +97,14 @@
* @param manager Instance of AccessManager
*/
virtual void access(boost::shared_ptr<AccessManager> manager) { access_ = manager; }
+ /**
+ * Marks the socket as libevent-driven: the handshake returns on WANT_READ/WANT_WRITE instead of blocking in waitForEvent().
+ */
+ void setLibeventSafe() { eventSafe_ = true; }
+ /**
+ * Returns the flag set by setLibeventSafe() (true means SSL operations must not block).
+ */
+ bool isLibeventSafe() const { return eventSafe_; }
protected:
/**
@@ -139,7 +148,15 @@
/**
* Initiate SSL handshake if not already initiated.
*/
- void checkHandshake();
+ void initializeHandshake();
+ /**
+   * Set up handshake state: make the socket non-blocking and create the SSL object.
+ */
+ void initializeHandshakeParams();
+ /**
+ * Check if SSL handshake is completed or not.
+ */
+ bool checkHandshake();
/**
* Waits for an socket or shutdown event.
*
@@ -155,6 +172,13 @@
boost::shared_ptr<SSLContext> ctx_;
boost::shared_ptr<AccessManager> access_;
friend class TSSLSocketFactory;
+
+private:
+ bool handshakeCompleted_;
+ int readRetryCount_;
+ bool eventSafe_;
+
+ void init();
};
/**
@@ -248,7 +272,7 @@
*
* @param path Path to trusted certificate file
*/
- virtual void loadTrustedCertificates(const char* path);
+ virtual void loadTrustedCertificates(const char* path, const char* capath = NULL);
/**
* Default randomize method.
*/
diff --git a/lib/cpp/src/thrift/transport/TServerSocket.cpp b/lib/cpp/src/thrift/transport/TServerSocket.cpp
index dc698d5..da869e0 100644
--- a/lib/cpp/src/thrift/transport/TServerSocket.cpp
+++ b/lib/cpp/src/thrift/transport/TServerSocket.cpp
@@ -75,21 +75,12 @@
delete ssock;
}
-class TGetAddrInfoWrapper {
-public:
- TGetAddrInfoWrapper(const char* node, const char* service, const struct addrinfo* hints);
+namespace apache {
+namespace thrift {
+namespace transport {
- virtual ~TGetAddrInfoWrapper();
-
- int init();
- const struct addrinfo* res();
-
-private:
- const char* node_;
- const char* service_;
- const struct addrinfo* hints_;
- struct addrinfo* res_;
-};
+using namespace std;
+using boost::shared_ptr;
TGetAddrInfoWrapper::TGetAddrInfoWrapper(const char* node,
const char* service,
@@ -111,13 +102,6 @@
return this->res_;
}
-namespace apache {
-namespace thrift {
-namespace transport {
-
-using namespace std;
-using boost::shared_ptr;
-
TServerSocket::TServerSocket(int port)
: interruptableChildren_(true),
port_(port),
diff --git a/lib/cpp/src/thrift/transport/TServerSocket.h b/lib/cpp/src/thrift/transport/TServerSocket.h
index 20a37e7..58254ee 100644
--- a/lib/cpp/src/thrift/transport/TServerSocket.h
+++ b/lib/cpp/src/thrift/transport/TServerSocket.h
@@ -25,12 +25,36 @@
#include <thrift/cxxfunctional.h>
#include <boost/shared_ptr.hpp>
+#include <sys/types.h>
+#ifdef HAVE_SYS_SOCKET_H
+#include <sys/socket.h>
+#endif
+#ifdef HAVE_NETDB_H
+#include <netdb.h>
+#endif
+
namespace apache {
namespace thrift {
namespace transport {
class TSocket;
+// Wrapper around a getaddrinfo() lookup; node/service/hints are presumably kept as the raw pointers below, so keep them alive.
+class TGetAddrInfoWrapper {
+public:
+  TGetAddrInfoWrapper(const char* node, const char* service, const struct addrinfo* hints);
+  virtual ~TGetAddrInfoWrapper();
+  int init();
+  const struct addrinfo* res();
+private:
+  TGetAddrInfoWrapper(const TGetAddrInfoWrapper&);            // non-copyable (declared, never defined):
+  TGetAddrInfoWrapper& operator=(const TGetAddrInfoWrapper&); // a copy would share res_, which the dtor presumably frees
+  const char* node_;
+  const char* service_;
+  const struct addrinfo* hints_;
+  struct addrinfo* res_;
+};
+
/**
* Server socket implementation of TServerTransport. Wrapper around a unix
* socket listen and accept calls.
@@ -113,6 +137,8 @@
// \throws std::logic_error if listen() has been called
void setInterruptableChildren(bool enable);
+ THRIFT_SOCKET getSocketFD() { return serverSocket_; } // listening socket; overrides TServerTransport::getSocketFD()
+
int getPort();
void listen();
diff --git a/lib/cpp/src/thrift/transport/TServerTransport.h b/lib/cpp/src/thrift/transport/TServerTransport.h
index cd1d3da..51cb3e8 100644
--- a/lib/cpp/src/thrift/transport/TServerTransport.h
+++ b/lib/cpp/src/thrift/transport/TServerTransport.h
@@ -83,6 +83,15 @@
virtual void interruptChildren() {}
/**
+   * Returns the socket descriptor this server transport listens on.
+   *
+   * @return server socket file descriptor, or THRIFT_INVALID_SOCKET for
+   *         transports that do not override this default implementation.
+   * @note The default implementation never throws.
+   */
+  virtual THRIFT_SOCKET getSocketFD() { return THRIFT_INVALID_SOCKET; }
+
+  /**
* Closes this transport such that future calls to accept will do nothing.
*/
virtual void close() = 0;
diff --git a/lib/cpp/src/thrift/transport/TSocket.h b/lib/cpp/src/thrift/transport/TSocket.h
index aa18c31..69d2533 100644
--- a/lib/cpp/src/thrift/transport/TSocket.h
+++ b/lib/cpp/src/thrift/transport/TSocket.h
@@ -120,7 +120,7 @@
/**
* Writes to the underlying socket. Does single send() and returns result.
*/
- uint32_t write_partial(const uint8_t* buf, uint32_t len);
+ virtual uint32_t write_partial(const uint8_t* buf, uint32_t len);
/**
* Get the host that the socket is connected to
diff --git a/lib/cpp/test/Makefile.am b/lib/cpp/test/Makefile.am
index f61cff1..feff930 100755
--- a/lib/cpp/test/Makefile.am
+++ b/lib/cpp/test/Makefile.am
@@ -99,7 +99,8 @@
noinst_PROGRAMS += \
processor_test
check_PROGRAMS += \
- TNonblockingServerTest
+ TNonblockingServerTest \
+ TNonblockingSSLServerTest
endif
TESTS_ENVIRONMENT= \
@@ -272,6 +273,21 @@
$(BOOST_TEST_LDADD) \
$(BOOST_LDFLAGS) \
$(LIBEVENT_LIBS)
+#
+# TNonblockingSSLServerTest
+#
+TNonblockingSSLServerTest_SOURCES = TNonblockingSSLServerTest.cpp
+
+TNonblockingSSLServerTest_LDADD = libprocessortest.la \
+ $(top_builddir)/lib/cpp/libthrift.la \
+ $(top_builddir)/lib/cpp/libthriftnb.la \
+ $(BOOST_TEST_LDADD) \
+ $(BOOST_LDFLAGS) \
+ $(BOOST_FILESYSTEM_LDADD) \
+ $(BOOST_CHRONO_LDADD) \
+ $(BOOST_SYSTEM_LDADD) \
+ $(BOOST_THREAD_LDADD) \
+ $(LIBEVENT_LIBS)
#
# OptionalRequiredTest
diff --git a/lib/cpp/test/TNonblockingSSLServerTest.cpp b/lib/cpp/test/TNonblockingSSLServerTest.cpp
new file mode 100644
index 0000000..f21dd18
--- /dev/null
+++ b/lib/cpp/test/TNonblockingSSLServerTest.cpp
@@ -0,0 +1,293 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#define BOOST_TEST_MODULE TNonblockingSSLServerTest
+#include <boost/test/unit_test.hpp>
+#include <boost/smart_ptr.hpp>
+#include <boost/shared_ptr.hpp>
+#include <boost/filesystem.hpp>
+#include <boost/format.hpp>
+
+#include "thrift/server/TNonblockingServer.h"
+#include "thrift/transport/TSSLSocket.h"
+#include "thrift/transport/TNonblockingSSLServerSocket.h"
+
+#include "gen-cpp/ParentService.h"
+
+#include <event.h>
+
+using namespace apache::thrift;
+using apache::thrift::concurrency::Guard;
+using apache::thrift::concurrency::Monitor;
+using apache::thrift::concurrency::Mutex;
+using apache::thrift::server::TServerEventHandler;
+using apache::thrift::transport::TSSLSocketFactory;
+using apache::thrift::transport::TSSLSocket;
+
+struct Handler : public test::ParentServiceIf {
+  // addString() appends to strings_; getStrings() hands back everything appended so far.
+  void addString(const std::string& value) { strings_.push_back(value); }
+  void getStrings(std::vector<std::string>& out) { out.assign(strings_.begin(), strings_.end()); }
+  std::vector<std::string> strings_;
+  // The remaining ParentService methods are not exercised by these tests; trivial stubs.
+  int32_t incrementGeneration() { return 0; }
+  int32_t getGeneration() { return 0; }
+  void getDataWait(std::string&, const int32_t) {}
+  void onewayWait() {}
+  void exceptionWait(const std::string&) {}
+  void unexpectedExceptionWait(const std::string&) {}
+};
+
+boost::filesystem::path keyDir; // assigned by GlobalFixtureSSL before any test case runs
+boost::filesystem::path certFile(const std::string& filename) { // path of a key/cert inside keyDir
+  boost::filesystem::path full(keyDir);
+  return full /= filename;
+}
+
+struct GlobalFixtureSSL {
+  // Logs argv, initializes OpenSSL for the whole run, and locates the test key directory.
+  GlobalFixtureSSL() {
+    using namespace boost::unit_test::framework;
+    const int argc = master_test_suite().argc;
+    char** const argv = master_test_suite().argv;
+    for (int idx = 0; idx < argc; ++idx) {
+      BOOST_TEST_MESSAGE(boost::format("argv[%1%] = \"%2%\"") % idx % argv[idx]);
+    }
+#ifdef __linux__
+    // OpenSSL calls send() without MSG_NOSIGPIPE so writing to a socket that has
+    // disconnected can cause a SIGPIPE signal...
+    signal(SIGPIPE, SIG_IGN);
+#endif
+    TSSLSocketFactory::setManualOpenSSLInitialization(true);
+    apache::thrift::transport::initializeOpenSSL();
+
+    // Prefer <cwd>/../../../test/keys; otherwise the last argument must name the key directory.
+    const boost::filesystem::path repoRoot =
+        boost::filesystem::current_path().parent_path().parent_path().parent_path();
+    keyDir = repoRoot / "test" / "keys";
+    if (boost::filesystem::exists(certFile("server.crt"))) {
+      return;
+    }
+    keyDir = boost::filesystem::path(argv[argc - 1]);
+    if (!boost::filesystem::exists(certFile("server.crt"))) {
+      throw std::invalid_argument("The last argument to this test must be the directory containing the test certificate(s).");
+    }
+  }
+
+  // Undoes the process-wide OpenSSL setup (and, on Linux, the SIGPIPE override) done above.
+  virtual ~GlobalFixtureSSL() {
+    apache::thrift::transport::cleanupOpenSSL();
+#ifdef __linux__
+    signal(SIGPIPE, SIG_DFL);
+#endif
+  }
+};
+
+#if BOOST_VERSION < 105900
+BOOST_GLOBAL_FIXTURE(GlobalFixtureSSL)
+#else
+BOOST_GLOBAL_FIXTURE(GlobalFixtureSSL); // from Boost 1.59 (105900) on, the macro is used with a trailing ';'
+#endif
+
+boost::shared_ptr<TSSLSocketFactory> createServerSocketFactory() {
+  // Server-side factory: cipher list without ADH/LOW/EXP/MD5 suites, plus the test server's cert and key.
+  boost::shared_ptr<TSSLSocketFactory> factory(new TSSLSocketFactory());
+  factory->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
+  factory->loadCertificate(certFile("server.crt").string().c_str());
+  factory->loadPrivateKey(certFile("server.key").string().c_str());
+  factory->server(true);
+
+  return factory;
+}
+
+boost::shared_ptr<TSSLSocketFactory> createClientSocketFactory() {
+  // Client-side factory: authenticate(true) with CA.pem as trust store; presents client.crt/client.key.
+  boost::shared_ptr<TSSLSocketFactory> factory(new TSSLSocketFactory());
+  factory->authenticate(true);
+  factory->loadCertificate(certFile("client.crt").string().c_str());
+  factory->loadPrivateKey(certFile("client.key").string().c_str());
+  factory->loadTrustedCertificates(certFile("CA.pem").string().c_str());
+
+  return factory;
+}
+
+// Per-test fixture: runs a TNonblockingServer over SSL on a background thread.
+class Fixture {
+private:
+  struct ListenEventHandler : public TServerEventHandler {
+  public:
+    ListenEventHandler(Mutex* mutex) : listenMonitor_(mutex), ready_(false) {}
+
+    void preServe() /* override */ {
+      Guard g(listenMonitor_.mutex());
+      ready_ = true;
+      listenMonitor_.notify();
+    }
+    Monitor listenMonitor_;
+    bool ready_;
+  };
+
+  struct Runner : public apache::thrift::concurrency::Runnable {
+    int port;
+    boost::shared_ptr<event_base> userEventBase;
+    boost::shared_ptr<TProcessor> processor;
+    boost::shared_ptr<server::TNonblockingServer> server;
+    boost::shared_ptr<ListenEventHandler> listenHandler;
+    boost::shared_ptr<TSSLSocketFactory> pServerSocketFactory;
+    boost::shared_ptr<transport::TNonblockingSSLServerSocket> socket;
+    Mutex mutex_;
+    std::string error_; // non-empty once run() has failed; guarded by mutex_
+
+    Runner() { listenHandler.reset(new ListenEventHandler(&mutex_)); }
+
+    virtual void run() {
+      // When binding to explicit port, allow retrying to workaround bind failures on ports in use
+      int retryCount = port ? 10 : 0;
+      try {
+        pServerSocketFactory = createServerSocketFactory();
+        startServer(retryCount);
+      } catch (const std::exception& e) { // report to readyBarrier() instead of leaving it blocked
+        Guard g(mutex_);
+        error_ = std::string("server thread failed: ") + e.what();
+        listenHandler->listenMonitor_.notify();
+      }
+    }
+
+    void readyBarrier() {
+      // block until server is listening, or until run() reports a failure
+      Guard g(mutex_);
+      while (!listenHandler->ready_ && error_.empty()) {
+        listenHandler->listenMonitor_.wait();
+      }
+      if (!listenHandler->ready_) {
+        throw std::runtime_error(error_);
+      }
+    }
+  private:
+    void startServer(int retry_count) {
+      try {
+        socket.reset(new transport::TNonblockingSSLServerSocket(port, pServerSocketFactory));
+        server.reset(new server::TNonblockingServer(processor, socket));
+        server->setServerEventHandler(listenHandler);
+        server->setNumIOThreads(1);
+        if (userEventBase) {
+          server->registerEvents(userEventBase.get());
+        }
+        server->serve();
+      } catch (const transport::TTransportException&) {
+        if (retry_count > 0) {
+          ++port;
+          startServer(retry_count - 1);
+        } else {
+          throw;
+        }
+      }
+    }
+  };
+
+  struct EventDeleter { void operator()(event_base* p) { event_base_free(p); } };
+
+protected:
+  Fixture() : processor(new test::ParentServiceProcessor(boost::make_shared<Handler>())) {}
+  ~Fixture() {
+    if (server) {
+      server->stop();
+    }
+    if (thread) {
+      thread->join();
+    }
+  }
+
+  void setEventBase(event_base* user_event_base) { userEventBase_.reset(user_event_base, EventDeleter()); }
+
+  int startServer(int port) {
+    boost::shared_ptr<Runner> runner(new Runner);
+    runner->port = port;
+    runner->processor = processor;
+    runner->userEventBase = userEventBase_;
+    boost::scoped_ptr<apache::thrift::concurrency::ThreadFactory> threadFactory(
+        new apache::thrift::concurrency::PlatformThreadFactory(
+#if !USE_BOOST_THREAD && !USE_STD_THREAD
+            concurrency::PlatformThreadFactory::OTHER, concurrency::PlatformThreadFactory::NORMAL,
+            1,
+#endif
+            false));
+    thread = threadFactory->newThread(runner);
+    thread->start();
+    runner->readyBarrier(); // throws if the server thread failed before it started listening
+    server = runner->server;
+    return runner->port;
+  }
+
+  bool canCommunicate(int serverPort) {
+    boost::shared_ptr<TSSLSocketFactory> pClientSocketFactory = createClientSocketFactory();
+    boost::shared_ptr<TSSLSocket> socket = pClientSocketFactory->createSocket("localhost", serverPort);
+    socket->open();
+    test::ParentServiceClient client(boost::make_shared<protocol::TBinaryProtocol>(
+        boost::make_shared<transport::TFramedTransport>(socket)));
+    client.addString("foo");
+    std::vector<std::string> strings;
+    client.getStrings(strings);
+    return strings.size() == 1 && !(strings[0].compare("foo"));
+  }
+
+private:
+  boost::shared_ptr<event_base> userEventBase_;
+  boost::shared_ptr<test::ParentServiceProcessor> processor;
+protected:
+  boost::shared_ptr<server::TNonblockingServer> server;
+private:
+  boost::shared_ptr<apache::thrift::concurrency::Thread> thread;
+};
+
+BOOST_AUTO_TEST_SUITE(TNonblockingSSLServerTest)
+
+// An explicit port is honoured, or moved upward when the fixture retries after a bind failure.
+BOOST_FIXTURE_TEST_CASE(get_specified_port, Fixture) {
+  const int boundPort = startServer(12345);
+  BOOST_REQUIRE_GE(boundPort, 12345);
+  BOOST_REQUIRE_EQUAL(server->getListenPort(), boundPort);
+  BOOST_CHECK(canCommunicate(boundPort));
+  server->stop();
+}
+
+// Requesting port 0 must leave the server listening on a real, non-zero port.
+BOOST_FIXTURE_TEST_CASE(get_assigned_port, Fixture) {
+  const int requestedPort = startServer(0);
+  BOOST_REQUIRE_EQUAL(requestedPort, 0);
+  const int assignedPort = server->getListenPort();
+  BOOST_REQUIRE_NE(assignedPort, 0);
+  BOOST_CHECK(canCommunicate(assignedPort));
+  server->stop();
+}
+
+// A caller-supplied event_base must be the one the server registers its events with.
+BOOST_FIXTURE_TEST_CASE(provide_event_base, Fixture) {
+  event_base* const base = event_base_new();
+  setEventBase(base);
+  startServer(0);
+  // assert that the server works
+  BOOST_CHECK(canCommunicate(server->getListenPort()));
+#if LIBEVENT_VERSION_NUMBER > 0x02010400
+  // also assert that the event_base is actually used when it's easy
+  BOOST_CHECK_GT(event_base_get_num_events(base, EVENT_BASE_COUNT_ADDED), 0);
+#endif
+}
+
+BOOST_AUTO_TEST_SUITE_END()
diff --git a/lib/cpp/test/TNonblockingServerTest.cpp b/lib/cpp/test/TNonblockingServerTest.cpp
index e933d6b..36c64b1 100644
--- a/lib/cpp/test/TNonblockingServerTest.cpp
+++ b/lib/cpp/test/TNonblockingServerTest.cpp
@@ -24,6 +24,7 @@
#include "thrift/concurrency/Monitor.h"
#include "thrift/concurrency/Thread.h"
#include "thrift/server/TNonblockingServer.h"
+#include "thrift/transport/TNonblockingServerSocket.h"
#include "gen-cpp/ParentService.h"
@@ -71,6 +72,7 @@
boost::shared_ptr<TProcessor> processor;
boost::shared_ptr<server::TNonblockingServer> server;
boost::shared_ptr<ListenEventHandler> listenHandler;
+ boost::shared_ptr<transport::TNonblockingServerSocket> socket;
Mutex mutex_;
Runner() {
@@ -93,7 +95,8 @@
private:
void startServer(int retry_count) {
try {
- server.reset(new server::TNonblockingServer(processor, port));
+ socket.reset(new transport::TNonblockingServerSocket(port));
+ server.reset(new server::TNonblockingServer(processor, socket));
server->setServerEventHandler(listenHandler);
if (userEventBase) {
server->registerEvents(userEventBase.get());
@@ -181,7 +184,6 @@
BOOST_CHECK(canCommunicate(specified_port));
server->stop();
- BOOST_CHECK_EQUAL(server->getListenPort(), specified_port);
}
BOOST_FIXTURE_TEST_CASE(get_assigned_port, Fixture) {
@@ -192,7 +194,6 @@
BOOST_CHECK(canCommunicate(assigned_port));
server->stop();
- BOOST_CHECK_EQUAL(server->getListenPort(), 0);
}
BOOST_FIXTURE_TEST_CASE(provide_event_base, Fixture) {
diff --git a/lib/cpp/test/processor/ProcessorTest.cpp b/lib/cpp/test/processor/ProcessorTest.cpp
index a4e984c..486b8cf 100644
--- a/lib/cpp/test/processor/ProcessorTest.cpp
+++ b/lib/cpp/test/processor/ProcessorTest.cpp
@@ -33,6 +33,7 @@
#include <thrift/server/TNonblockingServer.h>
#include <thrift/server/TSimpleServer.h>
#include <thrift/transport/TSocket.h>
+#include <thrift/transport/TNonblockingServerSocket.h>
#include "EventLog.h"
#include "ServerThread.h"
@@ -121,14 +122,14 @@
if (framedFactory == NULL) {
throw TException("TNonblockingServer must use TFramedTransport");
}
-
+ boost::shared_ptr<TNonblockingServerSocket> socket(new TNonblockingServerSocket(port));
boost::shared_ptr<PlatformThreadFactory> threadFactory(new PlatformThreadFactory);
boost::shared_ptr<ThreadManager> threadManager = ThreadManager::newSimpleThreadManager(8);
threadManager->threadFactory(threadFactory);
threadManager->start();
return boost::shared_ptr<TNonblockingServer>(
- new TNonblockingServer(processor, protocolFactory, port, threadManager));
+ new TNonblockingServer(processor, protocolFactory, socket, threadManager));
}
};
@@ -150,10 +151,11 @@
throw TException("TNonblockingServer must use TFramedTransport");
}
+ boost::shared_ptr<TNonblockingServerSocket> socket(new TNonblockingServerSocket(port));
// Use a NULL ThreadManager
boost::shared_ptr<ThreadManager> threadManager;
return boost::shared_ptr<TNonblockingServer>(
- new TNonblockingServer(processor, protocolFactory, port, threadManager));
+ new TNonblockingServer(processor, protocolFactory, socket, threadManager));
}
};