Implement proper TThreadedServer shutdown


git-svn-id: https://svn.apache.org/repos/asf/incubator/thrift/trunk@665049 13f79535-47bb-0310-9956-ffa450edef68
diff --git a/lib/cpp/src/server/TThreadedServer.cpp b/lib/cpp/src/server/TThreadedServer.cpp
index 079e046..34584bd 100644
--- a/lib/cpp/src/server/TThreadedServer.cpp
+++ b/lib/cpp/src/server/TThreadedServer.cpp
@@ -26,9 +26,11 @@
        
 public:
     
-  Task(shared_ptr<TProcessor> processor,
+  Task(TThreadedServer* server,
+       shared_ptr<TProcessor> processor,
        shared_ptr<TProtocol> input,
        shared_ptr<TProtocol> output) :
+    server_(server),
     processor_(processor),
     input_(input),
     output_(output) {
@@ -52,13 +54,25 @@
     }
     input_->getTransport()->close();
     output_->getTransport()->close();
+    
+    // Remove this task from parent bookkeeping
+    {
+      Synchronized s(server_->tasksMonitor_);
+      server_->tasks_.erase(this);
+      if (server_->tasks_.empty()) {
+        server_->tasksMonitor_.notify();
+      }
+    }
+
   }
 
  private:
+  TThreadedServer* server_;
+  friend class TThreadedServer;
+
   shared_ptr<TProcessor> processor_;
   shared_ptr<TProtocol> input_;
   shared_ptr<TProtocol> output_;
-
 };
 
 
@@ -66,7 +80,8 @@
                                  shared_ptr<TServerTransport> serverTransport,
                                  shared_ptr<TTransportFactory> transportFactory,
                                  shared_ptr<TProtocolFactory> protocolFactory):
-  TServer(processor, serverTransport, transportFactory, protocolFactory) {
+  TServer(processor, serverTransport, transportFactory, protocolFactory),
+  stop_(false) {
   threadFactory_ = shared_ptr<PosixThreadFactory>(new PosixThreadFactory());
 }
 
@@ -88,47 +103,86 @@
     return;
   }
 
-  while (true) {   
+  while (!stop_) {   
     try {
+      client.reset();
+      inputTransport.reset();
+      outputTransport.reset();
+      inputProtocol.reset();
+      outputProtocol.reset();
+
       // Fetch client from server
       client = serverTransport_->accept();
+
       // Make IO transports
       inputTransport = inputTransportFactory_->getTransport(client);
       outputTransport = outputTransportFactory_->getTransport(client);
       inputProtocol = inputProtocolFactory_->getProtocol(inputTransport);
       outputProtocol = outputProtocolFactory_->getProtocol(outputTransport);
 
-      TThreadedServer::Task* t = new TThreadedServer::Task(processor_, 
-                                                           inputProtocol,
-                                                           outputProtocol);
+      TThreadedServer::Task* task = new TThreadedServer::Task(this,
+                                                              processor_, 
+                                                              inputProtocol,
+                                                              outputProtocol);
+        
+      // Create a task
+      shared_ptr<Runnable> runnable =
+        shared_ptr<Runnable>(task);
 
       // Create a thread for this task
       shared_ptr<Thread> thread =
-        shared_ptr<Thread>(threadFactory_->newThread(shared_ptr<Runnable>(t)));
+        shared_ptr<Thread>(threadFactory_->newThread(runnable));
       
+      // Insert thread into the set of threads
+      {
+        Synchronized s(tasksMonitor_);
+        tasks_.insert(task);
+      }
+
       // Start the thread!
       thread->start();
 
     } catch (TTransportException& ttx) {
-      inputTransport->close();
-      outputTransport->close();
-      client->close();
-      cerr << "TThreadedServer: TServerTransport died on accept: " << ttx.what() << endl;
+      if (inputTransport != NULL) { inputTransport->close(); }
+      if (outputTransport != NULL) { outputTransport->close(); }
+      if (client != NULL) { client->close(); }
+      if (!stop_ || ttx.getType() != TTransportException::INTERRUPTED) {
+        cerr << "TThreadedServer: TServerTransport died on accept: " << ttx.what() << endl;
+      }
       continue;
     } catch (TException& tx) {
-      inputTransport->close();
-      outputTransport->close();
-      client->close();
+      if (inputTransport != NULL) { inputTransport->close(); }
+      if (outputTransport != NULL) { outputTransport->close(); }
+      if (client != NULL) { client->close(); }
       cerr << "TThreadedServer: Caught TException: " << tx.what() << endl;
       continue;
     } catch (string s) {
-      inputTransport->close();
-      outputTransport->close();
-      client->close();
+      if (inputTransport != NULL) { inputTransport->close(); }
+      if (outputTransport != NULL) { outputTransport->close(); }
+      if (client != NULL) { client->close(); }
       cerr << "TThreadedServer: Unknown exception: " << s << endl;
       break;
     }
   }
+
+  // If stopped manually, make sure to close server transport
+  if (stop_) {
+    try {
+      serverTransport_->close();
+    } catch (TException &tx) {
+      cerr << "TThreadedServer: Exception shutting down: " << tx.what() << endl;
+    }
+    try {
+      Synchronized s(tasksMonitor_);
+      while (!tasks_.empty()) {
+        tasksMonitor_.wait();
+      }
+    } catch (TException &tx) {
+      cerr << "TThreadedServer: Exception joining workers: " << tx.what() << endl;
+    }
+    stop_ = false;
+  }
+
 }
 
 }}} // facebook::thrift::server
diff --git a/lib/cpp/src/server/TThreadedServer.h b/lib/cpp/src/server/TThreadedServer.h
index c6306b2..4fe775a 100644
--- a/lib/cpp/src/server/TThreadedServer.h
+++ b/lib/cpp/src/server/TThreadedServer.h
@@ -9,6 +9,7 @@
 
 #include <server/TServer.h>
 #include <transport/TServerTransport.h>
+#include <concurrency/Monitor.h>
 #include <concurrency/Thread.h>
 
 #include <boost/shared_ptr.hpp>
@@ -18,6 +19,7 @@
 using facebook::thrift::TProcessor;
 using facebook::thrift::transport::TServerTransport;
 using facebook::thrift::transport::TTransportFactory;
+using facebook::thrift::concurrency::Monitor;
 using facebook::thrift::concurrency::ThreadFactory;
 
 class TThreadedServer : public TServer {
@@ -34,8 +36,17 @@
 
   virtual void serve();
 
+  void stop() {
+    stop_ = true;
+    serverTransport_->interrupt();
+  }
+
  protected:
   boost::shared_ptr<ThreadFactory> threadFactory_;
+  volatile bool stop_;
+
+  Monitor tasksMonitor_;
+  std::set<Task*> tasks_;
 
 };