THRIFT-1198 C++ TestClient and TestServer improvements (add Unix domain socket, HTTP transport, and JSON protocol support)


git-svn-id: https://svn.apache.org/repos/asf/thrift/trunk@1133116 13f79535-47bb-0310-9956-ffa450edef68
diff --git a/test/cpp/src/TestServer.cpp b/test/cpp/src/TestServer.cpp
index 8401647..bb3cd43 100644
--- a/test/cpp/src/TestServer.cpp
+++ b/test/cpp/src/TestServer.cpp
@@ -20,6 +20,7 @@
 #include <concurrency/ThreadManager.h>
 #include <concurrency/PosixThreadFactory.h>
 #include <protocol/TBinaryProtocol.h>
+#include <protocol/TJSONProtocol.h>
 #include <server/TSimpleServer.h>
 #include <server/TThreadedServer.h>
 #include <server/TThreadPoolServer.h>
@@ -27,6 +28,8 @@
 #include <transport/TServerSocket.h>
 #include <transport/TSSLServerSocket.h>
 #include <transport/TSSLSocket.h>
+#include <transport/THttpServer.h>
+#include <transport/THttpTransport.h>
 #include <transport/TTransportUtils.h>
 #include "ThriftTest.h"
 
@@ -34,6 +37,8 @@
 #include <stdexcept>
 #include <sstream>
 
+#include <boost/program_options.hpp>
+
 #define __STDC_FORMAT_MACROS
 #include <inttypes.h>
 #include <signal.h>
