blob: 7e69524d72edf7a1598feccecf479df442667df3 [file] [log] [blame]
Mark Sleee02385b2007-06-09 01:21:16 +00001#include <concurrency/ThreadManager.h>
2#include <concurrency/PosixThreadFactory.h>
3#include <concurrency/Monitor.h>
4#include <concurrency/Util.h>
5#include <concurrency/Mutex.h>
6#include <protocol/TBinaryProtocol.h>
7#include <server/TSimpleServer.h>
8#include <server/TThreadPoolServer.h>
9#include <server/TThreadedServer.h>
10#include <server/TNonblockingServer.h>
11#include <transport/TServerSocket.h>
12#include <transport/TSocket.h>
13#include <transport/TTransportUtils.h>
14#include <transport/TFileTransport.h>
15#include <TLogging.h>
16
17#include "Service.h"
18
19#include <boost/shared_ptr.hpp>
20
21#include <iostream>
22#include <set>
23#include <stdexcept>
24#include <sstream>
25
26#include <map>
27#include <ext/hash_map>
28using __gnu_cxx::hash_map;
29using __gnu_cxx::hash;
30
31using namespace std;
32using namespace boost;
33
34using namespace facebook::thrift;
35using namespace facebook::thrift::protocol;
36using namespace facebook::thrift::transport;
37using namespace facebook::thrift::server;
38using namespace facebook::thrift::concurrency;
39
40using namespace test::stress;
41
/// Equality predicate for NUL-terminated C strings: compares the characters
/// (via strcmp), not the pointer values. Only referenced by the commented-out
/// hash_map alternative for count_map.
struct eqstr {
  bool operator()(const char* s1, const char* s2) const {
    const int cmp = strcmp(s1, s2);
    return !cmp;
  }
};
47
/// Strict-weak-ordering comparator for NUL-terminated C strings. It orders by
/// content (strcmp) rather than by address, so two distinct pointers to equal
/// text are the same key in an ordered container.
struct ltstr {
  bool operator()(const char* s1, const char* s2) const {
    return 0 > strcmp(s1, s2);
  }
};
53
54
55// typedef hash_map<const char*, int, hash<const char*>, eqstr> count_map;
56typedef map<const char*, int, ltstr> count_map;
57
58class Server : public ServiceIf {
59 public:
60 Server() {}
61
62 void count(const char* method) {
63 MutexMonitor m(lock_);
64 int ct = counts_[method];
65 counts_[method] = ++ct;
66 }
67
68 void echoVoid() {
69 count("echoVoid");
70
71 //Sleep to simulate work
72 struct timeval time_struct;
73 time_struct.tv_sec = 0;
74 time_struct.tv_usec = 5000;
75
76 select( (int) NULL, (fd_set *)NULL, (fd_set *)NULL,(fd_set *)NULL, &time_struct );
77
78
79
80 return;
81 }
82
83 count_map getCount() {
84 MutexMonitor m(lock_);
85 return counts_;
86 }
87
88 int8_t echoByte(const int8_t arg) {return arg;}
89 int32_t echoI32(const int32_t arg) {return arg;}
90 int64_t echoI64(const int64_t arg) {return arg;}
91 void echoString(string& out, const string &arg) {
92 if (arg != "hello") {
93 T_ERROR_ABORT("WRONG STRING!!!!");
94 }
95 out = arg;
96 }
97 void echoList(vector<int8_t> &out, const vector<int8_t> &arg) { out = arg; }
98 void echoSet(set<int8_t> &out, const set<int8_t> &arg) { out = arg; }
99 void echoMap(map<int8_t, int8_t> &out, const map<int8_t, int8_t> &arg) { out = arg; }
100
101private:
102 count_map counts_;
103 Mutex lock_;
104
105};
106
107class ClientThread: public Runnable {
108public:
109
110 ClientThread(shared_ptr<TTransport>transport, shared_ptr<ServiceClient> client, Monitor& monitor, size_t& workerCount, size_t loopCount, TType loopType) :
111 _transport(transport),
112 _client(client),
113 _monitor(monitor),
114 _workerCount(workerCount),
115 _loopCount(loopCount),
116 _loopType(loopType)
117 {}
118
119 void run() {
120
121 // Wait for all worker threads to start
122
123 {Synchronized s(_monitor);
124 while(_workerCount == 0) {
125 _monitor.wait();
126 }
127 }
128
129 _startTime = Util::currentTime();
130
131 _transport->open();
132
133 switch(_loopType) {
134 case T_VOID: loopEchoVoid(); break;
135 case T_BYTE: loopEchoByte(); break;
136 case T_I32: loopEchoI32(); break;
137 case T_I64: loopEchoI64(); break;
138 case T_STRING: loopEchoString(); break;
139 default: cerr << "Unexpected loop type" << _loopType << endl; break;
140 }
141
142 _endTime = Util::currentTime();
143
144 _transport->close();
145
146 _done = true;
147
148 {Synchronized s(_monitor);
149
150 _workerCount--;
151
152 if(_workerCount == 0) {
153
154 _monitor.notify();
155 }
156 }
157 }
158
159 void loopEchoVoid() {
160 for(size_t ix = 0; ix < _loopCount; ix++) {
161 _client->echoVoid();
162 }
163 }
164
165 void loopEchoByte() {
166 for(size_t ix = 0; ix < _loopCount; ix++) {
167 int8_t arg = 1;
168 int8_t result;
169 result =_client->echoByte(arg);
170 assert(result == arg);
171 }
172 }
173
174 void loopEchoI32() {
175 for(size_t ix = 0; ix < _loopCount; ix++) {
176 int32_t arg = 1;
177 int32_t result;
178 result =_client->echoI32(arg);
179 assert(result == arg);
180 }
181 }
182
183 void loopEchoI64() {
184 for(size_t ix = 0; ix < _loopCount; ix++) {
185 int64_t arg = 1;
186 int64_t result;
187 result =_client->echoI64(arg);
188 assert(result == arg);
189 }
190 }
191
192 void loopEchoString() {
193 for(size_t ix = 0; ix < _loopCount; ix++) {
194 string arg = "hello";
195 string result;
196 _client->echoString(result, arg);
197 assert(result == arg);
198 }
199 }
200
201 shared_ptr<TTransport> _transport;
202 shared_ptr<ServiceClient> _client;
203 Monitor& _monitor;
204 size_t& _workerCount;
205 size_t _loopCount;
206 TType _loopType;
207 long long _startTime;
208 long long _endTime;
209 bool _done;
210 Monitor _sleep;
211};
212
213
214int main(int argc, char **argv) {
215
216 int port = 9091;
217 string serverType = "simple";
218 string protocolType = "binary";
219 size_t workerCount = 4;
220 size_t clientCount = 20;
221 size_t loopCount = 50000;
222 TType loopType = T_VOID;
223 string callName = "echoVoid";
224 bool runServer = true;
225 bool logRequests = false;
226 string requestLogPath = "./requestlog.tlog";
227 bool replayRequests = false;
228
229 ostringstream usage;
230
231 usage <<
232 argv[0] << " [--port=<port number>] [--server] [--server-type=<server-type>] [--protocol-type=<protocol-type>] [--workers=<worker-count>] [--clients=<client-count>] [--loop=<loop-count>]" << endl <<
233 "\tclients Number of client threads to create - 0 implies no clients, i.e. server only. Default is " << clientCount << endl <<
234 "\thelp Prints this help text." << endl <<
235 "\tcall Service method to call. Default is " << callName << endl <<
236 "\tloop The number of remote thrift calls each client makes. Default is " << loopCount << endl <<
237 "\tport The port the server and clients should bind to for thrift network connections. Default is " << port << endl <<
238 "\tserver Run the Thrift server in this process. Default is " << runServer << endl <<
239 "\tserver-type Type of server, \"simple\" or \"thread-pool\". Default is " << serverType << endl <<
240 "\tprotocol-type Type of protocol, \"binary\", \"ascii\", or \"xml\". Default is " << protocolType << endl <<
241 "\tlog-request Log all request to ./requestlog.tlog. Default is " << logRequests << endl <<
242 "\treplay-request Replay requests from log file (./requestlog.tlog) Default is " << replayRequests << endl <<
243 "\tworkers Number of thread pools workers. Only valid for thread-pool server type. Default is " << workerCount << endl;
244
245
246 map<string, string> args;
247
248 for(int ix = 1; ix < argc; ix++) {
249
250 string arg(argv[ix]);
251
252 if(arg.compare(0,2, "--") == 0) {
253
254 size_t end = arg.find_first_of("=", 2);
255
256 string key = string(arg, 2, end - 2);
257
258 if(end != string::npos) {
259 args[key] = string(arg, end + 1);
260 } else {
261 args[key] = "true";
262 }
263 } else {
264 throw invalid_argument("Unexcepted command line token: "+arg);
265 }
266 }
267
268 try {
269
270 if(!args["clients"].empty()) {
271 clientCount = atoi(args["clients"].c_str());
272 }
273
274 if(!args["help"].empty()) {
275 cerr << usage.str();
276 return 0;
277 }
278
279 if(!args["loop"].empty()) {
280 loopCount = atoi(args["loop"].c_str());
281 }
282
283 if(!args["call"].empty()) {
284 callName = args["call"];
285 }
286
287 if(!args["port"].empty()) {
288 port = atoi(args["port"].c_str());
289 }
290
291 if(!args["server"].empty()) {
292 runServer = args["server"] == "true";
293 }
294
295 if(!args["log-request"].empty()) {
296 logRequests = args["log-request"] == "true";
297 }
298
299 if(!args["replay-request"].empty()) {
300 replayRequests = args["replay-request"] == "true";
301 }
302
303 if(!args["server-type"].empty()) {
304 serverType = args["server-type"];
305 }
306
307 if(!args["workers"].empty()) {
308 workerCount = atoi(args["workers"].c_str());
309 }
310
311 } catch(exception& e) {
312 cerr << e.what() << endl;
313 cerr << usage;
314 }
315
316 shared_ptr<PosixThreadFactory> threadFactory = shared_ptr<PosixThreadFactory>(new PosixThreadFactory());
317
318 // Dispatcher
319 shared_ptr<Server> serviceHandler(new Server());
320
321 if (replayRequests) {
322 shared_ptr<Server> serviceHandler(new Server());
323 shared_ptr<ServiceProcessor> serviceProcessor(new ServiceProcessor(serviceHandler));
324
325 // Transports
326 shared_ptr<TFileTransport> fileTransport(new TFileTransport(requestLogPath));
327 fileTransport->setChunkSize(2 * 1024 * 1024);
328 fileTransport->setMaxEventSize(1024 * 16);
329 fileTransport->seekToEnd();
330
331 // Protocol Factory
332 shared_ptr<TProtocolFactory> protocolFactory(new TBinaryProtocolFactory());
333
334 TFileProcessor fileProcessor(serviceProcessor,
335 protocolFactory,
336 fileTransport);
337
338 fileProcessor.process(0, true);
339 exit(0);
340 }
341
342
343 if(runServer) {
344
345 shared_ptr<ServiceProcessor> serviceProcessor(new ServiceProcessor(serviceHandler));
346
347 // Protocol Factory
348 shared_ptr<TProtocolFactory> protocolFactory(new TBinaryProtocolFactory());
349
350 // Transport Factory
351 shared_ptr<TTransportFactory> transportFactory;
352
353 if (logRequests) {
354 // initialize the log file
355 shared_ptr<TFileTransport> fileTransport(new TFileTransport(requestLogPath));
356 fileTransport->setChunkSize(2 * 1024 * 1024);
357 fileTransport->setMaxEventSize(1024 * 16);
358
359 transportFactory =
360 shared_ptr<TTransportFactory>(new TPipedTransportFactory(fileTransport));
361 }
362
363 shared_ptr<Thread> serverThread;
364
365 if(serverType == "simple") {
366
367 serverThread = threadFactory->newThread(shared_ptr<TServer>(new TNonblockingServer(serviceProcessor, protocolFactory,port)));
368
369 } else if(serverType == "thread-pool") {
370
371 shared_ptr<ThreadManager> threadManager = ThreadManager::newSimpleThreadManager(workerCount);
372
373 threadManager->threadFactory(threadFactory);
374 threadManager->start();
375 serverThread = threadFactory->newThread(shared_ptr<TServer>(new TNonblockingServer(serviceProcessor, protocolFactory, port, threadManager)));
376 }
377
378 cerr << "Starting the server on port " << port << endl;
379 serverThread->start();
380
381 // If we aren't running clients, just wait forever for external clients
382
383 if (clientCount == 0) {
384 serverThread->join();
385 }
386 }
387
388 if (clientCount > 0) {
389
390 Monitor monitor;
391
392 size_t threadCount = 0;
393
394 set<shared_ptr<Thread> > clientThreads;
395
396 if(callName == "echoVoid") { loopType = T_VOID;}
397 else if(callName == "echoByte") { loopType = T_BYTE;}
398 else if(callName == "echoI32") { loopType = T_I32;}
399 else if(callName == "echoI64") { loopType = T_I64;}
400 else if(callName == "echoString") { loopType = T_STRING;}
401 else {throw invalid_argument("Unknown service call "+callName);}
402
403 for(size_t ix = 0; ix < clientCount; ix++) {
404
405 shared_ptr<TSocket> socket(new TSocket("127.0.0.1", port));
406 shared_ptr<TFramedTransport> framedSocket(new TFramedTransport(socket));
407 shared_ptr<TProtocol> protocol(new TBinaryProtocol(framedSocket));
408 shared_ptr<ServiceClient> serviceClient(new ServiceClient(protocol));
409
410 clientThreads.insert(threadFactory->newThread(shared_ptr<ClientThread>(new ClientThread(socket, serviceClient, monitor, threadCount, loopCount, loopType))));
411 }
412
413 for(std::set<shared_ptr<Thread> >::const_iterator thread = clientThreads.begin(); thread != clientThreads.end(); thread++) {
414 (*thread)->start();
415 }
416
417 long long time00;
418 long long time01;
419
420 {Synchronized s(monitor);
421 threadCount = clientCount;
422
423 cerr << "Launch "<< clientCount << " client threads" << endl;
424
425 time00 = Util::currentTime();
426
427 monitor.notifyAll();
428
429 while(threadCount > 0) {
430 monitor.wait();
431 }
432
433 time01 = Util::currentTime();
434 }
435
436 long long firstTime = 9223372036854775807LL;
437 long long lastTime = 0;
438
439 double averageTime = 0;
440 long long minTime = 9223372036854775807LL;
441 long long maxTime = 0;
442
443 for(set<shared_ptr<Thread> >::iterator ix = clientThreads.begin(); ix != clientThreads.end(); ix++) {
444
445 shared_ptr<ClientThread> client = dynamic_pointer_cast<ClientThread>((*ix)->runnable());
446
447 long long delta = client->_endTime - client->_startTime;
448
449 assert(delta > 0);
450
451 if(client->_startTime < firstTime) {
452 firstTime = client->_startTime;
453 }
454
455 if(client->_endTime > lastTime) {
456 lastTime = client->_endTime;
457 }
458
459 if(delta < minTime) {
460 minTime = delta;
461 }
462
463 if(delta > maxTime) {
464 maxTime = delta;
465 }
466
467 averageTime+= delta;
468 }
469
470 averageTime /= clientCount;
471
472
473 cout << "workers :" << workerCount << ", client : " << clientCount << ", loops : " << loopCount << ", rate : " << (clientCount * loopCount * 1000) / ((double)(time01 - time00)) << endl;
474
475 count_map count = serviceHandler->getCount();
476 count_map::iterator iter;
477 for (iter = count.begin(); iter != count.end(); ++iter) {
478 printf("%s => %d\n", iter->first, iter->second);
479 }
480 cerr << "done." << endl;
481 }
482
483 return 0;
484}