Server robustness fixes in Thrift C++ libs

Summary: ServerSockets can be interrupt() ed

Reviewed By: marc, karl


git-svn-id: https://svn.apache.org/repos/asf/incubator/thrift/trunk@665039 13f79535-47bb-0310-9956-ffa450edef68
diff --git a/lib/cpp/src/server/TSimpleServer.cpp b/lib/cpp/src/server/TSimpleServer.cpp
index 8657fab..e7d7e9b 100644
--- a/lib/cpp/src/server/TSimpleServer.cpp
+++ b/lib/cpp/src/server/TSimpleServer.cpp
@@ -57,21 +57,21 @@
       outputTransport->close();
       client->close();    
     } catch (TTransportException& ttx) {
-      inputTransport->close();
-      outputTransport->close();
-      client->close();
+      if (inputTransport.get() != NULL) { inputTransport->close(); }
+      if (outputTransport.get() != NULL) { outputTransport->close(); }
+      if (client.get() != NULL) { client->close(); }
       cerr << "TServerTransport died on accept: " << ttx.what() << endl;
       continue;
     } catch (TException& tx) {
-      inputTransport->close();
-      outputTransport->close();
-      client->close();
+      if (inputTransport.get() != NULL) { inputTransport->close(); }
+      if (outputTransport.get() != NULL) { outputTransport->close(); }
+      if (client.get() != NULL) { client->close(); }
       cerr << "Some kind of accept exception: " << tx.what() << endl;
       continue;
     } catch (string s) {
-      inputTransport->close();
-      outputTransport->close();
-      client->close();
+      if (inputTransport.get() != NULL) { inputTransport->close(); }
+      if (outputTransport.get() != NULL) { outputTransport->close(); }
+      if (client.get() != NULL) { client->close(); }
       cerr << "TThreadPoolServer: Unknown exception: " << s << endl;
       break;
     }
diff --git a/lib/cpp/src/server/TThreadPoolServer.cpp b/lib/cpp/src/server/TThreadPoolServer.cpp
index 32a0223..dcbf2f2 100644
--- a/lib/cpp/src/server/TThreadPoolServer.cpp
+++ b/lib/cpp/src/server/TThreadPoolServer.cpp
@@ -102,6 +102,7 @@
     try {
       // Fetch client from server
       client = serverTransport_->accept();
+
       // Make IO transports
       inputTransport = inputTransportFactory_->getTransport(client);
       outputTransport = outputTransportFactory_->getTransport(client);
@@ -109,25 +110,26 @@
       outputProtocol = outputProtocolFactory_->getProtocol(outputTransport);
 
       // Add to threadmanager pool
-      threadManager_->add(shared_ptr<TThreadPoolServer::Task>(new TThreadPoolServer::Task(processor_, 
-                                                                                          inputProtocol, 
-                                                                                          outputProtocol)));
+      threadManager_->add(shared_ptr<TThreadPoolServer::Task>(new TThreadPoolServer::Task(processor_, inputProtocol, outputProtocol)));
+
     } catch (TTransportException& ttx) {
-      inputTransport->close();
-      outputTransport->close();
-      client->close();
-      cerr << "TThreadPoolServer: TServerTransport died on accept: " << ttx.what() << endl;
+      if (inputTransport.get() != NULL) { inputTransport->close(); }
+      if (outputTransport.get() != NULL) { outputTransport->close(); }
+      if (client.get() != NULL) { client->close(); }
+      if (!stop_ || ttx.getType() != TTransportException::INTERRUPTED) {
+        cerr << "TThreadPoolServer: TServerTransport died on accept: " << ttx.what() << endl;
+      }
       continue;
     } catch (TException& tx) {
-      inputTransport->close();
-      outputTransport->close();
-      client->close();
+      if (inputTransport.get() != NULL) { inputTransport->close(); }
+      if (outputTransport.get() != NULL) { outputTransport->close(); }
+      if (client.get() != NULL) { client->close(); }
       cerr << "TThreadPoolServer: Caught TException: " << tx.what() << endl;
       continue;
     } catch (string s) {
-      inputTransport->close();
-      outputTransport->close();
-      client->close();
+      if (inputTransport.get() != NULL) { inputTransport->close(); }
+      if (outputTransport.get() != NULL) { outputTransport->close(); }
+      if (client.get() != NULL) { client->close(); }
       cerr << "TThreadPoolServer: Unknown exception: " << s << endl;
       break;
     }
