Implement proper TThreadedServer shutdown
git-svn-id: https://svn.apache.org/repos/asf/incubator/thrift/trunk@665049 13f79535-47bb-0310-9956-ffa450edef68
diff --git a/lib/cpp/src/server/TThreadedServer.cpp b/lib/cpp/src/server/TThreadedServer.cpp
index 079e046..34584bd 100644
--- a/lib/cpp/src/server/TThreadedServer.cpp
+++ b/lib/cpp/src/server/TThreadedServer.cpp
@@ -26,9 +26,11 @@
public:
- Task(shared_ptr<TProcessor> processor,
+ Task(TThreadedServer* server,
+ shared_ptr<TProcessor> processor,
shared_ptr<TProtocol> input,
shared_ptr<TProtocol> output) :
+ server_(server),
processor_(processor),
input_(input),
output_(output) {
@@ -52,13 +54,25 @@
}
input_->getTransport()->close();
output_->getTransport()->close();
+
+ // Remove this task from parent bookkeeping
+ {
+ Synchronized s(server_->tasksMonitor_);
+ server_->tasks_.erase(this);
+ if (server_->tasks_.empty()) {
+ server_->tasksMonitor_.notify();
+ }
+ }
+
}
private:
+ TThreadedServer* server_;
+ friend class TThreadedServer;
+
shared_ptr<TProcessor> processor_;
shared_ptr<TProtocol> input_;
shared_ptr<TProtocol> output_;
-
};
@@ -66,7 +80,8 @@
shared_ptr<TServerTransport> serverTransport,
shared_ptr<TTransportFactory> transportFactory,
shared_ptr<TProtocolFactory> protocolFactory):
- TServer(processor, serverTransport, transportFactory, protocolFactory) {
+ TServer(processor, serverTransport, transportFactory, protocolFactory),
+ stop_(false) {
threadFactory_ = shared_ptr<PosixThreadFactory>(new PosixThreadFactory());
}
@@ -88,47 +103,86 @@
return;
}
- while (true) {
+ while (!stop_) {
try {
+ client.reset();
+ inputTransport.reset();
+ outputTransport.reset();
+ inputProtocol.reset();
+ outputProtocol.reset();
+
// Fetch client from server
client = serverTransport_->accept();
+
// Make IO transports
inputTransport = inputTransportFactory_->getTransport(client);
outputTransport = outputTransportFactory_->getTransport(client);
inputProtocol = inputProtocolFactory_->getProtocol(inputTransport);
outputProtocol = outputProtocolFactory_->getProtocol(outputTransport);
- TThreadedServer::Task* t = new TThreadedServer::Task(processor_,
- inputProtocol,
- outputProtocol);
+ TThreadedServer::Task* task = new TThreadedServer::Task(this,
+ processor_,
+ inputProtocol,
+ outputProtocol);
+
+ // Create a task
+ shared_ptr<Runnable> runnable =
+ shared_ptr<Runnable>(task);
// Create a thread for this task
shared_ptr<Thread> thread =
- shared_ptr<Thread>(threadFactory_->newThread(shared_ptr<Runnable>(t)));
+ shared_ptr<Thread>(threadFactory_->newThread(runnable));
+ // Insert thread into the set of threads
+ {
+ Synchronized s(tasksMonitor_);
+ tasks_.insert(task);
+ }
+
// Start the thread!
thread->start();
} catch (TTransportException& ttx) {
- inputTransport->close();
- outputTransport->close();
- client->close();
- cerr << "TThreadedServer: TServerTransport died on accept: " << ttx.what() << endl;
+ if (inputTransport != NULL) { inputTransport->close(); }
+ if (outputTransport != NULL) { outputTransport->close(); }
+ if (client != NULL) { client->close(); }
+ if (!stop_ || ttx.getType() != TTransportException::INTERRUPTED) {
+ cerr << "TThreadedServer: TServerTransport died on accept: " << ttx.what() << endl;
+ }
continue;
} catch (TException& tx) {
- inputTransport->close();
- outputTransport->close();
- client->close();
+ if (inputTransport != NULL) { inputTransport->close(); }
+ if (outputTransport != NULL) { outputTransport->close(); }
+ if (client != NULL) { client->close(); }
cerr << "TThreadedServer: Caught TException: " << tx.what() << endl;
continue;
} catch (string s) {
- inputTransport->close();
- outputTransport->close();
- client->close();
+ if (inputTransport != NULL) { inputTransport->close(); }
+ if (outputTransport != NULL) { outputTransport->close(); }
+ if (client != NULL) { client->close(); }
cerr << "TThreadedServer: Unknown exception: " << s << endl;
break;
}
}
+
+ // If stopped manually, make sure to close server transport
+ if (stop_) {
+ try {
+ serverTransport_->close();
+ } catch (TException &tx) {
+ cerr << "TThreadedServer: Exception shutting down: " << tx.what() << endl;
+ }
+ try {
+ Synchronized s(tasksMonitor_);
+ while (!tasks_.empty()) {
+ tasksMonitor_.wait();
+ }
+ } catch (TException &tx) {
+ cerr << "TThreadedServer: Exception joining workers: " << tx.what() << endl;
+ }
+ stop_ = false;
+ }
+
}
}}} // facebook::thrift::server
diff --git a/lib/cpp/src/server/TThreadedServer.h b/lib/cpp/src/server/TThreadedServer.h
index c6306b2..4fe775a 100644
--- a/lib/cpp/src/server/TThreadedServer.h
+++ b/lib/cpp/src/server/TThreadedServer.h
@@ -9,6 +9,7 @@
#include <server/TServer.h>
#include <transport/TServerTransport.h>
+#include <concurrency/Monitor.h>
#include <concurrency/Thread.h>
#include <boost/shared_ptr.hpp>
@@ -18,6 +19,7 @@
using facebook::thrift::TProcessor;
using facebook::thrift::transport::TServerTransport;
using facebook::thrift::transport::TTransportFactory;
+using facebook::thrift::concurrency::Monitor;
using facebook::thrift::concurrency::ThreadFactory;
class TThreadedServer : public TServer {
@@ -34,8 +36,17 @@
virtual void serve();
+ void stop() {
+ stop_ = true;
+ serverTransport_->interrupt();
+ }
+
protected:
boost::shared_ptr<ThreadFactory> threadFactory_;
+ volatile bool stop_;
+
+ Monitor tasksMonitor_;
+ std::set<Task*> tasks_;
};