@@ -324,97 +329,100 @@
 
 
 int main(int argc, char **argv) {
-
   int port = 9090;
-  string serverType = "simple";
-  string protocolType = "binary";
-  size_t workerCount = 4;
   bool ssl = false;
+  string transport_type = "buffered";
+  string protocol_type = "binary";
+  string server_type = "simple";
+  string domain_socket = "";
+  size_t workers = 4;
 
-  ostringstream usage;
+ 
+  program_options::options_description desc("Allowed options");
+  desc.add_options()
+      ("help,h", "produce help message")
+      ("port", program_options::value<int>(&port)->default_value(port), "Port number to listen")
+	  ("domain-socket", program_options::value<string>(&domain_socket)->default_value(domain_socket),
+	    "Unix Domain Socket (e.g. /tmp/ThriftTest.thrift)")
+      ("server-type", program_options::value<string>(&server_type)->default_value(server_type),
+        "type of server, \"simple\", \"thread-pool\", \"threaded\", or \"nonblocking\"")
+      ("transport", program_options::value<string>(&transport_type)->default_value(transport_type),
+        "transport: buffered, framed, http")
+      ("protocol", program_options::value<string>(&protocol_type)->default_value(protocol_type),
+        "protocol: binary, json")
+	  ("ssl", "Encrypted Transport using SSL")
+	  ("processor-events", "processor-events")
+      ("workers,n", program_options::value<size_t>(&workers)->default_value(workers),
+        "Number of thread pools workers. Only valid for thread-pool server type")
+  ;
 
-  usage <<
-    argv[0] << " [--port=<port number>] [--server-type=<server-type>] [--protocol-type=<protocol-type>] [--workers=<worker-count>] [--processor-events]" << endl <<
+  program_options::variables_map vm;
+  program_options::store(program_options::parse_command_line(argc, argv, desc), vm);
+  program_options::notify(vm);    
 
-    "\t\tserver-type\t\ttype of server, \"simple\", \"thread-pool\", \"threaded\", or \"nonblocking\".  Default is " << serverType << endl <<
-
-    "\t\tprotocol-type\t\ttype of protocol, \"binary\", \"ascii\", or \"xml\".  Default is " << protocolType << endl <<
-
-    "\t\tworkers\t\tNumber of thread pools workers.  Only valid for thread-pool server type.  Default is " << workerCount << endl;
-
-  map<string, string>  args;
-
-  for (int ix = 1; ix < argc; ix++) {
-    string arg(argv[ix]);
-    if (arg.compare(0,2, "--") == 0) {
-      size_t end = arg.find_first_of("=", 2);
-      if (end != string::npos) {
-	args[string(arg, 2, end - 2)] = string(arg, end + 1);
-      } else {
-	args[string(arg, 2)] = "true";
-      }
-    } else {
-      throw invalid_argument("Unexcepted command line token: "+arg);
-    }
+  if (vm.count("help")) {
+      cout << desc << "\n";
+      return 1;
   }
-
+  
   try {
-
-    if (!args["port"].empty()) {
-      port = atoi(args["port"].c_str());
-    }
-
-    if (!args["server-type"].empty()) {
-      serverType = args["server-type"];
-      if (serverType == "simple") {
-      } else if (serverType == "thread-pool") {
-      } else if (serverType == "threaded") {
-      } else if (serverType == "nonblocking") {
+    if (!server_type.empty()) {
+      if (server_type == "simple") {
+      } else if (server_type == "thread-pool") {
+      } else if (server_type == "threaded") {
+      } else if (server_type == "nonblocking") {
       } else {
-	throw invalid_argument("Unknown server type "+serverType);
+          throw invalid_argument("Unknown server type "+server_type);
+      }
+    }
+    
+    if (!protocol_type.empty()) {
+      if (protocol_type == "binary") {
+      } else if (protocol_type == "json") {
+      } else {
+          throw invalid_argument("Unknown protocol type "+protocol_type);
       }
     }
 
-    if (!args["protocol-type"].empty()) {
-      protocolType = args["protocol-type"];
-      if (protocolType == "binary") {
-      } else if (protocolType == "ascii") {
-	throw invalid_argument("ASCII protocol not supported");
-      } else if (protocolType == "xml") {
-	throw invalid_argument("XML protocol not supported");
+	if (!transport_type.empty()) {
+      if (transport_type == "buffered") {
+      } else if (transport_type == "framed") {
+      } else if (transport_type == "http") {
       } else {
-	throw invalid_argument("Unknown protocol type "+protocolType);
+          throw invalid_argument("Unknown transport type "+transport_type);
       }
     }
 
-    if (!args["workers"].empty()) {
-      workerCount = atoi(args["workers"].c_str());
-    }
   } catch (std::exception& e) {
     cerr << e.what() << endl;
-    cerr << usage;
+    cout << desc << "\n";
+    return 1;
   }
 
-  if (args["ssl"] == "true") {
+  if (vm.count("ssl")) {
     ssl = true;
     signal(SIGPIPE, SIG_IGN);
   }
 
   // Dispatcher
-  shared_ptr<TProtocolFactory> protocolFactory(
-      new TBinaryProtocolFactoryT<TBufferBase>());
+  shared_ptr<TProtocolFactory> protocolFactory;
+  if (protocol_type == "json") {
+    shared_ptr<TProtocolFactory> jsonProtocolFactory(new TJSONProtocolFactory());
+    protocolFactory = jsonProtocolFactory;
+  } else {
+    shared_ptr<TProtocolFactory> binaryProtocolFactory(new TBinaryProtocolFactoryT<TBufferBase>());
+    protocolFactory = binaryProtocolFactory;
+  }
 
+  // Processor
   shared_ptr<TestHandler> testHandler(new TestHandler());
-
-  shared_ptr<TProcessor> testProcessor(
-      new ThriftTestProcessorT< TBinaryProtocolT<TBufferBase> >(testHandler));
-
-
-  if (!args["processor-events"].empty()) {
+  shared_ptr<ThriftTestProcessor> testProcessor(new ThriftTestProcessor(testHandler));
+  
+  if (vm.count("processor-events")) {
     testProcessor->setEventHandler(shared_ptr<TProcessorEventHandler>(
           new TestProcessorEventHandler()));
   }
-
+  
   // Transport
   shared_ptr<TSSLSocketFactory> sslSocketFactory;
   shared_ptr<TServerSocket> serverSocket;
@@ -426,26 +434,51 @@
     sslSocketFactory->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
     serverSocket = shared_ptr<TServerSocket>(new TSSLServerSocket(port, sslSocketFactory));
   } else {
-    serverSocket = shared_ptr<TServerSocket>(new TServerSocket(port));
+	if (domain_socket != "") {
+	  unlink(domain_socket.c_str());
+	  serverSocket = shared_ptr<TServerSocket>(new TServerSocket(domain_socket));
+	  port = 0;
+	}
+	else {
+      serverSocket = shared_ptr<TServerSocket>(new TServerSocket(port));
+	}
   }
+
   // Factory
-  shared_ptr<TTransportFactory> transportFactory(new TBufferedTransportFactory());
+  shared_ptr<TTransportFactory> transportFactory;
+  
+  if (transport_type == "http") {
+    shared_ptr<TTransportFactory> httpTransportFactory(new THttpServerTransportFactory()); 
+    transportFactory = httpTransportFactory;
+  } else if (transport_type == "framed") {
+    shared_ptr<TTransportFactory> framedTransportFactory(new TFramedTransportFactory()); 
+    transportFactory = framedTransportFactory;
+  } else {
+    shared_ptr<TTransportFactory> bufferedTransportFactory(new TBufferedTransportFactory()); 
+    transportFactory = bufferedTransportFactory;
+  }
 
-  if (serverType == "simple") {
+  // Server Info
+  cout << "Starting \"" << server_type << "\" server ("
+    << transport_type << "/" << protocol_type << ") listen on: " << domain_socket;
+  if (port != 0) {
+    cout << port;
+  }
+  cout << endl;
 
-    // Server
+  // Server
+  if (server_type == "simple") {
     TSimpleServer simpleServer(testProcessor,
-			       serverSocket,
+                               serverSocket,
                                transportFactory,
                                protocolFactory);
 
-    printf("Starting the server on port %d...\n", port);
     simpleServer.serve();
 
-  } else if (serverType == "thread-pool") {
+  } else if (server_type == "thread-pool") {
 
     shared_ptr<ThreadManager> threadManager =
-      ThreadManager::newSimpleThreadManager(workerCount);
+      ThreadManager::newSimpleThreadManager(workers);
 
     shared_ptr<PosixThreadFactory> threadFactory =
       shared_ptr<PosixThreadFactory>(new PosixThreadFactory());
@@ -455,30 +488,27 @@
     threadManager->start();
 
     TThreadPoolServer threadPoolServer(testProcessor,
-				       serverSocket,
+                                       serverSocket,
                                        transportFactory,
                                        protocolFactory,
-				       threadManager);
+                                       threadManager);
 
-    printf("Starting the server on port %d...\n", port);
     threadPoolServer.serve();
 
-  } else if (serverType == "threaded") {
+  } else if (server_type == "threaded") {
 
     TThreadedServer threadedServer(testProcessor,
                                    serverSocket,
                                    transportFactory,
                                    protocolFactory);
 
-    printf("Starting the server on port %d...\n", port);
     threadedServer.serve();
 
-  } else if (serverType == "nonblocking") {
+  } else if (server_type == "nonblocking") {
     TNonblockingServer nonblockingServer(testProcessor, port);
-    printf("Starting the nonblocking server on port %d...\n", port);
     nonblockingServer.serve();
   }
 
-  printf("done.\n");
+  cout << "done." << endl;
   return 0;
 }