Refactor TNonblockingServer to use event_base construct

Summary: This allows the event loop to be shared across different components of a program, or lets a separate thread running a TNonblockingServer safely use its own libevent code without conflicts (see the usage sketch below).

Reviewed By: mcslee

Test Plan: Updated test/, committed here

Other Notes: submitted by Ben Maurer, patched in by mcslee with slight modifications
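
Usage sketch: a minimal example of driving the server from a caller-owned event_base via the new listenSocket()/registerEvents() API introduced in this diff. The includes and the helper name runSharedLoop() are assumptions for illustration, not code from this change.

    // Illustrative sketch only: attach the server to an event_base owned by
    // the caller so other components can add events to the same loop.
    #include <event.h>
    #include <boost/shared_ptr.hpp>

    using boost::shared_ptr;
    using facebook::thrift::TProcessor;
    using facebook::thrift::protocol::TProtocolFactory;
    using facebook::thrift::server::TNonblockingServer;

    void runSharedLoop(shared_ptr<TProcessor> processor,
                       shared_ptr<TProtocolFactory> protocolFactory,
                       int port) {
      // Create (or reuse) an event base owned by the caller
      event_base* base = static_cast<event_base*>(event_init());

      TNonblockingServer server(processor, protocolFactory, port);
      server.listenSocket();        // create, bind, and listen on the port
      server.registerEvents(base);  // attach the accept event to this base

      // Other components may register their own events on 'base' here
      // before entering the shared loop.
      event_base_loop(base, 0);
    }

The existing serve() entry point is unchanged for callers: it now just calls listenSocket() and registerEvents(event_init()) internally before looping.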


git-svn-id: https://svn.apache.org/repos/asf/incubator/thrift/trunk@665364 13f79535-47bb-0310-9956-ffa450edef68
diff --git a/lib/cpp/configure.ac b/lib/cpp/configure.ac
index d9a66da..25b6152 100644
--- a/lib/cpp/configure.ac
+++ b/lib/cpp/configure.ac
@@ -14,6 +14,8 @@
 
 AC_FUNC_MALLOC
 
+AC_FUNC_MEMCMP
+
 AC_FUNC_REALLOC
 
 AC_FUNC_SELECT_ARGTYPES
diff --git a/lib/cpp/src/server/TNonblockingServer.cpp b/lib/cpp/src/server/TNonblockingServer.cpp
index b7ed761..ad1fb65 100644
--- a/lib/cpp/src/server/TNonblockingServer.cpp
+++ b/lib/cpp/src/server/TNonblockingServer.cpp
@@ -15,7 +15,7 @@
 #include <errno.h>
 #include <assert.h>
 
-namespace facebook { namespace thrift { namespace server { 
+namespace facebook { namespace thrift { namespace server {
 
 using namespace facebook::thrift::protocol;
 using namespace facebook::thrift::transport;
@@ -46,7 +46,7 @@
     } catch (...) {
       cerr << "TThreadedServer uncaught exception." << endl;
     }
-    
+
     // Signal completion back to the libevent thread via a socketpair
     int8_t b = 0;
     if (-1 == send(taskHandle_, &b, sizeof(int8_t), 0)) {
@@ -79,7 +79,7 @@
 
   socketState_ = SOCKET_RECV;
   appState_ = APP_INIT;
-  
+
   taskHandle_ = -1;
 
   // Set flags, which also registers the event
@@ -119,14 +119,14 @@
     // Read from the socket
     fetch = readWant_ - readBufferPos_;
     got = recv(socket_, readBuffer_ + readBufferPos_, fetch, 0);
-   
+
     if (got > 0) {
       // Move along in the buffer
       readBufferPos_ += got;
 
       // Check that we did not overdo it
       assert(readBufferPos_ <= readWant_);
-    
+
       // We are done reading, move onto the next state
       if (readBufferPos_ == readWant_) {
         transition();
@@ -145,7 +145,7 @@
 
     // Whenever we get down here it means a remote disconnect
     close();
-    
+
     return;
 
   case SOCKET_SEND:
@@ -154,7 +154,7 @@
 
     // If there is no data to send, then let us move on
     if (writeBufferPos_ == writeBufferSize_) {
-      fprintf(stderr, "WARNING: Send state with no data to send\n");
+      GlobalOutput("WARNING: Send state with no data to send\n");
       transition();
       return;
     }
@@ -186,7 +186,7 @@
     // Did we overdo it?
     assert(writeBufferPos_ <= writeBufferSize_);
 
-    // We are  done!
+    // We are done!
     if (writeBufferPos_ == writeBufferSize_) {
       transition();
     }
@@ -216,7 +216,7 @@
     // and get back some data from the dispatch function
     inputTransport_->resetBuffer(readBuffer_, readBufferPos_);
     outputTransport_->resetBuffer();
-    
+
     if (server_->isThreadPoolProcessing()) {
       // We are setting up a Task to do this work and we will wait on it
       int sv[2];
@@ -230,13 +230,19 @@
                                                inputProtocol_,
                                                outputProtocol_,
                                                sv[1]));
+        // The application is now waiting on the task to finish
         appState_ = APP_WAIT_TASK;
+
+        // Create an event to be notified when the task finishes
         event_set(&taskEvent_,
                   taskHandle_ = sv[0],
                   EV_READ,
                   TConnection::taskHandler,
                   this);
 
+        // Attach to the base
+        event_base_set(server_->getEventBase(), &taskEvent_);
+
         // Add the event and start up the server
         if (-1 == event_add(&taskEvent_, 0)) {
           GlobalOutput("TNonblockingServer::serve(): coult not event_add");
@@ -260,7 +266,7 @@
         return;
       } catch (TException &x) {
         fprintf(stderr, "TException: Server::process() %s\n", x.what());
-        close();     
+        close();
         return;
       } catch (...) {
         fprintf(stderr, "Server::process() unknown exception\n");
@@ -434,6 +440,7 @@
    * its own ev.
    */
   event_set(&event_, socket_, eventFlags_, TConnection::eventHandler, this);
+  event_base_set(server_->getEventBase(), &event_);
 
   // Add the event
   if (event_add(&event_, 0) == -1) {
@@ -493,15 +500,15 @@
 void TNonblockingServer::handleEvent(int fd, short which) {
   // Make sure that libevent didn't fuck up the socket handles
   assert(fd == serverSocket_);
-  
+
   // Server socket accepted a new connection
   socklen_t addrLen;
   struct sockaddr addr;
-  addrLen = sizeof(addr);   
-  
+  addrLen = sizeof(addr);
+
   // Going to accept a new client socket
   int clientSocket;
-  
+
   // Accept as many new clients as possible, even though libevent signaled only
   // one, this helps us to avoid having to go back into the libevent engine so
   // many times
@@ -530,7 +537,7 @@
     // Put this client connection into the proper state
     clientConnection->transition();
   }
-  
+
   // Done looping accept, now we have to make sure the error is due to
   // blocking. Any other error is a problem
   if (errno != EAGAIN && errno != EWOULDBLOCK) {
@@ -539,21 +546,13 @@
 }
 
 /**
- * Main workhorse function, starts up the server listening on a port and
- * loops over the libevent handler.
+ * Creates a socket to listen on and binds it to the local port.
  */
-void TNonblockingServer::serve() {
-  // Initialize libevent
-  event_init();
-
-  // Print some libevent stats
-  fprintf(stderr,
-          "libevent %s method %s\n",
-          event_get_version(),
-          event_get_method());
-
+void TNonblockingServer::listenSocket() {
+  int s;
   struct addrinfo hints, *res, *res0;
   int error;
+
   char port[sizeof("65536") + 1];
   memset(&hints, 0, sizeof(hints));
   hints.ai_family = PF_UNSPEC;
@@ -576,64 +575,105 @@
   }
 
   // Create the server socket
-  serverSocket_ = socket(res->ai_family, res->ai_socktype, res->ai_protocol);
-  if (serverSocket_ == -1) {
-    GlobalOutput("TNonblockingServer::serve() socket() -1");
-    return;
+  s = socket(res->ai_family, res->ai_socktype, res->ai_protocol);
+  if (s == -1) {
+    freeaddrinfo(res0);
+    throw TException("TNonblockingServer::serve() socket() -1");
   }
 
+  int one = 1;
+
+  // Set reuseaddr to avoid 2MSL delay on server restart
+  setsockopt(s, SOL_SOCKET, SO_REUSEADDR, &one, sizeof(one));
+
+  if (bind(s, res->ai_addr, res->ai_addrlen) == -1) {
+    close(s);
+    freeaddrinfo(res0);
+    throw TException("TNonblockingServer::serve() bind");
+  }
+
+  // Done with the addr info
+  freeaddrinfo(res0);
+
+  // Set up this file descriptor for listening
+  listenSocket(s);
+}
+
+/**
+ * Takes a socket created by listenSocket() and sets various options on it
+ * to prepare for use in the server.
+ */
+void TNonblockingServer::listenSocket(int s) {
   // Set socket to nonblocking mode
   int flags;
-  if ((flags = fcntl(serverSocket_, F_GETFL, 0)) < 0 ||
-      fcntl(serverSocket_, F_SETFL, flags | O_NONBLOCK) < 0) {
-    GlobalOutput("TNonblockingServer::serve() O_NONBLOCK");
-    ::close(serverSocket_);
-    return;
+  if ((flags = fcntl(s, F_GETFL, 0)) < 0 ||
+      fcntl(s, F_SETFL, flags | O_NONBLOCK) < 0) {
+    close(s);
+    throw TException("TNonblockingServer::serve() O_NONBLOCK");
   }
 
   int one = 1;
   struct linger ling = {0, 0};
-  
-  // Set reuseaddr to avoid 2MSL delay on server restart
-  setsockopt(serverSocket_, SOL_SOCKET, SO_REUSEADDR, &one, sizeof(one));
 
   // Keepalive to ensure full result flushing
-  setsockopt(serverSocket_, SOL_SOCKET, SO_KEEPALIVE, &one, sizeof(one));
+  setsockopt(s, SOL_SOCKET, SO_KEEPALIVE, &one, sizeof(one));
 
   // Turn linger off to avoid hung sockets
-  setsockopt(serverSocket_, SOL_SOCKET, SO_LINGER, &ling, sizeof(ling));
+  setsockopt(s, SOL_SOCKET, SO_LINGER, &ling, sizeof(ling));
 
   // Set TCP nodelay if available, MAC OS X Hack
   // See http://lists.danga.com/pipermail/memcached/2005-March/001240.html
   #ifndef TCP_NOPUSH
-  setsockopt(serverSocket_, IPPROTO_TCP, TCP_NODELAY, &one, sizeof(one));
+  setsockopt(s, IPPROTO_TCP, TCP_NODELAY, &one, sizeof(one));
   #endif
 
-  if (bind(serverSocket_, res->ai_addr, res->ai_addrlen) == -1) {
-    GlobalOutput("TNonblockingServer::serve() bind");
-    close(serverSocket_);
-    return;
+  if (listen(s, LISTEN_BACKLOG) == -1) {
+    close(s);
+    throw TException("TNonblockingServer::serve() listen");
   }
 
-  if (listen(serverSocket_, LISTEN_BACKLOG) == -1) {
-    GlobalOutput("TNonblockingServer::serve() listen");
-    close(serverSocket_);
-    return;
-  }
+  // Cool, this socket is good to go, set it as the serverSocket_
+  serverSocket_ = s;
+}
+
+/**
+ * Register the core libevent events onto the proper base.
+ */
+void TNonblockingServer::registerEvents(event_base* base) {
+  assert(serverSocket_ != -1);
+  assert(!eventBase_);
+  eventBase_ = base;
+
+  // Print some libevent stats
+  fprintf(stderr,
+          "libevent %s method %s\n",
+          event_get_version(),
+          event_get_method());
 
   // Register the server event
-  struct event serverEvent;
-  event_set(&serverEvent,
+  event_set(&serverEvent_,
             serverSocket_,
             EV_READ | EV_PERSIST,
             TNonblockingServer::eventHandler,
             this);
+  event_base_set(eventBase_, &serverEvent_);
 
   // Add the event and start up the server
-  if (-1 == event_add(&serverEvent, 0)) {
-    GlobalOutput("TNonblockingServer::serve(): coult not event_add");
-    return;
+  if (-1 == event_add(&serverEvent_, 0)) {
+    throw TException("TNonblockingServer::serve(): coult not event_add");
   }
+}
+
+/**
+ * Main workhorse function, starts up the server listening on a port and
+ * loops over the libevent handler.
+ */
+void TNonblockingServer::serve() {
+  // Init socket
+  listenSocket();
+
+  // Initialize libevent core
+  registerEvents(static_cast<event_base*>(event_init()));
 
   // Run pre-serve callback function if we have one
   if (preServeCallback_) {
@@ -641,7 +681,7 @@
   }
 
   // Run libevent engine, never returns, invokes calls to eventHandler
-  event_loop(0);
+  event_base_loop(eventBase_, 0);
 }
 
 }}} // facebook::thrift::server
diff --git a/lib/cpp/src/server/TNonblockingServer.h b/lib/cpp/src/server/TNonblockingServer.h
index 5470ad4..2cdd897 100644
--- a/lib/cpp/src/server/TNonblockingServer.h
+++ b/lib/cpp/src/server/TNonblockingServer.h
@@ -14,7 +14,7 @@
 #include <stack>
 #include <event.h>
 
-namespace facebook { namespace thrift { namespace server { 
+namespace facebook { namespace thrift { namespace server {
 
 using facebook::thrift::transport::TMemoryBuffer;
 using facebook::thrift::protocol::TProtocol;
@@ -55,6 +55,12 @@
   // Is thread pool processing?
   bool threadPoolProcessing_;
 
+  // The event base for libevent
+  event_base* eventBase_;
+
+  // Event struct, for use with eventBase_
+  struct event serverEvent_;
+
   /**
    * This is a stack of all the objects that have been created but that
    * are NOT currently in use. When we close a connection, we place it on this
@@ -64,7 +70,7 @@
   std::stack<TConnection*> connectionStack_;
 
   // Pointer to optional function called after opening the listen socket and
-  // before running the event loop, along with its argument data
+  // before running the event loop, along with its argument data.
   void (*preServeCallback_)(void*);
   void* preServeCallbackArg_;
 
@@ -74,22 +80,24 @@
   TNonblockingServer(boost::shared_ptr<TProcessor> processor,
                      int port) :
     TServer(processor),
-    serverSocket_(0),
+    serverSocket_(-1),
     port_(port),
     frameResponses_(true),
     threadPoolProcessing_(false),
+    eventBase_(NULL),
     preServeCallback_(NULL),
     preServeCallbackArg_(NULL) {}
 
-  TNonblockingServer(boost::shared_ptr<TProcessor> processor, 
+  TNonblockingServer(boost::shared_ptr<TProcessor> processor,
                      boost::shared_ptr<TProtocolFactory> protocolFactory,
                      int port,
                      boost::shared_ptr<ThreadManager> threadManager = boost::shared_ptr<ThreadManager>()) :
     TServer(processor),
-    serverSocket_(0),
+    serverSocket_(-1),
     port_(port),
     frameResponses_(true),
-    threadManager_(threadManager), 
+    threadManager_(threadManager),
+    eventBase_(NULL),
     preServeCallback_(NULL),
     preServeCallbackArg_(NULL) {
     setInputTransportFactory(boost::shared_ptr<TTransportFactory>(new TTransportFactory()));
@@ -119,7 +127,7 @@
     setOutputProtocolFactory(outputProtocolFactory);
     setThreadManager(threadManager);
   }
-        
+
   ~TNonblockingServer() {}
 
   void setThreadManager(boost::shared_ptr<ThreadManager> threadManager) {
@@ -127,7 +135,7 @@
     threadPoolProcessing_ = (threadManager != NULL);
   }
 
-  bool isThreadPoolProcessing() {
+  bool isThreadPoolProcessing() const {
     return threadPoolProcessing_;
   }
 
@@ -139,10 +147,14 @@
     frameResponses_ = frameResponses;
   }
 
-  bool getFrameResponses() {
+  bool getFrameResponses() const {
     return frameResponses_;
   }
 
+  event_base* getEventBase() const {
+    return eventBase_;
+  }
+
   TConnection* createConnection(int socket, short flags);
 
   void returnConnection(TConnection* connection);
@@ -151,6 +163,12 @@
     ((TNonblockingServer*)v)->handleEvent(fd, which);
   }
 
+  void listenSocket();
+
+  void listenSocket(int fd);
+
+  void registerEvents(event_base* base);
+
   void serve();
 
   void setPreServeCallback(void(*fn_ptr)(void*), void* arg = NULL) {
@@ -225,7 +243,7 @@
 
   // Write buffer
   uint8_t* writeBuffer_;
-  
+
   // Write buffer size
   uint32_t writeBufferSize_;
 
@@ -256,7 +274,7 @@
 
   // Protocol encoder
   boost::shared_ptr<TProtocol> outputProtocol_;
-  
+
   // Go into read mode
   void setRead() {
     setFlags(EV_READ | EV_PERSIST);
@@ -290,13 +308,13 @@
       throw new facebook::thrift::TException("Out of memory.");
     }
     readBufferSize_ = 1024;
-    
+
     // Allocate input and output tranpsorts
-    // these only need to be allocated once per TConnection (they don't need to be 
+    // these only need to be allocated once per TConnection (they don't need to be
     // reallocated on init() call)
     inputTransport_ = boost::shared_ptr<TMemoryBuffer>(new TMemoryBuffer(readBuffer_, readBufferSize_));
     outputTransport_ = boost::shared_ptr<TMemoryBuffer>(new TMemoryBuffer());
-        
+
     init(socket, eventFlags, s);
   }
 
@@ -311,7 +329,7 @@
     assert(fd == ((TConnection*)v)->socket_);
     ((TConnection*)v)->workSocket();
   }
-  
+
   // Handler wrapper for task block
   static void taskHandler(int fd, short which, void* v) {
     assert(fd == ((TConnection*)v)->taskHandle_);
diff --git a/lib/cpp/src/transport/TTransportUtils.h b/lib/cpp/src/transport/TTransportUtils.h
index 3a97448..1ebbfbd 100644
--- a/lib/cpp/src/transport/TTransportUtils.h
+++ b/lib/cpp/src/transport/TTransportUtils.h
@@ -12,14 +12,14 @@
 #include <transport/TTransport.h>
 #include <transport/TFileTransport.h>
 
-namespace facebook { namespace thrift { namespace transport { 
+namespace facebook { namespace thrift { namespace transport {
 
 /**
  * The null transport is a dummy transport that doesn't actually do anything.
  * It's sort of an analogy to /dev/null, you can never read anything from it
  * and it will let you write anything you want to it, though it won't actually
  * go anywhere.
- * 
+ *
  * @author Mark Slee <mcslee@facebook.com>
  */
 class TNullTransport : public TTransport {
@@ -82,8 +82,8 @@
   bool isOpen() {
     return transport_->isOpen();
   }
-  
-  bool peek() {    
+
+  bool peek() {
     if (rPos_ >= rLen_) {
       rLen_ = transport_->read(rBuf_, rBufSize_);
       rPos_ = 0;
@@ -101,7 +101,7 @@
   }
 
   uint32_t read(uint8_t* buf, uint32_t len);
-  
+
   void write(const uint8_t* buf, uint32_t len);
 
   void flush();
@@ -192,7 +192,7 @@
   void setWrite(bool write) {
     write_ = write;
   }
- 
+
   void open() {
     transport_->open();
   }
@@ -209,10 +209,12 @@
   }
 
   void close() {
-    flush();
+    if (wLen_ > 0) {
+      flush();
+    }
     transport_->close();
   }
- 
+
   uint32_t read(uint8_t* buf, uint32_t len);
 
   void write(const uint8_t* buf, uint32_t len);
@@ -397,16 +399,16 @@
 
   void consume(uint32_t len);
 
- private: 
+ private:
   // Data buffer
   uint8_t* buffer_;
-  
+
   // Allocated buffer size
   uint32_t bufferSize_;
 
   // Where the write is at
   uint32_t wPos_;
-  
+
   // Where the reader is at
   uint32_t rPos_;
 
@@ -416,17 +418,17 @@
 };
 
 /**
- * TPipedTransport. This transport allows piping of a request from one 
+ * TPipedTransport. This transport allows piping of a request from one
  * transport to another either when readEnd() or writeEnd(). The typical
  * use case for this is to log a request or a reply to disk.
- * The underlying buffer expands to a keep a copy of the entire 
+ * The underlying buffer expands to keep a copy of the entire
  * request/response.
  *
  * @author Aditya Agarwal <aditya@facebook.com>
  */
 class TPipedTransport : virtual public TTransport {
  public:
-  TPipedTransport(boost::shared_ptr<TTransport> srcTrans, 
+  TPipedTransport(boost::shared_ptr<TTransport> srcTrans,
                   boost::shared_ptr<TTransport> dstTrans) :
     srcTrans_(srcTrans),
     dstTrans_(dstTrans),
@@ -435,14 +437,14 @@
 
     // default is to to pipe the request when readEnd() is called
     pipeOnRead_ = true;
-    pipeOnWrite_ = false; 
+    pipeOnWrite_ = false;
 
     rBuf_ = (uint8_t*) malloc(sizeof(uint8_t) * rBufSize_);
     wBuf_ = (uint8_t*) malloc(sizeof(uint8_t) * wBufSize_);
   }
-    
-  TPipedTransport(boost::shared_ptr<TTransport> srcTrans, 
-                  boost::shared_ptr<TTransport> dstTrans, 
+
+  TPipedTransport(boost::shared_ptr<TTransport> srcTrans,
+                  boost::shared_ptr<TTransport> dstTrans,
                   uint32_t sz) :
     srcTrans_(srcTrans),
     dstTrans_(dstTrans),
@@ -461,15 +463,15 @@
   bool isOpen() {
     return srcTrans_->isOpen();
   }
-  
-  bool peek() {    
+
+  bool peek() {
     if (rPos_ >= rLen_) {
       // Double the size of the underlying buffer if it is full
       if (rLen_ == rBufSize_) {
         rBufSize_ *=2;
         rBuf_ = (uint8_t *)realloc(rBuf_, sizeof(uint8_t) * rBufSize_);
       }
-    
+
       // try to fill up the buffer
       rLen_ += srcTrans_->read(rBuf_+rPos_, rBufSize_ - rPos_);
     }
@@ -492,7 +494,7 @@
   void setPipeOnWrite(bool pipeVal) {
     pipeOnWrite_ = pipeVal;
   }
-  
+
   uint32_t read(uint8_t* buf, uint32_t len);
 
   void readEnd() {
@@ -522,7 +524,7 @@
 
   boost::shared_ptr<TTransport> getTargetTransport() {
     return dstTrans_;
-  } 
+  }
 
  protected:
   boost::shared_ptr<TTransport> srcTrans_;
@@ -576,7 +578,7 @@
 
 /**
  * TPipedFileTransport. This is just like a TTransport, except that
- * it is a templatized class, so that clients who rely on a specific 
+ * it is a templatized class, so that clients who rely on a specific
  * TTransport can still access the original transport.
  *
  * @author James Wang <jwang@facebook.com>
@@ -622,7 +624,7 @@
 class TPipedFileReaderTransportFactory : public TPipedTransportFactory {
  public:
   TPipedFileReaderTransportFactory() {}
-  TPipedFileReaderTransportFactory(boost::shared_ptr<TTransport> dstTrans) 
+  TPipedFileReaderTransportFactory(boost::shared_ptr<TTransport> dstTrans)
     : TPipedTransportFactory(dstTrans)
   {}
   virtual ~TPipedFileReaderTransportFactory() {}
diff --git a/test/cpp/src/nb-main.cpp b/test/cpp/src/nb-main.cpp
index 9fc1beb..164febb 100644
--- a/test/cpp/src/nb-main.cpp
+++ b/test/cpp/src/nb-main.cpp
@@ -61,7 +61,7 @@
   Server() {}
 
   void count(const char* method) {
-    MutexMonitor m(lock_);
+    Guard m(lock_);
     int ct = counts_[method];
     counts_[method] = ++ct;
   }
@@ -74,7 +74,7 @@
   }
 
   count_map getCount() {
-    MutexMonitor m(lock_);
+    Guard m(lock_);
     return counts_;
   }
 
@@ -354,10 +354,12 @@
     }
 
     shared_ptr<Thread> serverThread;
+    shared_ptr<Thread> serverThread2;
 
     if (serverType == "simple") {
 
       serverThread = threadFactory->newThread(shared_ptr<TServer>(new TNonblockingServer(serviceProcessor, protocolFactory, port)));
+      serverThread2 = threadFactory->newThread(shared_ptr<TServer>(new TNonblockingServer(serviceProcessor, protocolFactory, port+1)));
 
     } else if (serverType == "thread-pool") {
 
@@ -366,17 +368,21 @@
       threadManager->threadFactory(threadFactory);
       threadManager->start();
       serverThread = threadFactory->newThread(shared_ptr<TServer>(new TNonblockingServer(serviceProcessor, protocolFactory, port, threadManager)));
+      serverThread2 = threadFactory->newThread(shared_ptr<TServer>(new TNonblockingServer(serviceProcessor, protocolFactory, port+1, threadManager)));
     }
 
-    cerr << "Starting the server on port " << port << endl;
+    cerr << "Starting the server on port " << port << " and " << (port + 1) << endl;
     serverThread->start();
+    serverThread2->start();
 
     // If we aren't running clients, just wait forever for external clients
 
     if (clientCount == 0) {
       serverThread->join();
+      serverThread2->join();
     }
   }
+  sleep(1);
 
   if (clientCount > 0) {
 
@@ -395,7 +401,7 @@
 
     for (size_t ix = 0; ix < clientCount; ix++) {
 
-      shared_ptr<TSocket> socket(new TSocket("127.0.0.1", port));
+      shared_ptr<TSocket> socket(new TSocket("127.0.0.1", port + (ix % 2)));
       shared_ptr<TFramedTransport> framedSocket(new TFramedTransport(socket));
       shared_ptr<TProtocol> protocol(new TBinaryProtocol(framedSocket));
       shared_ptr<ServiceClient> serviceClient(new ServiceClient(protocol));