THRIFT-1198 C++ TestClient and Server Improvements (add Unix Domain Socket, HTTP, JSON)


git-svn-id: https://svn.apache.org/repos/asf/thrift/trunk@1133116 13f79535-47bb-0310-9956-ffa450edef68
diff --git a/test/cpp/Thrift-test.mk b/test/cpp/Thrift-test.mk
index 6987e33..e0a624c 100644
--- a/test/cpp/Thrift-test.mk
+++ b/test/cpp/Thrift-test.mk
@@ -46,8 +46,8 @@
 LD     = g++
 
 # Compiler flags
-DCFL  = -Wall -O3 -g -I. -I./gen-cpp $(include_flags) -L$(thrift_home)/lib/cpp/.libs -lthrift -lthriftnb -levent
-LFL   =  -L$(thrift_home)/lib/cpp/.libs -lthrift -lthriftnb -levent
+DCFL  = -Wall -O3 -g -I. -I./gen-cpp $(include_flags) -L$(thrift_home)/lib/cpp/.libs -lthrift -lthriftnb -levent -lboost_program_options
+LFL   =  -L$(thrift_home)/lib/cpp/.libs -lthrift -lthriftnb -levent -lboost_program_options
 CCFL  = -Wall -O3 -I. -I./gen-cpp $(include_flags)
 CFL   = $(CCFL) $(LFL)
 
diff --git a/test/cpp/src/TestClient.cpp b/test/cpp/src/TestClient.cpp
index 417a7a1..23d7dcd 100644
--- a/test/cpp/src/TestClient.cpp
+++ b/test/cpp/src/TestClient.cpp
@@ -17,15 +17,19 @@
  * under the License.
  */
 
-#include <stdio.h>
+#include <iostream>
 #include <unistd.h>
 #include <sys/time.h>
 #include <protocol/TBinaryProtocol.h>
+#include <protocol/TJSONProtocol.h>
+#include <transport/THttpClient.h>
 #include <transport/TTransportUtils.h>
 #include <transport/TSocket.h>
 #include <transport/TSSLSocket.h>
 
 #include <boost/shared_ptr.hpp>
+#include <boost/program_options.hpp>
+
 #include "ThriftTest.h"
 
 #define __STDC_FORMAT_MACROS
@@ -56,30 +60,66 @@
   string host = "localhost";
   int port = 9090;
   int numTests = 1;
-  bool framed = false;
   bool ssl = false;
+  string transport_type = "buffered";
+  string protocol_type = "binary";
+  string domain_socket = "";
 
