Add type annotations and refactor (work in progress)
diff --git a/wally/run_test.py b/wally/run_test.py
index 4390f5e..b96ecf8 100755
--- a/wally/run_test.py
+++ b/wally/run_test.py
@@ -1,28 +1,17 @@
import os
-import re
import time
import logging
import functools
import contextlib
import collections
-from typing import List, Dict, Optional, Iterable, Any, Generator, Mapping, Callable
-from yaml import load as _yaml_load
+from typing import List, Dict, Iterable, Any, Iterator, Mapping, Callable, Tuple, Optional
+from concurrent.futures import ThreadPoolExecutor, Future
-try:
- from yaml import CLoader
- yaml_load = functools.partial(_yaml_load, Loader=CLoader)
-except ImportError:
- yaml_load = _yaml_load
-
-from concurrent.futures import ThreadPoolExecutor, wait
-
from .config import Config
-from .config import get_test_files
-from .discover import discover, Node
from .inode import INode
+from .discover import discover
from .test_run_class import TestRun
-
-from . import pretty_yaml, utils, report, ssh_utils, start_vms
+from . import pretty_yaml, utils, report, ssh_utils, start_vms, hw_info
from .suits.mysql import MysqlTest
from .suits.itest import TestConfig
@@ -42,27 +30,33 @@
logger = logging.getLogger("wally")
-def connect_all(nodes: Iterable[INode], spawned_node: Optional[bool]=False) -> None:
+
+def connect_all(nodes: Iterable[INode],
+ pool: ThreadPoolExecutor,
+ conn_timeout: int = 30,
+ rpc_conn_callback: Optional[ssh_utils.RPCBeforeConnCallback] = None) -> None:
"""Connect to all nodes, log errors
- nodes:[Node] - list of nodes
- spawned_node:bool - whenever nodes is newly spawned VM
+ nodes - nodes to connect to; any iterable is accepted
+ pool - executor used to connect to the nodes in parallel
+ conn_timeout - timeout passed to node.connect_ssh (spawned VMs used 240)
+ rpc_conn_callback - optional callback forwarded to ssh_utils.setup_rpc
"""
- logger.info("Connecting to nodes")
-
- conn_timeout = 240 if spawned_node else 30
+ # len() needs a sized collection and pool.map consumes its input,
+ # so materialize the iterable once up front
+ nodes = list(nodes)
+ logger.info("Connecting to %s nodes", len(nodes))
def connect_ext(node: INode) -> bool:
try:
node.connect_ssh(conn_timeout)
- node.connect_rpc(conn_timeout)
+ node.rpc, node.rpc_params = ssh_utils.setup_rpc(node, rpc_conn_callback=rpc_conn_callback)
return True
except Exception as exc:
logger.error("During connect to {}: {!s}".format(node, exc))
return False
- with ThreadPoolExecutor(32) as pool:
- list(pool.map(connect_ext, nodes))
+ list(pool.map(connect_ext, nodes))
failed_testnodes = []
failed_nodes = []
@@ -88,39 +76,25 @@
logger.info("All nodes connected successfully")
-def collect_hw_info_stage(cfg: Config, nodes: Iterable[INode]) -> None:
- # TODO(koder): rewrite this function, to use other storage type
- if os.path.exists(cfg.hwreport_fname):
- msg = "{0} already exists. Skip hw info"
- logger.info(msg.format(cfg['hwreport_fname']))
- return
+def collect_info_stage(ctx: TestRun, nodes: Iterable[INode]) -> None:
+ futures = {} # type: Dict[str, Future]
- with ThreadPoolExecutor(32) as pool:
- fitures = pool.submit(node.discover_hardware_info
- for node in nodes)
- wait(fitures)
-
- with open(cfg.hwreport_fname, 'w') as hwfd:
+ with ctx.get_pool() as pool:
for node in nodes:
- hwfd.write("-" * 60 + "\n")
- hwfd.write("Roles : " + ", ".join(node.roles) + "\n")
- hwfd.write(str(node.hwinfo) + "\n")
- hwfd.write("-" * 60 + "\n\n")
+ hw_info_path = "hw_info/{}".format(node.node_id())
+ if hw_info_path not in ctx.storage:
+ futures[hw_info_path] = pool.submit(hw_info.get_hw_info, node)
- if node.hwinfo.hostname is not None:
- fname = os.path.join(
- cfg.hwinfo_directory,
- node.hwinfo.hostname + "_lshw.xml")
+ sw_info_path = "sw_info/{}".format(node.node_id())
+ if sw_info_path not in ctx.storage:
+ futures[sw_info_path] = pool.submit(hw_info.get_sw_info, node)
- with open(fname, "w") as fd:
- fd.write(node.hwinfo.raw)
-
- logger.info("Hardware report stored in " + cfg.hwreport_fname)
- logger.debug("Raw hardware info in " + cfg.hwinfo_directory + " folder")
+ for path, future in futures.items():
+ ctx.storage[path] = future.result()
@contextlib.contextmanager
-def suspend_vm_nodes_ctx(unused_nodes: Iterable[INode]) -> Generator[List[int]]:
+def suspend_vm_nodes_ctx(unused_nodes: List[INode]) -> Iterator[List[int]]:
pausable_nodes_ids = [node.os_vm_id for node in unused_nodes
if node.os_vm_id is not None]
@@ -163,17 +137,17 @@
@contextlib.contextmanager
-def sensor_monitoring(sensor_cfg: Any, nodes: Iterable[INode]) -> Generator[None]:
+def sensor_monitoring(sensor_cfg: Any, nodes: Iterable[INode]) -> Iterator[None]:
# TODO(koder): write this function
- pass
+ yield  # @contextmanager needs exactly one yield, even for a no-op stub
-def run_tests(cfg: Config, test_block, nodes: Iterable[INode]) -> None:
- """
- Run test from test block
- """
+def run_tests(cfg: Config,
+ test_block: Dict[str, Dict[str, Any]],
+ nodes: Iterable[INode]) -> Iterator[Tuple[str, List[Any]]]:
+ """Run test from test block"""
+
test_nodes = [node for node in nodes if 'testnode' in node.roles]
- not_test_nodes = [node for node in nodes if 'testnode' not in node.roles]
if len(test_nodes) == 0:
logger.error("No test nodes found")
@@ -185,7 +159,7 @@
# iterate over all node counts
limit = params.get('node_limit', len(test_nodes))
if isinstance(limit, int):
- vm_limits = [limit]
+ vm_limits = [limit] # type: List[int]
else:
list_or_tpl = isinstance(limit, (tuple, list))
all_ints = list_or_tpl and all(isinstance(climit, int)
@@ -194,7 +168,7 @@
- msg = "'node_limit' parameter ion config should" + \
- "be either int or list if integers, not {0!r}".format(limit)
+ msg = "'node_limit' parameter in config should " + \
+ "be either int or list of integers, not {0!r}".format(limit)
raise ValueError(msg)
vm_limits = limit
for vm_count in vm_limits:
# select test nodes
@@ -249,7 +223,8 @@
def connect_stage(cfg: Config, ctx: TestRun) -> None:
ctx.clear_calls_stack.append(disconnect_stage)
- connect_all(ctx.nodes)
+ with ctx.get_pool() as pool:
+     connect_all(ctx.nodes, pool)
- ctx.nodes = [node for node in ctx.nodes if node.connection is not None]
+ ctx.nodes = [node for node in ctx.nodes if node.is_connected()]
def discover_stage(cfg: Config, ctx: TestRun) -> None:
@@ -279,7 +253,7 @@
roles.remove('testnode')
if len(roles) != 0:
- cluster[node.conn_url] = roles
+ cluster[node.ssh_conn_url] = roles
with open(cfg.nodes_report_file, "w") as fd:
fd.write(pretty_yaml.dumps(cluster))
@@ -363,7 +337,7 @@
def get_vm_keypair(cfg: Config) -> Dict[str, str]:
- res = {}
+ res = {} # type: Dict[str, str]
for field, ext in (('keypair_file_private', 'pem'),
('keypair_file_public', 'pub')):
fpath = cfg.vm_configs.get(field)
@@ -379,7 +353,7 @@
@contextlib.contextmanager
-def create_vms_ctx(ctx: TestRun, cfg: Config, config, already_has_count: int=0) -> Generator[List[INode]]:
+def create_vms_ctx(ctx: TestRun, cfg: Config, config, already_has_count: int=0) -> Iterator[List[INode]]:
if config['count'].startswith('='):
count = int(config['count'][1:])
if count <= already_has_count:
@@ -405,7 +379,7 @@
if not config.get('skip_preparation', False):
logger.info("Preparing openstack")
- start_vms.prepare_os_subpr(nova, params, os_creds)
+ start_vms.prepare_os(nova, params, os_creds)
new_nodes = []
old_nodes = ctx.nodes[:]
@@ -431,13 +405,13 @@
ctx.results = collections.defaultdict(lambda: [])
for group in cfg.get('tests', []):
-
- if len(group.items()) != 1:
+ gitems = list(group.items())
+ if len(gitems) != 1:
msg = "Items in tests section should have len == 1"
logger.error(msg)
raise utils.StopTestError(msg)
- key, config = group.items()[0]
+ key, config = gitems[0]
if 'start_test_nodes' == key:
if 'openstack' not in config:
@@ -469,7 +443,8 @@
if not cfg.no_tests:
for test_group in tests:
with sensor_ctx:
- for tp, res in run_tests(cfg, test_group, ctx.nodes):
+ it = run_tests(cfg, test_group, ctx.nodes)
+ for tp, res in it:
ctx.results[tp].extend(res)
@@ -489,7 +464,7 @@
os.remove(vm_ids_fname)
-def store_nodes_in_log(cfg: Config, nodes_ids: Iterable[str]):
+def store_nodes_in_log(cfg: Config, nodes_ids: Iterable[str]) -> None:
with open(cfg.vm_ids_fname, 'w') as fd:
fd.write("\n".join(nodes_ids))
@@ -503,8 +478,7 @@
ssh_utils.close_all_sessions()
for node in ctx.nodes:
- if node.connection is not None:
- node.connection.close()
+ node.disconnect()
def store_raw_results_stage(cfg: Config, ctx: TestRun) -> None:
@@ -577,7 +551,7 @@
report.make_io_report(list(data[0]),
cfg.get('comment', ''),
html_rep_fname,
- lab_info=ctx.hw_info)
+ lab_info=ctx.nodes)
def load_data_from_path(test_res_dir: str) -> Mapping[str, List[Any]]:
@@ -599,5 +573,5 @@
ctx.results.setdefault(tp, []).extend(vals)
-def load_data_from(var_dir: str) -> Callable:
+def load_data_from(var_dir: str) -> Callable[[TestRun], None]:
return functools.partial(load_data_from_path_stage, var_dir)