THRIFT-5863: Make TServerTransport able to customize the max message size
diff --git a/lib/java/src/main/java/org/apache/thrift/server/AbstractNonblockingServer.java b/lib/java/src/main/java/org/apache/thrift/server/AbstractNonblockingServer.java
index 310bb78..0d940a6 100644
--- a/lib/java/src/main/java/org/apache/thrift/server/AbstractNonblockingServer.java
+++ b/lib/java/src/main/java/org/apache/thrift/server/AbstractNonblockingServer.java
@@ -290,10 +290,12 @@
       selectThread_ = selectThread;
       buffer_ = ByteBuffer.allocate(4);
 
-      frameTrans_ = new TMemoryInputTransport();
+      frameTrans_ = new TMemoryInputTransport(trans_.getConfiguration());
       response_ = new TByteArrayOutputStream();
       inTrans_ = inputTransportFactory_.getTransport(frameTrans_);
-      outTrans_ = outputTransportFactory_.getTransport(new TIOStreamTransport(response_));
+      outTrans_ =
+          outputTransportFactory_.getTransport(
+              new TIOStreamTransport(trans_.getConfiguration(), response_));
       inProt_ = inputProtocolFactory_.getProtocol(inTrans_);
       outProt_ = outputProtocolFactory_.getProtocol(outTrans_);
 
diff --git a/lib/java/src/main/java/org/apache/thrift/transport/TEndpointTransport.java b/lib/java/src/main/java/org/apache/thrift/transport/TEndpointTransport.java
index 99f3192..a839488 100644
--- a/lib/java/src/main/java/org/apache/thrift/transport/TEndpointTransport.java
+++ b/lib/java/src/main/java/org/apache/thrift/transport/TEndpointTransport.java
@@ -35,6 +35,10 @@
     getConfiguration().setMaxFrameSize(maxFrameSize);
   }
 
+  public void setMaxMessageSize(int maxMessageSize) {
+    getConfiguration().setMaxMessageSize(maxMessageSize);
+  }
+
   protected long knownMessageSize;
   protected long remainingMessageSize;
 
diff --git a/lib/java/src/main/java/org/apache/thrift/transport/TNonblockingServerSocket.java b/lib/java/src/main/java/org/apache/thrift/transport/TNonblockingServerSocket.java
index 9983797..0bcf601 100644
--- a/lib/java/src/main/java/org/apache/thrift/transport/TNonblockingServerSocket.java
+++ b/lib/java/src/main/java/org/apache/thrift/transport/TNonblockingServerSocket.java
@@ -49,6 +49,9 @@
   /** Limit for client sockets request size */
   private int maxFrameSize_ = 0;
 
+  /** Max message size */
+  private int maxMessageSize_ = 0;
+
   public static class NonblockingAbstractServerSocketArgs
       extends AbstractServerTransportArgs<NonblockingAbstractServerSocketArgs> {}
 
