THRIFT-3247 Generate a C++ thread-safe client
Client: cpp
Patch: Ben Craig <bencraig@apache.org>
diff --git a/lib/cpp/src/thrift/async/TConcurrentClientSyncInfo.cpp b/lib/cpp/src/thrift/async/TConcurrentClientSyncInfo.cpp
new file mode 100644
index 0000000..c7e27c0
--- /dev/null
+++ b/lib/cpp/src/thrift/async/TConcurrentClientSyncInfo.cpp
@@ -0,0 +1,242 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <thrift/async/TConcurrentClientSyncInfo.h>
+#include <thrift/TApplicationException.h>
+#include <thrift/transport/TTransportException.h>
+#include <limits>
+
+namespace apache { namespace thrift { namespace async {
+
+using namespace ::apache::thrift::concurrency;
+
+TConcurrentClientSyncInfo::TConcurrentClientSyncInfo() :
+  stop_(false),
+  seqidMutex_(),
+  // deliberately start 10 below INT32_MAX so that seqid rollover is exercised by every instance
+  nextseqid_((std::numeric_limits<int32_t>::max)()-10),
+  seqidToMonitorMap_(),
+  freeMonitors_(),
+  writeMutex_(),
+  readMutex_(),
+  recvPending_(false),
+  wakeupSomeone_(false),
+  seqidPending_(0),
+  fnamePending_(),
+  mtypePending_(::apache::thrift::protocol::T_CALL)
+{
+  freeMonitors_.reserve(MONITOR_CACHE_SIZE); // pre-size the cache; deleteMonitor_ counts on this space
+}
+
+bool TConcurrentClientSyncInfo::getPending(
+  std::string &fname,
+  ::apache::thrift::protocol::TMessageType &mtype,
+  int32_t &rseqid)
+{
+  // Fetches and clears the header stored by updatePending().  Returns false, leaving the
+  // out-params untouched, when no header is stored.  Throws once the client is marked dead.
+  if(stop_)
+    throwDeadConnection_();
+  wakeupSomeone_ = false;
+  if(!recvPending_)
+    return false;
+  recvPending_ = false;
+  rseqid = seqidPending_;
+  fname  = fnamePending_;
+  mtype  = mtypePending_;
+  return true;
+}
+
+void TConcurrentClientSyncInfo::updatePending(
+  const std::string &fname,
+  ::apache::thrift::protocol::TMessageType mtype,
+  int32_t rseqid)
+{
+  // Stores the header for getPending(), then notifies the monitor registered for rseqid.
+  recvPending_ = true;
+  seqidPending_ = rseqid;
+  fnamePending_ = fname;
+  mtypePending_ = mtype;
+  MonitorPtr target;
+  {
+    Guard seqidGuard(seqidMutex_);
+    const MonitorMap::iterator entry = seqidToMonitorMap_.find(rseqid);
+    if(entry == seqidToMonitorMap_.end()) throwBadSeqId_(); // header stays stored when this throws
+    target = entry->second;
+  }
+  target->notify(); // issued after seqidMutex_ has been released
+}
+
+void TConcurrentClientSyncInfo::waitForWork(int32_t seqid)
+{
+  MonitorPtr m;
+  {
+    Guard seqidGuard(seqidMutex_);
+    m = seqidToMonitorMap_[seqid]; // NOTE(review): operator[] inserts a null entry for an unknown seqid
+  }
+  while(true)
+  {
+    // Blocks until this call should try to read: the client died (throws), a wakeup was
+    // requested, or the header stored by updatePending() carries this seqid.
+    // Do not change wake-up state inside this loop: a caller can leave, lose the race for the
+    // work or the read mutex, and land right back here with whatever state it left behind.
+    if(stop_)
+      throwDeadConnection_();
+    if(wakeupSomeone_)
+      return;
+    if(recvPending_ && seqidPending_ == seqid)
+      return;
+    m->waitForever();
+  }
+}
+
+void TConcurrentClientSyncInfo::throwBadSeqId_()
+{
+  // Reports a reply whose seqid has no entry in seqidToMonitorMap_ (used by updatePending).
+  using apache::thrift::TApplicationException;
+  throw TApplicationException(TApplicationException::BAD_SEQUENCE_ID, "server sent a bad seqid");
+}
+
+void TConcurrentClientSyncInfo::throwDeadConnection_()
+{ // Raised by generateSeqId/getPending/waitForWork once stop_ has been set by markBad_().
+  throw apache::thrift::transport::TTransportException(
+    apache::thrift::transport::TTransportException::NOT_OPEN,
+    "this client died on another thread, and is now in an unusable state");
+}
+
+void TConcurrentClientSyncInfo::wakeupAnyone_(const Guard &)
+{
+  // Sets wakeupSomeone_ (which makes waitForWork() return) and notifies one registered monitor.
+  wakeupSomeone_ = true;
+  if(seqidToMonitorMap_.empty())
+    return;
+  // The map is ordered by seqid, and larger seqids belong to more recent messages, so the
+  // last entry is the most recent one.  We are guessing which thread will have its message
+  // complete next and bet on the most recent: the oldest message is likely to be some
+  // polling, long lived message.  If the guess is right, the woken thread handles the
+  // incoming message itself.  If it is wrong, the woken thread hands the work to the
+  // correct thread, which costs one extra context switch.
+  MonitorMap::reverse_iterator newest = seqidToMonitorMap_.rbegin();
+  newest->second->notify();
+}
+
+void TConcurrentClientSyncInfo::markBad_(const Guard &)
+{
+  wakeupSomeone_ = true; // same flag wakeupAnyone_() raises
+  stop_ = true;          // from now on generateSeqId/getPending/waitForWork throw
+  for(MonitorMap::iterator waiter = seqidToMonitorMap_.begin(); waiter != seqidToMonitorMap_.end(); ++waiter)
+    waiter->second->notify(); // wake every registered call so it can observe stop_
+}
+
+TConcurrentClientSyncInfo::MonitorPtr
+TConcurrentClientSyncInfo::newMonitor_(const Guard &)
+{
+  // Reuses a cached monitor when one exists, otherwise allocates one constructed with &readMutex_.
+  MonitorPtr recycled;
+  if(freeMonitors_.empty())
+    return MonitorPtr(new Monitor(&readMutex_));
+  recycled.swap(freeMonitors_.back()); // swap instead of copy avoids an atomic operation
+  freeMonitors_.pop_back();
+  return recycled;
+}
+
+void TConcurrentClientSyncInfo::deleteMonitor_(
+  const Guard &,
+  TConcurrentClientSyncInfo::MonitorPtr &m) /*noexcept*/
+{
+  // Returns m to the free cache, or simply releases it when the cache is already full.
+  // The cache is capped at the MONITOR_CACHE_SIZE slots reserved by the constructor (>=, not >),
+  // so the push_back below never has to reallocate and therefore cannot throw.
+  if(freeMonitors_.size() >= static_cast<std::size_t>(MONITOR_CACHE_SIZE))
+  {
+    m.reset();
+    return;
+  }
+  freeMonitors_.push_back(TConcurrentClientSyncInfo::MonitorPtr());
+  m.swap(freeMonitors_.back()); // swapping to avoid an atomic operation
+}
+
+int32_t TConcurrentClientSyncInfo::generateSeqId()
+{
+  // Reserves the next seqid and registers a monitor for it; throws if the client is dead.
+  Guard seqidGuard(seqidMutex_);
+  if(stop_)
+    throwDeadConnection_();
+  if(seqidToMonitorMap_.count(nextseqid_) != 0) // still in flight: operator[] below would clobber it
+    throw apache::thrift::TApplicationException(
+      TApplicationException::BAD_SEQUENCE_ID, "about to repeat a seqid");
+  const int32_t newSeqId = nextseqid_;
+  nextseqid_ = (newSeqId == (std::numeric_limits<int32_t>::max)()) // ++ at max() is signed-overflow UB
+    ? (std::numeric_limits<int32_t>::min)() : newSeqId + 1;
+  seqidToMonitorMap_[newSeqId] = newMonitor_(seqidGuard);
+  return newSeqId;
+}
+
+TConcurrentRecvSentry::TConcurrentRecvSentry(TConcurrentClientSyncInfo *sync, int32_t seqid) :
+  sync_(*sync), seqid_(seqid), committed_(false)
+{
+  // Takes the read mutex here; the destructor is what releases it again.
+  // sync is dereferenced above without a null check, so callers must pass a valid pointer.
+  sync_.getReadMutex().lock();
+}
+
+TConcurrentRecvSentry::~TConcurrentRecvSentry()
+{
+  // Retires this call's monitor, then wakes another reader (committed) or marks the client bad.
+  {
+    Guard seqidGuard(sync_.seqidMutex_);
+    // count() first: operator[] on an unknown seqid would feed a null monitor into the free cache
+    if(sync_.seqidToMonitorMap_.count(seqid_) != 0)
+      sync_.deleteMonitor_(seqidGuard, sync_.seqidToMonitorMap_[seqid_]);
+    sync_.seqidToMonitorMap_.erase(seqid_);
+    if(committed_) sync_.wakeupAnyone_(seqidGuard);
+    else sync_.markBad_(seqidGuard);
+  }
+  sync_.getReadMutex().unlock(); // pairs with the lock() taken in the constructor
+}
+
+void TConcurrentRecvSentry::commit()
+{
+  committed_ = true; // the destructor will call wakeupAnyone_() instead of markBad_()
+}
+
+TConcurrentSendSentry::TConcurrentSendSentry(TConcurrentClientSyncInfo *sync) :
+  sync_(*sync), committed_(false)
+{
+  // Takes the write mutex; the destructor releases it.  sync is dereferenced unchecked.
+  sync_.getWriteMutex().lock();
+}
+
+TConcurrentSendSentry::~TConcurrentSendSentry()
+{
+  if(!committed_) // commit() was never called: mark the client bad for every thread
+  {
+    Guard poisonGuard(sync_.seqidMutex_);
+    sync_.markBad_(poisonGuard);
+  }
+  sync_.getWriteMutex().unlock(); // pairs with the lock() taken in the constructor
+}
+
+void TConcurrentSendSentry::commit()
+{
+  committed_ = true; // the destructor will skip markBad_() and only release the write mutex
+}
+
+
+}}} // apache::thrift::async
diff --git a/lib/cpp/src/thrift/async/TConcurrentClientSyncInfo.h b/lib/cpp/src/thrift/async/TConcurrentClientSyncInfo.h
new file mode 100644
index 0000000..8997a23
--- /dev/null
+++ b/lib/cpp/src/thrift/async/TConcurrentClientSyncInfo.h
@@ -0,0 +1,127 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#ifndef _THRIFT_TCONCURRENTCLIENTSYNCINFO_H_
+#define _THRIFT_TCONCURRENTCLIENTSYNCINFO_H_ 1
+
+#include <thrift/protocol/TProtocol.h>
+#include <thrift/concurrency/Mutex.h>
+#include <thrift/concurrency/Monitor.h>
+#include <boost/shared_ptr.hpp>
+#include <vector>
+#include <string>
+#include <map>
+
+namespace apache { namespace thrift { namespace async {
+
+class TConcurrentClientSyncInfo;
+
+class TConcurrentSendSentry // locks the write mutex on construction, unlocks it on destruction
+{
+public:
+  explicit TConcurrentSendSentry(TConcurrentClientSyncInfo *sync); // sync is dereferenced; must not be null
+  ~TConcurrentSendSentry(); // marks the client bad unless commit() was called
+
+  void commit(); // sets committed_, so the destructor leaves the client usable
+private:
+  TConcurrentClientSyncInfo &sync_;
+  bool committed_;
+};
+
+class TConcurrentRecvSentry // locks the read mutex on construction, unlocks it on destruction
+{
+public:
+  TConcurrentRecvSentry(TConcurrentClientSyncInfo *sync, int32_t seqid); // sync must not be null
+  ~TConcurrentRecvSentry(); // removes seqid's monitor; marks the client bad unless commit() was called
+
+  void commit(); // sets committed_, so the destructor wakes another waiter instead
+private:
+  TConcurrentClientSyncInfo &sync_;
+  int32_t seqid_; // the call whose monitor the destructor retires
+  bool committed_;
+};
+
+class TConcurrentClientSyncInfo // per-client bookkeeping: seqids, one monitor per call, read/write mutexes
+{
+private: //typedefs
+  typedef boost::shared_ptr< ::apache::thrift::concurrency::Monitor> MonitorPtr;
+  typedef std::map<int32_t, MonitorPtr> MonitorMap;
+public:
+  TConcurrentClientSyncInfo();
+  // Returns a new seqid and registers a monitor for it; takes seqidMutex_ itself.
+  int32_t generateSeqId();
+  // Fetches and clears the header stored by updatePending(); returns false when none is stored.
+  bool getPending(
+    std::string &fname,
+    ::apache::thrift::protocol::TMessageType &mtype,
+    int32_t &rseqid); /* requires readMutex_ */
+  // Stores a header and notifies the monitor registered for rseqid; throws on an unknown rseqid.
+  void updatePending(
+    const std::string &fname,
+    ::apache::thrift::protocol::TMessageType mtype,
+    int32_t rseqid); /* requires readMutex_ */
+  // Waits on seqid's monitor until a wakeup is requested or a header for seqid is stored.
+  void waitForWork(int32_t seqid); /* requires readMutex_ */
+
+  ::apache::thrift::concurrency::Mutex &getReadMutex() {return readMutex_;}
+  ::apache::thrift::concurrency::Mutex &getWriteMutex() {return writeMutex_;}
+
+private: //constants
+  enum {MONITOR_CACHE_SIZE = 10}; // capacity reserved for freeMonitors_ in the constructor
+private: //functions
+  MonitorPtr newMonitor_(
+    const ::apache::thrift::concurrency::Guard &seqidGuard); /* requires seqidMutex_ */
+  void deleteMonitor_(
+    const ::apache::thrift::concurrency::Guard &seqidGuard,
+    MonitorPtr &m); /*noexcept*/ /* requires seqidMutex_ */
+  void wakeupAnyone_(
+    const ::apache::thrift::concurrency::Guard &seqidGuard); /* requires seqidMutex_ */
+  void markBad_(
+    const ::apache::thrift::concurrency::Guard &seqidGuard); /* requires seqidMutex_ */
+  void throwBadSeqId_();
+  void throwDeadConnection_();
+private: //data members
+  // Set by markBad_().  NOTE(review): volatile is not a synchronization primitive - confirm reads are safe.
+  volatile bool stop_;
+
+  ::apache::thrift::concurrency::Mutex seqidMutex_;
+  // begin seqidMutex_ protected members
+  int32_t nextseqid_;
+  MonitorMap seqidToMonitorMap_; // one monitor per outstanding seqid
+  std::vector<MonitorPtr> freeMonitors_; // monitors kept for reuse by newMonitor_
+  // end seqidMutex_ protected members
+
+  ::apache::thrift::concurrency::Mutex writeMutex_;
+
+  ::apache::thrift::concurrency::Mutex readMutex_;
+  // begin readMutex_ protected members
+  bool recvPending_; // true while the three *Pending_ fields below hold an unclaimed header
+  bool wakeupSomeone_;
+  int32_t seqidPending_;
+  std::string fnamePending_;
+  ::apache::thrift::protocol::TMessageType mtypePending_;
+  // end readMutex_ protected members
+
+  // the sentries lock/unlock the mutexes and call the private helpers above
+  friend class TConcurrentSendSentry;
+  friend class TConcurrentRecvSentry;
+};
+
+}}} // apache::thrift::async
+
+#endif // _THRIFT_TCONCURRENTCLIENTSYNCINFO_H_
diff --git a/lib/cpp/src/thrift/protocol/TProtocol.h b/lib/cpp/src/thrift/protocol/TProtocol.h
index 1aa2122..b44e91a 100644
--- a/lib/cpp/src/thrift/protocol/TProtocol.h
+++ b/lib/cpp/src/thrift/protocol/TProtocol.h
@@ -552,26 +552,36 @@
   inline boost::shared_ptr<TTransport> getInputTransport() { return ptrans_; }
   inline boost::shared_ptr<TTransport> getOutputTransport() { return ptrans_; }
 
