THRIFT-1429:The nonblocking servers is supposed to use TransportFactory to read the data
Client: Java
Patch: Bryan Duxbury 

Enforce the transport factory on the server-read side as well as on the server-write side



git-svn-id: https://svn.apache.org/repos/asf/thrift/trunk@1296060 13f79535-47bb-0310-9956-ffa450edef68
diff --git a/lib/java/src/org/apache/thrift/server/AbstractNonblockingServer.java b/lib/java/src/org/apache/thrift/server/AbstractNonblockingServer.java
index bdd21c5..7fd75bf 100644
--- a/lib/java/src/org/apache/thrift/server/AbstractNonblockingServer.java
+++ b/lib/java/src/org/apache/thrift/server/AbstractNonblockingServer.java
@@ -332,10 +332,11 @@
           }
 
           // increment the amount of memory allocated to read buffers
-          readBufferBytesAllocated.addAndGet(frameSize);
+          readBufferBytesAllocated.addAndGet(frameSize + 4);
 
           // reallocate the readbuffer as a frame-sized buffer
-          buffer_ = ByteBuffer.allocate(frameSize);
+          buffer_ = ByteBuffer.allocate(frameSize + 4);
+          buffer_.putInt(frameSize);
 
           state_ = FrameBufferState.READING_FRAME;
         } else {
@@ -492,7 +493,7 @@
      * the data it needs to handle an invocation.
      */
     private TTransport getInputTransport() {
-      return new TMemoryInputTransport(buffer_.array());
+      return inputTransportFactory_.getTransport(new TMemoryInputTransport(buffer_.array()));
     }
 
     /**
diff --git a/lib/java/test/org/apache/thrift/server/ServerTestBase.java b/lib/java/test/org/apache/thrift/server/ServerTestBase.java
index b9974c4..08b57e2 100644
--- a/lib/java/test/org/apache/thrift/server/ServerTestBase.java
+++ b/lib/java/test/org/apache/thrift/server/ServerTestBase.java
@@ -29,16 +29,24 @@
 import junit.framework.TestCase;
 
 import org.apache.thrift.TException;
-import org.apache.thrift.TApplicationException;
 import org.apache.thrift.TProcessor;
 import org.apache.thrift.protocol.TBinaryProtocol;
 import org.apache.thrift.protocol.TCompactProtocol;
 import org.apache.thrift.protocol.TProtocol;
 import org.apache.thrift.protocol.TProtocolFactory;
+import org.apache.thrift.transport.TFramedTransport;
 import org.apache.thrift.transport.TSocket;
 import org.apache.thrift.transport.TTransport;
+import org.apache.thrift.transport.TTransportFactory;
+import org.apache.thrift.transport.TFramedTransport.Factory;
 
-import thrift.test.*;
+import thrift.test.Insanity;
+import thrift.test.Numberz;
+import thrift.test.ThriftTest;
+import thrift.test.Xception;
+import thrift.test.Xception2;
+import thrift.test.Xtruct;
+import thrift.test.Xtruct2;
 
 public abstract class ServerTestBase extends TestCase {
 
@@ -278,7 +286,11 @@
   private static final Xtruct XSTRUCT = new Xtruct("Zero", (byte) 1, -3, -5);
   private static final Xtruct2 XSTRUCT2 = new Xtruct2((byte)1, XSTRUCT, 5);
 
-  public abstract void startServer(TProcessor processor, TProtocolFactory protoFactory) throws Exception;
+  public void startServer(TProcessor processor, TProtocolFactory protoFactory) throws Exception{
+    startServer(processor, protoFactory, null);
+  }
+
+  public abstract void startServer(TProcessor processor, TProtocolFactory protoFactory, TTransportFactory factory) throws Exception;
 
   public abstract void stopServer() throws Exception;
 
@@ -491,6 +503,45 @@
     testClient.testVoid();
   }
 
+  private static class CallCountingTransportFactory extends TTransportFactory {
+    public int count = 0;
+    private final Factory factory;
+
+    public CallCountingTransportFactory(Factory factory) {
+      this.factory = factory;
+    }
+
+    @Override
+    public TTransport getTransport(TTransport trans) {
+      count++;
+      return factory.getTransport(trans);
+    }
+  }
+
+  public void testTransportFactory() throws Exception {
+    
+    for (TProtocolFactory protoFactory : getProtocols()) {
+      TestHandler handler = new TestHandler();
+      ThriftTest.Processor processor = new ThriftTest.Processor(handler);
+  
+      final CallCountingTransportFactory factory = new CallCountingTransportFactory(new TFramedTransport.Factory());
+  
+      startServer(processor, protoFactory, factory);
+      assertEquals(0, factory.count);
+  
+      TSocket socket = new TSocket(HOST, PORT);
+      socket.setTimeout(SOCKET_TIMEOUT);
+      TTransport transport = getClientTransport(socket);
+      open(transport);
+  
+      TProtocol protocol = protoFactory.getProtocol(transport);
+      ThriftTest.Client testClient = new ThriftTest.Client(protocol);
+      assertEquals(0, testClient.testByte((byte) 0));
+      assertEquals(2, factory.count);
+      stopServer();
+    }
+  }
+
   private void testException(ThriftTest.Client testClient) throws TException, Xception {
     //@TODO testException
     //testClient.testException("no Exception");
diff --git a/lib/java/test/org/apache/thrift/server/TestNonblockingServer.java b/lib/java/test/org/apache/thrift/server/TestNonblockingServer.java
index 597074e..b23cd5c 100644
--- a/lib/java/test/org/apache/thrift/server/TestNonblockingServer.java
+++ b/lib/java/test/org/apache/thrift/server/TestNonblockingServer.java
@@ -28,6 +28,7 @@
 import org.apache.thrift.transport.TSocket;
 import org.apache.thrift.transport.TTransport;
 import org.apache.thrift.transport.TTransportException;
+import org.apache.thrift.transport.TTransportFactory;
 
 import thrift.test.ThriftTest;
 
@@ -37,12 +38,16 @@
   private TServer server;
   private static final int NUM_QUERIES = 10000;
 
-  protected TServer getServer(TProcessor processor, TNonblockingServerSocket socket, TProtocolFactory protoFactory) {
-    return new TNonblockingServer(new Args(socket).processor(processor).protocolFactory(protoFactory));
+  protected TServer getServer(TProcessor processor, TNonblockingServerSocket socket, TProtocolFactory protoFactory, TTransportFactory factory) {
+    final Args args = new Args(socket).processor(processor).protocolFactory(protoFactory);
+    if (factory != null) {
+      args.transportFactory(factory);
+    }
+    return new TNonblockingServer(args);
   }
 
   @Override
-  public void startServer(final TProcessor processor, final TProtocolFactory protoFactory) throws Exception {
+  public void startServer(final TProcessor processor, final TProtocolFactory protoFactory, final TTransportFactory factory) throws Exception {
     serverThread = new Thread() {
       public void run() {
         try {
@@ -50,7 +55,7 @@
           TNonblockingServerSocket tServerSocket =
             new TNonblockingServerSocket(PORT);
 
-          server = getServer(processor, tServerSocket, protoFactory);
+          server = getServer(processor, tServerSocket, protoFactory, factory);
 
           // Run it
           System.out.println("Starting the server on port " + PORT + "...");
diff --git a/lib/java/test/org/apache/thrift/transport/TestTSSLTransportFactory.java b/lib/java/test/org/apache/thrift/transport/TestTSSLTransportFactory.java
index 4bba451..478407a 100644
--- a/lib/java/test/org/apache/thrift/transport/TestTSSLTransportFactory.java
+++ b/lib/java/test/org/apache/thrift/transport/TestTSSLTransportFactory.java
@@ -47,13 +47,14 @@
   }
 
   @Override
-  public void startServer(final TProcessor processor, final TProtocolFactory protoFactory)
+  public void startServer(final TProcessor processor, final TProtocolFactory protoFactory, final TTransportFactory factory)
   throws Exception {
     serverThread = new Thread() {
       public void run() {
         try {
           TServerTransport serverTransport = TSSLTransportFactory.getServerSocket(PORT);
-          server = new TSimpleServer(new Args(serverTransport).processor(processor));
+          final Args args = new Args(serverTransport).processor(processor);
+          server = new TSimpleServer(args);
           server.serve();
         } catch (TTransportException e) {
           e.printStackTrace();
@@ -79,4 +80,9 @@
   public List<TProtocolFactory> getProtocols() {
     return protocols;
   }
+
+  @Override
+  public void testTransportFactory() throws Exception {
+    // this test doesn't really apply to this suite, so let's skip it.
+  }
 }
diff --git a/lib/java/test/org/apache/thrift/transport/TestTSaslTransports.java b/lib/java/test/org/apache/thrift/transport/TestTSaslTransports.java
index e12d6fe..41d08f6 100644
--- a/lib/java/test/org/apache/thrift/transport/TestTSaslTransports.java
+++ b/lib/java/test/org/apache/thrift/transport/TestTSaslTransports.java
@@ -275,7 +275,7 @@
     }
 
     @Override
-    public void startServer(final TProcessor processor, final TProtocolFactory protoFactory) throws Exception {
+    public void startServer(final TProcessor processor, final TProtocolFactory protoFactory, final TTransportFactory factory) throws Exception {
       serverThread = new Thread() {
         public void run() {
           try {