THRIFT-5297: Improve TThreadPoolServer Handling of Incoming Connections
Client: Java
Patch: David Mollitor

This closes #2266
diff --git a/lib/java/src/org/apache/thrift/server/TThreadPoolServer.java b/lib/java/src/org/apache/thrift/server/TThreadPoolServer.java
index 467b10d..01749b9 100644
--- a/lib/java/src/org/apache/thrift/server/TThreadPoolServer.java
+++ b/lib/java/src/org/apache/thrift/server/TThreadPoolServer.java
@@ -19,11 +19,11 @@
 
 package org.apache.thrift.server;
 
-import java.util.Random;
-import java.util.WeakHashMap;
+import java.util.Optional;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.RejectedExecutionException;
 import java.util.concurrent.SynchronousQueue;
+import java.util.concurrent.ThreadFactory;
 import java.util.concurrent.ThreadPoolExecutor;
 import java.util.concurrent.TimeUnit;
 
@@ -41,7 +41,7 @@
  * a worker pool that deals with client connections in blocking way.
  */
 public class TThreadPoolServer extends TServer {
-  private static final Logger LOGGER = LoggerFactory.getLogger(TThreadPoolServer.class.getName());
+  private static final Logger LOGGER = LoggerFactory.getLogger(TThreadPoolServer.class);
 
   public static class Args extends AbstractServerArgs<Args> {
     public int minWorkerThreads = 5;
@@ -49,10 +49,6 @@
     public ExecutorService executorService;
     public int stopTimeoutVal = 60;
     public TimeUnit stopTimeoutUnit = TimeUnit.SECONDS;
-    public int requestTimeout = 20;
-    public TimeUnit requestTimeoutUnit = TimeUnit.SECONDS;
-    public int beBackoffSlotLength = 100;
-    public TimeUnit beBackoffSlotLengthUnit = TimeUnit.MILLISECONDS;
 
     public Args(TServerTransport transport) {
       super(transport);
@@ -78,27 +74,6 @@
       return this;
     }
 
-    public Args requestTimeout(int n) {
-      requestTimeout = n;
-      return this;
-    }
-
-    public Args requestTimeoutUnit(TimeUnit tu) {
-      requestTimeoutUnit = tu;
-      return this;
-    }
-    //Binary exponential backoff slot length
-    public Args beBackoffSlotLength(int n) {
-      beBackoffSlotLength = n;
-      return this;
-    }
-
-    //Binary exponential backoff slot time unit
-    public Args beBackoffSlotLengthUnit(TimeUnit tu) {
-      beBackoffSlotLengthUnit = tu;
-      return this;
-    }
-
     public Args executorService(ExecutorService executorService) {
       this.executorService = executorService;
       return this;
@@ -107,41 +82,32 @@
 
   // Executor service for handling client connections
   private ExecutorService executorService_;
-  private WeakHashMap<WorkerProcess, Boolean> activeWorkers = new WeakHashMap<>();
 
   private final TimeUnit stopTimeoutUnit;
 
   private final long stopTimeoutVal;
 
-  private final TimeUnit requestTimeoutUnit;
-
-  private final long requestTimeout;
-
-  private final long beBackoffSlotInMillis;
-
-  private Random random = new Random(System.currentTimeMillis());
-
   public TThreadPoolServer(Args args) {
     super(args);
 
     stopTimeoutUnit = args.stopTimeoutUnit;
     stopTimeoutVal = args.stopTimeoutVal;
-    requestTimeoutUnit = args.requestTimeoutUnit;
-    requestTimeout = args.requestTimeout;
-    beBackoffSlotInMillis = args.beBackoffSlotLengthUnit.toMillis(args.beBackoffSlotLength);
 
     executorService_ = args.executorService != null ?
         args.executorService : createDefaultExecutorService(args);
   }
 
   private static ExecutorService createDefaultExecutorService(Args args) {
-    SynchronousQueue<Runnable> executorQueue =
-      new SynchronousQueue<Runnable>();
-    return new ThreadPoolExecutor(args.minWorkerThreads,
-                                  args.maxWorkerThreads,
-                                  args.stopTimeoutVal,
-                                  args.stopTimeoutUnit,
-                                  executorQueue);
+    return new ThreadPoolExecutor(args.minWorkerThreads, args.maxWorkerThreads, 60L, TimeUnit.SECONDS,
+        new SynchronousQueue<>(), new ThreadFactory() {
+          @Override
+          public Thread newThread(Runnable r) {
+            Thread thread = new Thread(r);
+            thread.setDaemon(true);
+            thread.setName("TThreadPoolServer WorkerProcess-%d");
+            return thread;
+          }
+        });
   }
 
   protected ExecutorService getExecutorService() {
@@ -149,7 +115,7 @@
   }
 
   protected boolean preServe() {
-  	try {
+    try {
       serverTransport_.listen();
     } catch (TTransportException ttx) {
       LOGGER.error("Error occurred during listening.", ttx);
@@ -166,13 +132,16 @@
   }
 
   public void serve() {
-  	if (!preServe()) {
-  		return;
-  	}
+    if (!preServe()) {
+      return;
+    }
 
-  	execute();
+    execute();
+
+    executorService_.shutdownNow();
+
     if (!waitForShutdown()) {
-  	  LOGGER.error("Shutdown is not done after " + stopTimeoutVal + stopTimeoutUnit);
+      LOGGER.error("Shutdown is not done after " + stopTimeoutVal + stopTimeoutUnit);
     }
 
     setServing(false);
@@ -182,51 +151,17 @@
     while (!stopped_) {
       try {
         TTransport client = serverTransport_.accept();
-        WorkerProcess wp = new WorkerProcess(client);
-
-        int retryCount = 0;
-        long remainTimeInMillis = requestTimeoutUnit.toMillis(requestTimeout);
-        while(true) {
-          try {
-            executorService_.execute(wp);
-            activeWorkers.put(wp, Boolean.TRUE);
-            break;
-          } catch(Throwable t) {
-            if (t instanceof RejectedExecutionException) {
-              retryCount++;
-              try {
-                if (remainTimeInMillis > 0) {
-                  //do a truncated 20 binary exponential backoff sleep
-                  long sleepTimeInMillis = ((long) (random.nextDouble() *
-                      (1L << Math.min(retryCount, 20)))) * beBackoffSlotInMillis;
-                  sleepTimeInMillis = Math.min(sleepTimeInMillis, remainTimeInMillis);
-                  TimeUnit.MILLISECONDS.sleep(sleepTimeInMillis);
-                  remainTimeInMillis = remainTimeInMillis - sleepTimeInMillis;
-                } else {
-                  client.close();
-                  wp = null;
-                  LOGGER.warn("Task has been rejected by ExecutorService " + retryCount
-                      + " times till timedout, reason: " + t);
-                  break;
-                }
-              } catch (InterruptedException e) {
-                LOGGER.warn("Interrupted while waiting to place client on executor queue.");
-                Thread.currentThread().interrupt();
-                break;
-              }
-            } else if (t instanceof Error) {
-              LOGGER.error("ExecutorService threw error: " + t, t);
-              throw (Error)t;
-            } else {
-              //for other possible runtime errors from ExecutorService, should also not kill serve
-              LOGGER.warn("ExecutorService threw error: " + t, t);
-              break;
-            }
+        try {
+          executorService_.execute(new WorkerProcess(client));
+        } catch (RejectedExecutionException ree) {
+          if (!stopped_) {
+            LOGGER.warn("ThreadPool is saturated with incoming requests. Closing latest connection.");
           }
+          client.close();
         }
       } catch (TTransportException ttx) {
         if (!stopped_) {
-          LOGGER.warn("Transport error occurred during acceptance of message.", ttx);
+          LOGGER.warn("Transport error occurred during acceptance of message", ttx);
         }
       }
     }
@@ -241,8 +176,7 @@
     long now = System.currentTimeMillis();
     while (timeoutMS >= 0) {
       try {
-        executorService_.awaitTermination(timeoutMS, TimeUnit.MILLISECONDS);
-        return true;
+        return executorService_.awaitTermination(timeoutMS, TimeUnit.MILLISECONDS);
       } catch (InterruptedException ix) {
         long newnow = System.currentTimeMillis();
         timeoutMS -= (newnow - now);
@@ -255,10 +189,6 @@
   public void stop() {
     stopped_ = true;
     serverTransport_.interrupt();
-    executorService_.shutdown();
-    for (WorkerProcess wp : activeWorkers.keySet()) {
-      wp.stop();
-    }
   }
 
   private class WorkerProcess implements Runnable {
@@ -287,7 +217,7 @@
       TProtocol inputProtocol = null;
       TProtocol outputProtocol = null;
 
-      TServerEventHandler eventHandler = null;
+      Optional<TServerEventHandler> eventHandler = Optional.empty();
       ServerContext connectionContext = null;
 
       try {
@@ -297,22 +227,25 @@
         inputProtocol = inputProtocolFactory_.getProtocol(inputTransport);
         outputProtocol = outputProtocolFactory_.getProtocol(outputTransport);
 
-        eventHandler = getEventHandler();
-        if (eventHandler != null) {
-          connectionContext = eventHandler.createContext(inputProtocol, outputProtocol);
+        eventHandler = Optional.ofNullable(getEventHandler());
+
+        if (eventHandler.isPresent()) {
+          connectionContext = eventHandler.get().createContext(inputProtocol, outputProtocol);
         }
-        // we check stopped_ first to make sure we're not supposed to be shutting
-        // down. this is necessary for graceful shutdown.
+
         while (true) {
-
-            if (eventHandler != null) {
-              eventHandler.processContext(connectionContext, inputTransport, outputTransport);
-            }
-
-            if (stopped_) {
-              break;
-            }
-            processor.process(inputProtocol, outputProtocol);
+          if (Thread.currentThread().isInterrupted()) {
+            LOGGER.debug("WorkerProcess requested to shutdown");
+            break;
+          }
+          if (eventHandler.isPresent()) {
+            eventHandler.get().processContext(connectionContext, inputTransport, outputTransport);
+          }
+          // This process cannot be interrupted by Interrupting the Thread. This
+          // will return once a message has been processed or the socket timeout
+          // has elapsed, at which point it will return and check the interrupt
+          // state of the thread.
+          processor.process(inputProtocol, outputProtocol);
         }
       } catch (Exception x) {
         LOGGER.debug("Error processing request", x);
@@ -322,11 +255,11 @@
         // Ignore err-logging all transport-level/type exceptions
         if (!isIgnorableException(x)) {
           // Log the exception at error level and continue
-          LOGGER.error((x instanceof TException? "Thrift " : "") + "Error occurred during processing of message.", x);
+          LOGGER.error((x instanceof TException ? "Thrift " : "") + "Error occurred during processing of message.", x);
         }
       } finally {
-        if (eventHandler != null) {
-          eventHandler.deleteContext(connectionContext, inputProtocol, outputProtocol);
+        if (eventHandler.isPresent()) {
+          eventHandler.get().deleteContext(connectionContext, inputProtocol, outputProtocol);
         }
         if (inputTransport != null) {
           inputTransport.close();
@@ -344,10 +277,9 @@
       TTransportException tTransportException = null;
 
       if (x instanceof TTransportException) {
-        tTransportException = (TTransportException)x;
-      }
-      else if (x.getCause() instanceof TTransportException) {
-        tTransportException = (TTransportException)x.getCause();
+        tTransportException = (TTransportException) x;
+      } else if (x.getCause() instanceof TTransportException) {
+        tTransportException = (TTransportException) x.getCause();
       }
 
       if (tTransportException != null) {
@@ -359,9 +291,5 @@
       }
       return false;
     }
-
-    private void stop() {
-      client_.close();
-    }
   }
 }
diff --git a/lib/java/test/org/apache/thrift/server/TestThreadPoolServer.java b/lib/java/test/org/apache/thrift/server/TestThreadPoolServer.java
index e81d801..4c84dc1 100644
--- a/lib/java/test/org/apache/thrift/server/TestThreadPoolServer.java
+++ b/lib/java/test/org/apache/thrift/server/TestThreadPoolServer.java
@@ -36,7 +36,7 @@
    */
   @Test
   public void testStopServerWithOpenClient() throws Exception {
-    TServerSocket serverSocket = new TServerSocket(0);
+    TServerSocket serverSocket = new TServerSocket(0, 3000);
     TThreadPoolServer server = buildServer(serverSocket);
     Thread serverThread = new Thread(() -> server.serve());
     serverThread.start();
@@ -45,11 +45,17 @@
       Thread.sleep(1000);
       // There is a thread listening to the client
       Assert.assertEquals(1, ((ThreadPoolExecutor) server.getExecutorService()).getActiveCount());
+
+      // Trigger the server to stop, but it does not wait
       server.stop();
-      server.waitForShutdown();
+      Assert.assertTrue(server.waitForShutdown());
+
       // After server is stopped, the executor thread pool should be shut down
-      Assert.assertTrue("Server thread pool should be terminated.", server.getExecutorService().isTerminated());
-      Assert.assertTrue("Client is still open.", client.isOpen());
+      Assert.assertTrue("Server thread pool should be terminated", server.getExecutorService().isTerminated());
+
+      // TODO: The socket is actually closed (timeout) but the client code
+      // ignores the timeout Exception and maintains the socket open state
+      Assert.assertTrue("Client should be closed after server shutdown", client.isOpen());
     }
   }