Support socket activation by fd passing
Client: cpp
Patch: Federico Giovanardi
This closes #3211
diff --git a/lib/cpp/src/thrift/transport/TServerSocket.cpp b/lib/cpp/src/thrift/transport/TServerSocket.cpp
index ffe9ed3..b0c9aeb 100644
--- a/lib/cpp/src/thrift/transport/TServerSocket.cpp
+++ b/lib/cpp/src/thrift/transport/TServerSocket.cpp
@@ -117,7 +117,8 @@
listening_(false),
interruptSockWriter_(THRIFT_INVALID_SOCKET),
interruptSockReader_(THRIFT_INVALID_SOCKET),
- childInterruptSockWriter_(THRIFT_INVALID_SOCKET) {
+ childInterruptSockWriter_(THRIFT_INVALID_SOCKET),
+ boundSocketType_(SocketType::NONE) {
}
TServerSocket::TServerSocket(int port, int sendTimeout, int recvTimeout)
@@ -136,7 +137,8 @@
listening_(false),
interruptSockWriter_(THRIFT_INVALID_SOCKET),
interruptSockReader_(THRIFT_INVALID_SOCKET),
- childInterruptSockWriter_(THRIFT_INVALID_SOCKET) {
+ childInterruptSockWriter_(THRIFT_INVALID_SOCKET),
+ boundSocketType_(SocketType::NONE) {
}
TServerSocket::TServerSocket(const string& address, int port)
@@ -156,7 +158,8 @@
listening_(false),
interruptSockWriter_(THRIFT_INVALID_SOCKET),
interruptSockReader_(THRIFT_INVALID_SOCKET),
- childInterruptSockWriter_(THRIFT_INVALID_SOCKET) {
+ childInterruptSockWriter_(THRIFT_INVALID_SOCKET),
+ boundSocketType_(SocketType::NONE) {
}
TServerSocket::TServerSocket(const string& path)
@@ -176,7 +179,28 @@
listening_(false),
interruptSockWriter_(THRIFT_INVALID_SOCKET),
interruptSockReader_(THRIFT_INVALID_SOCKET),
- childInterruptSockWriter_(THRIFT_INVALID_SOCKET) {
+ childInterruptSockWriter_(THRIFT_INVALID_SOCKET),
+ boundSocketType_(SocketType::NONE) {
+}
+TServerSocket::TServerSocket(THRIFT_SOCKET sock,SocketType socketType)
+ : interruptableChildren_(true),
+ port_(0),
+ path_(),
+ serverSocket_(sock),
+ acceptBacklog_(DEFAULT_BACKLOG),
+ sendTimeout_(0),
+ recvTimeout_(0),
+ accTimeout_(-1),
+ retryLimit_(0),
+ retryDelay_(0),
+ tcpSendBuffer_(0),
+ tcpRecvBuffer_(0),
+ keepAlive_(false),
+ listening_(false),
+ interruptSockWriter_(THRIFT_INVALID_SOCKET),
+ interruptSockReader_(THRIFT_INVALID_SOCKET),
+ childInterruptSockWriter_(THRIFT_INVALID_SOCKET),
+ boundSocketType_(socketType) {
}
TServerSocket::~TServerSocket() {
@@ -439,7 +463,8 @@
if (isUnixDomainSocket()) {
// -- Unix Domain Socket -- //
- serverSocket_ = socket(PF_UNIX, SOCK_STREAM, IPPROTO_IP);
+ if (serverSocket_ == THRIFT_INVALID_SOCKET)
+ serverSocket_ = socket(PF_UNIX, SOCK_STREAM, IPPROTO_IP);
if (serverSocket_ == THRIFT_INVALID_SOCKET) {
int errno_copy = THRIFT_GET_SOCKET_ERROR;
@@ -471,6 +496,8 @@
throw TTransportException(TTransportException::NOT_OPEN,
" Unix Domain socket path not supported");
#endif
+ } else if( boundSocketType_ != SocketType::NONE){
+ // -- Socket is already bound
} else {
// -- TCP socket -- //
@@ -516,25 +543,31 @@
// use short circuit evaluation here to only sleep if we need to
} while ((retries++ < retryLimit_) && (THRIFT_SLEEP_SEC(retryDelay_) == 0));
- // retrieve bind info
- if (port_ == 0 && retries <= retryLimit_) {
- struct sockaddr_storage sa;
- socklen_t len = sizeof(sa);
- std::memset(&sa, 0, len);
- if (::getsockname(serverSocket_, reinterpret_cast<struct sockaddr*>(&sa), &len) < 0) {
- errno_copy = THRIFT_GET_SOCKET_ERROR;
- GlobalOutput.perror("TServerSocket::getPort() getsockname() ", errno_copy);
+ } // TCP socket //
+
+ // retrieve bind info
+ if ((port_ == 0 || path_.empty() ) && retries <= retryLimit_) {
+ struct sockaddr_storage sa;
+ socklen_t len = sizeof(sa);
+ std::memset(&sa, 0, len);
+ if (::getsockname(serverSocket_, reinterpret_cast<struct sockaddr*>(&sa), &len) < 0) {
+ errno_copy = THRIFT_GET_SOCKET_ERROR;
+ GlobalOutput.perror("TServerSocket::getPort() getsockname() ", errno_copy);
+ } else {
+ if (sa.ss_family == AF_INET6) {
+ const auto* sin = reinterpret_cast<const struct sockaddr_in6*>(&sa);
+ port_ = ntohs(sin->sin6_port);
+ } else if (sa.ss_family == AF_INET) {
+ const auto* sin = reinterpret_cast<const struct sockaddr_in*>(&sa);
+ port_ = ntohs(sin->sin_port);
+ } else if (sa.ss_family == AF_UNIX) {
+ const auto* sin = reinterpret_cast<const struct sockaddr_un*>(&sa);
+ path_ = sin->sun_path;
} else {
- if (sa.ss_family == AF_INET6) {
- const auto* sin = reinterpret_cast<const struct sockaddr_in6*>(&sa);
- port_ = ntohs(sin->sin6_port);
- } else {
- const auto* sin = reinterpret_cast<const struct sockaddr_in*>(&sa);
- port_ = ntohs(sin->sin_port);
- }
+ GlobalOutput.perror("TServerSocket::getPort() getsockname() unhandled socket type",EINVAL);
}
}
- } // TCP socket //
+ }
// throw error if socket still wasn't created successfully
if (serverSocket_ == THRIFT_INVALID_SOCKET) {
@@ -569,7 +602,7 @@
listenCallback_(serverSocket_);
// Call listen
- if (-1 == ::listen(serverSocket_, acceptBacklog_)) {
+ if (boundSocketType_ == SocketType::NONE && -1 == ::listen(serverSocket_, acceptBacklog_)) {
errno_copy = THRIFT_GET_SOCKET_ERROR;
GlobalOutput.perror("TServerSocket::listen() listen() ", errno_copy);
close();
@@ -734,7 +767,8 @@
concurrency::Guard g(rwMutex_);
if (serverSocket_ != THRIFT_INVALID_SOCKET) {
shutdown(serverSocket_, THRIFT_SHUT_RDWR);
- ::THRIFT_CLOSESOCKET(serverSocket_);
+ if( boundSocketType_ == SocketType::NONE) //Do not close the server socket if it owned by systemd
+ ::THRIFT_CLOSESOCKET(serverSocket_);
}
if (interruptSockWriter_ != THRIFT_INVALID_SOCKET) {
::THRIFT_CLOSESOCKET(interruptSockWriter_);
diff --git a/lib/cpp/src/thrift/transport/TServerSocket.h b/lib/cpp/src/thrift/transport/TServerSocket.h
index e826707..4d43ccf 100644
--- a/lib/cpp/src/thrift/transport/TServerSocket.h
+++ b/lib/cpp/src/thrift/transport/TServerSocket.h
@@ -40,6 +40,13 @@
class TSocket;
+enum class SocketType {
+ NONE,
+ INET,
+ INET6,
+ UNIX
+};
+
/**
* Server socket implementation of TServerTransport. Wrapper around a unix
* socket listen and accept calls.
@@ -82,6 +89,14 @@
*/
TServerSocket(const std::string& path);
+ /**
+ * Constructor used for to initialize from an already bound unix socket.
+ * Useful for socket activation on systemd.
+ *
+ * @param fd
+ */
+ TServerSocket(THRIFT_SOCKET sock,SocketType socketType);
+
~TServerSocket() override;
@@ -172,6 +187,7 @@
socket_func_t listenCallback_;
socket_func_t acceptCallback_;
+ SocketType boundSocketType_;
};
}
}
diff --git a/test/cpp/src/TestServer.cpp b/test/cpp/src/TestServer.cpp
index 858fffa..dc95af6 100644
--- a/test/cpp/src/TestServer.cpp
+++ b/test/cpp/src/TestServer.cpp
@@ -31,6 +31,7 @@
#include <thrift/server/TSimpleServer.h>
#include <thrift/server/TThreadPoolServer.h>
#include <thrift/server/TThreadedServer.h>
+#include <thrift/transport/PlatformSocket.h>
#include <thrift/transport/THttpServer.h>
#include <thrift/transport/THttpTransport.h>
#include <thrift/transport/TNonblockingSSLServerSocket.h>
@@ -54,14 +55,21 @@
#ifdef HAVE_SIGNAL_H
#include <signal.h>
#endif
+#ifdef HAVE_SYS_SOCKET_H
+#include <sys/socket.h>
+#endif
+#ifdef HAVE_SYS_UN_H
+#include <sys/un.h>
+#endif
#include <iostream>
-#include <stdexcept>
+#include <memory>
#include <sstream>
+#include <stdexcept>
#include <boost/algorithm/string.hpp>
-#include <boost/program_options.hpp>
#include <boost/filesystem.hpp>
+#include <boost/program_options.hpp>
#if _WIN32
#include <thrift/windows/TWinsockSingleton.h>
@@ -570,6 +578,47 @@
std::shared_ptr<TestHandler> _delegate;
};
+struct DomainSocketFd {
+ THRIFT_SOCKET socket_fd;
+ std::string path;
+ DomainSocketFd(const std::string& path) : path(path) {
+#ifdef HAVE_SYS_UN_H
+ unlink(path.c_str());
+ socket_fd = socket(AF_UNIX, SOCK_STREAM, IPPROTO_IP);
+ if (socket_fd == -1) {
+ std::ostringstream os;
+ os << "Cannot create domain socket: " << strerror(errno);
+ throw std::runtime_error(os.str());
+ }
+ if (path.size() > sizeof(sockaddr_un::sun_path) - 1)
+ throw std::runtime_error("Path size on domain socket too big");
+ struct sockaddr_un sa;
+ memset(&sa, 0, sizeof(sa));
+ sa.sun_family = AF_UNIX;
+ strcpy(sa.sun_path, path.c_str());
+ int rv = bind(socket_fd, (struct sockaddr*)&sa, sizeof(sa));
+ if (rv == -1) {
+ std::ostringstream os;
+ os << "Cannot bind domain socket: " << strerror(errno);
+ throw std::runtime_error(os.str());
+ }
+
+ rv = ::listen(socket_fd, 16);
+ if (rv == -1) {
+ std::ostringstream os;
+ os << "Cannot listen on domain socket: " << strerror(errno);
+ throw std::runtime_error(os.str());
+ }
+#else
+ throw std::runtime_error("Cannot create a domain socket without AF_UNIX");
+#endif
+ }
+ ~DomainSocketFd() {
+ ::THRIFT_CLOSESOCKET(socket_fd);
+ unlink(path.c_str());
+ }
+};
+
namespace po = boost::program_options;
int main(int argc, char** argv) {
@@ -589,6 +638,8 @@
string server_type = "simple";
string domain_socket = "";
bool abstract_namespace = false;
+ bool emulate_socketactivation = false;
+ std::unique_ptr<DomainSocketFd> domain_socket_fd;
size_t workers = 4;
int string_limit = 0;
int container_limit = 0;
@@ -599,6 +650,7 @@
("port", po::value<int>(&port)->default_value(port), "Port number to listen")
("domain-socket", po::value<string>(&domain_socket) ->default_value(domain_socket), "Unix Domain Socket (e.g. /tmp/ThriftTest.thrift)")
("abstract-namespace", "Create the domain socket in the Abstract Namespace (no connection with filesystem pathnames)")
+ ("emulate-socketactivation","Open the socket from the tester program and pass the library an already open fd")
("server-type", po::value<string>(&server_type)->default_value(server_type), "type of server, \"simple\", \"thread-pool\", \"threaded\", or \"nonblocking\"")
("transport", po::value<string>(&transport_type)->default_value(transport_type), "transport: buffered, framed, http, websocket, zlib")
("protocol", po::value<string>(&protocol_type)->default_value(protocol_type), "protocol: binary, compact, header, json, multi, multic, multih, multij")
@@ -678,6 +730,9 @@
if (vm.count("abstract-namespace")) {
abstract_namespace = true;
}
+ if (vm.count("emulate-socketactivation")) {
+ emulate_socketactivation = true;
+ }
// Dispatcher
std::shared_ptr<TProtocolFactory> protocolFactory;
@@ -727,8 +782,16 @@
abstract_socket += domain_socket;
serverSocket = std::shared_ptr<TServerSocket>(new TServerSocket(abstract_socket));
} else {
- unlink(domain_socket.c_str());
- serverSocket = std::shared_ptr<TServerSocket>(new TServerSocket(domain_socket));
+ if (emulate_socketactivation) {
+ unlink(domain_socket.c_str());
+ // open and bind the socket
+ domain_socket_fd.reset(new DomainSocketFd(domain_socket));
+ serverSocket = std::shared_ptr<TServerSocket>(
+ new TServerSocket(domain_socket_fd->socket_fd, SocketType::UNIX));
+ } else {
+ unlink(domain_socket.c_str());
+ serverSocket = std::shared_ptr<TServerSocket>(new TServerSocket(domain_socket));
+ }
}
port = 0;
} else {
diff --git a/test/crossrunner/run.py b/test/crossrunner/run.py
index 3ccc6e3..e532417 100644
--- a/test/crossrunner/run.py
+++ b/test/crossrunner/run.py
@@ -306,7 +306,7 @@
return port if ok else self._get_domain_port()
def alloc_port(self, socket_type):
- if socket_type in ('domain', 'abstract'):
+ if socket_type in ('domain', 'abstract','domain-socketactivated'):
return self._get_domain_port()
else:
return self._get_tcp_port()
@@ -323,7 +323,7 @@
self._log.debug('free_port')
self._lock.acquire()
try:
- if socket_type == 'domain':
+ if socket_type in ['domain','domain-socketactivated']:
self._dom_ports.remove(port)
path = domain_socket_path(port)
if os.path.exists(path):
diff --git a/test/crossrunner/test.py b/test/crossrunner/test.py
index 2a1a4da..3da38f4 100644
--- a/test/crossrunner/test.py
+++ b/test/crossrunner/test.py
@@ -59,9 +59,11 @@
return cmd
def _socket_args(self, socket, port):
+ support_socket_activation = self.kind == 'server' and sys.platform != "win32"
return {
'ip-ssl': ['--ssl'],
'domain': ['--domain-socket=%s' % domain_socket_path(port)],
+ 'domain-socketactivated': (['--emulate-socketactivation'] if support_socket_activation else []) + ['--domain-socket=%s' % domain_socket_path(port)],
'abstract': ['--abstract-namespace', '--domain-socket=%s' % domain_socket_path(port)],
}.get(socket, None)
diff --git a/test/tests.json b/test/tests.json
index 16b47ac..9731a88 100644
--- a/test/tests.json
+++ b/test/tests.json
@@ -404,13 +404,13 @@
"buffered",
"http",
"framed",
- "zlib",
- "websocket"
+ "zlib"
],
"sockets": [
"ip",
"ip-ssl",
- "domain"
+ "domain",
+ "domain-socketactivated"
],
"protocols": [
"compact",