THRIFT-4327: add API to efficiently remove a single timer
Client: C++

This closes #1353
diff --git a/lib/cpp/src/thrift/concurrency/TimerManager.cpp b/lib/cpp/src/thrift/concurrency/TimerManager.cpp
index 9ae1f94..2017146 100644
--- a/lib/cpp/src/thrift/concurrency/TimerManager.cpp
+++ b/lib/cpp/src/thrift/concurrency/TimerManager.cpp
@@ -30,6 +30,7 @@
 namespace concurrency {
 
 using stdcxx::shared_ptr;
+using stdcxx::weak_ptr;
 
 /**
  * TimerManager class
@@ -54,6 +55,8 @@
 
   bool operator==(const shared_ptr<Runnable> & runnable) const { return runnable_ == runnable; }
 
+  task_iterator it_; // position in taskMap_ saved by add(); end() once taken for execution
+
 private:
   shared_ptr<Runnable> runnable_;
   friend class TimerManager::Dispatcher;
@@ -108,6 +111,7 @@
           for (task_iterator ix = manager_->taskMap_.begin(); ix != expiredTaskEnd; ix++) {
             shared_ptr<TimerManager::Task> task = ix->second;
             expiredTasks.insert(task);
+            task->it_ = manager_->taskMap_.end();
             if (task->state_ == TimerManager::Task::WAITING) {
               task->state_ = TimerManager::Task::EXECUTING;
             }
@@ -235,7 +239,7 @@
   return taskCount_;
 }
 
-void TimerManager::add(shared_ptr<Runnable> task, int64_t timeout) {
+TimerManager::Timer TimerManager::add(shared_ptr<Runnable> task, int64_t timeout) {
   int64_t now = Util::currentTime();
   timeout += now;
 
@@ -250,9 +254,9 @@
     // because the new task might insert at the front.
     bool notifyRequired = (taskCount_ == 0) ? true : timeout < taskMap_.begin()->first;
 
+    shared_ptr<Task> timer(new Task(task));
     taskCount_++;
-    taskMap_.insert(
-        std::pair<int64_t, shared_ptr<Task> >(timeout, shared_ptr<Task>(new Task(task))));
+    timer->it_ = taskMap_.insert(std::pair<int64_t, shared_ptr<Task> >(timeout, timer));
 
     // If the task map was empty, or if we have an expiration that is earlier
     // than any previously seen, kick the dispatcher so it can update its
@@ -260,10 +264,13 @@
     if (notifyRequired) {
       monitor_.notify();
     }
+
+    return timer;
   }
 }
 
-void TimerManager::add(shared_ptr<Runnable> task, const struct THRIFT_TIMESPEC& value) {
+TimerManager::Timer TimerManager::add(shared_ptr<Runnable> task,
+    const struct THRIFT_TIMESPEC& value) {
 
   int64_t expiration;
   Util::toMilliseconds(expiration, value);
@@ -274,10 +281,11 @@
     throw InvalidArgumentException();
   }
 
-  add(task, expiration - now);
+  return add(task, expiration - now);
 }
 
-void TimerManager::add(shared_ptr<Runnable> task, const struct timeval& value) {
+TimerManager::Timer TimerManager::add(shared_ptr<Runnable> task,
+    const struct timeval& value) {
 
   int64_t expiration;
   Util::toMilliseconds(expiration, value);
@@ -288,7 +296,7 @@
     throw InvalidArgumentException();
   }
 
-  add(task, expiration - now);
+  return add(task, expiration - now);
 }
 
 void TimerManager::remove(shared_ptr<Runnable> task) {
@@ -311,6 +319,26 @@
   }
 }
 
+void TimerManager::remove(Timer handle) {
+  // Removes the single timer returned by add(), using the iterator saved in its Task.
+  Synchronized s(monitor_);
+  if (state_ != TimerManager::STARTED) {
+    throw IllegalStateException();
+  }
+  // Timer is a weak_ptr: lock() returns null once no shared_ptr owns the Task.
+  shared_ptr<Task> task = handle.lock();
+  if (!task) {
+    throw NoSuchTaskException();
+  }
+  if (task->it_ == taskMap_.end()) {
+    // Task is being executed: the dispatcher set it_ to end() when it took the task.
+    throw UncancellableTaskException();
+  }
+  // it_ was saved by add(), so no search of taskMap_ is needed to find the entry.
+  taskMap_.erase(task->it_);
+  taskCount_--;
+}
+
 TimerManager::STATE TimerManager::state() const {
   return state_;
 }
diff --git a/lib/cpp/src/thrift/concurrency/TimerManager.h b/lib/cpp/src/thrift/concurrency/TimerManager.h
index f664348..2bfc6a7 100644
--- a/lib/cpp/src/thrift/concurrency/TimerManager.h
+++ b/lib/cpp/src/thrift/concurrency/TimerManager.h
@@ -42,6 +42,9 @@
 class TimerManager {
 
 public:
+  class Task;
+  typedef stdcxx::weak_ptr<Task> Timer;
+
   TimerManager();
 
   virtual ~TimerManager();
@@ -69,28 +72,33 @@
    *
    * @param task The task to execute
    * @param timeout Time in milliseconds to delay before executing task
+   * @return Handle of the timer, which can be used to remove the timer.
    */
-  virtual void add(stdcxx::shared_ptr<Runnable> task, int64_t timeout);
+  virtual Timer add(stdcxx::shared_ptr<Runnable> task, int64_t timeout);
 
   /**
    * Adds a task to be executed at some time in the future by a worker thread.
    *
    * @param task The task to execute
    * @param timeout Absolute time in the future to execute task.
+   * @return Handle of the timer, which can be used to remove the timer.
    */
-  virtual void add(stdcxx::shared_ptr<Runnable> task, const struct THRIFT_TIMESPEC& timeout);
+  virtual Timer add(stdcxx::shared_ptr<Runnable> task, const struct THRIFT_TIMESPEC& timeout);
 
   /**
    * Adds a task to be executed at some time in the future by a worker thread.
    *
    * @param task The task to execute
    * @param timeout Absolute time in the future to execute task.
+   * @return Handle of the timer, which can be used to remove the timer.
    */
-  virtual void add(stdcxx::shared_ptr<Runnable> task, const struct timeval& timeout);
+  virtual Timer add(stdcxx::shared_ptr<Runnable> task, const struct timeval& timeout);
 
   /**
    * Removes a pending task
    *
+   * @param task The task to remove. All timers which execute this task will
+   * be removed.
    * @throws NoSuchTaskException Specified task doesn't exist. It was either
    *                             processed already or this call was made for a
    *                             task that was never added to this timer
@@ -100,13 +108,26 @@
    */
   virtual void remove(stdcxx::shared_ptr<Runnable> task);
 
+  /**
+   * Removes a single pending task
+   *
+   * @param timer The timer to remove, as returned by add() on this same
+   * TimerManager; a handle obtained from another TimerManager must not be passed.
+   * @throws IllegalStateException The timer manager is not in the STARTED state.
+   * @throws NoSuchTaskException Specified task doesn't exist: it was
+   *                             processed or removed already, or the handle is empty.
+   *
+   * @throws UncancellableTaskException The dispatcher has already taken the
+   *                                    task for execution.
+   */
+  virtual void remove(Timer timer);
+
   enum STATE { UNINITIALIZED, STARTING, STARTED, STOPPING, STOPPED };
 
   virtual STATE state() const;
 
 private:
   stdcxx::shared_ptr<const ThreadFactory> threadFactory_;
-  class Task;
   friend class Task;
   std::multimap<int64_t, stdcxx::shared_ptr<Task> > taskMap_;
   size_t taskCount_;
diff --git a/lib/cpp/test/concurrency/Tests.cpp b/lib/cpp/test/concurrency/Tests.cpp
index d09d438..df5099d 100644
--- a/lib/cpp/test/concurrency/Tests.cpp
+++ b/lib/cpp/test/concurrency/Tests.cpp
@@ -45,7 +45,7 @@
 	  // lower the scale of every test
 	  WEIGHT = 1;
   }
