THRIFT-928. cpp: TNonblockingServer: use TSocket and support TClientInfo

Modify TNonblockingServer to use TSocket for I/O and support server
event handlers; this enables TClientInfo to function with a minor change
to the processing loop.

git-svn-id: https://svn.apache.org/repos/asf/incubator/thrift/trunk@1005145 13f79535-47bb-0310-9956-ffa450edef68
diff --git a/lib/cpp/src/server/TNonblockingServer.cpp b/lib/cpp/src/server/TNonblockingServer.cpp
index 4245d5e..6297030 100644
--- a/lib/cpp/src/server/TNonblockingServer.cpp
+++ b/lib/cpp/src/server/TNonblockingServer.cpp
@@ -52,12 +52,18 @@
     processor_(processor),
     input_(input),
     output_(output),
-    connection_(connection) {}
+    connection_(connection),
+    serverEventHandler_(connection_->getServerEventHandler()),
+    connectionContext_(connection_->getConnectionContext()) {}
 
   void run() {
     try {
-      while (processor_->process(input_, output_, NULL)) {
-        if (!input_->getTransport()->peek()) {
+      for (;;) {
+        if (serverEventHandler_ != NULL) {
+          serverEventHandler_->processContext(connectionContext_, connection_->getTSocket());
+        }
+        if (!processor_->process(input_, output_, connectionContext_) ||
+            !input_->getTransport()->peek()) {
           break;
         }
       }
@@ -87,10 +93,15 @@
   boost::shared_ptr<TProtocol> input_;
   boost::shared_ptr<TProtocol> output_;
   TConnection* connection_;
+  boost::shared_ptr<TServerEventHandler> serverEventHandler_;
+  void* connectionContext_;
 };
 
-void TConnection::init(int socket, short eventFlags, TNonblockingServer* s) {
-  socket_ = socket;
+void TConnection::init(int socket, short eventFlags, TNonblockingServer* s,
+                       const sockaddr* addr, socklen_t addrLen) {
+  tSocket_->setSocketFD(socket);
+  tSocket_->setCachedAddress(addr, addrLen);
+
   server_ = s;
   appState_ = APP_INIT;
   eventFlags_ = 0;
@@ -115,10 +126,18 @@
   // Create protocol
   inputProtocol_ = s->getInputProtocolFactory()->getProtocol(factoryInputTransport_);
   outputProtocol_ = s->getOutputProtocolFactory()->getProtocol(factoryOutputTransport_);
+
+  // Set up for any server event handler
+  serverEventHandler_ = server_->getEventHandler();
+  if (serverEventHandler_ != NULL) {
+    connectionContext_ = serverEventHandler_->createContext(inputProtocol_, outputProtocol_);
+  } else {
+    connectionContext_ = NULL;
+  }
 }
 
 void TConnection::workSocket() {
-  int flags=0, got=0, left=0, sent=0;
+  int got=0, left=0, sent=0;
   uint32_t fetch = 0;
 
   switch (socketState_) {
@@ -142,10 +161,18 @@
       readBufferSize_ = newSize;
     }
 
-    // Read from the socket
-    fetch = readWant_ - readBufferPos_;
-    got = recv(socket_, readBuffer_ + readBufferPos_, fetch, 0);
+    try {
+      // Read from the socket
+      fetch = readWant_ - readBufferPos_;
+      got = tSocket_->read(readBuffer_ + readBufferPos_, fetch);
+    }
+    catch (TTransportException& te) {
+      GlobalOutput.printf("TConnection::workSocket(): %s", te.what());
+      close();
 
+      return;
+    }
+        
     if (got > 0) {
       // Move along in the buffer
       readBufferPos_ += got;
@@ -158,15 +185,6 @@
         transition();
       }
       return;
-    } else if (got == -1) {
-      // Blocking errors are okay, just move on
-      if (errno == EAGAIN || errno == EWOULDBLOCK) {
-        return;
-      }
-
-      if (errno != ECONNRESET) {
-        GlobalOutput.perror("TConnection::workSocket() recv -1 ", errno);
-      }
     }
 
     // Whenever we get down here it means a remote disconnect
@@ -185,24 +203,12 @@
       return;
     }
 
-    flags = 0;
-    #ifdef MSG_NOSIGNAL
-    // Note the use of MSG_NOSIGNAL to suppress SIGPIPE errors, instead we
-    // check for the EPIPE return condition and close the socket in that case
-    flags |= MSG_NOSIGNAL;
-    #endif // ifdef MSG_NOSIGNAL
-
-    left = writeBufferSize_ - writeBufferPos_;
-    sent = send(socket_, writeBuffer_ + writeBufferPos_, left, flags);
-
-    if (sent <= 0) {
-      // Blocking errors are okay, just move on
-      if (errno == EAGAIN || errno == EWOULDBLOCK) {
-        return;
-      }
-      if (errno != EPIPE) {
-        GlobalOutput.perror("TConnection::workSocket() send -1 ", errno);
-      }
+    try {
+      left = writeBufferSize_ - writeBufferPos_;
+      sent = tSocket_->write_partial(writeBuffer_ + writeBufferPos_, left);
+    }
+    catch (TTransportException& te) {
+      GlobalOutput.printf("TConnection::workSocket(): %s ", te.what());
       close();
       return;
     }
@@ -478,7 +484,8 @@
    * ev structure for multiple monitored descriptors; each descriptor needs
    * its own ev.
    */
-  event_set(&event_, socket_, eventFlags_, TConnection::eventHandler, this);
+  event_set(&event_, tSocket_->getSocketFD(), eventFlags_,
+            TConnection::eventHandler, this);
   event_base_set(server_->getEventBase(), &event_);
 
   // Add the event
@@ -493,14 +500,15 @@
 void TConnection::close() {
   // Delete the registered libevent
   if (event_del(&event_) == -1) {
-    GlobalOutput("TConnection::close() event_del");
+    GlobalOutput.perror("TConnection::close() event_del", errno);
+  }
+
+  if (serverEventHandler_ != NULL) {
+    serverEventHandler_->deleteContext(connectionContext_, inputProtocol_, outputProtocol_);
   }
 
   // Close the socket
-  if (socket_ >= 0) {
-    ::close(socket_);
-  }
-  socket_ = -1;
+  tSocket_->close();
 
   // close any factory produced transports
   factoryInputTransport_->close();
@@ -548,14 +556,16 @@
  * Creates a new connection either by reusing an object off the stack or
  * by allocating a new one entirely
  */
-TConnection* TNonblockingServer::createConnection(int socket, short flags) {
+TConnection* TNonblockingServer::createConnection(int socket, short flags,
+                                                  const sockaddr* addr,
+                                                  socklen_t addrLen) {
   // Check the stack
   if (connectionStack_.empty()) {
-    return new TConnection(socket, flags, this);
+    return new TConnection(socket, flags, this, addr, addrLen);
   } else {
     TConnection* result = connectionStack_.top();
     connectionStack_.pop();
-    result->init(socket, flags, this);
+    result->init(socket, flags, this, addr, addrLen);
     return result;
   }
 }
@@ -583,8 +593,9 @@
 
   // Server socket accepted a new connection
   socklen_t addrLen;
-  struct sockaddr addr;
-  addrLen = sizeof(addr);
+  sockaddr_storage addrStorage;
+  sockaddr* addrp = (sockaddr*)&addrStorage;
+  addrLen = sizeof(addrStorage);
 
   // Going to accept a new client socket
   int clientSocket;
@@ -592,7 +603,7 @@
   // Accept as many new clients as possible, even though libevent signaled only
   // one, this helps us to avoid having to go back into the libevent engine so
   // many times
-  while ((clientSocket = accept(fd, &addr, &addrLen)) != -1) {
+  while ((clientSocket = ::accept(fd, addrp, &addrLen)) != -1) {
     // If we're overloaded, take action here
     if (overloadAction_ != T_OVERLOAD_NO_ACTION && serverOverloaded()) {
       nConnectionsDropped_++;
@@ -619,7 +630,7 @@
 
     // Create a new TConnection for this client socket.
     TConnection* clientConnection =
-      createConnection(clientSocket, EV_READ | EV_PERSIST);
+      createConnection(clientSocket, EV_READ | EV_PERSIST, addrp, addrLen);
 
     // Fail fast if we could not create a TConnection object
     if (clientConnection == NULL) {
@@ -632,7 +643,7 @@
     clientConnection->transition();
 
     // addrLen is written by the accept() call, so needs to be set before the next call.
-    addrLen = sizeof(addr);
+    addrLen = sizeof(addrStorage);
   }
 
   // Done looping accept, now we have to make sure the error is due to
diff --git a/lib/cpp/src/server/TNonblockingServer.h b/lib/cpp/src/server/TNonblockingServer.h
index 2dd5362..ac0e345 100644
--- a/lib/cpp/src/server/TNonblockingServer.h
+++ b/lib/cpp/src/server/TNonblockingServer.h
@@ -23,6 +23,7 @@
 #include <Thrift.h>
 #include <server/TServer.h>
 #include <transport/TBufferTransports.h>
+#include <transport/TSocket.h>
 #include <concurrency/ThreadManager.h>
 #include <climits>
 #include <stack>
@@ -35,6 +36,7 @@
 namespace apache { namespace thrift { namespace server {
 
 using apache::thrift::transport::TMemoryBuffer;
+using apache::thrift::transport::TSocket;
 using apache::thrift::protocol::TProtocol;
 using apache::thrift::concurrency::Runnable;
 using apache::thrift::concurrency::ThreadManager;
@@ -470,9 +472,12 @@
    *
    * @param socket FD of socket associated with this connection.
    * @param flags initial lib_event flags for this connection.
+   * @param addr the sockaddr of the client
+   * @param addrLen the length of addr
    * @return pointer to initialized TConnection object.
    */
-  TConnection* createConnection(int socket, short flags);
+  TConnection* createConnection(int socket, short flags,
+                                const sockaddr* addr, socklen_t addrLen);
 
   /**
    * Returns a connection to pool or deletion.  If the connection pool
@@ -576,7 +581,7 @@
  * Represents a connection that is handled via libevent. This connection
  * essentially encapsulates a socket that has some associated libevent state.
  */
-class TConnection {
+  class TConnection {
  private:
 
   /// Starting size for new connection buffer
@@ -585,8 +590,8 @@
   /// Server handle
   TNonblockingServer* server_;
 
-  /// Socket handle
-  int socket_;
+  /// Object wrapping network socket
+  boost::shared_ptr<TSocket> tSocket_;
 
   /// Libevent object
   struct event event_;
@@ -649,6 +654,12 @@
   /// Protocol encoder
   boost::shared_ptr<TProtocol> outputProtocol_;
 
+  /// Server event handler, if any
+  boost::shared_ptr<TServerEventHandler> serverEventHandler_;
+
+  /// Thrift call context, if any
+  void *connectionContext_;
+
   /// Go into read mode
   void setRead() {
     setFlags(EV_READ | EV_PERSIST);
@@ -687,7 +698,8 @@
   class Task;
 
   /// Constructor
-  TConnection(int socket, short eventFlags, TNonblockingServer *s) {
+  TConnection(int socket, short eventFlags, TNonblockingServer *s,
+              const sockaddr* addr, socklen_t addrLen) {
     readBuffer_ = (uint8_t*)std::malloc(STARTING_CONNECTION_BUFFER_SIZE);
     if (readBuffer_ == NULL) {
       throw new apache::thrift::TException("Out of memory.");
@@ -702,8 +714,9 @@
     // reallocated on init() call)
     inputTransport_ = boost::shared_ptr<TMemoryBuffer>(new TMemoryBuffer(readBuffer_, readBufferSize_));
     outputTransport_ = boost::shared_ptr<TMemoryBuffer>(new TMemoryBuffer());
+    tSocket_.reset(new TSocket());
 
-    init(socket, eventFlags, s);
+    init(socket, eventFlags, s, addr, addrLen);
     server_->incrementNumConnections();
   }
 
@@ -720,7 +733,8 @@
   void checkIdleBufferMemLimit(size_t limit);
 
   /// Initialize
-  void init(int socket, short eventFlags, TNonblockingServer *s);
+  void init(int socket, short eventFlags, TNonblockingServer *s,
+            const sockaddr* addr, socklen_t addrLen);
 
   /**
    * This is called when the application transitions from one state into
@@ -738,7 +752,7 @@
    * @param v void* callback arg where we placed TConnection's "this".
    */
   static void eventHandler(int fd, short /* which */, void* v) {
-    assert(fd == ((TConnection*)v)->socket_);
+    assert(fd == ((TConnection*)v)->getTSocket()->getSocketFD());
     ((TConnection*)v)->workSocket();
   }
 
@@ -799,6 +813,22 @@
   TAppState getState() {
     return appState_;
   }
+
+  /// return the TSocket transport wrapping this network connection
+  boost::shared_ptr<TSocket> getTSocket() const {
+    return tSocket_;
+  }
+
+  /// return the server event handler if any
+  boost::shared_ptr<TServerEventHandler> getServerEventHandler() {
+    return serverEventHandler_;
+  }
+
+  /// return the Thrift connection context if any
+  void* getConnectionContext() {
+    return connectionContext_;
+  }
+
 };
 
 }}} // apache::thrift::server
diff --git a/lib/cpp/src/transport/TSocket.cpp b/lib/cpp/src/transport/TSocket.cpp
index ee76c3f..7a48505 100644
--- a/lib/cpp/src/transport/TSocket.cpp
+++ b/lib/cpp/src/transport/TSocket.cpp
@@ -357,6 +357,13 @@
   socket_ = -1;
 }
 
+void TSocket::setSocketFD(int socket) {
+  if (socket_ >= 0) {
+    close();
+  }
+  socket_ = socket;
+}
+
 uint32_t TSocket::read(uint8_t* buf, uint32_t len) {
   if (socket_ < 0) {
     throw TTransportException(TTransportException::NOT_OPEN, "Called read on non-open socket");
@@ -379,7 +386,13 @@
  try_again:
   // Read from the socket
   struct timeval begin;
-  gettimeofday(&begin, NULL);
+  if (recvTimeout_ > 0) {
+    gettimeofday(&begin, NULL);
+  } else {
+    // if there is no read timeout we don't need the time of day to determine
+    // whether an EAGAIN is due to a timeout or an out-of-resource condition.
+    begin.tv_sec = begin.tv_usec = 0;
+  }
   int got = recv(socket_, buf, len, 0);
   int errno_copy = errno; //gettimeofday can change errno
   ++g_socket_syscalls;
@@ -387,6 +400,11 @@
   // Check for error on read
   if (got < 0) {
     if (errno_copy == EAGAIN) {
+      // if no timeout we can assume that resource exhaustion has occurred.
+      if (recvTimeout_ == 0) {
+        throw TTransportException(TTransportException::TIMED_OUT,
+                                    "EAGAIN (unavailable resources)");
+      }
       // check if this is the lack of resources or timeout case
       struct timeval end;
       gettimeofday(&end, NULL);
@@ -417,8 +435,8 @@
     if (errno_copy == ECONNRESET) {
       /* shigin: freebsd doesn't follow POSIX semantic of recv and fails with
        * ECONNRESET if peer performed shutdown 
+       * edhall: eliminated close() since we do that in the destructor.
        */
-      close();
       return 0;
     }
     #endif
@@ -447,7 +465,8 @@
 
   // The remote host has closed the socket
   if (got == 0) {
-    close();
+    // edhall: we used to call close() here, but our caller may want to deal
+    // with the socket fd and we'll close() in our destructor in any case.
     return 0;
   }
 
@@ -456,43 +475,57 @@
 }
 
 void TSocket::write(const uint8_t* buf, uint32_t len) {
+  uint32_t sent = 0;
+
+  while (sent < len) {
+    uint32_t b = write_partial(buf + sent, len - sent);
+    if (b == 0) {
+      // We assume that we got 0 because send() errored with EAGAIN due to
+      // lack of system resources; release the CPU for a bit.
+      usleep(50);
+    }
+    sent += b;
+  }
+}
+
+uint32_t TSocket::write_partial(const uint8_t* buf, uint32_t len) {
   if (socket_ < 0) {
     throw TTransportException(TTransportException::NOT_OPEN, "Called write on non-open socket");
   }
 
   uint32_t sent = 0;
 
-  while (sent < len) {
+  int flags = 0;
+#ifdef MSG_NOSIGNAL
+  // Note the use of MSG_NOSIGNAL to suppress SIGPIPE errors, instead we
+  // check for the EPIPE return condition and close the socket in that case
+  flags |= MSG_NOSIGNAL;
+#endif // ifdef MSG_NOSIGNAL
 
-    int flags = 0;
-    #ifdef MSG_NOSIGNAL
-    // Note the use of MSG_NOSIGNAL to suppress SIGPIPE errors, instead we
-    // check for the EPIPE return condition and close the socket in that case
-    flags |= MSG_NOSIGNAL;
-    #endif // ifdef MSG_NOSIGNAL
+  int b = send(socket_, buf + sent, len - sent, flags);
+  ++g_socket_syscalls;
 
-    int b = send(socket_, buf + sent, len - sent, flags);
-    ++g_socket_syscalls;
-
+  if (b < 0) {
+    if (errno == EWOULDBLOCK || errno == EAGAIN) {
+      return 0;
+    }
     // Fail on a send error
-    if (b < 0) {
-      int errno_copy = errno;
-      GlobalOutput.perror("TSocket::write() send() " + getSocketInfo(), errno_copy);
+    int errno_copy = errno;
+    GlobalOutput.perror("TSocket::write_partial() send() " + getSocketInfo(), errno_copy);
 
-      if (errno == EPIPE || errno == ECONNRESET || errno == ENOTCONN) {
-        close();
-        throw TTransportException(TTransportException::NOT_OPEN, "write() send()", errno_copy);
-      }
-
-      throw TTransportException(TTransportException::UNKNOWN, "write() send()", errno_copy);
+    if (errno_copy == EPIPE || errno_copy == ECONNRESET || errno_copy == ENOTCONN) {
+      close();
+      throw TTransportException(TTransportException::NOT_OPEN, "write() send()", errno_copy);
     }
 
-    // Fail on blocked send
-    if (b == 0) {
-      throw TTransportException(TTransportException::NOT_OPEN, "Socket send returned 0.");
-    }
-    sent += b;
+    throw TTransportException(TTransportException::UNKNOWN, "write() send()", errno_copy);
   }
+  
+  // Fail on blocked send
+  if (b == 0) {
+    throw TTransportException(TTransportException::NOT_OPEN, "Socket send returned 0.");
+  }
+  return b;
 }
 
 std::string TSocket::getHost() {
@@ -598,7 +631,12 @@
 
 string TSocket::getSocketInfo() {
   std::ostringstream oss;
-  oss << "<Host: " << host_ << " Port: " << port_ << ">";
+  if (host_.empty() || port_ == 0) {
+    oss << "<Host: " << getPeerAddress();
+    oss << " Port: " << getPeerPort() << ">";
+  } else {
+    oss << "<Host: " << host_ << " Port: " << port_ << ">";
+  }
   return oss.str();
 }
 
diff --git a/lib/cpp/src/transport/TSocket.h b/lib/cpp/src/transport/TSocket.h
index 97562c2..e89059f 100644
--- a/lib/cpp/src/transport/TSocket.h
+++ b/lib/cpp/src/transport/TSocket.h
@@ -95,11 +95,16 @@
   uint32_t read(uint8_t* buf, uint32_t len);
 
   /**
-   * Writes to the underlying socket.
+   * Writes to the underlying socket.  Loops until done or fail.
    */
   void write(const uint8_t* buf, uint32_t len);
 
   /**
+   * Writes to the underlying socket.  Does a single send() and returns the number of bytes sent (0 if the send would block).
+   */
+  uint32_t write_partial(const uint8_t* buf, uint32_t len);
+
+  /**
    * Get the host that the socket is connected to
    *
    * @return string host identifier
@@ -191,6 +196,15 @@
     return socket_;
   }
 
+  /**
+   * (Re-)initialize a TSocket for the supplied descriptor.  This is only
+   * intended for use by TNonblockingServer -- other use may result in
+   * unfortunate surprises.
+   *
+   * @param fd the descriptor for an already-connected socket
+   */
+  void setSocketFD(int fd);
+
   /*
    * Returns a cached copy of the peer address.
    */
@@ -211,16 +225,16 @@
    */
   TSocket(int socket);
 
- protected:
-  /** connect, called by open */
-  void openConnection(struct addrinfo *res);
-
   /**
    * Set a cache of the peer address (used when trivially available: e.g.
    * accept() or connect()). Only caches IPV4 and IPV6; unset for others.
    */
   void setCachedAddress(const sockaddr* addr, socklen_t len);
 
+ protected:
+  /** connect, called by open */
+  void openConnection(struct addrinfo *res);
+
   /** Host to connect to */
   std::string host_;