THRIFT-5863: Make TServerTransport able to customize the max message size
diff --git a/lib/java/src/main/java/org/apache/thrift/server/AbstractNonblockingServer.java b/lib/java/src/main/java/org/apache/thrift/server/AbstractNonblockingServer.java
index 310bb78..0d940a6 100644
--- a/lib/java/src/main/java/org/apache/thrift/server/AbstractNonblockingServer.java
+++ b/lib/java/src/main/java/org/apache/thrift/server/AbstractNonblockingServer.java
@@ -290,10 +290,12 @@
selectThread_ = selectThread;
buffer_ = ByteBuffer.allocate(4);
- frameTrans_ = new TMemoryInputTransport();
+ frameTrans_ = new TMemoryInputTransport(trans_.getConfiguration());
response_ = new TByteArrayOutputStream();
inTrans_ = inputTransportFactory_.getTransport(frameTrans_);
- outTrans_ = outputTransportFactory_.getTransport(new TIOStreamTransport(response_));
+ outTrans_ =
+ outputTransportFactory_.getTransport(
+ new TIOStreamTransport(trans_.getConfiguration(), response_));
inProt_ = inputProtocolFactory_.getProtocol(inTrans_);
outProt_ = outputProtocolFactory_.getProtocol(outTrans_);
diff --git a/lib/java/src/main/java/org/apache/thrift/transport/TEndpointTransport.java b/lib/java/src/main/java/org/apache/thrift/transport/TEndpointTransport.java
index 99f3192..a839488 100644
--- a/lib/java/src/main/java/org/apache/thrift/transport/TEndpointTransport.java
+++ b/lib/java/src/main/java/org/apache/thrift/transport/TEndpointTransport.java
@@ -35,6 +35,10 @@
getConfiguration().setMaxFrameSize(maxFrameSize);
}
+ public void setMaxMessageSize(int maxMessageSize) {
+ getConfiguration().setMaxMessageSize(maxMessageSize);
+ }
+
protected long knownMessageSize;
protected long remainingMessageSize;
diff --git a/lib/java/src/main/java/org/apache/thrift/transport/TNonblockingServerSocket.java b/lib/java/src/main/java/org/apache/thrift/transport/TNonblockingServerSocket.java
index 9983797..0bcf601 100644
--- a/lib/java/src/main/java/org/apache/thrift/transport/TNonblockingServerSocket.java
+++ b/lib/java/src/main/java/org/apache/thrift/transport/TNonblockingServerSocket.java
@@ -49,6 +49,9 @@
/** Limit for client sockets request size */
private int maxFrameSize_ = 0;
+ /** Max message size */
+ private int maxMessageSize_ = 0;
+
public static class NonblockingAbstractServerSocketArgs
extends AbstractServerTransportArgs<NonblockingAbstractServerSocketArgs> {}
@@ -93,6 +96,7 @@
throws TTransportException {
clientTimeout_ = args.clientTimeout;
maxFrameSize_ = args.maxFrameSize;
+ maxMessageSize_ = args.maxMessageSize;
try {
serverSocketChannel = ServerSocketChannel.open();
serverSocketChannel.configureBlocking(false);
@@ -135,6 +139,7 @@
TNonblockingSocket tsocket = new TNonblockingSocket(socketChannel);
tsocket.setTimeout(clientTimeout_);
tsocket.setMaxFrameSize(maxFrameSize_);
+ tsocket.setMaxMessageSize(maxMessageSize_);
return tsocket;
} catch (IOException iox) {
throw new TTransportException(iox);
diff --git a/lib/java/src/main/java/org/apache/thrift/transport/TServerSocket.java b/lib/java/src/main/java/org/apache/thrift/transport/TServerSocket.java
index e105662..59cef20 100644
--- a/lib/java/src/main/java/org/apache/thrift/transport/TServerSocket.java
+++ b/lib/java/src/main/java/org/apache/thrift/transport/TServerSocket.java
@@ -38,6 +38,9 @@
/** Timeout for client sockets from accept */
private int clientTimeout_ = 0;
+ /** Max message size */
+ private int maxMessageSize_ = 0;
+
public static class ServerSocketTransportArgs
extends AbstractServerTransportArgs<ServerSocketTransportArgs> {
ServerSocket serverSocket;
@@ -78,6 +81,7 @@
public TServerSocket(ServerSocketTransportArgs args) throws TTransportException {
clientTimeout_ = args.clientTimeout;
+ maxMessageSize_ = args.maxMessageSize;
if (args.serverSocket != null) {
this.serverSocket_ = args.serverSocket;
return;
@@ -123,6 +127,7 @@
}
TSocket socket = new TSocket(result);
socket.setTimeout(clientTimeout_);
+ socket.setMaxMessageSize(maxMessageSize_);
return socket;
}
diff --git a/lib/java/src/main/java/org/apache/thrift/transport/TServerTransport.java b/lib/java/src/main/java/org/apache/thrift/transport/TServerTransport.java
index 47fa251..05a3f09 100644
--- a/lib/java/src/main/java/org/apache/thrift/transport/TServerTransport.java
+++ b/lib/java/src/main/java/org/apache/thrift/transport/TServerTransport.java
@@ -32,6 +32,7 @@
int clientTimeout = 0;
InetSocketAddress bindAddr;
int maxFrameSize = TConfiguration.DEFAULT_MAX_FRAME_SIZE;
+ int maxMessageSize = TConfiguration.DEFAULT_MAX_MESSAGE_SIZE;
public T backlog(int backlog) {
this.backlog = backlog;
@@ -57,6 +58,11 @@
this.maxFrameSize = maxFrameSize;
return (T) this;
}
+
+ public T maxMessageSize(int maxMessageSize) {
+ this.maxMessageSize = maxMessageSize;
+ return (T) this;
+ }
}
public abstract void listen() throws TTransportException;
diff --git a/lib/java/src/main/java/org/apache/thrift/transport/sasl/NonblockingSaslHandler.java b/lib/java/src/main/java/org/apache/thrift/transport/sasl/NonblockingSaslHandler.java
index 930f8e8..33291aa 100644
--- a/lib/java/src/main/java/org/apache/thrift/transport/sasl/NonblockingSaslHandler.java
+++ b/lib/java/src/main/java/org/apache/thrift/transport/sasl/NonblockingSaslHandler.java
@@ -320,7 +320,8 @@
byte[] inputPayload = requestReader.getPayload();
requestReader.clear();
byte[] rawInput = dataProtected ? saslPeer.unwrap(inputPayload) : inputPayload;
- TMemoryTransport memoryTransport = new TMemoryTransport(rawInput);
+ TMemoryTransport memoryTransport =
+ new TMemoryTransport(underlyingTransport.getConfiguration(), rawInput);
TProtocol requestProtocol = inputProtocolFactory.getProtocol(memoryTransport);
TProtocol responseProtocol = outputProtocolFactory.getProtocol(memoryTransport);
diff --git a/lib/java/src/test/java/org/apache/thrift/server/TestThreadPoolServer.java b/lib/java/src/test/java/org/apache/thrift/server/TestThreadPoolServer.java
index 74205c7..c16f59c 100644
--- a/lib/java/src/test/java/org/apache/thrift/server/TestThreadPoolServer.java
+++ b/lib/java/src/test/java/org/apache/thrift/server/TestThreadPoolServer.java
@@ -23,10 +23,12 @@
import static org.junit.jupiter.api.Assertions.assertTrue;
import java.util.concurrent.ThreadPoolExecutor;
+import java.util.concurrent.atomic.AtomicReference;
import org.apache.thrift.protocol.TBinaryProtocol;
import org.apache.thrift.transport.TServerSocket;
import org.apache.thrift.transport.TServerTransport;
import org.apache.thrift.transport.TSocket;
+import org.apache.thrift.transport.TTransportException;
import org.junit.jupiter.api.Test;
import thrift.test.ThriftTest;
@@ -35,7 +37,20 @@
/** Test server is shut down properly even with some open clients. */
@Test
public void testStopServerWithOpenClient() throws Exception {
- TServerSocket serverSocket = new TServerSocket(0, 3000);
+ AtomicReference<TSocket> ref = new AtomicReference<>();
+ TServerSocket serverSocket =
+ new TServerSocket(
+ new TServerSocket.ServerSocketTransportArgs()
+ .port(0)
+ .clientTimeout(3000)
+ .maxMessageSize(51200)) {
+ @Override
+ public TSocket accept() throws TTransportException {
+ TSocket socket = super.accept();
+ ref.set(socket);
+ return socket;
+ }
+ };
TThreadPoolServer server = buildServer(serverSocket);
Thread serverThread = new Thread(server::serve);
serverThread.start();
@@ -44,6 +59,7 @@
Thread.sleep(1000);
// There is a thread listening to the client
assertEquals(1, ((ThreadPoolExecutor) server.getExecutorService()).getActiveCount());
+ assertEquals(51200, ref.get().getConfiguration().getMaxMessageSize());
// Trigger the server to stop, but it does not wait
server.stop();