Add Stage base class; refactor discovery and run_test stages into Stage subclasses
diff --git a/wally/run_test.py b/wally/run_test.py
index 9ae2c9e..1a645b6 100755
--- a/wally/run_test.py
+++ b/wally/run_test.py
@@ -1,21 +1,18 @@
-import os
import logging
-import contextlib
-from typing import List, Dict, Iterable, Iterator, Tuple, Optional, Union, cast
-from concurrent.futures import ThreadPoolExecutor, Future
+from concurrent.futures import Future
+from typing import List, Dict, Tuple, Optional, Union, cast
-from .node_interfaces import NodeInfo, IRPCNode
-from .test_run_class import TestRun
-from .discover import discover
-from . import pretty_yaml, utils, report, ssh_utils, start_vms, hw_info
+from . import utils, ssh_utils, hw_info
+from .config import ConfigBlock
from .node import setup_rpc, connect
-from .config import ConfigBlock, Config
-
-from .suits.mysql import MysqlTest
-from .suits.itest import TestInputConfig
+from .node_interfaces import NodeInfo, IRPCNode
+from .stage import Stage, StepOrder
from .suits.io.fio import IOPerfTest
-from .suits.postgres import PgBenchTest
+from .suits.itest import TestInputConfig
+from .suits.mysql import MysqlTest
from .suits.omgbench import OmgTest
+from .suits.postgres import PgBenchTest
+from .test_run_class import TestRun
TOOL_TYPE_MAPPER = {
@@ -29,431 +26,149 @@
logger = logging.getLogger("wally")
-def connect_all(nodes_info: List[NodeInfo], pool: ThreadPoolExecutor, conn_timeout: int = 30) -> List[IRPCNode]:
- """Connect to all nodes, log errors"""
+class ConnectStage(Stage):
+ """Connect to nodes stage"""
- logger.info("Connecting to %s nodes", len(nodes_info))
+ priority = StepOrder.CONNECT
- def connect_ext(node_info: NodeInfo) -> Tuple[bool, Union[IRPCNode, NodeInfo]]:
- try:
- ssh_node = connect(node_info, conn_timeout=conn_timeout)
- # TODO(koder): need to pass all required rpc bytes to this call
- return True, setup_rpc(ssh_node, b"")
- except Exception as exc:
- logger.error("During connect to {}: {!s}".format(node, exc))
- return False, node_info
-
- failed_testnodes = [] # type: List[NodeInfo]
- failed_nodes = [] # type: List[NodeInfo]
- ready = [] # type: List[IRPCNode]
-
- for ok, node in pool.map(connect_ext, nodes_info):
- if not ok:
- node = cast(NodeInfo, node)
- if 'testnode' in node.roles:
- failed_testnodes.append(node)
- else:
- failed_nodes.append(node)
- else:
- ready.append(cast(IRPCNode, node))
-
- if failed_nodes:
- msg = "Node(s) {} would be excluded - can't connect"
- logger.warning(msg.format(",".join(map(str, failed_nodes))))
-
- if failed_testnodes:
- msg = "Can't connect to testnode(s) " + \
- ",".join(map(str, failed_testnodes))
- logger.error(msg)
- raise utils.StopTestError(msg)
-
- if not failed_nodes:
- logger.info("All nodes connected successfully")
-
- return ready
-
-
-def collect_info_stage(ctx: TestRun) -> None:
- futures = {} # type: Dict[str, Future]
-
- with ctx.get_pool() as pool:
- for node in ctx.nodes:
- hw_info_path = "hw_info/{}".format(node.info.node_id())
- if hw_info_path not in ctx.storage:
- futures[hw_info_path] = pool.submit(hw_info.get_hw_info, node), node
-
- sw_info_path = "sw_info/{}".format(node.info.node_id())
- if sw_info_path not in ctx.storage:
- futures[sw_info_path] = pool.submit(hw_info.get_sw_info, node)
-
- for path, future in futures.items():
- ctx.storage[path] = future.result()
-
-
-@contextlib.contextmanager
-def suspend_vm_nodes_ctx(ctx: TestRun, unused_nodes: List[IRPCNode]) -> Iterator[List[int]]:
-
- pausable_nodes_ids = [cast(int, node.info.os_vm_id)
- for node in unused_nodes
- if node.info.os_vm_id is not None]
-
- non_pausable = len(unused_nodes) - len(pausable_nodes_ids)
-
- if non_pausable:
- logger.warning("Can't pause {} nodes".format(non_pausable))
-
- if pausable_nodes_ids:
- logger.debug("Try to pause {} unused nodes".format(len(pausable_nodes_ids)))
+ def run(self, ctx: TestRun) -> None:
with ctx.get_pool() as pool:
- start_vms.pause(ctx.os_connection, pausable_nodes_ids, pool)
+ logger.info("Connecting to %s nodes", len(ctx.nodes_info))
- try:
- yield pausable_nodes_ids
- finally:
- if pausable_nodes_ids:
- logger.debug("Unpausing {} nodes".format(len(pausable_nodes_ids)))
- with ctx.get_pool() as pool:
- start_vms.unpause(ctx.os_connection, pausable_nodes_ids, pool)
+ def connect_ext(node_info: NodeInfo) -> Tuple[bool, Union[IRPCNode, NodeInfo]]:
+ try:
+ ssh_node = connect(node_info, conn_timeout=ctx.config.connect_timeout)
+ # TODO(koder): need to pass all required rpc bytes to this call
+ return True, setup_rpc(ssh_node, b"")
+ except Exception as exc:
+                    logger.error("During connect to {}: {!s}".format(node_info, exc))
+ return False, node_info
+ failed_testnodes = [] # type: List[NodeInfo]
+ failed_nodes = [] # type: List[NodeInfo]
+ ctx.nodes = []
-def run_tests(ctx: TestRun, test_block: ConfigBlock, nodes: List[IRPCNode]) -> None:
- """Run test from test block"""
+ for ok, node in pool.map(connect_ext, ctx.nodes_info):
+ if not ok:
+ node = cast(NodeInfo, node)
+ if 'testnode' in node.roles:
+ failed_testnodes.append(node)
+ else:
+ failed_nodes.append(node)
+ else:
+ ctx.nodes.append(cast(IRPCNode, node))
- test_nodes = [node for node in nodes if 'testnode' in node.info.roles]
+ if failed_nodes:
+ msg = "Node(s) {} would be excluded - can't connect"
+ logger.warning(msg.format(",".join(map(str, failed_nodes))))
- if not test_nodes:
- logger.error("No test nodes found")
- return
-
- for name, params in test_block.items():
- vm_count = params.get('node_limit', None) # type: Optional[int]
-
- # select test nodes
- if vm_count is None:
- curr_test_nodes = test_nodes
- unused_nodes = [] # type: List[IRPCNode]
- else:
- curr_test_nodes = test_nodes[:vm_count]
- unused_nodes = test_nodes[vm_count:]
-
- if not curr_test_nodes:
- logger.error("No nodes found for test, skipping it.")
- continue
-
- # results_path = generate_result_dir_name(cfg.results_storage, name, params)
- # utils.mkdirs_if_unxists(results_path)
-
- # suspend all unused virtual nodes
- if ctx.config.get('suspend_unused_vms', True):
- suspend_ctx = suspend_vm_nodes_ctx(ctx, unused_nodes)
- else:
- suspend_ctx = utils.empty_ctx()
-
- resumable_nodes_ids = [cast(int, node.info.os_vm_id)
- for node in curr_test_nodes
- if node.info.os_vm_id is not None]
-
- if resumable_nodes_ids:
- logger.debug("Check and unpause {} nodes".format(len(resumable_nodes_ids)))
-
- with ctx.get_pool() as pool:
- start_vms.unpause(ctx.os_connection, resumable_nodes_ids, pool)
-
- with suspend_ctx:
- test_cls = TOOL_TYPE_MAPPER[name]
- remote_dir = ctx.config.default_test_local_folder.format(name=name, uuid=ctx.config.run_uuid)
- test_cfg = TestInputConfig(test_cls.__name__,
- params=params,
- run_uuid=ctx.config.run_uuid,
- nodes=test_nodes,
- storage=ctx.storage,
- remote_dir=remote_dir)
-
- test_cls(test_cfg).run()
-
-
-def connect_stage(ctx: TestRun) -> None:
- ctx.clear_calls_stack.append(disconnect_stage)
-
- with ctx.get_pool() as pool:
- ctx.nodes = connect_all(ctx.nodes_info, pool)
-
-
-def discover_stage(ctx: TestRun) -> None:
- """discover clusters and nodes stage"""
-
- # TODO(koder): Properly store discovery info and check if it available to skip phase
-
- discover_info = ctx.config.get('discover')
- if discover_info:
- if "discovered_nodes" in ctx.storage:
- nodes = ctx.storage.load_list("discovered_nodes", NodeInfo)
- ctx.fuel_openstack_creds = ctx.storage.load("fuel_openstack_creds", start_vms.OSCreds)
- else:
- discover_objs = [i.strip() for i in discover_info.strip().split(",")]
-
- ctx.fuel_openstack_creds, nodes = discover.discover(ctx,
- discover_objs,
- ctx.config.clouds,
- not ctx.config.dont_discover_nodes)
-
- ctx.storage["fuel_openstack_creds"] = ctx.fuel_openstack_creds # type: ignore
- ctx.storage["discovered_nodes"] = nodes # type: ignore
- ctx.nodes_info.extend(nodes)
-
- for url, roles in ctx.config.get('explicit_nodes', {}).items():
- creds = ssh_utils.parse_ssh_uri(url)
- roles = set(roles.split(","))
- ctx.nodes_info.append(NodeInfo(creds, roles))
-
-
-def save_nodes_stage(ctx: TestRun) -> None:
- """Save nodes list to file"""
- ctx.storage['nodes'] = ctx.nodes_info # type: ignore
-
-
-def ensure_connected_to_openstack(ctx: TestRun) -> None:
- if not ctx.os_connection is None:
- if ctx.os_creds is None:
- ctx.os_creds = get_OS_credentials(ctx)
- ctx.os_connection = start_vms.os_connect(ctx.os_creds)
-
-
-def reuse_vms_stage(ctx: TestRun) -> None:
- if "reused_nodes" in ctx.storage:
- ctx.nodes_info.extend(ctx.storage.load_list("reused_nodes", NodeInfo))
- else:
- reused_nodes = []
- vms_patterns = ctx.config.get('clouds/openstack/vms', [])
- private_key_path = get_vm_keypair_path(ctx.config)[0]
-
- for creds in vms_patterns:
- user_name, vm_name_pattern = creds.split("@", 1)
- msg = "Vm like {} lookup failed".format(vm_name_pattern)
-
- with utils.LogError(msg):
- msg = "Looking for vm with name like {0}".format(vm_name_pattern)
- logger.debug(msg)
-
- ensure_connected_to_openstack(ctx)
-
- for ip, vm_id in start_vms.find_vms(ctx.os_connection, vm_name_pattern):
- creds = ssh_utils.ConnCreds(host=ip, user=user_name, key_file=private_key_path)
- node_info = NodeInfo(creds, {'testnode'})
- node_info.os_vm_id = vm_id
- reused_nodes.append(node_info)
- ctx.nodes_info.append(node_info)
-
- ctx.storage["reused_nodes"] = reused_nodes # type: ignore
-
-
-def get_OS_credentials(ctx: TestRun) -> start_vms.OSCreds:
-
- if "openstack_openrc" in ctx.storage:
- return ctx.storage.load("openstack_openrc", start_vms.OSCreds)
-
- creds = None
- os_creds = None
- force_insecure = False
- cfg = ctx.config
-
- if 'openstack' in cfg.clouds:
- os_cfg = cfg.clouds['openstack']
- if 'OPENRC' in os_cfg:
- logger.info("Using OS credentials from " + os_cfg['OPENRC'])
- creds_tuple = utils.get_creds_openrc(os_cfg['OPENRC'])
- os_creds = start_vms.OSCreds(*creds_tuple)
- elif 'ENV' in os_cfg:
- logger.info("Using OS credentials from shell environment")
- os_creds = start_vms.get_openstack_credentials()
- elif 'OS_TENANT_NAME' in os_cfg:
- logger.info("Using predefined credentials")
- os_creds = start_vms.OSCreds(os_cfg['OS_USERNAME'].strip(),
- os_cfg['OS_PASSWORD'].strip(),
- os_cfg['OS_TENANT_NAME'].strip(),
- os_cfg['OS_AUTH_URL'].strip(),
- os_cfg.get('OS_INSECURE', False))
-
- elif 'OS_INSECURE' in os_cfg:
- force_insecure = os_cfg.get('OS_INSECURE', False)
-
- if os_creds is None and 'fuel' in cfg.clouds and 'openstack_env' in cfg.clouds['fuel'] and \
- ctx.fuel_openstack_creds is not None:
- logger.info("Using fuel creds")
- creds = ctx.fuel_openstack_creds
- elif os_creds is None:
- logger.error("Can't found OS credentials")
- raise utils.StopTestError("Can't found OS credentials", None)
-
- if creds is None:
- creds = os_creds
-
- if force_insecure and not creds.insecure:
- creds = start_vms.OSCreds(creds.name,
- creds.passwd,
- creds.tenant,
- creds.auth_url,
- True)
-
- logger.debug(("OS_CREDS: user={0.name} tenant={0.tenant} " +
- "auth_url={0.auth_url} insecure={0.insecure}").format(creds))
-
- ctx.storage["openstack_openrc"] = creds # type: ignore
- return creds
-
-
-def get_vm_keypair_path(cfg: Config) -> Tuple[str, str]:
- key_name = cfg.vm_configs['keypair_name']
- private_path = os.path.join(cfg.settings_dir, key_name + "_private.pem")
- public_path = os.path.join(cfg.settings_dir, key_name + "_public.pub")
- return (private_path, public_path)
-
-
-@contextlib.contextmanager
-def create_vms_ctx(ctx: TestRun, vm_config: ConfigBlock, already_has_count: int = 0) -> Iterator[List[NodeInfo]]:
-
- if 'spawned_vm_ids' in ctx.storage:
- os_nodes_ids = ctx.storage.get('spawned_vm_ids', []) # type: List[int]
- new_nodes = [] # type: List[NodeInfo]
-
- # TODO(koder): reconnect to old VM's
- raise NotImplementedError("Reconnect to old vms is not implemented")
- else:
- os_nodes_ids = []
- new_nodes = []
- no_spawn = False
- if vm_config['count'].startswith('='):
- count = int(vm_config['count'][1:])
- if count <= already_has_count:
- logger.debug("Not need new vms")
- no_spawn = True
-
- if not no_spawn:
- ensure_connected_to_openstack(ctx)
- params = ctx.config.vm_configs[vm_config['cfg_name']].copy()
- params.update(vm_config)
- params.update(get_vm_keypair_path(ctx.config))
- params['group_name'] = ctx.config.run_uuid
- params['keypair_name'] = ctx.config.vm_configs['keypair_name']
-
- if not vm_config.get('skip_preparation', False):
- logger.info("Preparing openstack")
- start_vms.prepare_os(ctx.os_connection, params)
-
- with ctx.get_pool() as pool:
- for node_info in start_vms.launch_vms(ctx.os_connection, params, pool, already_has_count):
- node_info.roles.add('testnode')
- os_nodes_ids.append(node_info.os_vm_id)
- new_nodes.append(node_info)
-
- ctx.storage['spawned_vm_ids'] = os_nodes_ids # type: ignore
- yield new_nodes
-
- # keep nodes in case of error for future test restart
- if not ctx.config.keep_vm:
- shut_down_vms_stage(ctx, os_nodes_ids)
-
- del ctx.storage['spawned_vm_ids']
-
-
-@contextlib.contextmanager
-def sensor_monitoring(ctx: TestRun, cfg: ConfigBlock, nodes: List[IRPCNode]) -> Iterator[None]:
- yield
-
-
-def run_tests_stage(ctx: TestRun) -> None:
- for group in ctx.config.get('tests', []):
- gitems = list(group.items())
- if len(gitems) != 1:
- msg = "Items in tests section should have len == 1"
- logger.error(msg)
- raise utils.StopTestError(msg)
-
- key, config = gitems[0]
-
- if 'start_test_nodes' == key:
- if 'openstack' not in config:
- msg = "No openstack block in config - can't spawn vm's"
+ if failed_testnodes:
+ msg = "Can't connect to testnode(s) " + \
+ ",".join(map(str, failed_testnodes))
logger.error(msg)
raise utils.StopTestError(msg)
- num_test_nodes = len([node for node in ctx.nodes if 'testnode' in node.info.roles])
- vm_ctx = create_vms_ctx(ctx, config['openstack'], num_test_nodes)
- tests = config.get('tests', [])
- else:
- vm_ctx = utils.empty_ctx([])
- tests = [group]
+ if not failed_nodes:
+ logger.info("All nodes connected successfully")
- # make mypy happy
- new_nodes = [] # type: List[NodeInfo]
+ def cleanup(self, ctx: TestRun) -> None:
+ # TODO(koder): what next line was for?
+ # ssh_utils.close_all_sessions()
- with vm_ctx as new_nodes:
- if new_nodes:
- with ctx.get_pool() as pool:
- new_rpc_nodes = connect_all(new_nodes, pool)
+ for node in ctx.nodes:
+ node.disconnect()
- test_nodes = ctx.nodes + new_rpc_nodes
- if ctx.config.get('sensors'):
- sensor_ctx = sensor_monitoring(ctx, ctx.config.get('sensors'), test_nodes)
- else:
- sensor_ctx = utils.empty_ctx([])
+class CollectInfoStage(Stage):
+ """Collect node info"""
+ priority = StepOrder.START_SENSORS - 1
+ config_block = 'collect_info'
+
+ def run(self, ctx: TestRun) -> None:
+ if not ctx.config.collect_info:
+ return
+
+ futures = {} # type: Dict[str, Future]
+
+ with ctx.get_pool() as pool:
+ for node in ctx.nodes:
+ hw_info_path = "hw_info/{}".format(node.info.node_id())
+ if hw_info_path not in ctx.storage:
+                    futures[hw_info_path] = pool.submit(hw_info.get_hw_info, node)
+
+ sw_info_path = "sw_info/{}".format(node.info.node_id())
+ if sw_info_path not in ctx.storage:
+ futures[sw_info_path] = pool.submit(hw_info.get_sw_info, node)
+
+ for path, future in futures.items():
+ ctx.storage[path] = future.result()
+
+
+class ExplicitNodesStage(Stage):
+ """add explicit nodes"""
+
+ priority = StepOrder.DISCOVER
+ config_block = 'nodes'
+
+ def run(self, ctx: TestRun) -> None:
+ explicit_nodes = []
+ for url, roles in ctx.config.get('explicit_nodes', {}).items():
+ creds = ssh_utils.parse_ssh_uri(url)
+ roles = set(roles.split(","))
+ explicit_nodes.append(NodeInfo(creds, roles))
+
+ ctx.nodes_info.extend(explicit_nodes)
+ ctx.storage['explicit_nodes'] = explicit_nodes # type: ignore
+
+
+class SaveNodesStage(Stage):
+ """Save nodes list to file"""
+
+ priority = StepOrder.CONNECT
+
+ def run(self, ctx: TestRun) -> None:
+ ctx.storage['all_nodes'] = ctx.nodes_info # type: ignore
+
+
+class RunTestsStage(Stage):
+
+ priority = StepOrder.TEST
+ config_block = 'tests'
+
+ def run(self, ctx: TestRun) -> None:
+ for test_group in ctx.config.get('tests', []):
if not ctx.config.no_tests:
- for test_group in tests:
- with sensor_ctx:
- run_tests(ctx, test_group, test_nodes)
+ test_nodes = [node for node in ctx.nodes if 'testnode' in node.info.roles]
- for node in new_rpc_nodes:
- node.disconnect()
+ if not test_nodes:
+ logger.error("No test nodes found")
+ return
+ for name, params in test_group.items():
+ vm_count = params.get('node_limit', None) # type: Optional[int]
-def clouds_connect_stage(ctx: TestRun) -> None:
- # TODO(koder): need to use this to connect to openstack in upper code
- # conn = ctx.config['clouds/openstack']
- # user, passwd, tenant = parse_creds(conn['creds'])
- # auth_data = dict(auth_url=conn['auth_url'],
- # username=user,
- # api_key=passwd,
- # project_id=tenant) # type: Dict[str, str]
- # logger.debug("Discovering openstack nodes with connection details: %r", conn)
- # connect to openstack, fuel
+ # select test nodes
+ if vm_count is None:
+ curr_test_nodes = test_nodes
+ else:
+ curr_test_nodes = test_nodes[:vm_count]
- # # parse FUEL REST credentials
- # username, tenant_name, password = parse_creds(fuel_data['creds'])
- # creds = {"username": username,
- # "tenant_name": tenant_name,
- # "password": password}
- #
- # # connect to FUEL
- # conn = fuel_rest_api.KeystoneAuth(fuel_data['url'], creds, headers=None)
- pass
+ if not curr_test_nodes:
+ logger.error("No nodes found for test, skipping it.")
+ continue
+ test_cls = TOOL_TYPE_MAPPER[name]
+ remote_dir = ctx.config.default_test_local_folder.format(name=name, uuid=ctx.config.run_uuid)
+ test_cfg = TestInputConfig(test_cls.__name__,
+ params=params,
+ run_uuid=ctx.config.run_uuid,
+                                           nodes=curr_test_nodes,
+ storage=ctx.storage,
+ remote_dir=remote_dir)
-def shut_down_vms_stage(ctx: TestRun, nodes_ids: List[int]) -> None:
- if nodes_ids:
- logger.info("Removing nodes")
- start_vms.clear_nodes(ctx.os_connection, nodes_ids)
- logger.info("Nodes has been removed")
+ test_cls(test_cfg).run()
-
-def clear_enviroment(ctx: TestRun) -> None:
- shut_down_vms_stage(ctx, ctx.storage.get('spawned_vm_ids', []))
- ctx.storage['spawned_vm_ids'] = [] # type: ignore
-
-
-def disconnect_stage(ctx: TestRun) -> None:
- # TODO(koder): what next line was for?
- # ssh_utils.close_all_sessions()
-
- for node in ctx.nodes:
- node.disconnect()
-
-
-def console_report_stage(ctx: TestRun) -> None:
- # TODO(koder): load data from storage
- raise NotImplementedError("...")
-
-def html_report_stage(ctx: TestRun) -> None:
- # TODO(koder): load data from storage
- raise NotImplementedError("...")
+ @classmethod
+ def validate_config(cls, cfg: ConfigBlock) -> None:
+ pass
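
The stage classes above depend on Stage and StepOrder, imported from wally/stage.py, which is not part of this diff. The sketch below is a minimal, assumed reconstruction of that interface, inferred from how the subclasses use it: the attribute and constant names (priority, config_block, run, cleanup, validate_config, StepOrder.DISCOVER/CONNECT/START_SENSORS/TEST) come from the code above, while the numeric ordering values and the run_stages driver are illustrative guesses, not the actual wally implementation.

# wally/stage.py sketch (assumed, not part of this commit): base class and
# ordering constants that ConnectStage, CollectInfoStage, etc. appear to expect.
from typing import List, Optional

from .config import ConfigBlock
from .test_run_class import TestRun


class StepOrder:
    # Relative pipeline positions; lower values are assumed to run earlier.
    # CollectInfoStage uses START_SENSORS - 1 to run just before sensors start.
    DISCOVER = 0
    CONNECT = 10
    START_SENSORS = 20
    TEST = 30


class Stage:
    # Where this stage sits in the pipeline (set by subclasses).
    priority = 0  # type: int
    # Config section that enables this stage; None is assumed to mean "always run".
    config_block = None  # type: Optional[str]

    def run(self, ctx: TestRun) -> None:
        raise NotImplementedError()

    def cleanup(self, ctx: TestRun) -> None:
        pass

    @classmethod
    def validate_config(cls, cfg: ConfigBlock) -> None:
        pass


def run_stages(stages: List[Stage], ctx: TestRun) -> None:
    # Hypothetical driver: run stages in priority order, then clean up in
    # reverse order, mirroring the old clear_calls_stack behaviour.
    # (The real runner presumably also skips stages whose config_block
    # section is absent from the config.)
    ordered = sorted(stages, key=lambda stage: stage.priority)
    try:
        for stage in ordered:
            stage.run(ctx)
    finally:
        for stage in reversed(ordered):
            stage.cleanup(ctx)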