-  
+
   bool runAll = args[0].compare("all") == 0;
 
   if (runAll || args[0].compare("thread-factory") == 0) {
@@ -137,6 +137,20 @@
       std::cerr << "\t\tTimerManager tests FAILED" << std::endl;
       return 1;
     }
+
+    std::cout << "\t\tTimerManager test03" << std::endl;
+
+    if (!timerManagerTests.test03()) {
+      std::cerr << "\t\tTimerManager tests FAILED" << std::endl;
+      return 1;
+    }
+
+    std::cout << "\t\tTimerManager test04" << std::endl;
+
+    if (!timerManagerTests.test04()) {
+      std::cerr << "\t\tTimerManager tests FAILED" << std::endl;
+      return 1;
+    }
   }
 
   if (runAll || args[0].compare("thread-manager") == 0) {
diff --git a/lib/cpp/test/concurrency/TimerManagerTests.h b/lib/cpp/test/concurrency/TimerManagerTests.h
index 80d373b..3779b0d 100644
--- a/lib/cpp/test/concurrency/TimerManagerTests.h
+++ b/lib/cpp/test/concurrency/TimerManagerTests.h
@@ -192,6 +192,85 @@
     return true;
   }
 
+  /**
+   * This test creates two tasks, removes the first one then waits for the second one. It then
+   * verifies that the removed task never ran and that removing its timer a second time throws
+   * NoSuchTaskException.
+   */
+  bool test03(int64_t timeout = 1000LL) {
+    TimerManager timerManager;
+    timerManager.threadFactory(shared_ptr<PlatformThreadFactory>(new PlatformThreadFactory()));
+    timerManager.start();
+    assert(timerManager.state() == TimerManager::STARTED);
+
+    Synchronized s(_monitor);
+
+    // Setup the two tasks
+    shared_ptr<TimerManagerTests::Task> taskToRemove
+        = shared_ptr<TimerManagerTests::Task>(new TimerManagerTests::Task(_monitor, timeout / 2));
+    TimerManager::Timer timer = timerManager.add(taskToRemove, taskToRemove->_timeout);
+
+    shared_ptr<TimerManagerTests::Task> task
+      = shared_ptr<TimerManagerTests::Task>(new TimerManagerTests::Task(_monitor, timeout));
+    timerManager.add(task, task->_timeout);
+
+    // Remove one task and wait until the other has completed. Looping on _done keeps a
+    // spurious wakeup from ending the wait before the task has run.
+    timerManager.remove(timer);
+    while (!task->_done) {
+      _monitor.wait(timeout * 2);
+    }
+
+    assert(!taskToRemove->_done);
+    assert(task->_done);
+
+    // Verify behavior when removing the removed task
+    try {
+      timerManager.remove(timer);
+      assert(0 == "ERROR: This remove should send a NoSuchTaskException exception.");
+    } catch (NoSuchTaskException&) {
+    }
+
+    return true;
+  }
+
+  /**
+   * This test creates one task, waits for it to run, and then tries to remove it after it has
+   * expired.
+   */
+  bool test04(int64_t timeout = 1000LL) {
+    TimerManager timerManager;
+    timerManager.threadFactory(shared_ptr<PlatformThreadFactory>(new PlatformThreadFactory()));
+    timerManager.start();
+    assert(timerManager.state() == TimerManager::STARTED);
+
+    Synchronized s(_monitor);
+
+    // Setup the task
+    shared_ptr<TimerManagerTests::Task> task
+      = shared_ptr<TimerManagerTests::Task>(new TimerManagerTests::Task(_monitor, timeout / 10));
+    TimerManager::Timer timer = timerManager.add(task, task->_timeout);
+
+    // Wait until the task has completed (looping guards against spurious wakeups)
+    while (!task->_done) {
+      _monitor.wait(timeout);
+    }
+
+    // Verify behavior when removing the expired task. The task signals completion before its
+    // run() has returned, so the dispatcher may still own the expired timer (with it_ reset to
+    // end()) when we wake up; remove() then throws UncancellableTaskException instead of
+    // NoSuchTaskException. Either exception shows that an expired timer cannot be removed.
+    try {
+      timerManager.remove(timer);
+      assert(0 == "ERROR: This remove should send a NoSuchTaskException or "
+                  "UncancellableTaskException exception.");
+    } catch (NoSuchTaskException&) {
+    } catch (UncancellableTaskException&) {
+    }
+
+    return true;
+  }
+
   friend class TestTask;
 
   Monitor _monitor;