THRIFT-3247 Generate a C++ thread-safe client
Client: cpp
Patch: Ben Craig <bencraig@apache.org>
diff --git a/lib/cpp/src/thrift/async/TConcurrentClientSyncInfo.cpp b/lib/cpp/src/thrift/async/TConcurrentClientSyncInfo.cpp
new file mode 100644
index 0000000..c7e27c0
--- /dev/null
+++ b/lib/cpp/src/thrift/async/TConcurrentClientSyncInfo.cpp
@@ -0,0 +1,242 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <thrift/async/TConcurrentClientSyncInfo.h>
+#include <thrift/TApplicationException.h>
+#include <thrift/transport/TTransportException.h>
+#include <limits>
+
+namespace apache { namespace thrift { namespace async {
+
+using namespace ::apache::thrift::concurrency;
+
+// Members are initialized in declaration order. The class-type members that are
+// simply default-constructed (the three mutexes, seqidToMonitorMap_,
+// freeMonitors_ and fnamePending_) are left to their default constructors.
+TConcurrentClientSyncInfo::TConcurrentClientSyncInfo()
+  : stop_(false)
+    // test rollover all the time: start just below INT32_MAX so that seqid
+    // wrap-around is exercised early in every client's lifetime.
+  , nextseqid_((std::numeric_limits<int32_t>::max)() - 10)
+  , recvPending_(false)
+  , wakeupSomeone_(false)
+  , seqidPending_(0)
+  , mtypePending_(::apache::thrift::protocol::T_CALL)
+{
+  // Pre-size the idle-monitor cache so that caching up to MONITOR_CACHE_SIZE
+  // released monitors does not need to reallocate.
+  freeMonitors_.reserve(MONITOR_CACHE_SIZE);
+}
+
+// Requires readMutex_. If another reader has buffered a reply header (see
+// updatePending), moves it into the out-parameters, clears the buffer and
+// returns true; otherwise returns false. Throws if the client has been marked
+// bad.
+bool TConcurrentClientSyncInfo::getPending(
+  std::string &fname,
+  ::apache::thrift::protocol::TMessageType &mtype,
+  int32_t &rseqid)
+{
+  if(stop_)
+    throwDeadConnection_();
+
+  // The caller is awake and holds readMutex_, so any outstanding wake-up
+  // request (set by wakeupAnyone_/markBad_) has been serviced.
+  wakeupSomeone_ = false;
+
+  if(!recvPending_)
+    return false;
+
+  recvPending_ = false;
+  rseqid = seqidPending_;
+  fname = fnamePending_;
+  mtype = mtypePending_;
+  return true;
+}
+
+// Requires readMutex_. Buffers the header of a reply that belongs to another
+// thread's outstanding call (rseqid) and wakes that thread.
+// Throws TApplicationException(BAD_SEQUENCE_ID) if no call is waiting on rseqid.
+void TConcurrentClientSyncInfo::updatePending(
+  const std::string &fname,
+  ::apache::thrift::protocol::TMessageType mtype,
+  int32_t rseqid)
+{
+  // Validate the seqid first, so an unknown seqid throws before any of the
+  // pending-reply members have been modified.
+  MonitorPtr monitor;
+  {
+    Guard seqidGuard(seqidMutex_);
+    MonitorMap::iterator i = seqidToMonitorMap_.find(rseqid);
+    if(i == seqidToMonitorMap_.end())
+      throwBadSeqId_();
+    monitor = i->second;
+  }
+
+  // Copy the only member whose assignment can throw (the string) before
+  // raising the recvPending_ flag that waiters test in waitForWork.
+  fnamePending_ = fname;
+  mtypePending_ = mtype;
+  seqidPending_ = rseqid;
+  recvPending_ = true;
+
+  monitor->notify();
+}
+
+// Requires readMutex_ (every per-seqid monitor is built on readMutex_, see
+// newMonitor_). Blocks until one of the following happens:
+//  - a reply for `seqid` has been buffered by another reader;
+//  - a wake-up was requested so this thread can become the reader.
+// Throws if the client has been marked bad, or if `seqid` is not outstanding.
+void TConcurrentClientSyncInfo::waitForWork(int32_t seqid)
+{
+  MonitorPtr m;
+  {
+    Guard seqidGuard(seqidMutex_);
+    // find() rather than operator[]: operator[] would insert a null MonitorPtr
+    // for an unregistered seqid, and m->waitForever() below would dereference it.
+    MonitorMap::iterator i = seqidToMonitorMap_.find(seqid);
+    if(i == seqidToMonitorMap_.end())
+      throw apache::thrift::TApplicationException(
+        TApplicationException::BAD_SEQUENCE_ID,
+        "waiting on a seqid that is not outstanding");
+    m = i->second;
+  }
+  while(true)
+  {
+    // be very careful about setting state in this loop that affects waking up. You may exit
+    // this function, attempt to grab some work, and someone else could have beaten you (or not
+    // left) the read mutex, and that will put you right back in this loop, with the mangled
+    // state you left behind.
+    if(stop_)
+      throwDeadConnection_();
+    // A new reader was requested (wakeupAnyone_/markBad_); go try to become it.
+    if(wakeupSomeone_)
+      return;
+    // Another reader has buffered our reply via updatePending.
+    if(recvPending_ && seqidPending_ == seqid)
+      return;
+    m->waitForever();
+  }
+}
+
+// Raised when a reply carries a seqid that has no registered waiter.
+void TConcurrentClientSyncInfo::throwBadSeqId_()
+{
+  const char *const message = "server sent a bad seqid";
+  throw ::apache::thrift::TApplicationException(
+    ::apache::thrift::TApplicationException::BAD_SEQUENCE_ID, message);
+}
+
+// Raised once stop_ has been set by markBad_: a call on some thread did not
+// complete cleanly, so the shared connection can no longer be used.
+void TConcurrentClientSyncInfo::throwDeadConnection_()
+{
+  using ::apache::thrift::transport::TTransportException;
+  throw TTransportException(
+    TTransportException::NOT_OPEN,
+    "this client died on another thread, and is now in an unusable state");
+}
+
+// Requires seqidMutex_ (the unused Guard parameter documents that). Asks one
+// waiting thread to wake up and take over as the reader.
+void TConcurrentClientSyncInfo::wakeupAnyone_(const Guard &)
+{
+  wakeupSomeone_ = true;
+  if(seqidToMonitorMap_.empty())
+    return;
+
+  // Heuristic: wake the waiter with the largest seqid, guessing that the newest
+  // request is the one most likely to be answered next (the oldest may be a
+  // long-lived poll). A wrong guess costs only an extra hand-off and context
+  // switch, because the woken thread passes the reply to its owner.
+  // Note: once nextseqid_ wraps past INT32_MAX, the largest key is no longer
+  // necessarily the most recent request.
+  MonitorMap::reverse_iterator newest = seqidToMonitorMap_.rbegin();
+  newest->second->notify();
+}
+
+// Requires seqidMutex_. Poisons the client and wakes every waiter so that each
+// one re-checks stop_ in waitForWork and throws.
+void TConcurrentClientSyncInfo::markBad_(const Guard &)
+{
+  wakeupSomeone_ = true;
+  stop_ = true;
+  MonitorMap::iterator it = seqidToMonitorMap_.begin();
+  const MonitorMap::iterator last = seqidToMonitorMap_.end();
+  while(it != last)
+  {
+    it->second->notify();
+    ++it;
+  }
+}
+
+// Requires seqidMutex_. Returns a monitor bound to readMutex_, reusing a cached
+// one when available.
+TConcurrentClientSyncInfo::MonitorPtr
+TConcurrentClientSyncInfo::newMonitor_(const Guard &)
+{
+  if(!freeMonitors_.empty())
+  {
+    MonitorPtr recycled;
+    // swap() instead of a copy avoids an atomic refcount round-trip
+    recycled.swap(freeMonitors_.back());
+    freeMonitors_.pop_back();
+    return recycled;
+  }
+  return MonitorPtr(new Monitor(&readMutex_));
+}
+
+// Requires seqidMutex_. Returns m's monitor to the idle cache, or releases it
+// if the cache is full. Either way, m is left empty. Must not throw (it runs
+// from ~TConcurrentRecvSentry).
+void TConcurrentClientSyncInfo::deleteMonitor_(
+  const Guard &,
+  TConcurrentClientSyncInfo::MonitorPtr &m) /*noexcept*/
+{
+  // >= (not >) keeps the cache at most MONITOR_CACHE_SIZE entries. That is the
+  // capacity reserved in the ctor. With >, a full cache would grow to
+  // MONITOR_CACHE_SIZE + 1, and the push_back below could reallocate and throw.
+  if(freeMonitors_.size() >= MONITOR_CACHE_SIZE)
+  {
+    m.reset();
+    return;
+  }
+  //freeMonitors_ was reserved up to MONITOR_CACHE_SIZE in the ctor,
+  //so this shouldn't throw
+  freeMonitors_.push_back(TConcurrentClientSyncInfo::MonitorPtr());
+  //swapping to avoid an atomic operation
+  m.swap(freeMonitors_.back());
+}
+
+// Allocates the seqid for a new call and registers a monitor for it.
+// Throws TTransportException(NOT_OPEN) if the client has been marked bad, and
+// TApplicationException(BAD_SEQUENCE_ID) if the next seqid is still in use.
+int32_t TConcurrentClientSyncInfo::generateSeqId()
+{
+  Guard seqidGuard(seqidMutex_);
+  if(stop_)
+    throwDeadConnection_();
+
+  // Refuse to hand out a seqid that is still outstanding. Searching the whole
+  // map, not just its smallest key, stays correct after nextseqid_ wraps,
+  // because the smallest key is then no longer the oldest request.
+  if(seqidToMonitorMap_.find(nextseqid_) != seqidToMonitorMap_.end())
+    throw apache::thrift::TApplicationException(
+      TApplicationException::BAD_SEQUENCE_ID,
+      "about to repeat a seqid");
+
+  // Obtain the monitor before touching the map. Otherwise a throwing Monitor
+  // construction could leave a null MonitorPtr registered, which
+  // wakeupAnyone_/markBad_ would later dereference.
+  MonitorPtr monitor = newMonitor_(seqidGuard);
+  const int32_t newSeqId = nextseqid_;
+  //swapping to avoid an atomic operation
+  seqidToMonitorMap_[newSeqId].swap(monitor);
+
+  // Advance with an explicit wrap-around. nextseqid_++ at INT32_MAX would be
+  // signed overflow (undefined behaviour), and the ctor starts only 10 below it.
+  if(newSeqId == (std::numeric_limits<int32_t>::max)())
+    nextseqid_ = (std::numeric_limits<int32_t>::min)();
+  else
+    nextseqid_ = newSeqId + 1;
+  return newSeqId;
+}
+
+// Locks the client's read mutex for the lifetime of this sentry; the
+// destructor releases it.
+TConcurrentRecvSentry::TConcurrentRecvSentry(TConcurrentClientSyncInfo *sync, int32_t seqid)
+  : sync_(*sync)
+  , seqid_(seqid)
+  , committed_(false)
+{
+  ::apache::thrift::concurrency::Mutex &readMutex = sync_.getReadMutex();
+  readMutex.lock();
+}
+
+// Unregisters this call's seqid and recycles its monitor. On a committed
+// receive, it then asks another waiter to take over reading; otherwise it
+// marks the whole client bad. Finally it releases the read mutex.
+TConcurrentRecvSentry::~TConcurrentRecvSentry()
+{
+  {
+    Guard seqidGuard(sync_.seqidMutex_);
+    // find() + erase(iterator): a single lookup, and unlike operator[] it can
+    // never insert (allocate, possibly throw) inside a destructor, nor push a
+    // null monitor into the free-monitor cache if the seqid is absent.
+    TConcurrentClientSyncInfo::MonitorMap::iterator i =
+      sync_.seqidToMonitorMap_.find(seqid_);
+    if(i != sync_.seqidToMonitorMap_.end())
+    {
+      sync_.deleteMonitor_(seqidGuard, i->second);
+      sync_.seqidToMonitorMap_.erase(i);
+    }
+
+    if(committed_)
+      sync_.wakeupAnyone_(seqidGuard);
+    else
+      sync_.markBad_(seqidGuard);
+  }
+  sync_.getReadMutex().unlock();
+}
+
+// Marks the receive as successful; the destructor will then hand the reader
+// role to another waiter instead of marking the client bad.
+void TConcurrentRecvSentry::commit()
+{
+  this->committed_ = true;
+}
+
+// Locks the client's write mutex for the lifetime of this sentry; the
+// destructor releases it.
+TConcurrentSendSentry::TConcurrentSendSentry(TConcurrentClientSyncInfo *sync)
+  : sync_(*sync)
+  , committed_(false)
+{
+  ::apache::thrift::concurrency::Mutex &writeMutex = sync_.getWriteMutex();
+  writeMutex.lock();
+}
+
+// Releases the write mutex. If commit() was never called, the request may have
+// been only partly written, so the whole client is marked bad first.
+// NOTE(review): markBad_ runs here holding seqidMutex_ but not readMutex_. It
+// writes wakeupSomeone_/stop_ and notifies monitors bound to readMutex_, so a
+// reader between its stop_ check and waitForever() may miss this wake-up.
+TConcurrentSendSentry::~TConcurrentSendSentry()
+{
+  if(committed_ == false)
+  {
+    Guard seqidGuard(sync_.seqidMutex_);
+    sync_.markBad_(seqidGuard);
+  }
+  sync_.getWriteMutex().unlock();
+}
+
+// Marks the request as fully sent; the destructor will then just release the
+// write mutex instead of marking the client bad.
+void TConcurrentSendSentry::commit()
+{
+  this->committed_ = true;
+}
+
+
+}}} // apache::thrift::async
diff --git a/lib/cpp/src/thrift/async/TConcurrentClientSyncInfo.h b/lib/cpp/src/thrift/async/TConcurrentClientSyncInfo.h
new file mode 100644
index 0000000..8997a23
--- /dev/null
+++ b/lib/cpp/src/thrift/async/TConcurrentClientSyncInfo.h
@@ -0,0 +1,127 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#ifndef _THRIFT_TCONCURRENTCLIENTSYNCINFO_H_
+#define _THRIFT_TCONCURRENTCLIENTSYNCINFO_H_ 1
+
+#include <thrift/protocol/TProtocol.h>
+#include <thrift/concurrency/Mutex.h>
+#include <thrift/concurrency/Monitor.h>
+#include <boost/shared_ptr.hpp>
+#include <vector>
+#include <string>
+#include <map>
+
+namespace apache { namespace thrift { namespace async {
+
+class TConcurrentClientSyncInfo;
+
+/**
+ * Scoped guard for the send half of a concurrent client call.
+ * The constructor locks the client's write mutex and the destructor unlocks it.
+ * If commit() has not been called by the time of destruction (e.g. an exception
+ * escaped while writing), the destructor marks the whole client as unusable.
+ */
+class TConcurrentSendSentry
+{
+public:
+  explicit TConcurrentSendSentry(TConcurrentClientSyncInfo *sync);
+  ~TConcurrentSendSentry();
+
+  /** Call once the request has been written completely. */
+  void commit();
+private:
+  // Non-copyable (declared, intentionally never defined): a copy would unlock
+  // the write mutex a second time when destroyed.
+  TConcurrentSendSentry(const TConcurrentSendSentry &);
+  TConcurrentSendSentry &operator=(const TConcurrentSendSentry &);
+
+  TConcurrentClientSyncInfo &sync_;
+  bool committed_;
+};
+
+/**
+ * Scoped guard for the receive half of a concurrent client call (seqid).
+ * The constructor locks the client's read mutex. The destructor:
+ *  - unregisters the seqid;
+ *  - if commit() was called, asks another waiting thread to take over reading,
+ *    otherwise marks the whole client as unusable;
+ *  - unlocks the read mutex.
+ */
+class TConcurrentRecvSentry
+{
+public:
+  TConcurrentRecvSentry(TConcurrentClientSyncInfo *sync, int32_t seqid);
+  ~TConcurrentRecvSentry();
+
+  /** Call once the reply for this seqid has been read completely. */
+  void commit();
+private:
+  // Non-copyable (declared, intentionally never defined): a copy would
+  // unregister the seqid and unlock the read mutex a second time.
+  TConcurrentRecvSentry(const TConcurrentRecvSentry &);
+  TConcurrentRecvSentry &operator=(const TConcurrentRecvSentry &);
+
+  TConcurrentClientSyncInfo &sync_;
+  int32_t seqid_;
+  bool committed_;
+};
+
+/**
+ * Shared state that lets several threads issue calls through one client.
+ * Writers serialize on writeMutex_ (via TConcurrentSendSentry). Readers
+ * serialize on readMutex_ (via TConcurrentRecvSentry). A reader that receives
+ * a reply for another thread's seqid buffers its header in the *Pending_
+ * members and wakes that thread through the seqid's Monitor.
+ */
+class TConcurrentClientSyncInfo
+{
+private: //typedefs
+ typedef boost::shared_ptr< ::apache::thrift::concurrency::Monitor> MonitorPtr;
+ // Outstanding seqid -> monitor its calling thread waits on.
+ typedef std::map<int32_t, MonitorPtr> MonitorMap;
+public:
+ TConcurrentClientSyncInfo();
+
+ // Allocates the seqid for a new call and registers a monitor for it. Throws
+ // if the client has been marked bad or a seqid is about to be reused.
+ int32_t generateSeqId();
+
+ // If another reader buffered a reply header, copies it out, clears it and
+ // returns true; otherwise returns false. Throws if the client is dead.
+ bool getPending(
+ std::string &fname,
+ ::apache::thrift::protocol::TMessageType &mtype,
+ int32_t &rseqid); /* requires readMutex_ */
+
+ // Buffers a reply header belonging to another thread's seqid and wakes it.
+ // Throws TApplicationException(BAD_SEQUENCE_ID) for an unknown seqid.
+ void updatePending(
+ const std::string &fname,
+ ::apache::thrift::protocol::TMessageType mtype,
+ int32_t rseqid); /* requires readMutex_ */
+
+ // Waits on seqid's monitor (which is bound to readMutex_) until its reply is
+ // buffered or a new reader is requested; throws if the client dies.
+ void waitForWork(int32_t seqid); /* requires readMutex_ */
+
+ ::apache::thrift::concurrency::Mutex &getReadMutex() {return readMutex_;}
+ ::apache::thrift::concurrency::Mutex &getWriteMutex() {return writeMutex_;}
+
+private: //constants
+ // Capacity reserved for freeMonitors_ (idle monitors kept for reuse).
+ enum {MONITOR_CACHE_SIZE = 10};
+private: //functions
+ MonitorPtr newMonitor_(
+ const ::apache::thrift::concurrency::Guard &seqidGuard); /* requires seqidMutex_ */
+ void deleteMonitor_(
+ const ::apache::thrift::concurrency::Guard &seqidGuard,
+ MonitorPtr &m); /*noexcept*/ /* requires seqidMutex_ */
+ void wakeupAnyone_(
+ const ::apache::thrift::concurrency::Guard &seqidGuard); /* requires seqidMutex_ */
+ void markBad_(
+ const ::apache::thrift::concurrency::Guard &seqidGuard); /* requires seqidMutex_ */
+ void throwBadSeqId_();
+ void throwDeadConnection_();
+private: //data members
+
+ // Set (never cleared) by markBad_ once any call fails part-way.
+ // NOTE(review): markBad_ writes it under seqidMutex_, while getPending and
+ // waitForWork read it under readMutex_ only; volatile does not make this
+ // race-free in C++ -- consider consistent locking or an atomic.
+ volatile bool stop_;
+
+ ::apache::thrift::concurrency::Mutex seqidMutex_;
+ // begin seqidMutex_ protected members
+ int32_t nextseqid_;
+ MonitorMap seqidToMonitorMap_;
+ std::vector<MonitorPtr> freeMonitors_;
+ // end seqidMutex_ protected members
+
+ ::apache::thrift::concurrency::Mutex writeMutex_;
+
+ ::apache::thrift::concurrency::Mutex readMutex_;
+ // begin readMutex_ protected members
+ // recvPending_/seqidPending_/fnamePending_/mtypePending_: one buffered reply
+ // header awaiting the thread that owns seqidPending_.
+ bool recvPending_;
+ // Request that some waiter wake up and become the reader.
+ // NOTE(review): markBad_ also writes this, and ~TConcurrentSendSentry calls
+ // markBad_ without holding readMutex_.
+ bool wakeupSomeone_;
+ int32_t seqidPending_;
+ std::string fnamePending_;
+ ::apache::thrift::protocol::TMessageType mtypePending_;
+ // end readMutex_ protected members
+
+
+ friend class TConcurrentSendSentry;
+ friend class TConcurrentRecvSentry;
+};
+
+}}} // apache::thrift::async
+
+#endif // _THRIFT_TCONCURRENTCLIENTSYNCINFO_H_
diff --git a/lib/cpp/src/thrift/protocol/TProtocol.h b/lib/cpp/src/thrift/protocol/TProtocol.h
index 1aa2122..b44e91a 100644
--- a/lib/cpp/src/thrift/protocol/TProtocol.h
+++ b/lib/cpp/src/thrift/protocol/TProtocol.h
@@ -552,26 +552,36 @@
inline boost::shared_ptr<TTransport> getInputTransport() { return ptrans_; }
inline boost::shared_ptr<TTransport> getOutputTransport() { return ptrans_; }
- void incrementRecursionDepth() {
- if (recursion_limit_ < ++recursion_depth_) {
+ // input and output recursion depth are kept separate so that one protocol
+ // can be used concurrently for both input and output.
+ //
+ // Enters one level of nested reading (TInputRecursionTracker calls this).
+ // Throws TProtocolException::DEPTH_LIMIT once the depth exceeds
+ // recursion_limit_. NOTE(review): the counter is incremented before the
+ // check, so it stays raised by one after a throw unless the caller undoes it.
+ void incrementInputRecursionDepth() {
+ if (recursion_limit_ < ++input_recursion_depth_) {
throw TProtocolException(TProtocolException::DEPTH_LIMIT);
}
}
+ // Leaves one level of nested reading.
+ void decrementInputRecursionDepth() { --input_recursion_depth_; }
- void decrementRecursionDepth() { --recursion_depth_; }
+ // Output-side counterparts (TOutputRecursionTracker calls these). They have
+ // the same DEPTH_LIMIT behaviour and increment-before-check caveat as above.
+ void incrementOutputRecursionDepth() {
+ if (recursion_limit_ < ++output_recursion_depth_) {
+ throw TProtocolException(TProtocolException::DEPTH_LIMIT);
+ }
+ }
+ void decrementOutputRecursionDepth() { --output_recursion_depth_; }
+
uint32_t getRecursionLimit() const {return recursion_limit_;}
void setRecurisionLimit(uint32_t depth) {recursion_limit_ = depth;}
protected:
TProtocol(boost::shared_ptr<TTransport> ptrans)
- : ptrans_(ptrans), recursion_depth_(0), recursion_limit_(DEFAULT_RECURSION_LIMIT)
+ : ptrans_(ptrans), input_recursion_depth_(0), output_recursion_depth_(0), recursion_limit_(DEFAULT_RECURSION_LIMIT)
{}
boost::shared_ptr<TTransport> ptrans_;
private:
+ // NOTE(review): this constructor leaves the depth counters and
+ // recursion_limit_ uninitialized; presumably it is never called.
TProtocol() {}
- uint32_t recursion_depth_;
+ // Current nesting depth of reads and of writes; each is compared with
+ // recursion_limit_.
+ uint32_t input_recursion_depth_;
+ uint32_t output_recursion_depth_;
uint32_t recursion_limit_;
};
@@ -617,13 +627,23 @@
static uint64_t fromWire64(uint64_t x) {return letohll(x);}
};
-struct TRecursionTracker {
+/**
+ * Scoped guard for one level of nested output (write) recursion.
+ * The constructor calls prot.incrementOutputRecursionDepth(); the destructor
+ * calls prot.decrementOutputRecursionDepth().
+ * The constructor throws TProtocolException::DEPTH_LIMIT when the limit is
+ * exceeded. A destructor never runs for an object whose constructor threw, so
+ * the constructor undoes the increment itself before rethrowing.
+ */
+struct TOutputRecursionTracker {
TProtocol &prot_;
- TRecursionTracker(TProtocol &prot) : prot_(prot) {
- prot_.incrementRecursionDepth();
+  explicit TOutputRecursionTracker(TProtocol &prot) : prot_(prot) {
+    try {
+      prot_.incrementOutputRecursionDepth();
+    } catch (...) {
+      // incrementOutputRecursionDepth() raises the counter before checking it.
+      prot_.decrementOutputRecursionDepth();
+      throw;
+    }
}
- ~TRecursionTracker() {
- prot_.decrementRecursionDepth();
+  ~TOutputRecursionTracker() {
+    prot_.decrementOutputRecursionDepth();
+  }
+private:
+  // Non-copyable (declared, never defined): a copy would decrement twice.
+  TOutputRecursionTracker(const TOutputRecursionTracker &);
+  TOutputRecursionTracker &operator=(const TOutputRecursionTracker &);
+};
+
+/**
+ * Input (read) counterpart of TOutputRecursionTracker; used e.g. by skip().
+ */
+struct TInputRecursionTracker {
+  TProtocol &prot_;
+  explicit TInputRecursionTracker(TProtocol &prot) : prot_(prot) {
+    try {
+      prot_.incrementInputRecursionDepth();
+    } catch (...) {
+      // incrementInputRecursionDepth() raises the counter before checking it.
+      prot_.decrementInputRecursionDepth();
+      throw;
+    }
+  }
+  ~TInputRecursionTracker() {
+    prot_.decrementInputRecursionDepth();
}
+private:
+  // Non-copyable (declared, never defined): a copy would decrement twice.
+  TInputRecursionTracker(const TInputRecursionTracker &);
+  TInputRecursionTracker &operator=(const TInputRecursionTracker &);
};
@@ -634,7 +654,7 @@
*/
template <class Protocol_>
uint32_t skip(Protocol_& prot, TType type) {
- TRecursionTracker tracker(prot);
+ TInputRecursionTracker tracker(prot);
switch (type) {
case T_BOOL: {