THRIFT-847 Test Framework harmonization across all languages
THRIFT-2946 Enhance usability of cross test framework
Patch: Nobuaki Sukegawa
This closes: #358
diff --git a/test/crossrunner/run.py b/test/crossrunner/run.py
new file mode 100644
index 0000000..e3300ba
--- /dev/null
+++ b/test/crossrunner/run.py
@@ -0,0 +1,317 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+import contextlib
+import multiprocessing
+import multiprocessing.managers
+import os
+import platform
+import random
+import socket
+import signal
+import subprocess
+import threading
+import time
+import traceback
+
+from crossrunner.test import TestEntry, domain_socket_path
+from crossrunner.report import ExecReporter, SummaryReporter
+
# Sentinel result codes that run_test reports instead of a client exit status.
RESULT_TIMEOUT = 0x80  # the client was killed because its timeout expired
RESULT_ERROR = 0x40  # run_test itself failed with an unexpected exception
+
+
class ExecutionContext(object):
    """Run one test program as a child process, with an optional timeout.

    Only these members of ``report`` are used here: ``out`` (the child's
    stdout; stderr is merged into it), ``logpath``, ``begin()``,
    ``end(returncode)`` and ``killed()``.
    """

    def __init__(self, cmd, cwd, env, report):
        self._log = multiprocessing.get_logger()
        self.report = report
        self.cmd = cmd
        self.cwd = cwd
        self.env = env
        # Set by start(); None beforehand so returncode/kill() are safe to
        # call before the process has been launched.
        self.proc = None
        self.timer = None
        # True once the timeout fired and _expire() killed the child.
        self.expired = False

    def _expire(self):
        # threading.Timer callback armed in start(); runs on the timer thread.
        self._log.info('Timeout')
        self.expired = True
        self.kill()

    def kill(self):
        """Forcefully kill the child and, off Windows, its process group.

        Does nothing if start() has not been called yet. Kill failures are
        logged and ignored so that cleanup stays best-effort.
        """
        if self.proc is None:
            return
        self._log.debug('Killing process : %d' % self.proc.pid)
        if platform.system() != 'Windows':
            try:
                # The child was started via os.setsid, so its pid is also its
                # process-group id; this also reaches processes it spawned in
                # that group.
                os.killpg(self.proc.pid, signal.SIGKILL)
            except Exception as err:
                self._log.info('Failed to kill process group : %s' % str(err))
        try:
            self.proc.kill()
        except Exception as err:
            self._log.info('Failed to kill process : %s' % str(err))
        self.report.killed()

    def _popen_args(self):
        """Return the keyword arguments used for subprocess.Popen."""
        args = {
            'cwd': self.cwd,
            'env': self.env,
            'stdout': self.report.out,
            'stderr': subprocess.STDOUT,
        }
        # Start the child in its own process group so that killing it does
        # not leave its child processes behind.
        if platform.system() == 'Windows':
            DETACHED_PROCESS = 0x00000008
            args.update(creationflags=DETACHED_PROCESS | subprocess.CREATE_NEW_PROCESS_GROUP)
        else:
            args.update(preexec_fn=os.setsid)
        return args

    def start(self, timeout=0):
        """Launch the child and arm the timeout (if ``timeout > 0`` seconds).

        :returns: a context manager yielding this object; leaving it kills
            the child, even when the ``with`` body raises.
        """
        self._log.debug('COMMAND: %s', ' '.join(self.cmd))
        self._log.debug('WORKDIR: %s', self.cwd)
        self._log.debug('LOGFILE: %s', self.report.logpath)
        self.report.begin()
        self.proc = subprocess.Popen(self.cmd, **self._popen_args())
        if timeout > 0:
            self.timer = threading.Timer(timeout, self._expire)
            self.timer.start()
        return self._scoped()

    @contextlib.contextmanager
    def _scoped(self):
        # try/finally so an exception inside the caller's ``with`` block
        # still kills the (session-detached) child instead of leaking it.
        try:
            yield self
        finally:
            self._log.debug('Killing scoped process')
            self.kill()

    def wait(self):
        """Wait for the child to exit, disarm the timer and report the exit code."""
        self.proc.communicate()
        if self.timer:
            self.timer.cancel()
        self.report.end(self.returncode)

    @property
    def returncode(self):
        """The child's exit status, or None if it was never started / has not exited."""
        return None if self.proc is None else self.proc.returncode
+
+
def exec_context(port, testdir, test, prog):
    """Create an ExecutionContext for one side (server or client) of a test.

    :param port: port handed to ``prog.build_command``.
    :param testdir: directory passed through to ExecReporter.
    :param test: the TestEntry being run.
    :param prog: the program description (e.g. ``test.server``/``test.client``).
    """
    reporter = ExecReporter(testdir, test, prog)
    # build_command runs before command/workdir/env are read below;
    # presumably it fills them in for this port.
    prog.build_command(port)
    command, workdir, env = prog.command, prog.workdir, prog.env
    return ExecutionContext(command, workdir, env, reporter)