-  for (int i = 0; i < argc; ++i) {
-    if (strcmp(argv[i], "-h") == 0) {
-      char* pch = strtok(argv[++i], ":");
-      if (pch != NULL) {
-        host = string(pch);
-      }
-      pch = strtok(NULL, ":");
-      if (pch != NULL) {
-        port = atoi(pch);
-      }
-    } else if (strcmp(argv[i], "-n") == 0) {
-      numTests = atoi(argv[++i]);
-    } else if (strcmp(argv[i], "-f") == 0) {
-      framed = true;
-    } else if (strcmp(argv[i], "--ssl") == 0) {
-      ssl = true;
-    }
+  program_options::options_description desc("Allowed options");
+  desc.add_options()
+      ("help,h", "produce help message")
+      ("host", program_options::value<string>(&host)->default_value(host), "Host to connect")
+      ("port", program_options::value<int>(&port)->default_value(port), "Port number to connect")
+	  ("domain-socket", program_options::value<string>(&domain_socket)->default_value(domain_socket), "Domain Socket (e.g. /tmp/ThriftTest.thrift), instead of host and port")
+      ("transport", program_options::value<string>(&transport_type)->default_value(transport_type), "Transport: buffered, framed, http")
+      ("protocol", program_options::value<string>(&protocol_type)->default_value(protocol_type), "Protocol: binary, json")
+	  ("ssl", "Encrypted Transport using SSL")
+      ("testloops,n", program_options::value<int>(&numTests)->default_value(numTests), "Number of Tests")
+  ;
+
+  program_options::variables_map vm;
+  program_options::store(program_options::parse_command_line(argc, argv, desc), vm);
+  program_options::notify(vm);    
+
+  if (vm.count("help")) {
+    cout << desc << "\n";
+    return 1;
   }
 
+  try {   
+    if (!protocol_type.empty()) {
+      if (protocol_type == "binary") {
+      } else if (protocol_type == "json") {
+      } else {
+          throw invalid_argument("Unknown protocol type "+protocol_type);
+      }
+    }
+
+	if (!transport_type.empty()) {
+      if (transport_type == "buffered") {
+      } else if (transport_type == "framed") {
+      } else if (transport_type == "http") {
+      } else {
+          throw invalid_argument("Unknown transport type "+transport_type);
+      }
+    }
+
+  } catch (std::exception& e) {
+    cerr << e.what() << endl;
+    cout << desc << "\n";
+    return 1;
+  }
+
+  if (vm.count("ssl")) {
+    ssl = true;
+  }
+
+  shared_ptr<TTransport> transport;
+  shared_ptr<TProtocol> protocol;
+
   shared_ptr<TSocket> socket;
   shared_ptr<TSSLSocketFactory> factory;
+
   if (ssl) {
     factory = shared_ptr<TSSLSocketFactory>(new TSSLSocketFactory());
     factory->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
@@ -87,22 +127,42 @@
     factory->authenticate(true);
     socket = factory->createSocket(host, port);
   } else {
-    socket = shared_ptr<TSocket>(new TSocket(host, port));
+    if (domain_socket != "") {
+      socket = shared_ptr<TSocket>(new TSocket(domain_socket));
+      port = 0;
+    }
+    else {
+      socket = shared_ptr<TSocket>(new TSocket(host, port));
+    }
   }
 
-  shared_ptr<TBufferBase> transport;
-
-  if (framed) {
+  if (transport_type.compare("http") == 0) {
+    shared_ptr<TTransport> httpSocket(new THttpClient(socket, host, "/service"));
+    transport = httpSocket;
+  } else if (transport_type.compare("framed") == 0){
     shared_ptr<TFramedTransport> framedSocket(new TFramedTransport(socket));
     transport = framedSocket;
-  } else {
+  } else{
     shared_ptr<TBufferedTransport> bufferedSocket(new TBufferedTransport(socket));
     transport = bufferedSocket;
   }
 
-  shared_ptr< TBinaryProtocolT<TBufferBase> > protocol(
-      new TBinaryProtocolT<TBufferBase>(transport));
-  ThriftTestClientT< TBinaryProtocolT<TBufferBase> > testClient(protocol);
+  if (protocol_type.compare("json") == 0) {
+    shared_ptr<TProtocol> jsonProtocol(new TJSONProtocol(transport));
+    protocol = jsonProtocol;
+  } else{
+    shared_ptr<TBinaryProtocol> binaryProtocol(new TBinaryProtocol(transport));
+    protocol = binaryProtocol;
+  }
+
+  // Connection info
+  cout << "Connecting (" << transport_type << "/" << protocol_type << ") to: " << domain_socket;
+  if (port != 0) {
+    cout << host << ":" << port;
+  }
+  cout << endl;
+
+  ThriftTestClient testClient(protocol);
 
   uint64_t time_min = 0;
   uint64_t time_max = 0;
diff --git a/test/cpp/src/TestServer.cpp b/test/cpp/src/TestServer.cpp
index 8401647..bb3cd43 100644
--- a/test/cpp/src/TestServer.cpp
+++ b/test/cpp/src/TestServer.cpp
@@ -20,6 +20,7 @@
 #include <concurrency/ThreadManager.h>
 #include <concurrency/PosixThreadFactory.h>
 #include <protocol/TBinaryProtocol.h>
+#include <protocol/TJSONProtocol.h>
 #include <server/TSimpleServer.h>
 #include <server/TThreadedServer.h>
 #include <server/TThreadPoolServer.h>
@@ -27,6 +28,8 @@
 #include <transport/TServerSocket.h>
 #include <transport/TSSLServerSocket.h>
 #include <transport/TSSLSocket.h>
+#include <transport/THttpServer.h>
+#include <transport/THttpTransport.h>
 #include <transport/TTransportUtils.h>
 #include "ThriftTest.h"
 
@@ -34,6 +37,8 @@
 #include <stdexcept>
 #include <sstream>
 
+#include <boost/program_options.hpp>
+
 #define __STDC_FORMAT_MACROS
 #include <inttypes.h>
 #include <signal.h>
@@ -324,97 +329,100 @@
 
 
 int main(int argc, char **argv) {
-
   int port = 9090;
-  string serverType = "simple";
-  string protocolType = "binary";
-  size_t workerCount = 4;
   bool ssl = false;
+  string transport_type = "buffered";
+  string protocol_type = "binary";
+  string server_type = "simple";
+  string domain_socket = "";
+  size_t workers = 4;
 
-  ostringstream usage;
+ 
+  program_options::options_description desc("Allowed options");
+  desc.add_options()
+      ("help,h", "produce help message")
+      ("port", program_options::value<int>(&port)->default_value(port), "Port number to listen")
+	  ("domain-socket", program_options::value<string>(&domain_socket)->default_value(domain_socket),
+	    "Unix Domain Socket (e.g. /tmp/ThriftTest.thrift)")
+      ("server-type", program_options::value<string>(&server_type)->default_value(server_type),
+        "type of server, \"simple\", \"thread-pool\", \"threaded\", or \"nonblocking\"")
+      ("transport", program_options::value<string>(&transport_type)->default_value(transport_type),
+        "transport: buffered, framed, http")
+      ("protocol", program_options::value<string>(&protocol_type)->default_value(protocol_type),
+        "protocol: binary, json")
+	  ("ssl", "Encrypted Transport using SSL")
+	  ("processor-events", "processor-events")
+      ("workers,n", program_options::value<size_t>(&workers)->default_value(workers),
+        "Number of thread pools workers. Only valid for thread-pool server type")
+  ;
 
-  usage <<
-    argv[0] << " [--port=<port number>] [--server-type=<server-type>] [--protocol-type=<protocol-type>] [--workers=<worker-count>] [--processor-events]" << endl <<
+  program_options::variables_map vm;
+  program_options::store(program_options::parse_command_line(argc, argv, desc), vm);
+  program_options::notify(vm);    
 
-    "\t\tserver-type\t\ttype of server, \"simple\", \"thread-pool\", \"threaded\", or \"nonblocking\".  Default is " << serverType << endl <<
-
-    "\t\tprotocol-type\t\ttype of protocol, \"binary\", \"ascii\", or \"xml\".  Default is " << protocolType << endl <<
-
-    "\t\tworkers\t\tNumber of thread pools workers.  Only valid for thread-pool server type.  Default is " << workerCount << endl;
-
-  map<string, string>  args;
-
-  for (int ix = 1; ix < argc; ix++) {
-    string arg(argv[ix]);
-    if (arg.compare(0,2, "--") == 0) {
-      size_t end = arg.find_first_of("=", 2);
-      if (end != string::npos) {
-	args[string(arg, 2, end - 2)] = string(arg, end + 1);
-      } else {
-	args[string(arg, 2)] = "true";
-      }
-    } else {
-      throw invalid_argument("Unexcepted command line token: "+arg);
-    }
+  if (vm.count("help")) {
+      cout << desc << "\n";
+      return 1;
   }
-
+  
   try {
-
-    if (!args["port"].empty()) {
-      port = atoi(args["port"].c_str());
-    }
-
-    if (!args["server-type"].empty()) {
-      serverType = args["server-type"];
-      if (serverType == "simple") {
-      } else if (serverType == "thread-pool") {
-      } else if (serverType == "threaded") {
-      } else if (serverType == "nonblocking") {
+    if (!server_type.empty()) {
+      if (server_type == "simple") {
+      } else if (server_type == "thread-pool") {
+      } else if (server_type == "threaded") {
+      } else if (server_type == "nonblocking") {
       } else {
-	throw invalid_argument("Unknown server type "+serverType);
+          throw invalid_argument("Unknown server type "+server_type);
+      }
+    }
+    
+    if (!protocol_type.empty()) {
+      if (protocol_type == "binary") {
+      } else if (protocol_type == "json") {
+      } else {
+          throw invalid_argument("Unknown protocol type "+protocol_type);
       }
     }
 
-    if (!args["protocol-type"].empty()) {
-      protocolType = args["protocol-type"];
-      if (protocolType == "binary") {
-      } else if (protocolType == "ascii") {
-	throw invalid_argument("ASCII protocol not supported");
-      } else if (protocolType == "xml") {
-	throw invalid_argument("XML protocol not supported");
+	if (!transport_type.empty()) {
+      if (transport_type == "buffered") {
+      } else if (transport_type == "framed") {
+      } else if (transport_type == "http") {
       } else {
-	throw invalid_argument("Unknown protocol type "+protocolType);
+          throw invalid_argument("Unknown transport type "+transport_type);
       }
     }
 
-    if (!args["workers"].empty()) {
-      workerCount = atoi(args["workers"].c_str());
-    }
   } catch (std::exception& e) {
     cerr << e.what() << endl;
-    cerr << usage;
+    cout << desc << "\n";
+    return 1;
   }
 
-  if (args["ssl"] == "true") {
+  if (vm.count("ssl")) {
     ssl = true;
     signal(SIGPIPE, SIG_IGN);
   }
 
   // Dispatcher
-  shared_ptr<TProtocolFactory> protocolFactory(
-      new TBinaryProtocolFactoryT<TBufferBase>());
+  shared_ptr<TProtocolFactory> protocolFactory;
+  if (protocol_type == "json") {
+    shared_ptr<TProtocolFactory> jsonProtocolFactory(new TJSONProtocolFactory());
+    protocolFactory = jsonProtocolFactory;
+  } else {
+    shared_ptr<TProtocolFactory> binaryProtocolFactory(new TBinaryProtocolFactoryT<TBufferBase>());
+    protocolFactory = binaryProtocolFactory;
+  }
 
+  // Processor
   shared_ptr<TestHandler> testHandler(new TestHandler());
-
-  shared_ptr<TProcessor> testProcessor(
-      new ThriftTestProcessorT< TBinaryProtocolT<TBufferBase> >(testHandler));
-
-
-  if (!args["processor-events"].empty()) {
+  shared_ptr<ThriftTestProcessor> testProcessor(new ThriftTestProcessor(testHandler));
+  
+  if (vm.count("processor-events")) {
     testProcessor->setEventHandler(shared_ptr<TProcessorEventHandler>(
           new TestProcessorEventHandler()));
   }
-
+  
   // Transport
   shared_ptr<TSSLSocketFactory> sslSocketFactory;
   shared_ptr<TServerSocket> serverSocket;
@@ -426,26 +434,51 @@
     sslSocketFactory->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
     serverSocket = shared_ptr<TServerSocket>(new TSSLServerSocket(port, sslSocketFactory));
   } else {
-    serverSocket = shared_ptr<TServerSocket>(new TServerSocket(port));
+	if (domain_socket != "") {
+	  unlink(domain_socket.c_str());
+	  serverSocket = shared_ptr<TServerSocket>(new TServerSocket(domain_socket));
+	  port = 0;
+	}
+	else {
+      serverSocket = shared_ptr<TServerSocket>(new TServerSocket(port));
+	}
   }
+
   // Factory
-  shared_ptr<TTransportFactory> transportFactory(new TBufferedTransportFactory());
+  shared_ptr<TTransportFactory> transportFactory;
+  
+  if (transport_type == "http") {
+    shared_ptr<TTransportFactory> httpTransportFactory(new THttpServerTransportFactory()); 
+    transportFactory = httpTransportFactory;
+  } else if (transport_type == "framed") {
+    shared_ptr<TTransportFactory> framedTransportFactory(new TFramedTransportFactory()); 
+    transportFactory = framedTransportFactory;
+  } else {
+    shared_ptr<TTransportFactory> bufferedTransportFactory(new TBufferedTransportFactory()); 
+    transportFactory = bufferedTransportFactory;
+  }
 
-  if (serverType == "simple") {
+  // Server Info
+  cout << "Starting \"" << server_type << "\" server ("
+    << transport_type << "/" << protocol_type << ") listen on: " << domain_socket;
+  if (port != 0) {
+    cout << port;
+  }
+  cout << endl;
 
-    // Server
+  // Server
+  if (server_type == "simple") {
     TSimpleServer simpleServer(testProcessor,
-			       serverSocket,
+                               serverSocket,
                                transportFactory,
                                protocolFactory);
 
-    printf("Starting the server on port %d...\n", port);
     simpleServer.serve();
 
-  } else if (serverType == "thread-pool") {
+  } else if (server_type == "thread-pool") {
 
     shared_ptr<ThreadManager> threadManager =
-      ThreadManager::newSimpleThreadManager(workerCount);
+      ThreadManager::newSimpleThreadManager(workers);
 
     shared_ptr<PosixThreadFactory> threadFactory =
       shared_ptr<PosixThreadFactory>(new PosixThreadFactory());
@@ -455,30 +488,27 @@
     threadManager->start();
 
     TThreadPoolServer threadPoolServer(testProcessor,
-				       serverSocket,
+                                       serverSocket,
                                        transportFactory,
                                        protocolFactory,
-				       threadManager);
+                                       threadManager);
 
-    printf("Starting the server on port %d...\n", port);
     threadPoolServer.serve();
 
-  } else if (serverType == "threaded") {
+  } else if (server_type == "threaded") {
 
     TThreadedServer threadedServer(testProcessor,
                                    serverSocket,
                                    transportFactory,
                                    protocolFactory);
 
-    printf("Starting the server on port %d...\n", port);
     threadedServer.serve();
 
-  } else if (serverType == "nonblocking") {
+  } else if (server_type == "nonblocking") {
     TNonblockingServer nonblockingServer(testProcessor, port);
-    printf("Starting the nonblocking server on port %d...\n", port);
     nonblockingServer.serve();
   }
 
-  printf("done.\n");
+  cout << "done." << endl;
   return 0;
 }