-  void incrementRecursionDepth() {
-    if (recursion_limit_ < ++recursion_depth_) {
+  // Input and output recursion depth are kept separate so one protocol can be used concurrently
+  // for both.  The limit is tested before the counter moves, so a DEPTH_LIMIT throw changes nothing.
+  void incrementInputRecursionDepth() {
+    if (recursion_limit_ <= input_recursion_depth_)
       throw TProtocolException(TProtocolException::DEPTH_LIMIT);
-    }
+    ++input_recursion_depth_;
   }
+  void decrementInputRecursionDepth() { --input_recursion_depth_; }
 
-  void decrementRecursionDepth() { --recursion_depth_; }
+  void incrementOutputRecursionDepth() {
+    if (recursion_limit_ <= output_recursion_depth_)
+      throw TProtocolException(TProtocolException::DEPTH_LIMIT);
+    ++output_recursion_depth_;
+  }
+  void decrementOutputRecursionDepth() { --output_recursion_depth_; }
+
   uint32_t getRecursionLimit() const {return recursion_limit_;}
   void setRecurisionLimit(uint32_t depth) {recursion_limit_ = depth;}
 
 protected:
   TProtocol(boost::shared_ptr<TTransport> ptrans)
-    : ptrans_(ptrans), recursion_depth_(0), recursion_limit_(DEFAULT_RECURSION_LIMIT)
+    : ptrans_(ptrans), input_recursion_depth_(0), output_recursion_depth_(0), recursion_limit_(DEFAULT_RECURSION_LIMIT)
   {}
 
   boost::shared_ptr<TTransport> ptrans_;
 
 private:
   TProtocol() {}
-  uint32_t recursion_depth_;
+  uint32_t input_recursion_depth_;   // nesting level tracked by TInputRecursionTracker
+  uint32_t output_recursion_depth_;  // nesting level tracked by TOutputRecursionTracker
   uint32_t recursion_limit_;
 };
 
