Support socket activation by fd passing
Client: cpp
Patch: Federico Giovanardi

This closes #3211
diff --git a/lib/cpp/src/thrift/transport/TServerSocket.cpp b/lib/cpp/src/thrift/transport/TServerSocket.cpp
index ffe9ed3..b0c9aeb 100644
--- a/lib/cpp/src/thrift/transport/TServerSocket.cpp
+++ b/lib/cpp/src/thrift/transport/TServerSocket.cpp
@@ -117,7 +117,8 @@
     listening_(false),
     interruptSockWriter_(THRIFT_INVALID_SOCKET),
     interruptSockReader_(THRIFT_INVALID_SOCKET),
-    childInterruptSockWriter_(THRIFT_INVALID_SOCKET) {
+    childInterruptSockWriter_(THRIFT_INVALID_SOCKET),
+    boundSocketType_(SocketType::NONE) {
 }
 
 TServerSocket::TServerSocket(int port, int sendTimeout, int recvTimeout)
@@ -136,7 +137,8 @@
     listening_(false),
     interruptSockWriter_(THRIFT_INVALID_SOCKET),
     interruptSockReader_(THRIFT_INVALID_SOCKET),
-    childInterruptSockWriter_(THRIFT_INVALID_SOCKET) {
+    childInterruptSockWriter_(THRIFT_INVALID_SOCKET),
+    boundSocketType_(SocketType::NONE) {
 }
 
 TServerSocket::TServerSocket(const string& address, int port)
@@ -156,7 +158,8 @@
     listening_(false),
     interruptSockWriter_(THRIFT_INVALID_SOCKET),
     interruptSockReader_(THRIFT_INVALID_SOCKET),
-    childInterruptSockWriter_(THRIFT_INVALID_SOCKET) {
+    childInterruptSockWriter_(THRIFT_INVALID_SOCKET),
+    boundSocketType_(SocketType::NONE) {
 }
 
 TServerSocket::TServerSocket(const string& path)
@@ -176,7 +179,28 @@
     listening_(false),
     interruptSockWriter_(THRIFT_INVALID_SOCKET),
     interruptSockReader_(THRIFT_INVALID_SOCKET),
-    childInterruptSockWriter_(THRIFT_INVALID_SOCKET) {
+    childInterruptSockWriter_(THRIFT_INVALID_SOCKET),
+    boundSocketType_(SocketType::NONE) {
+}
+TServerSocket::TServerSocket(THRIFT_SOCKET sock,SocketType socketType)
+  : interruptableChildren_(true),
+    port_(0),
+    path_(),
+    serverSocket_(sock),
+    acceptBacklog_(DEFAULT_BACKLOG),
+    sendTimeout_(0),
+    recvTimeout_(0),
+    accTimeout_(-1),
+    retryLimit_(0),
+    retryDelay_(0),
+    tcpSendBuffer_(0),
+    tcpRecvBuffer_(0),
+    keepAlive_(false),
+    listening_(false),
+    interruptSockWriter_(THRIFT_INVALID_SOCKET),
+    interruptSockReader_(THRIFT_INVALID_SOCKET),
+    childInterruptSockWriter_(THRIFT_INVALID_SOCKET),
+    boundSocketType_(socketType) {
 }
 
 TServerSocket::~TServerSocket() {
@@ -439,7 +463,8 @@
   if (isUnixDomainSocket()) {
     // -- Unix Domain Socket -- //
 
-    serverSocket_ = socket(PF_UNIX, SOCK_STREAM, IPPROTO_IP);
+    if (serverSocket_ == THRIFT_INVALID_SOCKET)
+      serverSocket_ = socket(PF_UNIX, SOCK_STREAM, IPPROTO_IP);
 
     if (serverSocket_ == THRIFT_INVALID_SOCKET) {
       int errno_copy = THRIFT_GET_SOCKET_ERROR;
@@ -471,6 +496,8 @@
     throw TTransportException(TTransportException::NOT_OPEN,
                               " Unix Domain socket path not supported");
 #endif
+  } else if( boundSocketType_ != SocketType::NONE){
+    // -- Socket is already bound
   } else {
     // -- TCP socket -- //
 
@@ -516,25 +543,31 @@
       // use short circuit evaluation here to only sleep if we need to
     } while ((retries++ < retryLimit_) && (THRIFT_SLEEP_SEC(retryDelay_) == 0));
 
-    // retrieve bind info
-    if (port_ == 0 && retries <= retryLimit_) {
-      struct sockaddr_storage sa;
-      socklen_t len = sizeof(sa);
-      std::memset(&sa, 0, len);
-      if (::getsockname(serverSocket_, reinterpret_cast<struct sockaddr*>(&sa), &len) < 0) {
-        errno_copy = THRIFT_GET_SOCKET_ERROR;
-        GlobalOutput.perror("TServerSocket::getPort() getsockname() ", errno_copy);
+  } // TCP socket //
+
+  // retrieve bind info
+  if ((port_ == 0 || path_.empty() ) && retries <= retryLimit_) {
+    struct sockaddr_storage sa;
+    socklen_t len = sizeof(sa);
+    std::memset(&sa, 0, len);
+    if (::getsockname(serverSocket_, reinterpret_cast<struct sockaddr*>(&sa), &len) < 0) {
+      errno_copy = THRIFT_GET_SOCKET_ERROR;
+      GlobalOutput.perror("TServerSocket::getPort() getsockname() ", errno_copy);
+    } else {
+      if (sa.ss_family == AF_INET6) {
+        const auto* sin = reinterpret_cast<const struct sockaddr_in6*>(&sa);
+        port_ = ntohs(sin->sin6_port);
+      } else if (sa.ss_family == AF_INET) {
+        const auto* sin = reinterpret_cast<const struct sockaddr_in*>(&sa);
+        port_ = ntohs(sin->sin_port);
+      } else if (sa.ss_family == AF_UNIX) {
+        const auto* sin = reinterpret_cast<const struct sockaddr_un*>(&sa);
+        path_ = sin->sun_path;
       } else {
-        if (sa.ss_family == AF_INET6) {
-          const auto* sin = reinterpret_cast<const struct sockaddr_in6*>(&sa);
-          port_ = ntohs(sin->sin6_port);
-        } else {
-          const auto* sin = reinterpret_cast<const struct sockaddr_in*>(&sa);
-          port_ = ntohs(sin->sin_port);
-        }
+        GlobalOutput.perror("TServerSocket::getPort() getsockname() unhandled socket type",EINVAL);
       }
     }
-  } // TCP socket //
+  }
 
   // throw error if socket still wasn't created successfully
   if (serverSocket_ == THRIFT_INVALID_SOCKET) {
@@ -569,7 +602,7 @@
     listenCallback_(serverSocket_);
 
   // Call listen
-  if (-1 == ::listen(serverSocket_, acceptBacklog_)) {
+  if (boundSocketType_ == SocketType::NONE && -1 == ::listen(serverSocket_, acceptBacklog_)) {
     errno_copy = THRIFT_GET_SOCKET_ERROR;
     GlobalOutput.perror("TServerSocket::listen() listen() ", errno_copy);
     close();
@@ -734,7 +767,8 @@
   concurrency::Guard g(rwMutex_);
   if (serverSocket_ != THRIFT_INVALID_SOCKET) {
     shutdown(serverSocket_, THRIFT_SHUT_RDWR);
-    ::THRIFT_CLOSESOCKET(serverSocket_);
+    if( boundSocketType_ == SocketType::NONE) //Do not close the server socket if it owned by systemd
+      ::THRIFT_CLOSESOCKET(serverSocket_);
   }
   if (interruptSockWriter_ != THRIFT_INVALID_SOCKET) {
     ::THRIFT_CLOSESOCKET(interruptSockWriter_);
diff --git a/lib/cpp/src/thrift/transport/TServerSocket.h b/lib/cpp/src/thrift/transport/TServerSocket.h
index e826707..4d43ccf 100644
--- a/lib/cpp/src/thrift/transport/TServerSocket.h
+++ b/lib/cpp/src/thrift/transport/TServerSocket.h
@@ -40,6 +40,13 @@
 
 class TSocket;
 
+enum class SocketType {
+    NONE,
+    INET,
+    INET6,
+    UNIX
+};
+
 /**
  * Server socket implementation of TServerTransport. Wrapper around a unix
  * socket listen and accept calls.
@@ -82,6 +89,14 @@
    */
   TServerSocket(const std::string& path);
 
+  /**
+   * Constructor used for to initialize from an already bound unix socket.
+   * Useful for socket activation on systemd.
+   *
+   * @param fd
+   */
+  TServerSocket(THRIFT_SOCKET sock,SocketType socketType);
+
   ~TServerSocket() override;
 
 
@@ -172,6 +187,7 @@
 
   socket_func_t listenCallback_;
   socket_func_t acceptCallback_;
+  SocketType boundSocketType_;
 };
 }
 }
diff --git a/test/cpp/src/TestServer.cpp b/test/cpp/src/TestServer.cpp
index 858fffa..dc95af6 100644
--- a/test/cpp/src/TestServer.cpp
+++ b/test/cpp/src/TestServer.cpp
@@ -31,6 +31,7 @@
 #include <thrift/server/TSimpleServer.h>
 #include <thrift/server/TThreadPoolServer.h>
 #include <thrift/server/TThreadedServer.h>
+#include <thrift/transport/PlatformSocket.h>
 #include <thrift/transport/THttpServer.h>
 #include <thrift/transport/THttpTransport.h>
 #include <thrift/transport/TNonblockingSSLServerSocket.h>
@@ -54,14 +55,21 @@
 #ifdef HAVE_SIGNAL_H
 #include <signal.h>
 #endif
+#ifdef HAVE_SYS_SOCKET_H
+#include <sys/socket.h>
+#endif
+#ifdef HAVE_SYS_UN_H
+#include <sys/un.h>
+#endif
 
 #include <iostream>
-#include <stdexcept>
+#include <memory>
 #include <sstream>
+#include <stdexcept>
 
 #include <boost/algorithm/string.hpp>
-#include <boost/program_options.hpp>
 #include <boost/filesystem.hpp>
+#include <boost/program_options.hpp>
 
 #if _WIN32
 #include <thrift/windows/TWinsockSingleton.h>
@@ -570,6 +578,47 @@
   std::shared_ptr<TestHandler> _delegate;
 };
 
+struct DomainSocketFd {
+  THRIFT_SOCKET socket_fd;
+  std::string path;
+  DomainSocketFd(const std::string& path) : path(path) {
+#ifdef HAVE_SYS_UN_H
+    unlink(path.c_str());
+    socket_fd = socket(AF_UNIX, SOCK_STREAM, IPPROTO_IP);
+    if (socket_fd == -1) {
+      std::ostringstream os;
+      os << "Cannot create domain socket: " << strerror(errno);
+      throw std::runtime_error(os.str());
+    }
+    if (path.size() > sizeof(sockaddr_un::sun_path) - 1)
+      throw std::runtime_error("Path size on domain socket too big");
+    struct sockaddr_un sa;
+    memset(&sa, 0, sizeof(sa));
+    sa.sun_family = AF_UNIX;
+    strcpy(sa.sun_path, path.c_str());
+    int rv = bind(socket_fd, (struct sockaddr*)&sa, sizeof(sa));
+    if (rv == -1) {
+      std::ostringstream os;
+      os << "Cannot bind domain socket: " << strerror(errno);
+      throw std::runtime_error(os.str());
+    }
+
+    rv = ::listen(socket_fd, 16);
+    if (rv == -1) {
+      std::ostringstream os;
+      os << "Cannot listen on domain socket: " << strerror(errno);
+      throw std::runtime_error(os.str());
+    }
+#else
+    throw std::runtime_error("Cannot create a domain socket without AF_UNIX");
+#endif
+  }
+  ~DomainSocketFd() {
+    ::THRIFT_CLOSESOCKET(socket_fd);
+    unlink(path.c_str());
+  }
+};
+
 namespace po = boost::program_options;
 
 int main(int argc, char** argv) {
@@ -589,6 +638,8 @@
   string server_type = "simple";
   string domain_socket = "";
   bool abstract_namespace = false;
+  bool emulate_socketactivation = false;
+  std::unique_ptr<DomainSocketFd> domain_socket_fd;
   size_t workers = 4;
   int string_limit = 0;
   int container_limit = 0;
@@ -599,6 +650,7 @@
     ("port", po::value<int>(&port)->default_value(port), "Port number to listen")
     ("domain-socket", po::value<string>(&domain_socket) ->default_value(domain_socket), "Unix Domain Socket (e.g. /tmp/ThriftTest.thrift)")
     ("abstract-namespace", "Create the domain socket in the Abstract Namespace (no connection with filesystem pathnames)")
+    ("emulate-socketactivation","Open the socket from the tester program and pass the library an already open fd")
     ("server-type", po::value<string>(&server_type)->default_value(server_type), "type of server, \"simple\", \"thread-pool\", \"threaded\", or \"nonblocking\"")
     ("transport", po::value<string>(&transport_type)->default_value(transport_type), "transport: buffered, framed, http, websocket, zlib")
     ("protocol", po::value<string>(&protocol_type)->default_value(protocol_type), "protocol: binary, compact, header, json, multi, multic, multih, multij")
@@ -678,6 +730,9 @@
   if (vm.count("abstract-namespace")) {
     abstract_namespace = true;
   }
+  if (vm.count("emulate-socketactivation")) {
+    emulate_socketactivation = true;
+  }
 
   // Dispatcher
   std::shared_ptr<TProtocolFactory> protocolFactory;
@@ -727,8 +782,16 @@
         abstract_socket += domain_socket;
         serverSocket = std::shared_ptr<TServerSocket>(new TServerSocket(abstract_socket));
       } else {
-        unlink(domain_socket.c_str());
-        serverSocket = std::shared_ptr<TServerSocket>(new TServerSocket(domain_socket));
+        if (emulate_socketactivation) {
+          unlink(domain_socket.c_str());
+          // open and bind the socket
+          domain_socket_fd.reset(new DomainSocketFd(domain_socket));
+          serverSocket = std::shared_ptr<TServerSocket>(
+              new TServerSocket(domain_socket_fd->socket_fd, SocketType::UNIX));
+        } else {
+          unlink(domain_socket.c_str());
+          serverSocket = std::shared_ptr<TServerSocket>(new TServerSocket(domain_socket));
+        }
       }
       port = 0;
     } else {
diff --git a/test/crossrunner/run.py b/test/crossrunner/run.py
index 3ccc6e3..e532417 100644
--- a/test/crossrunner/run.py
+++ b/test/crossrunner/run.py
@@ -306,7 +306,7 @@
         return port if ok else self._get_domain_port()
 
     def alloc_port(self, socket_type):
-        if socket_type in ('domain', 'abstract'):
+        if socket_type in ('domain', 'abstract','domain-socketactivated'):
             return self._get_domain_port()
         else:
             return self._get_tcp_port()
@@ -323,7 +323,7 @@
         self._log.debug('free_port')
         self._lock.acquire()
         try:
-            if socket_type == 'domain':
+            if socket_type in ['domain','domain-socketactivated']:
                 self._dom_ports.remove(port)
                 path = domain_socket_path(port)
                 if os.path.exists(path):
diff --git a/test/crossrunner/test.py b/test/crossrunner/test.py
index 2a1a4da..3da38f4 100644
--- a/test/crossrunner/test.py
+++ b/test/crossrunner/test.py
@@ -59,9 +59,11 @@
         return cmd
 
     def _socket_args(self, socket, port):
+        support_socket_activation = self.kind == 'server' and sys.platform != "win32"
         return {
             'ip-ssl': ['--ssl'],
             'domain': ['--domain-socket=%s' % domain_socket_path(port)],
+            'domain-socketactivated': (['--emulate-socketactivation'] if support_socket_activation else []) + ['--domain-socket=%s' % domain_socket_path(port)],
             'abstract': ['--abstract-namespace', '--domain-socket=%s' % domain_socket_path(port)],
         }.get(socket, None)
 
diff --git a/test/tests.json b/test/tests.json
index 16b47ac..9731a88 100644
--- a/test/tests.json
+++ b/test/tests.json
@@ -404,13 +404,13 @@
       "buffered",
       "http",
       "framed",
-      "zlib",
-      "websocket"
+      "zlib"
     ],
     "sockets": [
       "ip",
       "ip-ssl",
-      "domain"
+      "domain",
+      "domain-socketactivated"
     ],
     "protocols": [
       "compact",