+
+
def _run_client_with_retry(logger, test, cl, max_retry=10, retry_wait=0.5):
    """Run the client, restarting it while its report looks like a false positive.

    A false positive here is taken to mean the server was not yet ready.
    Gives up after ``max_retry`` restarts.
    """
    retry_count = 0
    while True:
        logger.debug('Starting client')
        cl.start(test.timeout)
        logger.debug('Waiting client')
        cl.wait()
        if not cl.report.maybe_false_positive() or retry_count >= max_retry:
            if 0 < retry_count < max_retry:
                logger.warning('[%s]: Connected after %d retry (%.2f sec each)' % (test.server.name, retry_count, retry_wait))
            return
        logger.debug('Server may not be ready, waiting %.2f second...' % retry_wait)
        time.sleep(retry_wait)
        retry_count += 1


def run_test(testdir, test_dict, async_mode=True, max_retry=3):
    """Run one server/client test pair and return the client's outcome.

    Reads the module globals ``stop`` (an Event) and ``ports`` (a
    PortAllocator or a proxy to one), which TestDispatcher installs.

    :param testdir: directory passed to TestEntry and the exec reporters.
    :param test_dict: keyword arguments used to build the TestEntry.
    :param async_mode: when True (pool worker), exceptions are logged and
        turned into result codes; when False they are re-raised. Renamed
        from ``async``, which is a reserved keyword since Python 3.7; the
        positional slot is unchanged.
    :param max_retry: how many times the whole server/client run is repeated
        when the server report looks like a socket bind failure.
    :returns: the client's exit code, RESULT_TIMEOUT if the client timed
        out, RESULT_ERROR on an unexpected exception, or None when skipped
        or interrupted.
    """
    logger = multiprocessing.get_logger()
    try:
        retry_count = 0
        test = TestEntry(testdir, **test_dict)
        while True:
            if stop.is_set():
                logger.debug('Skipping because shutting down')
                return None
            logger.debug('Start')
            with PortAllocator.alloc_port_scoped(ports, test.socket) as port:
                logger.debug('Start with port %d' % port)
                sv = exec_context(port, testdir, test, test.server)
                cl = exec_context(port, testdir, test, test.client)

                logger.debug('Starting server')
                # Leaving this block kills the server.
                with sv.start():
                    if test.delay > 0:
                        logger.debug('Delaying client for %.2f seconds' % test.delay)
                        time.sleep(test.delay)
                    _run_client_with_retry(logger, test, cl)

            if not sv.report.maybe_false_positive() or retry_count >= max_retry:
                logger.debug('Finish')
                return RESULT_TIMEOUT if cl.expired else cl.proc.returncode
            logger.warning('[%s]: Detected socket bind failure, retrying...' % test.server.name)
            retry_count += 1
    except (KeyboardInterrupt, SystemExit):
        logger.info('Interrupted execution')
        if not async_mode:
            raise
        stop.set()
        return None
    except Exception as ex:
        logger.warning('Error while executing test : %s' % str(ex))
        if not async_mode:
            raise
        # format_exc returns the traceback text; print_exc would return None
        # (logging "None") and write to stderr instead.
        logger.info(traceback.format_exc())
        return RESULT_ERROR
+
+
class PortAllocator(object):
    """Hand out TCP ports and domain-socket ids not already in use by this allocator.

    Used directly for synchronous runs, or through a
    multiprocessing.managers proxy when tests run concurrently.
    """

    def __init__(self):
        self._log = multiprocessing.get_logger()
        self._lock = multiprocessing.Lock()
        self._ports = set()
        self._dom_ports = set()
        self._last_alloc = 0

    def _get_tcp_port(self):
        # Ask the OS for a free ephemeral port on 127.0.0.1, then make sure
        # this allocator has not already handed it out; loop, rather than
        # recurse, until a fresh one turns up.
        while True:
            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            try:
                sock.bind(('127.0.0.1', 0))
                port = sock.getsockname()[1]
                with self._lock:
                    if port not in self._ports:
                        self._ports.add(port)
                        self._last_alloc = time.time()
                        return port
            finally:
                # Release the port so the test server can bind it; also
                # closes the socket if bind() itself failed.
                sock.close()

    def _get_domain_port(self):
        # Domain "ports" are random ids used to name socket paths.
        while True:
            port = random.randint(1024, 65536)
            with self._lock:
                if port not in self._dom_ports:
                    self._dom_ports.add(port)
                    return port

    def alloc_port(self, socket_type):
        """Allocate a domain-socket id if ``socket_type == 'domain'``, else a TCP port."""
        if socket_type == 'domain':
            return self._get_domain_port()
        return self._get_tcp_port()

    # static method for inter-process invocation: ``allocator`` may be a
    # manager proxy rather than a local instance
    @staticmethod
    @contextlib.contextmanager
    def alloc_port_scoped(allocator, socket_type):
        """Allocate a port for the ``with`` body and always free it afterwards."""
        port = allocator.alloc_port(socket_type)
        try:
            yield port
        finally:
            # Free even when the body raises, so failing tests do not leak ports.
            allocator.free_port(socket_type, port)

    def free_port(self, socket_type, port):
        """Return ``port`` to the pool; for domain sockets also remove the socket file.

        Errors (unknown port, failure to remove the file) are logged, not
        raised, so cleanup stays best-effort.
        """
        self._log.debug('free_port')
        with self._lock:
            try:
                if socket_type == 'domain':
                    self._dom_ports.remove(port)
                    path = domain_socket_path(port)
                    if os.path.exists(path):
                        os.remove(path)
                else:
                    self._ports.remove(port)
            except (KeyError, IOError, OSError) as err:
                # set.remove raises KeyError; os.remove raises OSError
                # (distinct from IOError on Python 2).
                self._log.info('Error while freeing port : %s' % str(err))