@@ -617,13 +627,23 @@
   static uint64_t fromWire64(uint64_t x) {return letohll(x);}
 };
 
-struct TRecursionTracker {
+struct TOutputRecursionTracker { // holds one level of output nesting for as long as it is alive
   TProtocol &prot_;
-  TRecursionTracker(TProtocol &prot) : prot_(prot) {
-    prot_.incrementRecursionDepth();
+  TOutputRecursionTracker(TProtocol &prot) : prot_(prot) {
+    prot_.incrementOutputRecursionDepth(); // throws TProtocolException(DEPTH_LIMIT) past the limit
   }
-  ~TRecursionTracker() {
-    prot_.decrementRecursionDepth();
+  ~TOutputRecursionTracker() {
+    prot_.decrementOutputRecursionDepth();
+  }
+};
+
+struct TInputRecursionTracker { // holds one level of input nesting for as long as it is alive
+  TProtocol &prot_;
+  TInputRecursionTracker(TProtocol &prot) : prot_(prot) {
+    prot_.incrementInputRecursionDepth(); // throws TProtocolException(DEPTH_LIMIT) past the limit
+  }
+  ~TInputRecursionTracker() {
+    prot_.decrementInputRecursionDepth();
   }
 };
 
@@ -634,7 +654,7 @@
  */
 template <class Protocol_>
 uint32_t skip(Protocol_& prot, TType type) {
-  TRecursionTracker tracker(prot);
+  TInputRecursionTracker tracker(prot);
 
   switch (type) {
   case T_BOOL: {