THRIFT-2441 Cannot shutdown TThreadedServer when clients are still connected
Author: James E. King, III <Jim.King@simplivity.com>
diff --git a/lib/cpp/src/thrift/server/TSimpleServer.cpp b/lib/cpp/src/thrift/server/TSimpleServer.cpp
index fa6bff5..19f44ac 100644
--- a/lib/cpp/src/thrift/server/TSimpleServer.cpp
+++ b/lib/cpp/src/thrift/server/TSimpleServer.cpp
@@ -70,11 +70,11 @@
if (client) {
client->close();
}
- if (!stop_ || ttx.getType() != TTransportException::INTERRUPTED) {
+ if (ttx.getType() != TTransportException::INTERRUPTED) {
string errStr = string("TServerTransport died on accept: ") + ttx.what();
GlobalOutput(errStr.c_str());
}
- continue;
+ if (stop_) break; else continue;
} catch (TException& tx) {
if (inputTransport) {
inputTransport->close();
@@ -88,7 +88,7 @@
string errStr = string("Some kind of accept exception: ") + tx.what();
GlobalOutput(errStr.c_str());
continue;
- } catch (string s) {
+ } catch (const string& s) {
if (inputTransport) {
inputTransport->close();
}
@@ -122,8 +122,12 @@
}
}
} catch (const TTransportException& ttx) {
- string errStr = string("TSimpleServer client died: ") + ttx.what();
- GlobalOutput(errStr.c_str());
+ if (ttx.getType() != TTransportException::END_OF_FILE &&
+ ttx.getType() != TTransportException::INTERRUPTED)
+ {
+ string errStr = string("TSimpleServer client died: ") + ttx.what();
+ GlobalOutput(errStr.c_str());
+ }
} catch (const std::exception& x) {
GlobalOutput.printf("TSimpleServer exception: %s: %s", typeid(x).name(), x.what());
} catch (...) {
@@ -163,6 +167,15 @@
stop_ = false;
}
}
+
+void TSimpleServer::stop() {
+  if (stop_)
+    return; // a stop has already been requested
+  stop_ = true;
+  serverTransport_->interrupt();         // wake serve() out of accept()
+  serverTransport_->interruptChildren(); // wake connected clients out of blocking reads
+}
+
}
}
} // apache::thrift::server
diff --git a/lib/cpp/src/thrift/server/TSimpleServer.h b/lib/cpp/src/thrift/server/TSimpleServer.h
index 967f834..941f12b 100644
--- a/lib/cpp/src/thrift/server/TSimpleServer.h
+++ b/lib/cpp/src/thrift/server/TSimpleServer.h
@@ -32,7 +32,6 @@
* continuous loop of accepting a single connection, processing requests on
* that connection until it closes, and then repeating. It is a good example
* of how to extend the TServer interface.
- *
*/
class TSimpleServer : public TServer {
public:
@@ -84,14 +83,20 @@
outputProtocolFactory),
stop_(false) {}
- ~TSimpleServer() {}
-
+ /**
+ * Process one connection at a time using the caller's thread.
+ * Call stop() on another thread to interrupt processing and
+ * return control to the caller.
+ * Post-conditions (return guarantees):
+ * The serverTransport will be closed.
+ * There will be no connected client.
+ */
void serve();
- void stop() {
- stop_ = true;
- serverTransport_->interrupt();
- }
+ /**
+ * Interrupt serve() so that it meets post-conditions.
+ */
+ void stop();
protected:
bool stop_;
diff --git a/lib/cpp/src/thrift/server/TThreadPoolServer.cpp b/lib/cpp/src/thrift/server/TThreadPoolServer.cpp
index 0530d8d..58cfe3e 100644
--- a/lib/cpp/src/thrift/server/TThreadPoolServer.cpp
+++ b/lib/cpp/src/thrift/server/TThreadPoolServer.cpp
@@ -69,11 +69,12 @@
break;
}
}
- } catch (const TTransportException&) {
- // This is reasonably expected, client didn't send a full request so just
- // ignore him
- // string errStr = string("TThreadPoolServer client died: ") + ttx.what();
- // GlobalOutput(errStr.c_str());
+ } catch (const TTransportException& ttx) {
+ if (ttx.getType() != TTransportException::END_OF_FILE &&
+ ttx.getType() != TTransportException::INTERRUPTED) {
+ string errStr = string("TThreadPoolServer::Task client died: ") + ttx.what();
+ GlobalOutput(errStr.c_str());
+ }
} catch (const std::exception& x) {
GlobalOutput.printf("TThreadPoolServer exception %s: %s", typeid(x).name(), x.what());
} catch (...) {
@@ -108,8 +109,7 @@
shared_ptr<TTransport> transport_;
};
-TThreadPoolServer::~TThreadPoolServer() {
-}
+TThreadPoolServer::~TThreadPoolServer() {}
void TThreadPoolServer::serve() {
shared_ptr<TTransport> client;
@@ -160,11 +160,11 @@
if (client) {
client->close();
}
- if (!stop_ || ttx.getType() != TTransportException::INTERRUPTED) {
+ if (ttx.getType() != TTransportException::INTERRUPTED) {
string errStr = string("TThreadPoolServer: TServerTransport died on accept: ") + ttx.what();
GlobalOutput(errStr.c_str());
}
- continue;
+ if (stop_) break; else continue;
} catch (TException& tx) {
if (inputTransport) {
inputTransport->close();
@@ -178,7 +178,7 @@
string errStr = string("TThreadPoolServer: Caught TException: ") + tx.what();
GlobalOutput(errStr.c_str());
continue;
- } catch (string s) {
+ } catch (const string& s) {
if (inputTransport) {
inputTransport->close();
}
@@ -207,6 +207,14 @@
}
}
+void TThreadPoolServer::stop() {
+  if (stop_)
+    return; // a stop has already been requested
+  stop_ = true;
+  serverTransport_->interrupt();         // wake serve() out of accept()
+  serverTransport_->interruptChildren(); // wake connected clients out of blocking reads
+}
+
int64_t TThreadPoolServer::getTimeout() const {
return timeout_;
}
diff --git a/lib/cpp/src/thrift/server/TThreadPoolServer.h b/lib/cpp/src/thrift/server/TThreadPoolServer.h
index ad7e7ef..1696700 100644
--- a/lib/cpp/src/thrift/server/TThreadPoolServer.h
+++ b/lib/cpp/src/thrift/server/TThreadPoolServer.h
@@ -109,15 +109,12 @@
virtual void serve();
+ virtual void stop();
+
virtual int64_t getTimeout() const;
virtual void setTimeout(int64_t value);
- virtual void stop() {
- stop_ = true;
- serverTransport_->interrupt();
- }
-
virtual int64_t getTaskExpiration() const;
virtual void setTaskExpiration(int64_t value);
diff --git a/lib/cpp/src/thrift/server/TThreadedServer.cpp b/lib/cpp/src/thrift/server/TThreadedServer.cpp
index 380f69c..118c9cb 100644
--- a/lib/cpp/src/thrift/server/TThreadedServer.cpp
+++ b/lib/cpp/src/thrift/server/TThreadedServer.cpp
@@ -55,10 +55,6 @@
~Task() {}
- void stop() {
- input_->getTransport()->close();
- }
-
void run() {
boost::shared_ptr<TServerEventHandler> eventHandler = server_.getEventHandler();
void* connectionContext = NULL;
@@ -76,7 +72,8 @@
}
}
} catch (const TTransportException& ttx) {
- if (ttx.getType() != TTransportException::END_OF_FILE) {
+ if (ttx.getType() != TTransportException::END_OF_FILE &&
+ ttx.getType() != TTransportException::INTERRUPTED) {
string errStr = string("TThreadedServer client died: ") + ttx.what();
GlobalOutput(errStr.c_str());
}
@@ -130,8 +127,7 @@
}
}
-TThreadedServer::~TThreadedServer() {
-}
+TThreadedServer::~TThreadedServer() {}
void TThreadedServer::serve() {
@@ -196,11 +192,11 @@
if (client) {
client->close();
}
- if (!stop_ || ttx.getType() != TTransportException::INTERRUPTED) {
+ if (ttx.getType() != TTransportException::INTERRUPTED) {
string errStr = string("TThreadedServer: TServerTransport died on accept: ") + ttx.what();
GlobalOutput(errStr.c_str());
}
- continue;
+ if (stop_) break; else continue;
} catch (TException& tx) {
if (inputTransport) {
inputTransport->close();
@@ -214,7 +210,7 @@
string errStr = string("TThreadedServer: Caught TException: ") + tx.what();
GlobalOutput(errStr.c_str());
continue;
- } catch (string s) {
+ } catch (const string& s) {
if (inputTransport) {
inputTransport->close();
}
@@ -240,8 +236,6 @@
}
try {
Synchronized s(tasksMonitor_);
- for ( std::set<Task*>::iterator tIt = tasks_.begin(); tIt != tasks_.end(); ++tIt )
- (*tIt)->stop();
while (!tasks_.empty()) {
tasksMonitor_.wait();
}
@@ -252,6 +246,14 @@
stop_ = false;
}
}
+
+void TThreadedServer::stop() {
+  if (stop_)
+    return; // a stop has already been requested
+  stop_ = true;
+  serverTransport_->interrupt();         // wake serve() out of accept()
+  serverTransport_->interruptChildren(); // wake connected clients out of blocking reads
+}
}
}
} // apache::thrift::server
diff --git a/lib/cpp/src/thrift/server/TThreadedServer.h b/lib/cpp/src/thrift/server/TThreadedServer.h
index 2b1f757..b9b24fe 100644
--- a/lib/cpp/src/thrift/server/TThreadedServer.h
+++ b/lib/cpp/src/thrift/server/TThreadedServer.h
@@ -75,11 +75,7 @@
virtual ~TThreadedServer();
virtual void serve();
-
- void stop() {
- stop_ = true;
- serverTransport_->interrupt();
- }
+ void stop();
protected:
void init();
diff --git a/lib/cpp/src/thrift/transport/TServerSocket.cpp b/lib/cpp/src/thrift/transport/TServerSocket.cpp
index 8c8fd73..f5c3ea5 100644
--- a/lib/cpp/src/thrift/transport/TServerSocket.cpp
+++ b/lib/cpp/src/thrift/transport/TServerSocket.cpp
@@ -20,6 +20,7 @@
#include <thrift/thrift-config.h>
#include <cstring>
+#include <stdexcept>
#include <sys/types.h>
#ifdef HAVE_SYS_SOCKET_H
#include <sys/socket.h>
@@ -69,6 +70,12 @@
return reinterpret_cast<SOCKOPT_CAST_T*>(v);
}
+static void destroyer_of_fine_sockets(THRIFT_SOCKET* ssock) // shared_ptr deleter, TU-local
+{
+  ::THRIFT_CLOSESOCKET(*ssock); // close the descriptor, then free its heap holder
+  delete ssock;
+}
+
namespace apache {
namespace thrift {
namespace transport {
@@ -88,9 +95,12 @@
tcpSendBuffer_(0),
tcpRecvBuffer_(0),
keepAlive_(false),
- intSock1_(THRIFT_INVALID_SOCKET),
- intSock2_(THRIFT_INVALID_SOCKET) {
-}
+ interruptableChildren_(true),
+ listening_(false),
+ interruptSockWriter_(THRIFT_INVALID_SOCKET),
+ interruptSockReader_(THRIFT_INVALID_SOCKET),
+ childInterruptSockWriter_(THRIFT_INVALID_SOCKET)
+{}
TServerSocket::TServerSocket(int port, int sendTimeout, int recvTimeout)
: port_(port),
@@ -104,9 +114,12 @@
tcpSendBuffer_(0),
tcpRecvBuffer_(0),
keepAlive_(false),
- intSock1_(THRIFT_INVALID_SOCKET),
- intSock2_(THRIFT_INVALID_SOCKET) {
-}
+ interruptableChildren_(true),
+ listening_(false),
+ interruptSockWriter_(THRIFT_INVALID_SOCKET),
+ interruptSockReader_(THRIFT_INVALID_SOCKET),
+ childInterruptSockWriter_(THRIFT_INVALID_SOCKET)
+{}
TServerSocket::TServerSocket(const string& address, int port)
: port_(port),
@@ -121,9 +134,12 @@
tcpSendBuffer_(0),
tcpRecvBuffer_(0),
keepAlive_(false),
- intSock1_(THRIFT_INVALID_SOCKET),
- intSock2_(THRIFT_INVALID_SOCKET) {
-}
+ interruptableChildren_(true),
+ listening_(false),
+ interruptSockWriter_(THRIFT_INVALID_SOCKET),
+ interruptSockReader_(THRIFT_INVALID_SOCKET),
+ childInterruptSockWriter_(THRIFT_INVALID_SOCKET)
+{}
TServerSocket::TServerSocket(const string& path)
: port_(0),
@@ -138,9 +154,12 @@
tcpSendBuffer_(0),
tcpRecvBuffer_(0),
keepAlive_(false),
- intSock1_(THRIFT_INVALID_SOCKET),
- intSock2_(THRIFT_INVALID_SOCKET) {
-}
+ interruptableChildren_(true),
+ listening_(false),
+ interruptSockWriter_(THRIFT_INVALID_SOCKET),
+ interruptSockReader_(THRIFT_INVALID_SOCKET),
+ childInterruptSockWriter_(THRIFT_INVALID_SOCKET)
+{}
TServerSocket::~TServerSocket() {
close();
@@ -178,18 +197,41 @@
tcpRecvBuffer_ = tcpRecvBuffer;
}
+void TServerSocket::setInterruptableChildren(bool enable) {
+  if (listening_)
+    throw std::logic_error("setInterruptableChildren cannot be called after listen()");
+  // Only permitted between construction (or close()) and listen().
+  interruptableChildren_ = enable;
+}
+
void TServerSocket::listen() {
+ listening_ = true;
#ifdef _WIN32
TWinsockSingleton::create();
#endif // _WIN32
THRIFT_SOCKET sv[2];
// Create the socket pair used to interrupt a blocking accept()
if (-1 == THRIFT_SOCKETPAIR(AF_LOCAL, SOCK_STREAM, 0, sv)) {
- GlobalOutput.perror("TServerSocket::listen() socketpair() ", THRIFT_GET_SOCKET_ERROR);
- intSock1_ = THRIFT_INVALID_SOCKET;
- intSock2_ = THRIFT_INVALID_SOCKET;
+ GlobalOutput.perror("TServerSocket::listen() socketpair() interrupt",
+ THRIFT_GET_SOCKET_ERROR);
+ interruptSockWriter_ = THRIFT_INVALID_SOCKET;
+ interruptSockReader_ = THRIFT_INVALID_SOCKET;
} else {
- intSock1_ = sv[1];
- intSock2_ = sv[0];
+ interruptSockWriter_ = sv[1];
+ interruptSockReader_ = sv[0];
+ }
+
+ // Create the socket pair used to interrupt all clients
+ if (-1 == THRIFT_SOCKETPAIR(AF_LOCAL, SOCK_STREAM, 0, sv)) {
+ GlobalOutput.perror("TServerSocket::listen() socketpair() childInterrupt",
+ THRIFT_GET_SOCKET_ERROR);
+ childInterruptSockWriter_ = THRIFT_INVALID_SOCKET;
+ pChildInterruptSockReader_.reset();
+ } else {
+ childInterruptSockWriter_ = sv[1];
+ pChildInterruptSockReader_ =
+ boost::shared_ptr<THRIFT_SOCKET>(new THRIFT_SOCKET(sv[0]),
+ destroyer_of_fine_sockets);
}
// Validate port number
@@ -469,8 +511,8 @@
std::memset(fds, 0, sizeof(fds));
fds[0].fd = serverSocket_;
fds[0].events = THRIFT_POLLIN;
- if (intSock2_ != THRIFT_INVALID_SOCKET) {
- fds[1].fd = intSock2_;
+ if (interruptSockReader_ != THRIFT_INVALID_SOCKET) {
+ fds[1].fd = interruptSockReader_;
fds[1].events = THRIFT_POLLIN;
}
/*
@@ -491,9 +533,9 @@
throw TTransportException(TTransportException::UNKNOWN, "Unknown", errno_copy);
} else if (ret > 0) {
// Check for an interrupt signal
- if (intSock2_ != THRIFT_INVALID_SOCKET && (fds[1].revents & THRIFT_POLLIN)) {
+ if (interruptSockReader_ != THRIFT_INVALID_SOCKET && (fds[1].revents & THRIFT_POLLIN)) {
int8_t buf;
- if (-1 == recv(intSock2_, cast_sockopt(&buf), sizeof(int8_t), 0)) {
+ if (-1 == recv(interruptSockReader_, cast_sockopt(&buf), sizeof(int8_t), 0)) {
GlobalOutput.perror("TServerSocket::acceptImpl() recv() interrupt ",
THRIFT_GET_SOCKET_ERROR);
}
@@ -562,16 +604,28 @@
}
shared_ptr<TSocket> TServerSocket::createSocket(THRIFT_SOCKET clientSocket) {
- return shared_ptr<TSocket>(new TSocket(clientSocket));
+ if (interruptableChildren_) {
+ return shared_ptr<TSocket>(new TSocket(clientSocket, pChildInterruptSockReader_));
+ } else {
+ return shared_ptr<TSocket>(new TSocket(clientSocket));
+ }
+}
+
+void TServerSocket::notify(THRIFT_SOCKET notifySocket) {
+ if (notifySocket != THRIFT_INVALID_SOCKET) {
+ int8_t byte = 0;
+ if (-1 == send(notifySocket, cast_sockopt(&byte), sizeof(int8_t), 0)) {
+ GlobalOutput.perror("TServerSocket::notify() send() ", THRIFT_GET_SOCKET_ERROR);
+ }
+ }
}
void TServerSocket::interrupt() {
- if (intSock1_ != THRIFT_INVALID_SOCKET) {
- int8_t byte = 0;
- if (-1 == send(intSock1_, cast_sockopt(&byte), sizeof(int8_t), 0)) {
- GlobalOutput.perror("TServerSocket::interrupt() send() ", THRIFT_GET_SOCKET_ERROR);
- }
- }
+ notify(interruptSockWriter_);
+}
+
+void TServerSocket::interruptChildren() {
+ notify(childInterruptSockWriter_);
}
void TServerSocket::close() {
@@ -579,16 +633,23 @@
shutdown(serverSocket_, THRIFT_SHUT_RDWR);
::THRIFT_CLOSESOCKET(serverSocket_);
}
- if (intSock1_ != THRIFT_INVALID_SOCKET) {
- ::THRIFT_CLOSESOCKET(intSock1_);
+ if (interruptSockWriter_ != THRIFT_INVALID_SOCKET) {
+ ::THRIFT_CLOSESOCKET(interruptSockWriter_);
}
- if (intSock2_ != THRIFT_INVALID_SOCKET) {
- ::THRIFT_CLOSESOCKET(intSock2_);
+ if (interruptSockReader_ != THRIFT_INVALID_SOCKET) {
+ ::THRIFT_CLOSESOCKET(interruptSockReader_);
+ }
+ if (childInterruptSockWriter_ != THRIFT_INVALID_SOCKET) {
+ ::THRIFT_CLOSESOCKET(childInterruptSockWriter_);
}
serverSocket_ = THRIFT_INVALID_SOCKET;
- intSock1_ = THRIFT_INVALID_SOCKET;
- intSock2_ = THRIFT_INVALID_SOCKET;
+ interruptSockWriter_ = THRIFT_INVALID_SOCKET;
+ interruptSockReader_ = THRIFT_INVALID_SOCKET;
+ childInterruptSockWriter_ = THRIFT_INVALID_SOCKET;
+ pChildInterruptSockReader_.reset();
+ listening_ = false;
}
+
}
}
} // apache::thrift::transport
diff --git a/lib/cpp/src/thrift/transport/TServerSocket.h b/lib/cpp/src/thrift/transport/TServerSocket.h
index 49711e8..58e4afd 100644
--- a/lib/cpp/src/thrift/transport/TServerSocket.h
+++ b/lib/cpp/src/thrift/transport/TServerSocket.h
@@ -100,17 +100,33 @@
// socket, this is the place to do it.
void setAcceptCallback(const socket_func_t& acceptCallback) { acceptCallback_ = acceptCallback; }
- void listen();
- void close();
+ // When enabled (the default), new child TSockets will be constructed so
+ // they can be interrupted by TServerTransport::interruptChildren().
+ // This costs an extra system call per read (poll before recv); however, it
+ // ensures a connected client cannot interfere with TServer::stop().
+ //
+ // When disabled, TSocket children do not incur an additional poll() call.
+ // Server-side reads are more efficient; however, a client can interfere with
+ // the server's ability to shut down properly by staying connected.
+ //
+ // Must be called before listen(); mode cannot be switched after that.
+ // \throws std::logic_error if listen() has been called
+ void setInterruptableChildren(bool enable);
- void interrupt();
int getPort();
+ void listen();
+ void interrupt();
+ void interruptChildren();
+ void close();
+
protected:
boost::shared_ptr<TTransport> acceptImpl();
virtual boost::shared_ptr<TSocket> createSocket(THRIFT_SOCKET client);
private:
+ void notify(THRIFT_SOCKET notifySock);
+
int port_;
std::string address_;
std::string path_;
@@ -124,9 +140,13 @@
int tcpSendBuffer_;
int tcpRecvBuffer_;
bool keepAlive_;
+ bool interruptableChildren_;
+ bool listening_;
- THRIFT_SOCKET intSock1_;
- THRIFT_SOCKET intSock2_;
+ THRIFT_SOCKET interruptSockWriter_; // is notified on interrupt()
+ THRIFT_SOCKET interruptSockReader_; // is used in select/poll with serverSocket_ for interruptability
+ THRIFT_SOCKET childInterruptSockWriter_; // is notified on interruptChildren()
+ boost::shared_ptr<THRIFT_SOCKET> pChildInterruptSockReader_; // if interruptableChildren_ this is shared with child TSockets
socket_func_t listenCallback_;
socket_func_t acceptCallback_;
diff --git a/lib/cpp/src/thrift/transport/TServerTransport.h b/lib/cpp/src/thrift/transport/TServerTransport.h
index 7c4a7c3..cd1d3da 100644
--- a/lib/cpp/src/thrift/transport/TServerTransport.h
+++ b/lib/cpp/src/thrift/transport/TServerTransport.h
@@ -30,8 +30,9 @@
/**
* Server transport framework. A server needs to have some facility for
- * creating base transports to read/write from.
- *
+ * creating base transports to read/write from. The server is expected
+ * to keep track of TTransport children that it creates for purposes of
+ * controlling their lifetime.
*/
class TServerTransport {
public:
@@ -67,11 +68,21 @@
* For "smart" TServerTransport implementations that work in a multi
* threaded context this can be used to break out of an accept() call.
* It is expected that the transport will throw a TTransportException
- * with the interrupted error code.
+ * with the INTERRUPTED error code.
+ *
+ * This will not make an attempt to interrupt any TTransport children.
*/
virtual void interrupt() {}
/**
+ * This will interrupt the children created by the server transport,
+ * allowing them to break out of any blocking data reception call.
+ * It is expected that the children will throw a TTransportException
+ * with the INTERRUPTED error code.
+ */
+ virtual void interruptChildren() {}
+
+ /**
* Closes this transport such that future calls to accept will do nothing.
*/
virtual void close() = 0;
diff --git a/lib/cpp/src/thrift/transport/TSocket.cpp b/lib/cpp/src/thrift/transport/TSocket.cpp
index bcea291..cc4dce0 100644
--- a/lib/cpp/src/thrift/transport/TSocket.cpp
+++ b/lib/cpp/src/thrift/transport/TSocket.cpp
@@ -74,7 +74,7 @@
*
*/
-TSocket::TSocket(string host, int port)
+TSocket::TSocket(const string& host, int port)
: host_(host),
port_(port),
path_(""),
@@ -89,7 +89,7 @@
maxRecvRetries_(5) {
}
-TSocket::TSocket(string path)
+TSocket::TSocket(const string& path)
: host_(""),
port_(0),
path_(path),
@@ -143,6 +143,30 @@
#endif
}
+TSocket::TSocket(THRIFT_SOCKET socket, // wraps an already-accepted descriptor (see TServerSocket::createSocket)
+ boost::shared_ptr<THRIFT_SOCKET> interruptListener) :
+ host_(""),
+ port_(0),
+ path_(""),
+ socket_(socket),
+ interruptListener_(interruptListener), // may be null: read()/peek() then skip the interrupt poll()
+ connTimeout_(0),
+ sendTimeout_(0),
+ recvTimeout_(0),
+ keepAlive_(false),
+ lingerOn_(1),
+ lingerVal_(0),
+ noDelay_(1),
+ maxRecvRetries_(5) { // initializers must stay in member-declaration order
+ cachedPeerAddr_.ipv4.sin_family = AF_UNSPEC; // marks the peer-address cache as unset
+#ifdef SO_NOSIGPIPE
+ {
+ int one = 1;
+ setsockopt(socket_, SOL_SOCKET, SO_NOSIGPIPE, &one, sizeof(one)); // suppress SIGPIPE where the platform supports it
+ }
+#endif
+}
+
TSocket::~TSocket() {
close();
}
@@ -153,8 +177,41 @@
bool TSocket::peek() {
if (!isOpen()) {
- return false;
+ return false;
}
+ if (interruptListener_)
+ {
+ for (int retries = 0; ; ) {
+ struct THRIFT_POLLFD fds[2];
+ std::memset(fds, 0, sizeof(fds));
+ fds[0].fd = socket_;
+ fds[0].events = THRIFT_POLLIN;
+ fds[1].fd = *(interruptListener_.get());
+ fds[1].events = THRIFT_POLLIN;
+ int ret = THRIFT_POLL(fds, 2, (recvTimeout_ == 0) ? -1 : recvTimeout_);
+ int errno_copy = THRIFT_GET_SOCKET_ERROR;
+ if (ret < 0) {
+ // error cases
+ if (errno_copy == THRIFT_EINTR && (retries++ < maxRecvRetries_)) {
+ continue;
+ }
+ GlobalOutput.perror("TSocket::peek() THRIFT_POLL() ", errno_copy);
+ throw TTransportException(TTransportException::UNKNOWN, "Unknown", errno_copy);
+ } else if (ret > 0) {
+ // Check the interruptListener
+ if (fds[1].revents & THRIFT_POLLIN) {
+ return false;
+ }
+ // There must be data or a disconnection, fall through to the PEEK
+ break;
+ } else {
+ // timeout
+ return false;
+ }
+ }
+ }
+
+ // Check to see if data is available or if the remote side closed
uint8_t buf;
int r = static_cast<int>(recv(socket_, cast_sockopt(&buf), 1, MSG_PEEK));
if (r == -1) {
@@ -455,9 +512,44 @@
// an THRIFT_EAGAIN is due to a timeout or an out-of-resource condition.
begin.tv_sec = begin.tv_usec = 0;
}
- int got = static_cast<int>(recv(socket_, cast_sockopt(buf), len, 0));
- int errno_copy = THRIFT_GET_SOCKET_ERROR; // THRIFT_GETTIMEOFDAY can change
- // THRIFT_GET_SOCKET_ERROR
+
+ int got = 0;
+
+ if (interruptListener_)
+ {
+ struct THRIFT_POLLFD fds[2];
+ std::memset(fds, 0, sizeof(fds));
+ fds[0].fd = socket_;
+ fds[0].events = THRIFT_POLLIN;
+ fds[1].fd = *(interruptListener_.get());
+ fds[1].events = THRIFT_POLLIN;
+
+ int ret = THRIFT_POLL(fds, 2, (recvTimeout_ == 0) ? -1 : recvTimeout_);
+ int errno_copy = THRIFT_GET_SOCKET_ERROR;
+ if (ret < 0) {
+ // error cases
+ if (errno_copy == THRIFT_EINTR && (retries++ < maxRecvRetries_)) {
+ goto try_again;
+ }
+ GlobalOutput.perror("TSocket::read() THRIFT_POLL() ", errno_copy);
+ throw TTransportException(TTransportException::UNKNOWN, "Unknown", errno_copy);
+ } else if (ret > 0) {
+ // Check the interruptListener
+ if (fds[1].revents & THRIFT_POLLIN) {
+ throw TTransportException(TTransportException::INTERRUPTED,
+ "Interrupted");
+ }
+ } else /* ret == 0 */ {
+ throw TTransportException(TTransportException::TIMED_OUT,
+ "THRIFT_EAGAIN (timed out)");
+ }
+
+ // falling through means there is something to recv and it cannot block
+ }
+
+ got = static_cast<int>(recv(socket_, cast_sockopt(buf), len, 0));
+ // THRIFT_GETTIMEOFDAY can change THRIFT_GET_SOCKET_ERROR
+ int errno_copy = THRIFT_GET_SOCKET_ERROR;
// Check for error on read
if (got < 0) {
@@ -493,29 +585,9 @@
goto try_again;
}
-#if defined __FreeBSD__ || defined __MACH__
if (errno_copy == THRIFT_ECONNRESET) {
- /* shigin: freebsd doesn't follow POSIX semantic of recv and fails with
- * THRIFT_ECONNRESET if peer performed shutdown
- * edhall: eliminated close() since we do that in the destructor.
- */
return 0;
}
-#endif
-
-#ifdef _WIN32
- if (errno_copy == WSAECONNRESET) {
- return 0; // EOF
- }
-#endif
-
- // Now it's not a try again case, but a real probblez
- GlobalOutput.perror("TSocket::read() recv() " + getSocketInfo(), errno_copy);
-
- // If we disconnect with no linger time
- if (errno_copy == THRIFT_ECONNRESET) {
- throw TTransportException(TTransportException::NOT_OPEN, "THRIFT_ECONNRESET");
- }
// This ish isn't open
if (errno_copy == THRIFT_ENOTCONN) {
@@ -527,18 +599,13 @@
throw TTransportException(TTransportException::TIMED_OUT, "THRIFT_ETIMEDOUT");
}
+ // Now it's not a try again case, but a real probblez
+ GlobalOutput.perror("TSocket::read() recv() " + getSocketInfo(), errno_copy);
+
// Some other error, whatevz
throw TTransportException(TTransportException::UNKNOWN, "Unknown", errno_copy);
}
- // The remote host has closed the socket
- if (got == 0) {
- // edhall: we used to call close() here, but our caller may want to deal
- // with the socket fd and we'll close() in our destructor in any case.
- return 0;
- }
-
- // Pack data into string
return got;
}
diff --git a/lib/cpp/src/thrift/transport/TSocket.h b/lib/cpp/src/thrift/transport/TSocket.h
index 427b91f..09a64e3 100644
--- a/lib/cpp/src/thrift/transport/TSocket.h
+++ b/lib/cpp/src/thrift/transport/TSocket.h
@@ -61,7 +61,7 @@
* @param host An IP address or hostname to connect to
* @param port The port to connect on
*/
- TSocket(std::string host, int port);
+ TSocket(const std::string& host, int port);
/**
* Constructs a new Unix domain socket.
@@ -69,7 +69,7 @@
*
* @param path The Unix domain socket e.g. "/tmp/ThriftTest.binary.thrift"
*/
- TSocket(std::string path);
+ TSocket(const std::string& path);
/**
* Destroyes the socket object, closing it if necessary.
@@ -102,6 +102,13 @@
/**
* Reads from the underlying socket.
+ * \returns the number of bytes read, or 0 to indicate EOF
+ * \throws TTransportException of types:
+ * INTERRUPTED means the socket was interrupted
+ * out of a blocking call
+ * NOT_OPEN means the socket has been closed
+ * TIMED_OUT means the receive timeout expired
+ * UNKNOWN means something unexpected happened
*/
virtual uint32_t read(uint8_t* buf, uint32_t len);
@@ -242,11 +249,18 @@
virtual const std::string getOrigin();
/**
- * Constructor to create socket from raw UNIX handle.
+ * Constructor to create socket from file descriptor.
*/
TSocket(THRIFT_SOCKET socket);
/**
+ * Constructor to create socket from file descriptor that
+ * can be interrupted safely.
+ */
+ TSocket(THRIFT_SOCKET socket,
+ boost::shared_ptr<THRIFT_SOCKET> interruptListener);
+
+ /**
* Set a cache of the peer address (used when trivially available: e.g.
* accept() or connect()). Only caches IPV4 and IPV6; unset for others.
*/
@@ -274,9 +288,15 @@
/** UNIX domain socket path */
std::string path_;
- /** Underlying UNIX socket handle */
+ /** Underlying socket handle */
THRIFT_SOCKET socket_;
+ /**
+ * A shared socket pointer that will interrupt a blocking read if data
+ * becomes available on it
+ */
+ boost::shared_ptr<THRIFT_SOCKET> interruptListener_;
+
/** Connect timeout in ms */
int connTimeout_;