THRIFT-3247 Generate a C++ thread-safe client
Client: cpp
Patch: Ben Craig <bencraig@apache.org>
diff --git a/compiler/cpp/src/generate/t_cpp_generator.cc b/compiler/cpp/src/generate/t_cpp_generator.cc
index 847f0ba..f591107 100644
--- a/compiler/cpp/src/generate/t_cpp_generator.cc
+++ b/compiler/cpp/src/generate/t_cpp_generator.cc
@@ -1368,7 +1368,7 @@
// Declare stack tmp variables
out << endl
- << indent() << "apache::thrift::protocol::TRecursionTracker tracker(*iprot);" << endl
+ << indent() << "apache::thrift::protocol::TInputRecursionTracker tracker(*iprot);" << endl
<< indent() << "uint32_t xfer = 0;" << endl
<< indent() << "std::string fname;" << endl
<< indent() << "::apache::thrift::protocol::TType ftype;" << endl
@@ -1492,7 +1492,7 @@
out << indent() << "uint32_t xfer = 0;" << endl;
- indent(out) << "apache::thrift::protocol::TRecursionTracker tracker(*oprot);" << endl;
+ indent(out) << "apache::thrift::protocol::TOutputRecursionTracker tracker(*oprot);" << endl;
indent(out) << "xfer += oprot->writeStructBegin(\"" << name << "\");" << endl;
for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) {
@@ -1790,6 +1790,7 @@
if (gen_cob_style_) {
f_header_ << "#include <thrift/async/TAsyncDispatchProcessor.h>" << endl;
}
+ f_header_ << "#include <thrift/async/TConcurrentClientSyncInfo.h>" << endl;
f_header_ << "#include \"" << get_include_prefix(*get_program()) << program_name_ << "_types.h\""
<< endl;
@@ -1845,6 +1846,7 @@
generate_service_processor(tservice, "");
generate_service_multiface(tservice);
generate_service_skeleton(tservice);
+ generate_service_client(tservice, "Concurrent");
// Generate all the cob components
if (gen_cob_style_) {
@@ -2326,6 +2328,13 @@
}
// Generate the header portion
+ if(style == "Concurrent")
+ {
+ f_header_ <<
+ "// The \'concurrent\' client is a thread safe client that correctly handles\n"
+ "// out of order responses. It is slower than the regular client, so should\n"
+ "// only be used when you need to share a connection among multiple threads\n";
+ }
f_header_ << template_header << "class " << service_name_ << style << "Client" << short_suffix
<< " : "
<< "virtual public " << service_name_ << ifstyle << if_suffix << extends_client << " {"
@@ -2438,16 +2447,38 @@
for (f_iter = functions.begin(); f_iter != functions.end(); ++f_iter) {
indent(f_header_) << function_signature(*f_iter, ifstyle) << ";" << endl;
// TODO(dreiss): Use private inheritance to avoid generating thise in cob-style.
- t_function send_function(g_type_void,
- string("send_") + (*f_iter)->get_name(),
- (*f_iter)->get_arglist());
- indent(f_header_) << function_signature(&send_function, "") << ";" << endl;
+ if(style == "Concurrent" && !(*f_iter)->is_oneway()) {
+ // concurrent clients need to move the seqid from the send function to the
+ // recv function. Oneway methods don't have a recv function, so we don't need to
+ // move the seqid for them. Attempting to do so would result in a seqid leak.
+ t_function send_function(g_type_i32, /*returning seqid*/
+ string("send_") + (*f_iter)->get_name(),
+ (*f_iter)->get_arglist());
+ indent(f_header_) << function_signature(&send_function, "") << ";" << endl;
+ }
+ else {
+ t_function send_function(g_type_void,
+ string("send_") + (*f_iter)->get_name(),
+ (*f_iter)->get_arglist());
+ indent(f_header_) << function_signature(&send_function, "") << ";" << endl;
+ }
if (!(*f_iter)->is_oneway()) {
- t_struct noargs(program_);
- t_function recv_function((*f_iter)->get_returntype(),
- string("recv_") + (*f_iter)->get_name(),
- &noargs);
- indent(f_header_) << function_signature(&recv_function, "") << ";" << endl;
+ if(style == "Concurrent") {
+ t_field seqIdArg(g_type_i32, "seqid");
+ t_struct seqIdArgStruct(program_);
+ seqIdArgStruct.append(&seqIdArg);
+ t_function recv_function((*f_iter)->get_returntype(),
+ string("recv_") + (*f_iter)->get_name(),
+ &seqIdArgStruct);
+ indent(f_header_) << function_signature(&recv_function, "") << ";" << endl;
+ }
+ else {
+ t_struct noargs(program_);
+ t_function recv_function((*f_iter)->get_returntype(),
+ string("recv_") + (*f_iter)->get_name(),
+ &noargs);
+ indent(f_header_) << function_signature(&recv_function, "") << ";" << endl;
+ }
}
}
indent_down();
@@ -2465,10 +2496,16 @@
<< "boost::shared_ptr< ::apache::thrift::transport::TMemoryBuffer> otrans_;"
<< endl;
}
- f_header_ << indent() << prot_ptr << " piprot_;" << endl << indent() << prot_ptr << " poprot_;"
- << endl << indent() << protocol_type << "* iprot_;" << endl << indent()
- << protocol_type << "* oprot_;" << endl;
+ f_header_ <<
+ indent() << prot_ptr << " piprot_;" << endl <<
+ indent() << prot_ptr << " poprot_;" << endl <<
+ indent() << protocol_type << "* iprot_;" << endl <<
+ indent() << protocol_type << "* oprot_;" << endl;
+ if (style == "Concurrent") {
+ f_header_ <<
+ indent() << "::apache::thrift::async::TConcurrentClientSyncInfo sync_;"<<endl;
+ }
indent_down();
}
@@ -2486,6 +2523,15 @@
// Generate client method implementations
for (f_iter = functions.begin(); f_iter != functions.end(); ++f_iter) {
+ string seqIdCapture;
+ string seqIdUse;
+ string seqIdCommaUse;
+ if (style == "Concurrent" && !(*f_iter)->is_oneway()) {
+ seqIdCapture = "int32_t seqid = ";
+ seqIdUse = "seqid";
+ seqIdCommaUse = ", seqid";
+ }
+
string funname = (*f_iter)->get_name();
// Open function
@@ -2494,7 +2540,7 @@
}
indent(out) << function_signature(*f_iter, ifstyle, scope) << endl;
scope_up(out);
- indent(out) << "send_" << funname << "(";
+ indent(out) << seqIdCapture << "send_" << funname << "(";
// Get the struct of function call params
t_struct* arg_struct = (*f_iter)->get_arglist();
@@ -2518,12 +2564,12 @@
out << indent();
if (!(*f_iter)->get_returntype()->is_void()) {
if (is_complex_type((*f_iter)->get_returntype())) {
- out << "recv_" << funname << "(_return);" << endl;
+ out << "recv_" << funname << "(_return" << seqIdCommaUse << ");" << endl;
} else {
- out << "return recv_" << funname << "();" << endl;
+ out << "return recv_" << funname << "(" << seqIdUse << ");" << endl;
}
} else {
- out << "recv_" << funname << "();" << endl;
+ out << "recv_" << funname << "(" << seqIdUse << ");" << endl;
}
}
} else {
@@ -2541,8 +2587,12 @@
// if (style != "Cob") // TODO(dreiss): Libify the client and don't generate this for cob-style
if (true) {
+ t_type *send_func_return_type = g_type_void;
+ if (style == "Concurrent" && !(*f_iter)->is_oneway()) {
+ send_func_return_type = g_type_i32;
+ }
// Function for sending
- t_function send_function(g_type_void,
+ t_function send_function(send_func_return_type,
string("send_") + (*f_iter)->get_name(),
(*f_iter)->get_arglist());
@@ -2557,11 +2607,25 @@
string argsname = tservice->get_name() + "_" + (*f_iter)->get_name() + "_pargs";
string resultname = tservice->get_name() + "_" + (*f_iter)->get_name() + "_presult";
+ string cseqidVal = "0";
+ if(style == "Concurrent") {
+ if (!(*f_iter)->is_oneway()) {
+ cseqidVal = "this->sync_.generateSeqId()";
+ }
+ }
// Serialize the request
- out << indent() << "int32_t cseqid = 0;" << endl << indent() << _this
- << "oprot_->writeMessageBegin(\"" << (*f_iter)->get_name()
- << "\", ::apache::thrift::protocol::" << ((*f_iter)->is_oneway() ? "T_ONEWAY" : "T_CALL")
- << ", cseqid);" << endl << endl << indent() << argsname << " args;" << endl;
+ out <<
+ indent() << "int32_t cseqid = " << cseqidVal << ";" << endl;
+ if(style == "Concurrent") {
+ out <<
+ indent() << "::apache::thrift::async::TConcurrentSendSentry sentry(&this->sync_);" << endl;
+ }
+ out <<
+ indent() << _this << "oprot_->writeMessageBegin(\"" <<
+ (*f_iter)->get_name() <<
+ "\", ::apache::thrift::protocol::" << ((*f_iter)->is_oneway() ? "T_ONEWAY" : "T_CALL") <<
+ ", cseqid);" << endl << endl <<
+ indent() << argsname << " args;" << endl;
for (fld_iter = fields.begin(); fld_iter != fields.end(); ++fld_iter) {
out << indent() << "args." << (*fld_iter)->get_name() << " = &" << (*fld_iter)->get_name()
@@ -2573,15 +2637,35 @@
<< "oprot_->getTransport()->writeEnd();" << endl << indent() << _this
<< "oprot_->getTransport()->flush();" << endl;
+ if (style == "Concurrent") {
+ out <<
+ endl <<
+ indent() << "sentry.commit();" << endl;
+
+ if(!(*f_iter)->is_oneway()) {
+ out <<
+ indent() << "return cseqid;" << endl;
+ }
+ }
scope_down(out);
out << endl;
// Generate recv function only if not an oneway function
if (!(*f_iter)->is_oneway()) {
t_struct noargs(program_);
+
+ t_field seqIdArg(g_type_i32, "seqid");
+ t_struct seqIdArgStruct(program_);
+ seqIdArgStruct.append(&seqIdArg);
+
+ t_struct *recv_function_args = &noargs;
+ if(style == "Concurrent") {
+ recv_function_args = &seqIdArgStruct;
+ }
+
t_function recv_function((*f_iter)->get_returntype(),
string("recv_") + (*f_iter)->get_name(),
- &noargs);
+ recv_function_args);
// Open the recv function
if (gen_templates_) {
indent(out) << template_header;
@@ -2589,42 +2673,76 @@
indent(out) << function_signature(&recv_function, "", scope) << endl;
scope_up(out);
- out << endl << indent() << "int32_t rseqid = 0;" << endl << indent() << "std::string fname;"
- << endl << indent() << "::apache::thrift::protocol::TMessageType mtype;" << endl;
+ out << endl <<
+ indent() << "int32_t rseqid = 0;" << endl <<
+ indent() << "std::string fname;" << endl <<
+ indent() << "::apache::thrift::protocol::TMessageType mtype;" << endl;
+ if(style == "Concurrent") {
+ out <<
+ endl <<
+ indent() << "// the read mutex gets dropped and reacquired as part of waitForWork()" << endl <<
+ indent() << "// The destructor of this sentry wakes up other clients" << endl <<
+ indent() << "::apache::thrift::async::TConcurrentRecvSentry sentry(&this->sync_, seqid);" << endl;
+ }
if (style == "Cob" && !gen_no_client_completion_) {
out << indent() << "bool completed = false;" << endl << endl << indent() << "try {";
indent_up();
}
- out << endl << indent() << _this << "iprot_->readMessageBegin(fname, mtype, rseqid);"
- << endl << indent() << "if (mtype == ::apache::thrift::protocol::T_EXCEPTION) {" << endl
- << indent() << " ::apache::thrift::TApplicationException x;" << endl << indent()
- << " x.read(" << _this << "iprot_);" << endl << indent() << " " << _this
- << "iprot_->readMessageEnd();" << endl << indent() << " " << _this
- << "iprot_->getTransport()->readEnd();" << endl;
+ out << endl;
+ if (style == "Concurrent") {
+ out <<
+ indent() << "while(true) {" << endl <<
+ indent() << " if(!this->sync_.getPending(fname, mtype, rseqid)) {" << endl;
+ indent_up();
+ indent_up();
+ }
+ out <<
+ indent() << _this << "iprot_->readMessageBegin(fname, mtype, rseqid);" << endl;
+ if (style == "Concurrent") {
+ scope_down(out);
+ out << indent() << "if(seqid == rseqid) {" << endl;
+ indent_up();
+ }
+ out <<
+ indent() << "if (mtype == ::apache::thrift::protocol::T_EXCEPTION) {" << endl <<
+ indent() << " ::apache::thrift::TApplicationException x;" << endl <<
+ indent() << " x.read(" << _this << "iprot_);" << endl <<
+ indent() << " " << _this << "iprot_->readMessageEnd();" << endl <<
+ indent() << " " << _this << "iprot_->getTransport()->readEnd();" << endl;
if (style == "Cob" && !gen_no_client_completion_) {
out << indent() << " completed = true;" << endl << indent() << " completed__(true);"
<< endl;
}
- out << indent() << " throw x;" << endl << indent() << "}" << endl << indent()
- << "if (mtype != ::apache::thrift::protocol::T_REPLY) {" << endl << indent() << " "
- << _this << "iprot_->skip("
- << "::apache::thrift::protocol::T_STRUCT);" << endl << indent() << " " << _this
- << "iprot_->readMessageEnd();" << endl << indent() << " " << _this
- << "iprot_->getTransport()->readEnd();" << endl;
+ if (style == "Concurrent") {
+ out << indent() << " sentry.commit();" << endl;
+ }
+ out <<
+ indent() << " throw x;" << endl <<
+ indent() << "}" << endl <<
+ indent() << "if (mtype != ::apache::thrift::protocol::T_REPLY) {" << endl <<
+ indent() << " " << _this << "iprot_->skip(" << "::apache::thrift::protocol::T_STRUCT);" << endl <<
+ indent() << " " << _this << "iprot_->readMessageEnd();" << endl <<
+ indent() << " " << _this << "iprot_->getTransport()->readEnd();" << endl;
if (style == "Cob" && !gen_no_client_completion_) {
out << indent() << " completed = true;" << endl << indent() << " completed__(false);"
<< endl;
}
- out << indent() << "}" << endl << indent() << "if (fname.compare(\""
- << (*f_iter)->get_name() << "\") != 0) {" << endl << indent() << " " << _this
- << "iprot_->skip("
- << "::apache::thrift::protocol::T_STRUCT);" << endl << indent() << " " << _this
- << "iprot_->readMessageEnd();" << endl << indent() << " " << _this
- << "iprot_->getTransport()->readEnd();" << endl;
+ out <<
+ indent() << "}" << endl <<
+ indent() << "if (fname.compare(\"" << (*f_iter)->get_name() << "\") != 0) {" << endl <<
+ indent() << " " << _this << "iprot_->skip(" << "::apache::thrift::protocol::T_STRUCT);" << endl <<
+ indent() << " " << _this << "iprot_->readMessageEnd();" << endl <<
+ indent() << " " << _this << "iprot_->getTransport()->readEnd();" << endl;
if (style == "Cob" && !gen_no_client_completion_) {
out << indent() << " completed = true;" << endl << indent() << " completed__(false);"
<< endl;
}
+ if (style == "Concurrent") {
+ out << endl <<
+ indent() << " // in a bad state, don't commit" << endl <<
+ indent() << " using ::apache::thrift::protocol::TProtocolException;" << endl <<
+ indent() << " throw TProtocolException(TProtocolException::INVALID_DATA);" << endl;
+ }
out << indent() << "}" << endl;
if (!(*f_iter)->get_returntype()->is_void()
@@ -2646,19 +2764,29 @@
// Careful, only look for _result if not a void function
if (!(*f_iter)->get_returntype()->is_void()) {
if (is_complex_type((*f_iter)->get_returntype())) {
- out << indent() << "if (result.__isset.success) {" << endl << indent()
- << " // _return pointer has now been filled" << endl;
+ out <<
+ indent() << "if (result.__isset.success) {" << endl;
+ out <<
+ indent() << " // _return pointer has now been filled" << endl;
if (style == "Cob" && !gen_no_client_completion_) {
out << indent() << " completed = true;" << endl << indent() << " completed__(true);"
<< endl;
}
- out << indent() << " return;" << endl << indent() << "}" << endl;
+ if (style == "Concurrent") {
+ out << indent() << " sentry.commit();" << endl;
+ }
+ out <<
+ indent() << " return;" << endl <<
+ indent() << "}" << endl;
} else {
out << indent() << "if (result.__isset.success) {" << endl;
if (style == "Cob" && !gen_no_client_completion_) {
out << indent() << " completed = true;" << endl << indent() << " completed__(true);"
<< endl;
}
+ if (style == "Concurrent") {
+ out << indent() << " sentry.commit();" << endl;
+ }
out << indent() << " return _return;" << endl << indent() << "}" << endl;
}
}
@@ -2672,6 +2800,9 @@
out << indent() << " completed = true;" << endl << indent() << " completed__(true);"
<< endl;
}
+ if (style == "Concurrent") {
+ out << indent() << " sentry.commit();" << endl;
+ }
out << indent() << " throw result." << (*x_iter)->get_name() << ";" << endl << indent()
<< "}" << endl;
}
@@ -2682,17 +2813,35 @@
out << indent() << "completed = true;" << endl << indent() << "completed__(true);"
<< endl;
}
+ if (style == "Concurrent") {
+ out << indent() << "sentry.commit();" << endl;
+ }
indent(out) << "return;" << endl;
} else {
if (style == "Cob" && !gen_no_client_completion_) {
out << indent() << "completed = true;" << endl << indent() << "completed__(true);"
<< endl;
}
+ if (style == "Concurrent") {
+ out << indent() << "// in a bad state, don't commit" << endl;
+ }
out << indent() << "throw "
"::apache::thrift::TApplicationException(::apache::thrift::"
"TApplicationException::MISSING_RESULT, \"" << (*f_iter)->get_name()
<< " failed: unknown result\");" << endl;
}
+ if(style == "Concurrent") {
+ indent_down();
+ indent_down();
+ out <<
+ indent() << " }" << endl <<
+ indent() << " // seqid != rseqid" << endl <<
+ indent() << " this->sync_.updatePending(fname, mtype, rseqid);" << endl <<
+ endl <<
+ indent() << " // this will temporarily unlock the readMutex, and let other clients get work done" << endl <<
+ indent() << " this->sync_.waitForWork(seqid);" << endl <<
+ indent() << "} // end while(true)" << endl;
+ }
if (style == "Cob" && !gen_no_client_completion_) {
indent_down();
out << indent() << "} catch (...) {" << endl << indent() << " if (!completed) {" << endl
@@ -3096,7 +3245,7 @@
<< "return processor;" << endl;
indent_down();
- f_out_ << indent() << "}" << endl;
+ f_out_ << indent() << "}" << endl << endl;
}
/**
diff --git a/lib/cpp/CMakeLists.txt b/lib/cpp/CMakeLists.txt
index bab2e84..d6107bc 100755
--- a/lib/cpp/CMakeLists.txt
+++ b/lib/cpp/CMakeLists.txt
@@ -35,6 +35,8 @@
src/thrift/TApplicationException.cpp
src/thrift/TOutput.cpp
src/thrift/async/TAsyncChannel.cpp
+ src/thrift/async/TConcurrentClientSyncInfo.h
+ src/thrift/async/TConcurrentClientSyncInfo.cpp
src/thrift/concurrency/ThreadManager.cpp
src/thrift/concurrency/TimerManager.cpp
src/thrift/concurrency/Util.cpp
diff --git a/lib/cpp/Makefile.am b/lib/cpp/Makefile.am
index 0ecbeee..82f2e3a 100755
--- a/lib/cpp/Makefile.am
+++ b/lib/cpp/Makefile.am
@@ -66,6 +66,7 @@
src/thrift/TOutput.cpp \
src/thrift/VirtualProfiling.cpp \
src/thrift/async/TAsyncChannel.cpp \
+ src/thrift/async/TConcurrentClientSyncInfo.cpp \
src/thrift/concurrency/ThreadManager.cpp \
src/thrift/concurrency/TimerManager.cpp \
src/thrift/concurrency/Util.cpp \
diff --git a/lib/cpp/src/thrift/async/TConcurrentClientSyncInfo.cpp b/lib/cpp/src/thrift/async/TConcurrentClientSyncInfo.cpp
new file mode 100644
index 0000000..c7e27c0
--- /dev/null
+++ b/lib/cpp/src/thrift/async/TConcurrentClientSyncInfo.cpp
@@ -0,0 +1,242 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <thrift/async/TConcurrentClientSyncInfo.h>
+#include <thrift/TApplicationException.h>
+#include <thrift/transport/TTransportException.h>
+#include <limits>
+
+namespace apache { namespace thrift { namespace async {
+
+using namespace ::apache::thrift::concurrency;
+
+TConcurrentClientSyncInfo::TConcurrentClientSyncInfo() :
+ stop_(false),
+ seqidMutex_(),
+ // test rollover all the time
+ nextseqid_((std::numeric_limits<int32_t>::max)()-10),
+ seqidToMonitorMap_(),
+ freeMonitors_(),
+ writeMutex_(),
+ readMutex_(),
+ recvPending_(false),
+ wakeupSomeone_(false),
+ seqidPending_(0),
+ fnamePending_(),
+ mtypePending_(::apache::thrift::protocol::T_CALL)
+{
+ freeMonitors_.reserve(MONITOR_CACHE_SIZE);
+}
+
+bool TConcurrentClientSyncInfo::getPending(
+ std::string &fname,
+ ::apache::thrift::protocol::TMessageType &mtype,
+ int32_t &rseqid)
+{
+ if(stop_)
+ throwDeadConnection_();
+ wakeupSomeone_ = false;
+ if(recvPending_)
+ {
+ recvPending_ = false;
+ rseqid = seqidPending_;
+ fname = fnamePending_;
+ mtype = mtypePending_;
+ return true;
+ }
+ return false;
+}
+
+void TConcurrentClientSyncInfo::updatePending(
+ const std::string &fname,
+ ::apache::thrift::protocol::TMessageType mtype,
+ int32_t rseqid)
+{
+ recvPending_ = true;
+ seqidPending_ = rseqid;
+ fnamePending_ = fname;
+ mtypePending_ = mtype;
+ MonitorPtr monitor;
+ {
+ Guard seqidGuard(seqidMutex_);
+ MonitorMap::iterator i = seqidToMonitorMap_.find(rseqid);
+ if(i == seqidToMonitorMap_.end())
+ throwBadSeqId_();
+ monitor = i->second;
+ }
+ monitor->notify();
+}
+
+void TConcurrentClientSyncInfo::waitForWork(int32_t seqid)
+{
+ MonitorPtr m;
+ {
+ Guard seqidGuard(seqidMutex_);
+ m = seqidToMonitorMap_[seqid];
+ }
+ while(true)
+ {
+ // be very careful about setting state in this loop that affects waking up. You may exit
+ // this function, attempt to grab some work, and someone else could have beaten you (or not
+ // left) the read mutex, and that will put you right back in this loop, with the mangled
+ // state you left behind.
+ if(stop_)
+ throwDeadConnection_();
+ if(wakeupSomeone_)
+ return;
+ if(recvPending_ && seqidPending_ == seqid)
+ return;
+ m->waitForever();
+ }
+}
+
+void TConcurrentClientSyncInfo::throwBadSeqId_()
+{
+ throw apache::thrift::TApplicationException(
+ TApplicationException::BAD_SEQUENCE_ID,
+ "server sent a bad seqid");
+}
+
+void TConcurrentClientSyncInfo::throwDeadConnection_()
+{
+ throw apache::thrift::transport::TTransportException(
+ apache::thrift::transport::TTransportException::NOT_OPEN,
+ "this client died on another thread, and is now in an unusable state");
+}
+
+void TConcurrentClientSyncInfo::wakeupAnyone_(const Guard &)
+{
+ wakeupSomeone_ = true;
+ if(!seqidToMonitorMap_.empty())
+ {
+ // The monitor map maps integers to monitors. Larger integers are more recent
+ // messages. Since this is ordered, it means that the last element is the most recent.
+ // We are trying to guess which thread will have its message complete next, so we are picking
+ // the most recent. The oldest message is likely to be some polling, long lived message.
+ // If we guess right, the thread we wake up will handle the message that comes in.
+ // If we guess wrong, the thread we wake up will hand off the work to the correct thread,
+ // costing us an extra context switch.
+ seqidToMonitorMap_.rbegin()->second->notify();
+ }
+}
+
+void TConcurrentClientSyncInfo::markBad_(const Guard &)
+{
+ wakeupSomeone_ = true;
+ stop_ = true;
+ for(MonitorMap::iterator i = seqidToMonitorMap_.begin(); i != seqidToMonitorMap_.end(); ++i)
+ i->second->notify();
+}
+
+TConcurrentClientSyncInfo::MonitorPtr
+TConcurrentClientSyncInfo::newMonitor_(const Guard &)
+{
+ if(freeMonitors_.empty())
+ return MonitorPtr(new Monitor(&readMutex_));
+ MonitorPtr retval;
+ //swapping to avoid an atomic operation
+ retval.swap(freeMonitors_.back());
+ freeMonitors_.pop_back();
+ return retval;
+}
+
+void TConcurrentClientSyncInfo::deleteMonitor_(
+ const Guard &,
+ TConcurrentClientSyncInfo::MonitorPtr &m) /*noexcept*/
+{
+ if(freeMonitors_.size() > MONITOR_CACHE_SIZE)
+ {
+ m.reset();
+ return;
+ }
+ //freeMonitors_ was reserved up to MONITOR_CACHE_SIZE in the ctor,
+ //so this shouldn't throw
+ freeMonitors_.push_back(TConcurrentClientSyncInfo::MonitorPtr());
+ //swapping to avoid an atomic operation
+ m.swap(freeMonitors_.back());
+}
+
+int32_t TConcurrentClientSyncInfo::generateSeqId()
+{
+ Guard seqidGuard(seqidMutex_);
+ if(stop_)
+ throwDeadConnection_();
+
+ if(!seqidToMonitorMap_.empty())
+ if(nextseqid_ == seqidToMonitorMap_.begin()->first)
+ throw apache::thrift::TApplicationException(
+ TApplicationException::BAD_SEQUENCE_ID,
+ "about to repeat a seqid");
+ int32_t newSeqId = nextseqid_++;
+ seqidToMonitorMap_[newSeqId] = newMonitor_(seqidGuard);
+ return newSeqId;
+}
+
+TConcurrentRecvSentry::TConcurrentRecvSentry(TConcurrentClientSyncInfo *sync, int32_t seqid) :
+ sync_(*sync),
+ seqid_(seqid),
+ committed_(false)
+{
+ sync_.getReadMutex().lock();
+}
+
+TConcurrentRecvSentry::~TConcurrentRecvSentry()
+{
+ {
+ Guard seqidGuard(sync_.seqidMutex_);
+ sync_.deleteMonitor_(seqidGuard, sync_.seqidToMonitorMap_[seqid_]);
+
+ sync_.seqidToMonitorMap_.erase(seqid_);
+ if(committed_)
+ sync_.wakeupAnyone_(seqidGuard);
+ else
+ sync_.markBad_(seqidGuard);
+ }
+ sync_.getReadMutex().unlock();
+}
+
+void TConcurrentRecvSentry::commit()
+{
+ committed_ = true;
+}
+
+TConcurrentSendSentry::TConcurrentSendSentry(TConcurrentClientSyncInfo *sync) :
+ sync_(*sync),
+ committed_(false)
+{
+ sync_.getWriteMutex().lock();
+}
+
+TConcurrentSendSentry::~TConcurrentSendSentry()
+{
+ if(!committed_)
+ {
+ Guard seqidGuard(sync_.seqidMutex_);
+ sync_.markBad_(seqidGuard);
+ }
+ sync_.getWriteMutex().unlock();
+}
+
+void TConcurrentSendSentry::commit()
+{
+ committed_ = true;
+}
+
+
+}}} // apache::thrift::async
diff --git a/lib/cpp/src/thrift/async/TConcurrentClientSyncInfo.h b/lib/cpp/src/thrift/async/TConcurrentClientSyncInfo.h
new file mode 100644
index 0000000..8997a23
--- /dev/null
+++ b/lib/cpp/src/thrift/async/TConcurrentClientSyncInfo.h
@@ -0,0 +1,127 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#ifndef _THRIFT_TCONCURRENTCLIENTSYNCINFO_H_
+#define _THRIFT_TCONCURRENTCLIENTSYNCINFO_H_ 1
+
+#include <thrift/protocol/TProtocol.h>
+#include <thrift/concurrency/Mutex.h>
+#include <thrift/concurrency/Monitor.h>
+#include <boost/shared_ptr.hpp>
+#include <vector>
+#include <string>
+#include <map>
+
+namespace apache { namespace thrift { namespace async {
+
+class TConcurrentClientSyncInfo;
+
+class TConcurrentSendSentry
+{
+public:
+ explicit TConcurrentSendSentry(TConcurrentClientSyncInfo *sync);
+ ~TConcurrentSendSentry();
+
+ void commit();
+private:
+ TConcurrentClientSyncInfo &sync_;
+ bool committed_;
+};
+
+class TConcurrentRecvSentry
+{
+public:
+ TConcurrentRecvSentry(TConcurrentClientSyncInfo *sync, int32_t seqid);
+ ~TConcurrentRecvSentry();
+
+ void commit();
+private:
+ TConcurrentClientSyncInfo &sync_;
+ int32_t seqid_;
+ bool committed_;
+};
+
+class TConcurrentClientSyncInfo
+{
+private: //typedefs
+ typedef boost::shared_ptr< ::apache::thrift::concurrency::Monitor> MonitorPtr;
+ typedef std::map<int32_t, MonitorPtr> MonitorMap;
+public:
+ TConcurrentClientSyncInfo();
+
+ int32_t generateSeqId();
+
+ bool getPending(
+ std::string &fname,
+ ::apache::thrift::protocol::TMessageType &mtype,
+ int32_t &rseqid); /* requires readMutex_ */
+
+ void updatePending(
+ const std::string &fname,
+ ::apache::thrift::protocol::TMessageType mtype,
+ int32_t rseqid); /* requires readMutex_ */
+
+ void waitForWork(int32_t seqid); /* requires readMutex_ */
+
+ ::apache::thrift::concurrency::Mutex &getReadMutex() {return readMutex_;}
+ ::apache::thrift::concurrency::Mutex &getWriteMutex() {return writeMutex_;}
+
+private: //constants
+ enum {MONITOR_CACHE_SIZE = 10};
+private: //functions
+ MonitorPtr newMonitor_(
+ const ::apache::thrift::concurrency::Guard &seqidGuard); /* requires seqidMutex_ */
+ void deleteMonitor_(
+ const ::apache::thrift::concurrency::Guard &seqidGuard,
+ MonitorPtr &m); /*noexcept*/ /* requires seqidMutex_ */
+ void wakeupAnyone_(
+ const ::apache::thrift::concurrency::Guard &seqidGuard); /* requires seqidMutex_ */
+ void markBad_(
+ const ::apache::thrift::concurrency::Guard &seqidGuard); /* requires seqidMutex_ */
+ void throwBadSeqId_();
+ void throwDeadConnection_();
+private: //data members
+
+ volatile bool stop_;
+
+ ::apache::thrift::concurrency::Mutex seqidMutex_;
+ // begin seqidMutex_ protected members
+ int32_t nextseqid_;
+ MonitorMap seqidToMonitorMap_;
+ std::vector<MonitorPtr> freeMonitors_;
+ // end seqidMutex_ protected members
+
+ ::apache::thrift::concurrency::Mutex writeMutex_;
+
+ ::apache::thrift::concurrency::Mutex readMutex_;
+ // begin readMutex_ protected members
+ bool recvPending_;
+ bool wakeupSomeone_;
+ int32_t seqidPending_;
+ std::string fnamePending_;
+ ::apache::thrift::protocol::TMessageType mtypePending_;
+ // end readMutex_ protected members
+
+
+ friend class TConcurrentSendSentry;
+ friend class TConcurrentRecvSentry;
+};
+
+}}} // apache::thrift::async
+
+#endif // _THRIFT_TCONCURRENTCLIENTSYNCINFO_H_
diff --git a/lib/cpp/src/thrift/protocol/TProtocol.h b/lib/cpp/src/thrift/protocol/TProtocol.h
index 1aa2122..b44e91a 100644
--- a/lib/cpp/src/thrift/protocol/TProtocol.h
+++ b/lib/cpp/src/thrift/protocol/TProtocol.h
@@ -552,26 +552,36 @@
inline boost::shared_ptr<TTransport> getInputTransport() { return ptrans_; }
inline boost::shared_ptr<TTransport> getOutputTransport() { return ptrans_; }
- void incrementRecursionDepth() {
- if (recursion_limit_ < ++recursion_depth_) {
+ // input and output recursion depth are kept separate so that one protocol
+ // can be used concurrently for both input and output.
+ void incrementInputRecursionDepth() {
+ if (recursion_limit_ < ++input_recursion_depth_) {
throw TProtocolException(TProtocolException::DEPTH_LIMIT);
}
}
+ void decrementInputRecursionDepth() { --input_recursion_depth_; }
- void decrementRecursionDepth() { --recursion_depth_; }
+ void incrementOutputRecursionDepth() {
+ if (recursion_limit_ < ++output_recursion_depth_) {
+ throw TProtocolException(TProtocolException::DEPTH_LIMIT);
+ }
+ }
+ void decrementOutputRecursionDepth() { --output_recursion_depth_; }
+
uint32_t getRecursionLimit() const {return recursion_limit_;}
void setRecurisionLimit(uint32_t depth) {recursion_limit_ = depth;}
protected:
TProtocol(boost::shared_ptr<TTransport> ptrans)
- : ptrans_(ptrans), recursion_depth_(0), recursion_limit_(DEFAULT_RECURSION_LIMIT)
+ : ptrans_(ptrans), input_recursion_depth_(0), output_recursion_depth_(0), recursion_limit_(DEFAULT_RECURSION_LIMIT)
{}
boost::shared_ptr<TTransport> ptrans_;
private:
TProtocol() {}
- uint32_t recursion_depth_;
+ uint32_t input_recursion_depth_;
+ uint32_t output_recursion_depth_;
uint32_t recursion_limit_;
};
@@ -617,13 +627,23 @@
static uint64_t fromWire64(uint64_t x) {return letohll(x);}
};
-struct TRecursionTracker {
+struct TOutputRecursionTracker {
TProtocol &prot_;
- TRecursionTracker(TProtocol &prot) : prot_(prot) {
- prot_.incrementRecursionDepth();
+ TOutputRecursionTracker(TProtocol &prot) : prot_(prot) {
+ prot_.incrementOutputRecursionDepth();
}
- ~TRecursionTracker() {
- prot_.decrementRecursionDepth();
+ ~TOutputRecursionTracker() {
+ prot_.decrementOutputRecursionDepth();
+ }
+};
+
+struct TInputRecursionTracker {
+ TProtocol &prot_;
+ TInputRecursionTracker(TProtocol &prot) : prot_(prot) {
+ prot_.incrementInputRecursionDepth();
+ }
+ ~TInputRecursionTracker() {
+ prot_.decrementInputRecursionDepth();
}
};
@@ -634,7 +654,7 @@
*/
template <class Protocol_>
uint32_t skip(Protocol_& prot, TType type) {
- TRecursionTracker tracker(prot);
+ TInputRecursionTracker tracker(prot);
switch (type) {
case T_BOOL: {
diff --git a/test/cpp/src/StressTest.cpp b/test/cpp/src/StressTest.cpp
index fa468a4..9371bce 100644
--- a/test/cpp/src/StressTest.cpp
+++ b/test/cpp/src/StressTest.cpp
@@ -33,7 +33,6 @@
#include <thrift/TLogging.h>
#include "Service.h"
-
#include <iostream>
#include <set>
#include <stdexcept>
@@ -102,20 +101,26 @@
Mutex lock_;
};
+enum TransportOpenCloseBehavior {
+ OpenAndCloseTransportInThread,
+ DontOpenAndCloseTransportInThread
+};
class ClientThread : public Runnable {
public:
ClientThread(boost::shared_ptr<TTransport> transport,
- boost::shared_ptr<ServiceClient> client,
+ boost::shared_ptr<ServiceIf> client,
Monitor& monitor,
size_t& workerCount,
size_t loopCount,
- TType loopType)
+ TType loopType,
+ TransportOpenCloseBehavior behavior)
: _transport(transport),
_client(client),
_monitor(monitor),
_workerCount(workerCount),
_loopCount(loopCount),
- _loopType(loopType) {}
+ _loopType(loopType),
+ _behavior(behavior) {}
void run() {
@@ -129,8 +134,9 @@
}
_startTime = Util::currentTime();
-
- _transport->open();
+ if(_behavior == OpenAndCloseTransportInThread) {
+ _transport->open();
+ }
switch (_loopType) {
case T_VOID:
@@ -155,7 +161,9 @@
_endTime = Util::currentTime();
- _transport->close();
+ if(_behavior == OpenAndCloseTransportInThread) {
+ _transport->close();
+ }
_done = true;
@@ -217,7 +225,7 @@
}
boost::shared_ptr<TTransport> _transport;
- boost::shared_ptr<ServiceClient> _client;
+ boost::shared_ptr<ServiceIf> _client;
Monitor& _monitor;
size_t& _workerCount;
size_t _loopCount;
@@ -226,6 +234,7 @@
int64_t _endTime;
bool _done;
Monitor _sleep;
+ TransportOpenCloseBehavior _behavior;
};
class TStartObserver : public apache::thrift::server::TServerEventHandler {
@@ -253,6 +262,7 @@
#endif
int port = 9091;
+ string clientType = "regular";
string serverType = "thread-pool";
string protocolType = "binary";
size_t workerCount = 4;
@@ -269,23 +279,23 @@
usage << argv[0] << " [--port=<port number>] [--server] [--server-type=<server-type>] "
"[--protocol-type=<protocol-type>] [--workers=<worker-count>] "
- "[--clients=<client-count>] [--loop=<loop-count>]" << endl
+ "[--clients=<client-count>] [--loop=<loop-count>] "
+ "[--client-type=<client-type>]" << endl
<< "\tclients Number of client threads to create - 0 implies no clients, i.e. "
- "server only. Default is " << clientCount << endl
+ "server only. Default is " << clientCount << endl
<< "\thelp Prints this help text." << endl
<< "\tcall Service method to call. Default is " << callName << endl
- << "\tloop The number of remote thrift calls each client makes. Default is "
- << loopCount << endl << "\tport The port the server and clients should bind to "
- "for thrift network connections. Default is " << port << endl
- << "\tserver Run the Thrift server in this process. Default is " << runServer
- << endl << "\tserver-type Type of server, \"simple\" or \"thread-pool\". Default is "
- << serverType << endl
- << "\tprotocol-type Type of protocol, \"binary\", \"ascii\", or \"xml\". Default is "
- << protocolType << endl
- << "\tlog-request Log all request to ./requestlog.tlog. Default is " << logRequests
- << endl << "\treplay-request Replay requests from log file (./requestlog.tlog) Default is "
- << replayRequests << endl << "\tworkers Number of thread pools workers. Only valid "
- "for thread-pool server type. Default is " << workerCount
+ << "\tloop The number of remote thrift calls each client makes. Default is " << loopCount << endl
+ << "\tport The port the server and clients should bind to "
+ "for thrift network connections. Default is " << port << endl
+ << "\tserver Run the Thrift server in this process. Default is " << runServer << endl
+ << "\tserver-type Type of server, \"simple\" or \"thread-pool\". Default is " << serverType << endl
+ << "\tprotocol-type Type of protocol, \"binary\", \"ascii\", or \"xml\". Default is " << protocolType << endl
+ << "\tlog-request Log all request to ./requestlog.tlog. Default is " << logRequests << endl
+ << "\treplay-request Replay requests from log file (./requestlog.tlog) Default is " << replayRequests << endl
+ << "\tworkers Number of thread pools workers. Only valid "
+ "for thread-pool server type. Default is " << workerCount << endl
+ << "\tclient-type Type of client, \"regular\" or \"concurrent\". Default is " << clientType << endl
<< endl;
map<string, string> args;
@@ -359,7 +369,18 @@
throw invalid_argument("Unknown server type " + serverType);
}
}
+ if (!args["client-type"].empty()) {
+ clientType = args["client-type"];
+ if (clientType == "regular") {
+
+ } else if (clientType == "concurrent") {
+
+ } else {
+
+ throw invalid_argument("Unknown client type " + clientType);
+ }
+ }
if (!args["workers"].empty()) {
workerCount = atoi(args["workers"].c_str());
}
@@ -458,7 +479,7 @@
}
}
- if (clientCount > 0) {
+ if (clientCount > 0) { //FIXME: start here for client type?
Monitor monitor;
@@ -480,15 +501,28 @@
throw invalid_argument("Unknown service call " + callName);
}
- for (size_t ix = 0; ix < clientCount; ix++) {
+ if(clientType == "regular") {
+ for (size_t ix = 0; ix < clientCount; ix++) {
+ boost::shared_ptr<TSocket> socket(new TSocket("127.0.0.1", port));
+ boost::shared_ptr<TBufferedTransport> bufferedSocket(new TBufferedTransport(socket, 2048));
+ boost::shared_ptr<TProtocol> protocol(new TBinaryProtocol(bufferedSocket));
+ boost::shared_ptr<ServiceClient> serviceClient(new ServiceClient(protocol));
+
+ clientThreads.insert(threadFactory->newThread(boost::shared_ptr<ClientThread>(
+ new ClientThread(socket, serviceClient, monitor, threadCount, loopCount, loopType, OpenAndCloseTransportInThread))));
+ }
+ } else if(clientType == "concurrent") {
boost::shared_ptr<TSocket> socket(new TSocket("127.0.0.1", port));
boost::shared_ptr<TBufferedTransport> bufferedSocket(new TBufferedTransport(socket, 2048));
boost::shared_ptr<TProtocol> protocol(new TBinaryProtocol(bufferedSocket));
- boost::shared_ptr<ServiceClient> serviceClient(new ServiceClient(protocol));
-
- clientThreads.insert(threadFactory->newThread(boost::shared_ptr<ClientThread>(
- new ClientThread(socket, serviceClient, monitor, threadCount, loopCount, loopType))));
+ //boost::shared_ptr<ServiceClient> serviceClient(new ServiceClient(protocol));
+ boost::shared_ptr<ServiceConcurrentClient> serviceClient(new ServiceConcurrentClient(protocol));
+ socket->open();
+ for (size_t ix = 0; ix < clientCount; ix++) {
+ clientThreads.insert(threadFactory->newThread(boost::shared_ptr<ClientThread>(
+ new ClientThread(socket, serviceClient, monitor, threadCount, loopCount, loopType, DontOpenAndCloseTransportInThread))));
+ }
}
for (std::set<boost::shared_ptr<Thread> >::const_iterator thread = clientThreads.begin();
@@ -504,7 +538,7 @@
Synchronized s(monitor);
threadCount = clientCount;
- cerr << "Launch " << clientCount << " client threads" << endl;
+ cerr << "Launch " << clientCount << " " << clientType << " client threads" << endl;
time00 = Util::currentTime();