+
+
class NonAsyncResult(object):
    """Already-finished result mirroring the AsyncResult methods used here.

    Synchronous dispatch returns this so callers can treat it like the
    object returned by ``Pool.apply_async``.
    """

    def __init__(self, value):
        self._value = value

    def ready(self):
        # The value existed before this wrapper was built.
        return True

    def wait(self, timeout=None):
        # Nothing to wait for; ``timeout`` is accepted for API parity only.
        return None

    def get(self, timeout=None):
        result = self._value
        return result

    def successful(self):
        # A result of 0 counts as success.
        return self._value == 0
+
+
class TestDispatcher(object):
    """Dispatch test entries inline (concurrency <= 1) or to a process pool.

    run_test reads the module globals ``stop`` and ``ports``. This class
    installs them: directly in this process for synchronous runs, or in
    every pool worker through _pool_init for concurrent runs.
    """

    def __init__(self, testdir, concurrency):
        global stop, ports
        self._log = multiprocessing.get_logger()
        self.testdir = testdir
        # seems needed for python 2.x to handle keyboard interrupt
        self._stop = multiprocessing.Event()
        self._async = concurrency > 1
        if self._async:
            # Concurrent mode: start a manager serving PortAllocator objects;
            # workers connect to it by address in _pool_init.
            self._m = multiprocessing.managers.BaseManager()
            self._m.register('ports', PortAllocator)
            self._m.start()
            self._pool = multiprocessing.Pool(
                concurrency, self._pool_init, (self._m.address,))
        else:
            # Synchronous mode: run_test executes in this process, so its
            # globals are installed here.
            self._pool = None
            stop = self._stop
            ports = PortAllocator()
        self._report = SummaryReporter(testdir, self._async)
        self._log.debug(
            'TestDispatcher started with %d concurrent jobs' % concurrency)

    def _pool_init(self, address):
        # Pool initializer: runs once in each worker process and installs
        # the globals run_test relies on.
        global stop, m, ports
        stop = self._stop
        m = multiprocessing.managers.BaseManager(address)
        m.connect()
        # NOTE(review): each m.ports() call asks the manager to construct a
        # new PortAllocator, so workers appear to get separate allocators
        # rather than one shared instance — confirm this is intended.
        ports = m.ports()

    def _dispatch_sync(self, test, callback):
        outcome = run_test(self.testdir, test, False)
        callback(outcome)
        return NonAsyncResult(outcome)

    def _dispatch_async(self, test, callback):
        task_args = (self.testdir, test)
        return self._pool.apply_async(func=run_test, args=task_args, callback=callback)

    def dispatch(self, test):
        """Register ``test`` with the summary report and start running it.

        :returns: a NonAsyncResult (synchronous) or the pool's AsyncResult.
        """
        index = self._report.add_test(test)

        def on_result(r):
            # Results that arrive after terminate() are dropped.
            if self._stop.is_set():
                return
            self._log.debug('freeing port')
            self._log.debug('adding result')
            self._report.add_result(index, r, r == RESULT_TIMEOUT)
            self._log.debug('finish continuation')

        if self._async:
            return self._dispatch_async(test, on_result)
        return self._dispatch_sync(test, on_result)

    def wait(self):
        """Wait for all dispatched tests, shut the pool down and finish the report."""
        if not self._async:
            return self._report.end()
        self._pool.close()
        self._pool.join()
        self._m.shutdown()
        return self._report.end()

    def terminate(self):
        """Signal stop and, in concurrent mode, kill the pool and the manager."""
        self._stop.set()
        if not self._async:
            return
        self._pool.terminate()
        self._pool.join()
        self._m.shutdown()