Create a TServerEventHandler interface in TServer

Summary: Allow users to supply an event handler to a server; the server core calls it to signal various events that take place inside it (preServe, clientBegin, clientEnd). This replaces TNonblockingServer's setPreServeCallback().

Reviewed By: dreiss

Test Plan: Rebuilt all servers; they work by default, i.e. with no event handler set.

Other Notes: Partially submitted and also reviewed by Dave Simpson at Powerset


git-svn-id: https://svn.apache.org/repos/asf/incubator/thrift/trunk@665371 13f79535-47bb-0310-9956-ffa450edef68
diff --git a/lib/cpp/src/server/TNonblockingServer.cpp b/lib/cpp/src/server/TNonblockingServer.cpp
index 7735ec2..de32db5 100644
--- a/lib/cpp/src/server/TNonblockingServer.cpp
+++ b/lib/cpp/src/server/TNonblockingServer.cpp
@@ -675,9 +675,9 @@
   // Initialize libevent core
   registerEvents(static_cast<event_base*>(event_init()));
 
-  // Run pre-serve callback function if we have one
-  if (preServeCallback_) {
-    preServeCallback_(preServeCallbackArg_);
+  // Run the preServe event
+  if (eventHandler_ != NULL) {
+    eventHandler_->preServe();
   }
 
   // Run libevent engine, never returns, invokes calls to eventHandler
diff --git a/lib/cpp/src/server/TNonblockingServer.h b/lib/cpp/src/server/TNonblockingServer.h
index 2cdd897..8483306 100644
--- a/lib/cpp/src/server/TNonblockingServer.h
+++ b/lib/cpp/src/server/TNonblockingServer.h
@@ -69,11 +69,6 @@
    */
   std::stack<TConnection*> connectionStack_;
 
-  // Pointer to optional function called after opening the listen socket and
-  // before running the event loop, along with its argument data.
-  void (*preServeCallback_)(void*);
-  void* preServeCallbackArg_;
-
   void handleEvent(int fd, short which);
 
  public:
@@ -84,9 +79,7 @@
     port_(port),
     frameResponses_(true),
     threadPoolProcessing_(false),
-    eventBase_(NULL),
-    preServeCallback_(NULL),
-    preServeCallbackArg_(NULL) {}
+    eventBase_(NULL) {}
 
   TNonblockingServer(boost::shared_ptr<TProcessor> processor,
                      boost::shared_ptr<TProtocolFactory> protocolFactory,
@@ -97,9 +90,7 @@
     port_(port),
     frameResponses_(true),
     threadManager_(threadManager),
-    eventBase_(NULL),
-    preServeCallback_(NULL),
-    preServeCallbackArg_(NULL) {
+    eventBase_(NULL) {
     setInputTransportFactory(boost::shared_ptr<TTransportFactory>(new TTransportFactory()));
     setOutputTransportFactory(boost::shared_ptr<TTransportFactory>(new TTransportFactory()));
     setInputProtocolFactory(protocolFactory);
@@ -118,9 +109,7 @@
     serverSocket_(0),
     port_(port),
     frameResponses_(true),
-    threadManager_(threadManager),
-    preServeCallback_(NULL),
-    preServeCallbackArg_(NULL) {
+    threadManager_(threadManager) {
     setInputTransportFactory(inputTransportFactory);
     setOutputTransportFactory(outputTransportFactory);
     setInputProtocolFactory(inputProtocolFactory);
@@ -171,11 +160,6 @@
 
   void serve();
 
-  void setPreServeCallback(void(*fn_ptr)(void*), void* arg = NULL) {
-    preServeCallback_ = fn_ptr;
-    preServeCallbackArg_ = arg;
-  }
-
 };
 
 /**
diff --git a/lib/cpp/src/server/TServer.h b/lib/cpp/src/server/TServer.h
index 5f13255..1b6bd88 100644
--- a/lib/cpp/src/server/TServer.h
+++ b/lib/cpp/src/server/TServer.h
@@ -14,22 +14,62 @@
 
 #include <boost/shared_ptr.hpp>
 
-namespace facebook { namespace thrift { namespace server { 
+namespace facebook { namespace thrift { namespace server {
 
 using facebook::thrift::TProcessor;
 using facebook::thrift::protocol::TBinaryProtocolFactory;
+using facebook::thrift::protocol::TProtocol;
 using facebook::thrift::protocol::TProtocolFactory;
 using facebook::thrift::transport::TServerTransport;
 using facebook::thrift::transport::TTransport;
 using facebook::thrift::transport::TTransportFactory;
 
 /**
+ * Virtual interface class that can handle events from the server core. To
+ * use it, subclass it and override only the methods you care about; every
+ * default implementation is a no-op. Your subclass may also keep its own
+ * state, such as extra "arguments" for these methods (stored in the object).
+ * TThreadPoolServer/TThreadedServer call the client hooks from worker threads.
+ */
+class TServerEventHandler {
+ public:
+
+  virtual ~TServerEventHandler() {}
+
+  /**
+   * Called once from serve(), after server setup and before clients are served.
+   */
+  virtual void preServe() {}
+
+  /**
+   * Called when a new client has connected and is about to begin processing.
+   */
+  virtual void clientBegin(boost::shared_ptr<TProtocol> input,
+                           boost::shared_ptr<TProtocol> output) {}
+
+  /**
+   * Called when a client has finished making requests, before its transports are closed.
+   */
+  virtual void clientEnd(boost::shared_ptr<TProtocol> input,
+                         boost::shared_ptr<TProtocol> output) {}
+
+ protected:
+
+  /**
+   * Protected so that only subclasses can be instantiated.
+   */
+  TServerEventHandler() {}
+
+};
+
+/**
  * Thrift server.
  *
  * @author Mark Slee <mcslee@facebook.com>
  */
 class TServer : public concurrency::Runnable {
-public:
+ public:
+
   virtual ~TServer() {}
 
   virtual void serve() = 0;
@@ -40,7 +80,7 @@
   virtual void run() {
     serve();
   }
-  
+
   boost::shared_ptr<TProcessor> getProcessor() {
     return processor_;
   }
@@ -56,7 +96,7 @@
   boost::shared_ptr<TTransportFactory> getOutputTransportFactory() {
     return outputTransportFactory_;
   }
-  
+
   boost::shared_ptr<TProtocolFactory> getInputProtocolFactory() {
     return inputProtocolFactory_;
   }
@@ -65,6 +105,10 @@
     return outputProtocolFactory_;
   }
 
+  boost::shared_ptr<TServerEventHandler> getEventHandler() {
+    return eventHandler_;
+  }
+
 protected:
   TServer(boost::shared_ptr<TProcessor> processor):
     processor_(processor) {
@@ -108,7 +152,7 @@
     inputProtocolFactory_(inputProtocolFactory),
     outputProtocolFactory_(outputProtocolFactory) {}
 
- 
+
   // Class variables
   boost::shared_ptr<TProcessor> processor_;
   boost::shared_ptr<TServerTransport> serverTransport_;
@@ -119,6 +163,8 @@
   boost::shared_ptr<TProtocolFactory> inputProtocolFactory_;
   boost::shared_ptr<TProtocolFactory> outputProtocolFactory_;
 
+  boost::shared_ptr<TServerEventHandler> eventHandler_;
+
   void setInputTransportFactory(boost::shared_ptr<TTransportFactory> inputTransportFactory) {
     inputTransportFactory_ = inputTransportFactory;
   }
@@ -135,8 +181,12 @@
     outputProtocolFactory_ = outputProtocolFactory;
   }
 
+ public:  // public so code outside TServer subclasses can install a handler
+  void setServerEventHandler(boost::shared_ptr<TServerEventHandler> eventHandler) {
+    eventHandler_ = eventHandler;  // an empty pointer disables all event callbacks
+  }
 };
-  
+
 }}} // facebook::thrift::server
 
 #endif // #ifndef _THRIFT_SERVER_TSERVER_H_
diff --git a/lib/cpp/src/server/TSimpleServer.cpp b/lib/cpp/src/server/TSimpleServer.cpp
index d5d3797..f3011a2 100644
--- a/lib/cpp/src/server/TSimpleServer.cpp
+++ b/lib/cpp/src/server/TSimpleServer.cpp
@@ -9,7 +9,7 @@
 #include <string>
 #include <iostream>
 
-namespace facebook { namespace thrift { namespace server { 
+namespace facebook { namespace thrift { namespace server {
 
 using namespace std;
 using namespace facebook::thrift;
@@ -38,6 +38,11 @@
     return;
   }
 
+  // Run the preServe event
+  if (eventHandler_ != NULL) {
+    eventHandler_->preServe();
+  }
+
   // Fetch client from server
   while (!stop_) {
     try {
@@ -46,6 +51,9 @@
       outputTransport = outputTransportFactory_->getTransport(client);
       inputProtocol = inputProtocolFactory_->getProtocol(inputTransport);
       outputProtocol = outputProtocolFactory_->getProtocol(outputTransport);
+      if (eventHandler_ != NULL) {
+        eventHandler_->clientBegin(inputProtocol, outputProtocol);
+      }
       try {
         while (processor_->process(inputProtocol, outputProtocol)) {
           // Peek ahead, is the remote side closed?
@@ -58,9 +66,12 @@
       } catch (TException& tx) {
         cerr << "TSimpleServer exception: " << tx.what() << endl;
       }
+      if (eventHandler_ != NULL) {
+        eventHandler_->clientEnd(inputProtocol, outputProtocol);
+      }
       inputTransport->close();
       outputTransport->close();
-      client->close();    
+      client->close();
     } catch (TTransportException& ttx) {
       if (inputTransport != NULL) { inputTransport->close(); }
       if (outputTransport != NULL) { outputTransport->close(); }
diff --git a/lib/cpp/src/server/TSimpleServer.h b/lib/cpp/src/server/TSimpleServer.h
index b02f106..1ab6f07 100644
--- a/lib/cpp/src/server/TSimpleServer.h
+++ b/lib/cpp/src/server/TSimpleServer.h
@@ -10,7 +10,7 @@
 #include "server/TServer.h"
 #include "transport/TServerTransport.h"
 
-namespace facebook { namespace thrift { namespace server { 
+namespace facebook { namespace thrift { namespace server {
 
 /**
  * This is the most basic simple server. It is single-threaded and runs a
@@ -35,11 +35,11 @@
                 boost::shared_ptr<TTransportFactory> outputTransportFactory,
                 boost::shared_ptr<TProtocolFactory> inputProtocolFactory,
                 boost::shared_ptr<TProtocolFactory> outputProtocolFactory):
-    TServer(processor, serverTransport, 
+    TServer(processor, serverTransport,
             inputTransportFactory, outputTransportFactory,
             inputProtocolFactory, outputProtocolFactory),
     stop_(false) {}
-    
+
   ~TSimpleServer() {}
 
   void serve();
diff --git a/lib/cpp/src/server/TThreadPoolServer.cpp b/lib/cpp/src/server/TThreadPoolServer.cpp
index 26eb373..355bf8e 100644
--- a/lib/cpp/src/server/TThreadPoolServer.cpp
+++ b/lib/cpp/src/server/TThreadPoolServer.cpp
@@ -11,7 +11,7 @@
 #include <string>
 #include <iostream>
 
-namespace facebook { namespace thrift { namespace server { 
+namespace facebook { namespace thrift { namespace server {
 
 using boost::shared_ptr;
 using namespace std;
@@ -20,21 +20,28 @@
 using namespace facebook::thrift::protocol;;
 using namespace facebook::thrift::transport;
 
-class TThreadPoolServer::Task: public Runnable {
-       
+class TThreadPoolServer::Task : public Runnable {
+
 public:
-    
-  Task(shared_ptr<TProcessor> processor,
+
+  Task(TThreadPoolServer &server,
+       shared_ptr<TProcessor> processor,
        shared_ptr<TProtocol> input,
        shared_ptr<TProtocol> output) :
+    server_(server),
     processor_(processor),
     input_(input),
     output_(output) {
   }
 
   ~Task() {}
-    
+
   void run() {
+    boost::shared_ptr<TServerEventHandler> eventHandler =
+      server_.getEventHandler();
+    if (eventHandler != NULL) {
+      eventHandler->clientBegin(input_, output_);
+    }
     try {
       while (processor_->process(input_, output_)) {
         if (!input_->getTransport()->peek()) {
@@ -50,23 +57,27 @@
     } catch (...) {
       cerr << "TThreadPoolServer uncaught exception." << endl;
     }
+    if (eventHandler != NULL) {
+      eventHandler->clientEnd(input_, output_);
+    }
     input_->getTransport()->close();
     output_->getTransport()->close();
   }
 
  private:
+  TServer& server_;
   shared_ptr<TProcessor> processor_;
   shared_ptr<TProtocol> input_;
   shared_ptr<TProtocol> output_;
 
 };
-  
+
 TThreadPoolServer::TThreadPoolServer(shared_ptr<TProcessor> processor,
                                      shared_ptr<TServerTransport> serverTransport,
                                      shared_ptr<TTransportFactory> transportFactory,
                                      shared_ptr<TProtocolFactory> protocolFactory,
                                      shared_ptr<ThreadManager> threadManager) :
-  TServer(processor, serverTransport, transportFactory, protocolFactory), 
+  TServer(processor, serverTransport, transportFactory, protocolFactory),
   threadManager_(threadManager),
   stop_(false), timeout_(0) {}
 
@@ -75,7 +86,7 @@
                                      shared_ptr<TTransportFactory> inputTransportFactory,
                                      shared_ptr<TTransportFactory> outputTransportFactory,
                                      shared_ptr<TProtocolFactory> inputProtocolFactory,
-                                     shared_ptr<TProtocolFactory> outputProtocolFactory, 
+                                     shared_ptr<TProtocolFactory> outputProtocolFactory,
                                      shared_ptr<ThreadManager> threadManager) :
   TServer(processor, serverTransport, inputTransportFactory, outputTransportFactory,
           inputProtocolFactory, outputProtocolFactory),
@@ -99,7 +110,12 @@
     cerr << "TThreadPoolServer::run() listen(): " << ttx.what() << endl;
     return;
   }
-  
+
+  // Run the preServe event
+  if (eventHandler_ != NULL) {
+    eventHandler_->preServe();
+  }
+
   while (!stop_) {
     try {
       client.reset();
@@ -118,7 +134,7 @@
       outputProtocol = outputProtocolFactory_->getProtocol(outputTransport);
 
       // Add to threadmanager pool
-      threadManager_->add(shared_ptr<TThreadPoolServer::Task>(new TThreadPoolServer::Task(processor_, inputProtocol, outputProtocol)), timeout_);
+      threadManager_->add(shared_ptr<TThreadPoolServer::Task>(new TThreadPoolServer::Task(*this, processor_, inputProtocol, outputProtocol)), timeout_);
 
     } catch (TTransportException& ttx) {
       if (inputTransport != NULL) { inputTransport->close(); }
diff --git a/lib/cpp/src/server/TThreadPoolServer.h b/lib/cpp/src/server/TThreadPoolServer.h
index 769db47..f27e8f7 100644
--- a/lib/cpp/src/server/TThreadPoolServer.h
+++ b/lib/cpp/src/server/TThreadPoolServer.h
@@ -13,7 +13,7 @@
 
 #include <boost/shared_ptr.hpp>
 
-namespace facebook { namespace thrift { namespace server { 
+namespace facebook { namespace thrift { namespace server {
 
 using facebook::thrift::concurrency::ThreadManager;
 using facebook::thrift::protocol::TProtocolFactory;
@@ -23,7 +23,7 @@
 class TThreadPoolServer : public TServer {
  public:
   class Task;
-  
+
   TThreadPoolServer(boost::shared_ptr<TProcessor> processor,
                     boost::shared_ptr<TServerTransport> serverTransport,
                     boost::shared_ptr<TTransportFactory> transportFactory,
@@ -35,7 +35,7 @@
                     boost::shared_ptr<TTransportFactory> inputTransportFactory,
                     boost::shared_ptr<TTransportFactory> outputTransportFactory,
                     boost::shared_ptr<TProtocolFactory> inputProtocolFactory,
-                    boost::shared_ptr<TProtocolFactory> outputProtocolFactory, 
+                    boost::shared_ptr<TProtocolFactory> outputProtocolFactory,
                     boost::shared_ptr<ThreadManager> threadManager);
 
   virtual ~TThreadPoolServer();
@@ -45,7 +45,7 @@
   virtual int64_t getTimeout() const;
 
   virtual void setTimeout(int64_t value);
-  
+
   virtual void stop() {
     stop_ = true;
     serverTransport_->interrupt();
@@ -58,7 +58,7 @@
   volatile bool stop_;
 
   volatile int64_t timeout_;
-  
+
 };
 
 }}} // facebook::thrift::server
diff --git a/lib/cpp/src/server/TThreadedServer.cpp b/lib/cpp/src/server/TThreadedServer.cpp
index 34584bd..d07e9da 100644
--- a/lib/cpp/src/server/TThreadedServer.cpp
+++ b/lib/cpp/src/server/TThreadedServer.cpp
@@ -13,7 +13,7 @@
 #include <pthread.h>
 #include <unistd.h>
 
-namespace facebook { namespace thrift { namespace server { 
+namespace facebook { namespace thrift { namespace server {
 
 using boost::shared_ptr;
 using namespace std;
@@ -23,10 +23,10 @@
 using namespace facebook::thrift::concurrency;
 
 class TThreadedServer::Task: public Runnable {
-       
+
 public:
-    
-  Task(TThreadedServer* server,
+
+  Task(TThreadedServer& server,
        shared_ptr<TProcessor> processor,
        shared_ptr<TProtocol> input,
        shared_ptr<TProtocol> output) :
@@ -37,8 +37,13 @@
   }
 
   ~Task() {}
-    
+
   void run() {
+    boost::shared_ptr<TServerEventHandler> eventHandler =
+      server_.getEventHandler();
+    if (eventHandler != NULL) {
+      eventHandler->clientBegin(input_, output_);
+    }
     try {
       while (processor_->process(input_, output_)) {
         if (!input_->getTransport()->peek()) {
@@ -52,22 +57,25 @@
     } catch (...) {
       cerr << "TThreadedServer uncaught exception." << endl;
     }
+    if (eventHandler != NULL) {
+      eventHandler->clientEnd(input_, output_);
+    }
     input_->getTransport()->close();
     output_->getTransport()->close();
-    
+
     // Remove this task from parent bookkeeping
     {
-      Synchronized s(server_->tasksMonitor_);
-      server_->tasks_.erase(this);
-      if (server_->tasks_.empty()) {
-        server_->tasksMonitor_.notify();
+      Synchronized s(server_.tasksMonitor_);
+      server_.tasks_.erase(this);
+      if (server_.tasks_.empty()) {
+        server_.tasksMonitor_.notify();
       }
     }
 
   }
 
  private:
-  TThreadedServer* server_;
+  TThreadedServer& server_;
   friend class TThreadedServer;
 
   shared_ptr<TProcessor> processor_;
@@ -103,7 +111,12 @@
     return;
   }
 
-  while (!stop_) {   
+  // Run the preServe event
+  if (eventHandler_ != NULL) {
+    eventHandler_->preServe();
+  }
+
+  while (!stop_) {
     try {
       client.reset();
       inputTransport.reset();
@@ -120,11 +133,11 @@
       inputProtocol = inputProtocolFactory_->getProtocol(inputTransport);
       outputProtocol = outputProtocolFactory_->getProtocol(outputTransport);
 
-      TThreadedServer::Task* task = new TThreadedServer::Task(this,
-                                                              processor_, 
+      TThreadedServer::Task* task = new TThreadedServer::Task(*this,
+                                                              processor_,
                                                               inputProtocol,
                                                               outputProtocol);
-        
+
       // Create a task
       shared_ptr<Runnable> runnable =
         shared_ptr<Runnable>(task);
@@ -132,7 +145,7 @@
       // Create a thread for this task
       shared_ptr<Thread> thread =
         shared_ptr<Thread>(threadFactory_->newThread(runnable));
-      
+
       // Insert thread into the set of threads
       {
         Synchronized s(tasksMonitor_);
diff --git a/lib/cpp/src/server/TThreadedServer.h b/lib/cpp/src/server/TThreadedServer.h
index 4fe775a..28d3549 100644
--- a/lib/cpp/src/server/TThreadedServer.h
+++ b/lib/cpp/src/server/TThreadedServer.h
@@ -14,7 +14,7 @@
 
 #include <boost/shared_ptr.hpp>
 
-namespace facebook { namespace thrift { namespace server { 
+namespace facebook { namespace thrift { namespace server {
 
 using facebook::thrift::TProcessor;
 using facebook::thrift::transport::TServerTransport;
@@ -26,7 +26,7 @@
 
  public:
   class Task;
-  
+
   TThreadedServer(boost::shared_ptr<TProcessor> processor,
                   boost::shared_ptr<TServerTransport> serverTransport,
                   boost::shared_ptr<TTransportFactory> transportFactory,