blob: abbd70b5e3842983c53c3dbadba284da8a4a8603 [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
import contextlib
import multiprocessing
import multiprocessing.managers
import os
import platform
import random
import signal
import socket
import subprocess
import sys
import threading
import time
from .compat import str_join
from .test import TestEntry, domain_socket_path
from .report import ExecReporter, SummaryReporter
RESULT_TIMEOUT = 128
RESULT_ERROR = 64
class ExecutionContext(object):
def __init__(self, cmd, cwd, env, report):
self._log = multiprocessing.get_logger()
self.report = report
self.cmd = cmd
self.cwd = cwd
self.env = env
self.timer = None
self.expired = False
def _expire(self):
self._log.info('Timeout')
self.expired = True
self.kill()
def kill(self):
self._log.debug('Killing process : %d' % self.proc.pid)
if platform.system() != 'Windows':
try:
os.killpg(self.proc.pid, signal.SIGKILL)
except Exception as err:
self._log.info('Failed to kill process group : %s' % str(err))
try:
self.proc.kill()
except Exception as err:
self._log.info('Failed to kill process : %s' % str(err))
self.report.killed()
def _popen_args(self):
args = {
'cwd': self.cwd,
'env': self.env,
'stdout': self.report.out,
'stderr': subprocess.STDOUT,
}
# make sure child processes doesn't remain after killing
if platform.system() == 'Windows':
DETACHED_PROCESS = 0x00000008
args.update(creationflags=DETACHED_PROCESS | subprocess.CREATE_NEW_PROCESS_GROUP)
else:
args.update(preexec_fn=os.setsid)
return args
def start(self, timeout=0):
joined = str_join(' ', self.cmd)
self._log.debug('COMMAND: %s', joined)
self._log.debug('WORKDIR: %s', self.cwd)
self._log.debug('LOGFILE: %s', self.report.logpath)
self.report.begin()
self.proc = subprocess.Popen(self.cmd, **self._popen_args())
if timeout > 0:
self.timer = threading.Timer(timeout, self._expire)
self.timer.start()
return self._scoped()
@contextlib.contextmanager
def _scoped(self):
yield self
self._log.debug('Killing scoped process')
self.kill()
def wait(self):
self.proc.communicate()
if self.timer:
self.timer.cancel()
self.report.end(self.returncode)
@property
def returncode(self):
return self.proc.returncode if self.proc else None
def exec_context(port, logdir, test, prog):
report = ExecReporter(logdir, test, prog)
prog.build_command(port)
return ExecutionContext(prog.command, prog.workdir, prog.env, report)
def run_test(testdir, logdir, test_dict, async=True, max_retry=3):
try:
logger = multiprocessing.get_logger()
retry_count = 0
test = TestEntry(testdir, **test_dict)
while True:
if stop.is_set():
logger.debug('Skipping because shutting down')
return None
logger.debug('Start')
with PortAllocator.alloc_port_scoped(ports, test.socket) as port:
logger.debug('Start with port %d' % port)
sv = exec_context(port, logdir, test, test.server)
cl = exec_context(port, logdir, test, test.client)
logger.debug('Starting server')
with sv.start():
if test.delay > 0:
logger.debug('Delaying client for %.2f seconds' % test.delay)
time.sleep(test.delay)
cl_retry_count = 0
cl_max_retry = 10
cl_retry_wait = 0.5
while True:
logger.debug('Starting client')
cl.start(test.timeout)
logger.debug('Waiting client')
cl.wait()
if not cl.report.maybe_false_positive() or cl_retry_count >= cl_max_retry:
if cl_retry_count > 0 and cl_retry_count < cl_max_retry:
logger.warn('[%s]: Connected after %d retry (%.2f sec each)' % (test.server.name, cl_retry_count, cl_retry_wait))
break
logger.debug('Server may not be ready, waiting %.2f second...' % cl_retry_wait)
time.sleep(cl_retry_wait)
cl_retry_count += 1
if not sv.report.maybe_false_positive() or retry_count >= max_retry:
logger.debug('Finish')
return RESULT_TIMEOUT if cl.expired else cl.proc.returncode
logger.warn('[%s]: Detected socket bind failure, retrying...' % test.server.name)
retry_count += 1
except (KeyboardInterrupt, SystemExit):
logger.info('Interrupted execution')
if not async:
raise
stop.set()
return None
except Exception as ex:
logger.warn('%s', ex)
if not async:
raise
logger.debug('Error executing [%s]', test.name, exc_info=sys.exc_info())
return RESULT_ERROR
class PortAllocator(object):
def __init__(self):
self._log = multiprocessing.get_logger()
self._lock = multiprocessing.Lock()
self._ports = set()
self._dom_ports = set()
self._last_alloc = 0
def _get_tcp_port(self):
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.bind(('127.0.0.1', 0))
port = sock.getsockname()[1]
self._lock.acquire()
try:
ok = port not in self._ports
if ok:
self._ports.add(port)
self._last_alloc = time.time()
finally:
self._lock.release()
sock.close()
return port if ok else self._get_tcp_port()
def _get_domain_port(self):
port = random.randint(1024, 65536)
self._lock.acquire()
try:
ok = port not in self._dom_ports
if ok:
self._dom_ports.add(port)
finally:
self._lock.release()
return port if ok else self._get_domain_port()
def alloc_port(self, socket_type):
if socket_type in ('domain', 'abstract'):
return self._get_domain_port()
else:
return self._get_tcp_port()
# static method for inter-process invokation
@staticmethod
@contextlib.contextmanager
def alloc_port_scoped(allocator, socket_type):
port = allocator.alloc_port(socket_type)
yield port
allocator.free_port(socket_type, port)
def free_port(self, socket_type, port):
self._log.debug('free_port')
self._lock.acquire()
try:
if socket_type == 'domain':
self._dom_ports.remove(port)
path = domain_socket_path(port)
if os.path.exists(path):
os.remove(path)
elif socket_type == 'abstract':
self._dom_ports.remove(port)
else:
self._ports.remove(port)
except IOError as err:
self._log.info('Error while freeing port : %s' % str(err))
finally:
self._lock.release()
class NonAsyncResult(object):
def __init__(self, value):
self._value = value
def get(self, timeout=None):
return self._value
def wait(self, timeout=None):
pass
def ready(self):
return True
def successful(self):
return self._value == 0
class TestDispatcher(object):
def __init__(self, testdir, logdir, concurrency):
self._log = multiprocessing.get_logger()
self.testdir = testdir
self.logdir = logdir
# seems needed for python 2.x to handle keyboard interrupt
self._stop = multiprocessing.Event()
self._async = concurrency > 1
if not self._async:
self._pool = None
global stop
global ports
stop = self._stop
ports = PortAllocator()
else:
self._m = multiprocessing.managers.BaseManager()
self._m.register('ports', PortAllocator)
self._m.start()
self._pool = multiprocessing.Pool(concurrency, self._pool_init, (self._m.address,))
self._report = SummaryReporter(logdir, concurrency > 1)
self._log.debug(
'TestDispatcher started with %d concurrent jobs' % concurrency)
def _pool_init(self, address):
global stop
global m
global ports
stop = self._stop
m = multiprocessing.managers.BaseManager(address)
m.connect()
ports = m.ports()
def _dispatch_sync(self, test, cont):
r = run_test(self.testdir, self.logdir, test, False)
cont(r)
return NonAsyncResult(r)
def _dispatch_async(self, test, cont):
self._log.debug('_dispatch_async')
return self._pool.apply_async(func=run_test, args=(self.testdir, self.logdir, test,), callback=cont)
def dispatch(self, test):
index = self._report.add_test(test)
def cont(r):
if not self._stop.is_set():
self._log.debug('freeing port')
self._log.debug('adding result')
self._report.add_result(index, r, r == RESULT_TIMEOUT)
self._log.debug('finish continuation')
fn = self._dispatch_async if self._async else self._dispatch_sync
return fn(test, cont)
def wait(self):
if self._async:
self._pool.close()
self._pool.join()
self._m.shutdown()
return self._report.end()
def terminate(self):
self._stop.set()
if self._async:
self._pool.terminate()
self._pool.join()
self._m.shutdown()