THRIFT-1429:The nonblocking servers is supposed to use TransportFactory to read the data
Client: Java
Patch: Bryan Duxbury
Enforce the transport factory on the server-read side as well as on the server-write side
git-svn-id: https://svn.apache.org/repos/asf/thrift/trunk@1296060 13f79535-47bb-0310-9956-ffa450edef68
diff --git a/lib/java/src/org/apache/thrift/server/AbstractNonblockingServer.java b/lib/java/src/org/apache/thrift/server/AbstractNonblockingServer.java
index bdd21c5..7fd75bf 100644
--- a/lib/java/src/org/apache/thrift/server/AbstractNonblockingServer.java
+++ b/lib/java/src/org/apache/thrift/server/AbstractNonblockingServer.java
@@ -332,10 +332,11 @@
}
// increment the amount of memory allocated to read buffers
- readBufferBytesAllocated.addAndGet(frameSize);
+ readBufferBytesAllocated.addAndGet(frameSize + 4);
// reallocate the readbuffer as a frame-sized buffer
- buffer_ = ByteBuffer.allocate(frameSize);
+ buffer_ = ByteBuffer.allocate(frameSize + 4);
+ buffer_.putInt(frameSize);
state_ = FrameBufferState.READING_FRAME;
} else {
@@ -492,7 +493,7 @@
* the data it needs to handle an invocation.
*/
private TTransport getInputTransport() {
- return new TMemoryInputTransport(buffer_.array());
+ return inputTransportFactory_.getTransport(new TMemoryInputTransport(buffer_.array()));
}
/**
diff --git a/lib/java/test/org/apache/thrift/server/ServerTestBase.java b/lib/java/test/org/apache/thrift/server/ServerTestBase.java
index b9974c4..08b57e2 100644
--- a/lib/java/test/org/apache/thrift/server/ServerTestBase.java
+++ b/lib/java/test/org/apache/thrift/server/ServerTestBase.java
@@ -29,16 +29,24 @@
import junit.framework.TestCase;
import org.apache.thrift.TException;
-import org.apache.thrift.TApplicationException;
import org.apache.thrift.TProcessor;
import org.apache.thrift.protocol.TBinaryProtocol;
import org.apache.thrift.protocol.TCompactProtocol;
import org.apache.thrift.protocol.TProtocol;
import org.apache.thrift.protocol.TProtocolFactory;
+import org.apache.thrift.transport.TFramedTransport;
import org.apache.thrift.transport.TSocket;
import org.apache.thrift.transport.TTransport;
+import org.apache.thrift.transport.TTransportFactory;
+import org.apache.thrift.transport.TFramedTransport.Factory;
-import thrift.test.*;
+import thrift.test.Insanity;
+import thrift.test.Numberz;
+import thrift.test.ThriftTest;
+import thrift.test.Xception;
+import thrift.test.Xception2;
+import thrift.test.Xtruct;
+import thrift.test.Xtruct2;
public abstract class ServerTestBase extends TestCase {
@@ -278,7 +286,11 @@
private static final Xtruct XSTRUCT = new Xtruct("Zero", (byte) 1, -3, -5);
private static final Xtruct2 XSTRUCT2 = new Xtruct2((byte)1, XSTRUCT, 5);
- public abstract void startServer(TProcessor processor, TProtocolFactory protoFactory) throws Exception;
+ public void startServer(TProcessor processor, TProtocolFactory protoFactory) throws Exception{
+ startServer(processor, protoFactory, null);
+ }
+
+ public abstract void startServer(TProcessor processor, TProtocolFactory protoFactory, TTransportFactory factory) throws Exception;
public abstract void stopServer() throws Exception;
@@ -491,6 +503,45 @@
testClient.testVoid();
}
+ private static class CallCountingTransportFactory extends TTransportFactory {
+ public int count = 0;
+ private final Factory factory;
+
+ public CallCountingTransportFactory(Factory factory) {
+ this.factory = factory;
+ }
+
+ @Override
+ public TTransport getTransport(TTransport trans) {
+ count++;
+ return factory.getTransport(trans);
+ }
+ }
+
+ public void testTransportFactory() throws Exception {
+
+ for (TProtocolFactory protoFactory : getProtocols()) {
+ TestHandler handler = new TestHandler();
+ ThriftTest.Processor processor = new ThriftTest.Processor(handler);
+
+ final CallCountingTransportFactory factory = new CallCountingTransportFactory(new TFramedTransport.Factory());
+
+ startServer(processor, protoFactory, factory);
+ assertEquals(0, factory.count);
+
+ TSocket socket = new TSocket(HOST, PORT);
+ socket.setTimeout(SOCKET_TIMEOUT);
+ TTransport transport = getClientTransport(socket);
+ open(transport);
+
+ TProtocol protocol = protoFactory.getProtocol(transport);
+ ThriftTest.Client testClient = new ThriftTest.Client(protocol);
+ assertEquals(0, testClient.testByte((byte) 0));
+ assertEquals(2, factory.count);
+ stopServer();
+ }
+ }
+
private void testException(ThriftTest.Client testClient) throws TException, Xception {
//@TODO testException
//testClient.testException("no Exception");
diff --git a/lib/java/test/org/apache/thrift/server/TestNonblockingServer.java b/lib/java/test/org/apache/thrift/server/TestNonblockingServer.java
index 597074e..b23cd5c 100644
--- a/lib/java/test/org/apache/thrift/server/TestNonblockingServer.java
+++ b/lib/java/test/org/apache/thrift/server/TestNonblockingServer.java
@@ -28,6 +28,7 @@
import org.apache.thrift.transport.TSocket;
import org.apache.thrift.transport.TTransport;
import org.apache.thrift.transport.TTransportException;
+import org.apache.thrift.transport.TTransportFactory;
import thrift.test.ThriftTest;
@@ -37,12 +38,16 @@
private TServer server;
private static final int NUM_QUERIES = 10000;
- protected TServer getServer(TProcessor processor, TNonblockingServerSocket socket, TProtocolFactory protoFactory) {
- return new TNonblockingServer(new Args(socket).processor(processor).protocolFactory(protoFactory));
+ protected TServer getServer(TProcessor processor, TNonblockingServerSocket socket, TProtocolFactory protoFactory, TTransportFactory factory) {
+ final Args args = new Args(socket).processor(processor).protocolFactory(protoFactory);
+ if (factory != null) {
+ args.transportFactory(factory);
+ }
+ return new TNonblockingServer(args);
}
@Override
- public void startServer(final TProcessor processor, final TProtocolFactory protoFactory) throws Exception {
+ public void startServer(final TProcessor processor, final TProtocolFactory protoFactory, final TTransportFactory factory) throws Exception {
serverThread = new Thread() {
public void run() {
try {
@@ -50,7 +55,7 @@
TNonblockingServerSocket tServerSocket =
new TNonblockingServerSocket(PORT);
- server = getServer(processor, tServerSocket, protoFactory);
+ server = getServer(processor, tServerSocket, protoFactory, factory);
// Run it
System.out.println("Starting the server on port " + PORT + "...");
diff --git a/lib/java/test/org/apache/thrift/transport/TestTSSLTransportFactory.java b/lib/java/test/org/apache/thrift/transport/TestTSSLTransportFactory.java
index 4bba451..478407a 100644
--- a/lib/java/test/org/apache/thrift/transport/TestTSSLTransportFactory.java
+++ b/lib/java/test/org/apache/thrift/transport/TestTSSLTransportFactory.java
@@ -47,13 +47,14 @@
}
@Override
- public void startServer(final TProcessor processor, final TProtocolFactory protoFactory)
+ public void startServer(final TProcessor processor, final TProtocolFactory protoFactory, final TTransportFactory factory)
throws Exception {
serverThread = new Thread() {
public void run() {
try {
TServerTransport serverTransport = TSSLTransportFactory.getServerSocket(PORT);
- server = new TSimpleServer(new Args(serverTransport).processor(processor));
+ final Args args = new Args(serverTransport).processor(processor);
+ server = new TSimpleServer(args);
server.serve();
} catch (TTransportException e) {
e.printStackTrace();
@@ -79,4 +80,9 @@
public List<TProtocolFactory> getProtocols() {
return protocols;
}
+
+ @Override
+ public void testTransportFactory() throws Exception {
+ // this test doesn't really apply to this suite, so let's skip it.
+ }
}
diff --git a/lib/java/test/org/apache/thrift/transport/TestTSaslTransports.java b/lib/java/test/org/apache/thrift/transport/TestTSaslTransports.java
index e12d6fe..41d08f6 100644
--- a/lib/java/test/org/apache/thrift/transport/TestTSaslTransports.java
+++ b/lib/java/test/org/apache/thrift/transport/TestTSaslTransports.java
@@ -275,7 +275,7 @@
}
@Override
- public void startServer(final TProcessor processor, final TProtocolFactory protoFactory) throws Exception {
+ public void startServer(final TProcessor processor, final TProtocolFactory protoFactory, final TTransportFactory factory) throws Exception {
serverThread = new Thread() {
public void run() {
try {