@@ -141,8 +143,8 @@
     } catch (TException &tx) {
       cerr << "TThreadPoolServer: Exception shutting down: " << tx.what() << endl;
     }
+    stop_ = false;
   }
-  stop_ = false;
 
 }
 
diff --git a/lib/cpp/src/server/TThreadPoolServer.h b/lib/cpp/src/server/TThreadPoolServer.h
index f6809fc..bdb0e47 100644
--- a/lib/cpp/src/server/TThreadPoolServer.h
+++ b/lib/cpp/src/server/TThreadPoolServer.h
@@ -41,7 +41,10 @@
 
   virtual void serve();
   
-  virtual void stop() { stop_ = true; }
+  virtual void stop() {
+    stop_ = true;
+    serverTransport_->interrupt();
+  }
 
  protected:
 
diff --git a/lib/cpp/src/transport/TServerSocket.cpp b/lib/cpp/src/transport/TServerSocket.cpp
index e4d7c02..448898a 100644
--- a/lib/cpp/src/transport/TServerSocket.cpp
+++ b/lib/cpp/src/transport/TServerSocket.cpp
@@ -5,8 +5,10 @@
 // http://developers.facebook.com/thrift/
 
 #include <sys/socket.h>
+#include <sys/select.h>
 #include <netinet/in.h>
 #include <netinet/tcp.h>
+#include <fcntl.h>
 #include <errno.h>
 
 #include "TSocket.h"
@@ -22,14 +24,16 @@
   serverSocket_(-1),
   acceptBacklog_(1024),
   sendTimeout_(0),
-  recvTimeout_(0) {}
+  recvTimeout_(0),
+  interrupt_(false) {}
 
 TServerSocket::TServerSocket(int port, int sendTimeout, int recvTimeout) :
   port_(port),
   serverSocket_(-1),
   acceptBacklog_(1024),
   sendTimeout_(sendTimeout),
-  recvTimeout_(recvTimeout) {}
+  recvTimeout_(recvTimeout),
+  interrupt_(false) {}
 
 TServerSocket::~TServerSocket() {
   close();
@@ -87,6 +91,15 @@
     throw TTransportException(TTransportException::NOT_OPEN, "Could not set TCP_NODELAY");
   }
 
+  // Set NONBLOCK on the accept socket
+  int flags = fcntl(serverSocket_, F_GETFL, 0);
+  if (flags == -1) {
+    throw TTransportException(TTransportException::NOT_OPEN, "fcntl() failed");
+  }
+  if (-1 == fcntl(serverSocket_, F_SETFL, flags | O_NONBLOCK)) {
+    throw TTransportException(TTransportException::NOT_OPEN, "fcntl() failed");
+  }
+
   // Bind to a port
   struct sockaddr_in addr;
   memset(&addr, 0, sizeof(addr));
@@ -116,6 +129,37 @@
     throw TTransportException(TTransportException::NOT_OPEN, "TServerSocket not listening");
   }
 
+  // 200ms timeout on accept
+  struct timeval c = {0, 200000};
+  fd_set fds;
+
+  while (true) {
+    FD_ZERO(&fds);
+    FD_SET(serverSocket_, &fds);
+    int ret = select(serverSocket_+1, &fds, NULL, NULL, &c);
+
+    // Check for interrupt case
+    if (ret == 0 && interrupt_) {
+      interrupt_ = false;
+      throw TTransportException(TTransportException::INTERRUPTED);
+    }
+
+    // Reset interrupt flag no matter what
+    interrupt_ = false;
+
+    if (ret > 0) {
+      break;
+    } else if (ret == 0) {
+      if (errno != EINTR && errno != EAGAIN) {
+        perror("TServerSocket::select() errcode");
+        throw TTransportException(TTransportException::UNKNOWN);
+      }
+    } else {
+      perror("TServerSocket::select() negret");
+      throw TTransportException(TTransportException::UNKNOWN);
+    }
+  }
+
   struct sockaddr_in clientAddress;
   int size = sizeof(clientAddress);
   int clientSocket = ::accept(serverSocket_,
@@ -126,6 +170,17 @@
     perror("TServerSocket::accept()");
     throw TTransportException(TTransportException::UNKNOWN, "ERROR:" + errno);
   }