@@ -93,6 +96,7 @@
       throws TTransportException {
     clientTimeout_ = args.clientTimeout;
     maxFrameSize_ = args.maxFrameSize;
+    maxMessageSize_ = args.maxMessageSize;
     try {
       serverSocketChannel = ServerSocketChannel.open();
       serverSocketChannel.configureBlocking(false);
@@ -135,6 +139,7 @@
       TNonblockingSocket tsocket = new TNonblockingSocket(socketChannel);
       tsocket.setTimeout(clientTimeout_);
       tsocket.setMaxFrameSize(maxFrameSize_);
+      tsocket.setMaxMessageSize(maxMessageSize_);
       return tsocket;
     } catch (IOException iox) {
       throw new TTransportException(iox);
diff --git a/lib/java/src/main/java/org/apache/thrift/transport/TServerSocket.java b/lib/java/src/main/java/org/apache/thrift/transport/TServerSocket.java
index e105662..59cef20 100644
--- a/lib/java/src/main/java/org/apache/thrift/transport/TServerSocket.java
+++ b/lib/java/src/main/java/org/apache/thrift/transport/TServerSocket.java
@@ -38,6 +38,9 @@
   /** Timeout for client sockets from accept */
   private int clientTimeout_ = 0;
 
+  /** Max message size */
+  private int maxMessageSize_ = 0;
+
   public static class ServerSocketTransportArgs
       extends AbstractServerTransportArgs<ServerSocketTransportArgs> {
     ServerSocket serverSocket;
@@ -78,6 +81,7 @@
 
   public TServerSocket(ServerSocketTransportArgs args) throws TTransportException {
     clientTimeout_ = args.clientTimeout;
+    maxMessageSize_ = args.maxMessageSize;
     if (args.serverSocket != null) {
       this.serverSocket_ = args.serverSocket;
       return;
@@ -123,6 +127,7 @@
     }
     TSocket socket = new TSocket(result);
     socket.setTimeout(clientTimeout_);
+    socket.setMaxMessageSize(maxMessageSize_);
     return socket;
   }
 
diff --git a/lib/java/src/main/java/org/apache/thrift/transport/TServerTransport.java b/lib/java/src/main/java/org/apache/thrift/transport/TServerTransport.java
index 47fa251..05a3f09 100644
--- a/lib/java/src/main/java/org/apache/thrift/transport/TServerTransport.java
+++ b/lib/java/src/main/java/org/apache/thrift/transport/TServerTransport.java
@@ -32,6 +32,7 @@
     int clientTimeout = 0;
     InetSocketAddress bindAddr;
     int maxFrameSize = TConfiguration.DEFAULT_MAX_FRAME_SIZE;
+    int maxMessageSize = TConfiguration.DEFAULT_MAX_MESSAGE_SIZE;
 
     public T backlog(int backlog) {
       this.backlog = backlog;
@@ -57,6 +58,11 @@
       this.maxFrameSize = maxFrameSize;
       return (T) this;
     }
+
+    public T maxMessageSize(int maxMessageSize) {
+      this.maxMessageSize = maxMessageSize;
+      return (T) this;
+    }
   }
 
   public abstract void listen() throws TTransportException;
diff --git a/lib/java/src/main/java/org/apache/thrift/transport/sasl/NonblockingSaslHandler.java b/lib/java/src/main/java/org/apache/thrift/transport/sasl/NonblockingSaslHandler.java
index 930f8e8..33291aa 100644
--- a/lib/java/src/main/java/org/apache/thrift/transport/sasl/NonblockingSaslHandler.java
+++ b/lib/java/src/main/java/org/apache/thrift/transport/sasl/NonblockingSaslHandler.java
@@ -320,7 +320,8 @@
       byte[] inputPayload = requestReader.getPayload();
       requestReader.clear();
       byte[] rawInput = dataProtected ? saslPeer.unwrap(inputPayload) : inputPayload;
-      TMemoryTransport memoryTransport = new TMemoryTransport(rawInput);
+      TMemoryTransport memoryTransport =
+          new TMemoryTransport(underlyingTransport.getConfiguration(), rawInput);
       TProtocol requestProtocol = inputProtocolFactory.getProtocol(memoryTransport);
       TProtocol responseProtocol = outputProtocolFactory.getProtocol(memoryTransport);
 
diff --git a/lib/java/src/test/java/org/apache/thrift/server/TestThreadPoolServer.java b/lib/java/src/test/java/org/apache/thrift/server/TestThreadPoolServer.java
index 74205c7..c16f59c 100644
--- a/lib/java/src/test/java/org/apache/thrift/server/TestThreadPoolServer.java
+++ b/lib/java/src/test/java/org/apache/thrift/server/TestThreadPoolServer.java
@@ -23,10 +23,12 @@
 import static org.junit.jupiter.api.Assertions.assertTrue;
 
 import java.util.concurrent.ThreadPoolExecutor;
+import java.util.concurrent.atomic.AtomicReference;
 import org.apache.thrift.protocol.TBinaryProtocol;
 import org.apache.thrift.transport.TServerSocket;
 import org.apache.thrift.transport.TServerTransport;
 import org.apache.thrift.transport.TSocket;
+import org.apache.thrift.transport.TTransportException;
 import org.junit.jupiter.api.Test;
 import thrift.test.ThriftTest;
 
@@ -35,7 +37,20 @@
   /** Test server is shut down properly even with some open clients. */
   @Test
   public void testStopServerWithOpenClient() throws Exception {
-    TServerSocket serverSocket = new TServerSocket(0, 3000);
+    AtomicReference<TSocket> ref = new AtomicReference<>();
+    TServerSocket serverSocket =
+        new TServerSocket(
+            new TServerSocket.ServerSocketTransportArgs()
+                .port(0)
+                .clientTimeout(3000)
+                .maxMessageSize(51200)) {
+          @Override
+          public TSocket accept() throws TTransportException {
+            TSocket socket = super.accept();
+            ref.set(socket);
+            return socket;
+          }
+        };
     TThreadPoolServer server = buildServer(serverSocket);
     Thread serverThread = new Thread(server::serve);
     serverThread.start();
@@ -44,6 +59,7 @@
       Thread.sleep(1000);
       // There is a thread listening to the client
       assertEquals(1, ((ThreadPoolExecutor) server.getExecutorService()).getActiveCount());
+      assertEquals(51200, ref.get().getConfiguration().getMaxMessageSize());
 
       // Trigger the server to stop, but it does not wait
       server.stop();