+
+  // Make sure client socket is blocking
+  int flags = fcntl(clientSocket, F_GETFL, 0);
+  if (flags == -1) {
+    perror("TServerSocket::select() fcntl GETFL");
+    throw TTransportException(TTransportException::UNKNOWN, "ERROR:" + errno);
+  }
+  if (-1 == fcntl(clientSocket, F_SETFL, flags & ~O_NONBLOCK)) {
+    perror("TServerSocket::select() fcntl SETFL");
+    throw TTransportException(TTransportException::UNKNOWN, "ERROR:" + errno);
+  }
   
   shared_ptr<TSocket> client(new TSocket(clientSocket));
   if (sendTimeout_ > 0) {
@@ -133,7 +188,8 @@
   }
   if (recvTimeout_ > 0) {
     client->setRecvTimeout(recvTimeout_);
-  }                          
+  }
+  
   return client;
 }
 
diff --git a/lib/cpp/src/transport/TServerSocket.h b/lib/cpp/src/transport/TServerSocket.h
index e801f84..b482e02 100644
--- a/lib/cpp/src/transport/TServerSocket.h
+++ b/lib/cpp/src/transport/TServerSocket.h
@@ -33,6 +33,10 @@
   void listen();
   void close();
 
+  void interrupt() {
+    interrupt_ = true;
+  }
+
  protected:
   shared_ptr<TTransport> acceptImpl();
 
@@ -42,6 +46,7 @@
   int acceptBacklog_;
   int sendTimeout_;
   int recvTimeout_;
+  volatile bool interrupt_;
 };
 
 }}} // facebook::thrift::transport
diff --git a/lib/cpp/src/transport/TServerTransport.h b/lib/cpp/src/transport/TServerTransport.h
index 5a1748e..bb47756 100644
--- a/lib/cpp/src/transport/TServerTransport.h
+++ b/lib/cpp/src/transport/TServerTransport.h
@@ -52,6 +52,14 @@
   }
 
   /**
+   * For "smart" TServerTransport implementations that work in a multi
+   * threaded context this can be used to break out of an accept() call.
+   * It is expected that the transport will throw a TTransportException
+   * with the interrupted error code.
+   */
+  virtual void interrupt() {}
+
+  /**
    * Closes this transport such that future calls to accept will do nothing.
    */
   virtual void close() = 0;
diff --git a/lib/cpp/src/transport/TSocket.cpp b/lib/cpp/src/transport/TSocket.cpp
index 5d653c4..95781f0 100644
--- a/lib/cpp/src/transport/TSocket.cpp
+++ b/lib/cpp/src/transport/TSocket.cpp
@@ -157,9 +157,13 @@
   // Set the socket to be non blocking for connect if a timeout exists
   int flags = fcntl(socket_, F_GETFL, 0); 
   if (connTimeout_ > 0) {
-    fcntl(socket_, F_SETFL, flags | O_NONBLOCK);
+    if (-1 == fcntl(socket_, F_SETFL, flags | O_NONBLOCK)) {
+      throw TTransportException(TTransportException::NOT_OPEN, "fcntl() failed");
+    }
   } else {
-    fcntl(socket_, F_SETFL, flags | ~O_NONBLOCK);
+    if (-1 == fcntl(socket_, F_SETFL, flags & ~O_NONBLOCK)) {
+      throw TTransportException(TTransportException::NOT_OPEN, "fcntl() failed");
+    }
   }
 
   // Conn timeout
diff --git a/lib/cpp/src/transport/TTransportException.h b/lib/cpp/src/transport/TTransportException.h
index 52ff97e..5083766 100644
--- a/lib/cpp/src/transport/TTransportException.h
+++ b/lib/cpp/src/transport/TTransportException.h
@@ -32,6 +32,7 @@
     ALREADY_OPEN = 2,
     TIMED_OUT = 3,
     END_OF_FILE = 4,
+    INTERRUPTED = 5
   };
   
   TTransportException() :