Refactoring and typing in progress
diff --git a/configs-examples/full.yaml b/configs-examples/full.yaml
index 85382c8..f5f479b 100644
--- a/configs-examples/full.yaml
+++ b/configs-examples/full.yaml
@@ -2,6 +2,7 @@
suspend_unused_vms: false
results_storage: /var/wally_results
var_dir_root: /tmp/perf_tests
+settings_dir: ~/.wally
discover: fuel_openrc_only
collect_info: false
suspend_unused_vms: true
diff --git a/requirements.txt b/requirements.txt
index 62fa724..ea91948 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -12,3 +12,7 @@
pycrypto
ecdsa
psutil
+python-novaclient
+python-cinderclient
+python-keystoneclient
+python-glanceclient
diff --git a/requirements_extra.txt b/requirements_extra.txt
new file mode 100644
index 0000000..a2a3828
--- /dev/null
+++ b/requirements_extra.txt
@@ -0,0 +1,2 @@
+oktest
+iso8601==0.1.10
\ No newline at end of file
diff --git a/scripts/install.sh b/scripts/install.sh
index c767b3d..3dfeef7 100755
--- a/scripts/install.sh
+++ b/scripts/install.sh
@@ -7,24 +7,20 @@
popd > /dev/null
function install_apt() {
- apt-get install -y python-openssl python-novaclient python-cinderclient \
- python-keystoneclient python-glanceclient python-faulthandler \
- python-pip
-
+ MODULES="python-openssl python-faulthandler python-pip"
if [ "$FULL" == "--full" ] ; then
- apt-get install -y python-scipy python-numpy python-matplotlib python-psutil
+ MODULES="$MODULES python-scipy python-numpy python-matplotlib python-psutil"
fi
+ apt-get install -y $MODULES
}
function install_yum() {
- yum -y install pyOpenSSL python-novaclient python-cinderclient \
- python-keystoneclient python-glanceclient \
- python-pip python-ecdsa
-
+ MODULES="pyOpenSSL python-pip python-ecdsa"
if [ "$FULL" == "--full" ] ; then
- yum -y install scipy numpy python-matplotlib python-psutil
+ MODULES="$MODULES scipy numpy python-matplotlib python-psutil"
fi
+ yum -y install $MODULES
}
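+
+# Usage sketch: plain "install.sh" installs the base dependencies, while
+# "install.sh --full" additionally pulls the scientific stack
+# (scipy/numpy/matplotlib/psutil) and the extra pip requirements.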
if which apt-get >/dev/null; then
@@ -33,7 +29,7 @@
if which yum >/dev/null; then
install_yum
else
- echo "Error: Neither apt-get, not yum installed. Can't install deps"
+ echo "Error: Neither apt-get, not yum installed. Can't install binary dependencies."
exit 1
fi
fi
@@ -41,5 +37,5 @@
pip install -r "$SCRIPTPATH/../requirements.txt"
if [ "$FULL" == "--full" ] ; then
- pip install oktest iso8601==0.1.10
+ pip install -r "$SCRIPTPATH/../requirements_extra.txt"
fi
diff --git a/wally/config.py b/wally/config.py
index 59db0e1..cac4acf 100644
--- a/wally/config.py
+++ b/wally/config.py
@@ -1,14 +1,11 @@
from typing import Any, Dict
-from .storage import IStorable, IStorage
+from .storage import IStorable
-class NoData:
- @classmethod
- def get(cls: type, name: str, x: Any) -> type:
- return cls
+ConfigBlock = Dict[str, Any]
class Config(IStorable):
- # for mypy only
+ # make mypy happy
run_uuid = None # type: str
storage_url = None # type: str
comment = None # type: str
@@ -18,11 +15,13 @@
build_id = None # type: str
build_description = None # type: str
build_type = None # type: str
+ default_test_local_folder = None # type: str
+ settings_dir = None # type: str
- def __init__(self, dct: Dict[str, Any]) -> None:
+ def __init__(self, dct: ConfigBlock) -> None:
self.__dict__['_dct'] = dct
- def get(self, path: str, default: Any = NoData) -> Any:
+ def get(self, path: str, default: Any = None) -> Any:
curr = self
while path:
@@ -53,8 +52,5 @@
def __setattr__(self, name: str, val: Any):
self.__dct[name] = val
-
-class Context:
- def __init__(self, config: Config, storage: IStorage):
- self.config = config
- self.storage = storage
\ No newline at end of file
+ def __contains__(self, name: str) -> bool:
+ return self.get(name) is not None
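+
+
+# Usage sketch ('/'-separated paths, as used by callers elsewhere in wally):
+#     cfg = Config({'clouds': {'openstack': {'vms': []}}})
+#     cfg.get('clouds/openstack/vms', [])   # -> []
+#     'clouds/openstack' in cfg             # -> True, via __contains__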
diff --git a/wally/discover/discover.py b/wally/discover/discover.py
index 233650e..d1eb9ac 100644
--- a/wally/discover/discover.py
+++ b/wally/discover/discover.py
@@ -8,7 +8,7 @@
from . import openstack
from ..utils import parse_creds, StopTestError
from ..test_run_class import TestRun
-from ..node import Node, NodeInfo
+from ..node import Node
logger = logging.getLogger("wally.discover")
diff --git a/wally/hw_info.py b/wally/hw_info.py
index 4f21314..812921e 100644
--- a/wally/hw_info.py
+++ b/wally/hw_info.py
@@ -4,7 +4,7 @@
from typing import List, Tuple
from . import utils
-from .interfaces import IRemoteNode
+from .node_interfaces import IRPCNode
def get_data(rr: str, data: str) -> str:
@@ -115,7 +115,7 @@
self.ceph_version = None # type: str
-def get_sw_info(node: IRemoteNode) -> SWInfo:
+def get_sw_info(node: IRPCNode) -> SWInfo:
res = SWInfo()
res.OS_version = utils.get_os(node)
@@ -128,7 +128,7 @@
return res
-def get_hw_info(node: IRemoteNode) -> HWInfo:
+def get_hw_info(node: IRPCNode) -> HWInfo:
res = HWInfo()
lshw_out = node.run('sudo lshw -xml 2>/dev/null', nolog=True)
diff --git a/wally/inode.py b/wally/inode.py
deleted file mode 100644
index 7851a4a..0000000
--- a/wally/inode.py
+++ /dev/null
@@ -1,61 +0,0 @@
-import abc
-from typing import Set, Dict, Optional
-
-from .ssh_utils import parse_ssh_uri
-from . import hw_info
-from .interfaces import IRemoteNode, IHost
-
-
-class FuelNodeInfo:
- """FUEL master node additional info"""
- def __init__(self,
- version: str,
- fuel_ext_iface: str,
- openrc: Dict[str, str]) -> None:
-
- self.version = version # type: str
- self.fuel_ext_iface = fuel_ext_iface # type: str
- self.openrc = openrc # type: Dict[str, str]
-
-
-class NodeInfo:
- """Node information object"""
- def __init__(self,
- ssh_conn_url: str,
- roles: Set[str],
- bind_ip: str = None,
- ssh_key: str = None) -> None:
- self.ssh_conn_url = ssh_conn_url # type: str
- self.roles = roles # type: Set[str]
-
- if bind_ip is None:
- bind_ip = parse_ssh_uri(self.ssh_conn_url).host
-
- self.bind_ip = bind_ip # type: str
- self.ssh_key = ssh_key # type: Optional[str]
-
-
-class INode(IRemoteNode, metaclass=abc.ABCMeta):
- """Node object"""
-
- def __init__(self, node_info: NodeInfo):
- IRemoteNode.__init__(self)
- self.node_info = node_info # type: NodeInfo
- self.hwinfo = None # type: hw_info.HWInfo
- self.swinfo = None # type: hw_info.SWInfo
- self.os_vm_id = None # type: str
- self.ssh_conn = None # type: IHost
- self.ssh_conn_url = None # type: str
- self.rpc_conn = None
- self.rpc_conn_url = None # type: str
-
- @abc.abstractmethod
- def __str__(self):
- pass
-
- def __repr__(self):
- return str(self)
-
- @abc.abstractmethod
- def node_id(self) -> str:
- pass
diff --git a/wally/interfaces.py b/wally/interfaces.py
deleted file mode 100644
index d87c753..0000000
--- a/wally/interfaces.py
+++ /dev/null
@@ -1,81 +0,0 @@
-import abc
-from typing import Any, Set, Dict
-
-
-class IRemoteShell(metaclass=abc.ABCMeta):
- @abc.abstractmethod
- def run(self, cmd: str, timeout: int = 60, nolog: bool = False) -> str:
- pass
-
-
-class IHost(IRemoteShell, metaclass=abc.ABCMeta):
- @abc.abstractmethod
- def get_ip(self) -> str:
- pass
-
- @abc.abstractmethod
- def __str__(self) -> str:
- pass
-
- @abc.abstractmethod
- def put_to_file(self, path: str, content: bytes) -> None:
- pass
-
-
-class IRemoteFS(metaclass=abc.ABCMeta):
- @abc.abstractmethod
- def copy_file(self, local_path: str, remote_path: str = None) -> str:
- pass
-
- @abc.abstractmethod
- def get_file_content(self, path: str) -> bytes:
- pass
-
- @abc.abstractmethod
- def put_to_file(self, path:str, content: bytes) -> None:
- pass
-
- @abc.abstractmethod
- def forward_port(self, ip: str, remote_port: int, local_port: int = None) -> int:
- pass
-
- @abc.abstractmethod
- def get_interface(self, ip: str) -> str:
- pass
-
- @abc.abstractmethod
- def stat_file(self, path:str) -> Any:
- pass
-
-
-class IRPC(metaclass=abc.ABCMeta):
- pass
-
-
-class IRemoteNode(IRemoteFS, IRemoteShell, metaclass=abc.ABCMeta):
-
- def __init__(self) -> None:
- self.roles = set() # type: Set[str]
- self.rpc = None # type: IRPC
- self.rpc_params = None # type: Dict[str, Any]
-
- @abc.abstractmethod
- def is_connected(self) -> bool:
- pass
-
- @abc.abstractmethod
- def disconnect(self) -> None:
- pass
-
- @abc.abstractmethod
- def connect_ssh(self, timeout: int = None) -> None:
- pass
-
- @abc.abstractmethod
- def get_ip(self) -> str:
- pass
-
- @abc.abstractmethod
- def get_user(self) -> str:
- pass
-
diff --git a/wally/main.py b/wally/main.py
index 11d6fc1..bfb9af7 100644
--- a/wally/main.py
+++ b/wally/main.py
@@ -5,7 +5,7 @@
import logging
import argparse
import functools
-from typing import List, Tuple, Any, Callable, IO, cast, TYPE_CHECKING
+from typing import List, Tuple, Any, Callable, IO, cast, Optional
from yaml import load as _yaml_load
@@ -29,7 +29,7 @@
from . import utils, run_test, pretty_yaml
-from .storage import make_storage, IStorage
+from .storage import make_storage, Storage
from .config import Config
from .logger import setup_loggers
from .stage import log_stage, StageType
@@ -72,6 +72,8 @@
descr = "Disk io performance test suite"
parser = argparse.ArgumentParser(prog='wally', description=descr)
parser.add_argument("-l", '--log-level', help="print some extra log info")
+ parser.add_argument("-s", '--settings-dir', default=None,
+ help="Folder to store key/settings/history files")
subparsers = parser.add_subparsers(dest='subparser_name')
@@ -112,6 +114,17 @@
return parser.parse_args(argv[1:])
+def get_config_path(config: Config, opts_value: Optional[str]) -> str:
+ if opts_value is None and 'settings_dir' not in config:
+ val = "~/.wally"
+ elif opts_value is not None:
+ val = opts_value
+ else:
+ val = config.settings_dir
+
+ return os.path.abspath(os.path.expanduser(val))
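+
+# Resolution order (illustrative): the --settings-dir option wins, then the
+# config file's settings_dir value, then the "~/.wally" default; the result
+# is always expanded to an absolute path.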
+
+
def main(argv: List[str]) -> int:
if faulthandler is not None:
faulthandler.register(signal.SIGUSR1, all_threads=True)
@@ -123,7 +136,7 @@
# stop mypy from telling that config & storage might be undeclared
config = None # type: Config
- storage = None # type: IStorage
+ storage = None # type: Storage
if opts.subparser_name == 'test':
if opts.resume:
@@ -143,6 +156,7 @@
config.build_id = opts.build_id
config.build_description = opts.build_description
config.build_type = opts.build_type
+ config.settings_dir = get_config_path(config, opts.settings_dir)
storage = make_storage(config.storage_url)
@@ -174,7 +188,7 @@
elif opts.subparser_name == 'report':
storage = make_storage(opts.data_dir, existing=True)
         config = storage.load('config', Config)
+        config.settings_dir = get_config_path(config, opts.settings_dir)
elif opts.subparser_name == 'compare':
x = run_test.load_data_from_path(opts.data_path1)
diff --git a/wally/node.py b/wally/node.py
index 8435cf3..3bc52fc 100644
--- a/wally/node.py
+++ b/wally/node.py
@@ -1,116 +1,214 @@
-import re
-import getpass
-from typing import Tuple
-from .inode import INode, NodeInfo
+import os
+import time
+import json
+import socket
+import logging
+import subprocess
+from typing import Callable
-from .ssh_utils import parse_ssh_uri, run_over_ssh, connect
+import agent
+
+from .node_interfaces import IRPCNode, NodeInfo, ISSHHost, RPCBeforeConnCallback
+from .ssh_utils import parse_ssh_uri, ssh_connect
-class Node(INode):
- """Node object"""
+logger = logging.getLogger("wally")
- def __init__(self, node_info: NodeInfo) -> None:
- INode.__init__(self)
- self.info = node_info
- self.roles = node_info.roles
- self.bind_ip = node_info.bind_ip
+class SSHHost(ISSHHost):
+ def __init__(self, ssh_conn, node_name: str, ip: str) -> None:
+ self.conn = ssh_conn
+ self.node_name = node_name
+ self.ip = ip
- assert self.ssh_conn_url.startswith("ssh://")
- self.ssh_conn_url = node_info.ssh_conn_url
-
- self.ssh_conn = None
- self.rpc_conn_url = None
- self.rpc_conn = None
- self.os_vm_id = None
- self.hw_info = None
-
- if self.ssh_conn_url is not None:
- self.ssh_cred = parse_ssh_uri(self.ssh_conn_url)
- self.node_id = "{0.host}:{0.port}".format(self.ssh_cred)
- else:
- self.ssh_cred = None
- self.node_id = None
+ def get_ip(self) -> str:
+ return self.ip
def __str__(self) -> str:
- template = "<Node: url={conn_url!r} roles={roles}" + \
- " connected={is_connected}>"
- return template.format(conn_url=self.ssh_conn_url,
- roles=", ".join(self.roles),
- is_connected=self.ssh_conn is not None)
+ return self.node_name
+
+ def put_to_file(self, path: str, content: bytes) -> None:
+ with self.conn.open_sftp() as sftp:
+ with sftp.open(path, "wb") as fd:
+ fd.write(content)
+
+ def run(self, cmd: str, timeout: int = 60, nolog: bool = False) -> str:
+ transport = self.conn.get_transport()
+ session = transport.open_session()
+
+ try:
+ session.set_combine_stderr(True)
+
+ stime = time.time()
+
+ if not nolog:
+ logger.debug("SSH:{0} Exec {1!r}".format(self, cmd))
+
+ session.exec_command(cmd)
+ session.settimeout(1)
+ session.shutdown_write()
+ output = ""
+
+ while True:
+ try:
+                    # recv returns bytes under python3 - decode to keep output a str
+                    ndata = session.recv(1024).decode("utf8")
+                    output += ndata
+                    if "" == ndata:
+ break
+ except socket.timeout:
+ pass
+
+ if time.time() - stime > timeout:
+ raise OSError(output + "\nExecution timeout")
+
+ code = session.recv_exit_status()
+        finally:
+            session.close()
+
+ if code != 0:
+ templ = "SSH:{0} Cmd {1!r} failed with code {2}. Output: {3}"
+ raise OSError(templ.format(self, cmd, code, output))
+
+ return output
+
+
+class LocalHost(ISSHHost):
+ def __str__(self):
+ return "<Local>"
+
+ def get_ip(self) -> str:
+ return 'localhost'
+
+ def put_to_file(self, path: str, content: bytes) -> None:
+ dir_name = os.path.dirname(path)
+ os.makedirs(dir_name, exist_ok=True)
+
+ with open(path, "wb") as fd:
+ fd.write(content)
+
+ def run(self, cmd: str, timeout: int = 60, nolog: bool = False) -> str:
+ proc = subprocess.Popen(cmd, shell=True,
+ stdin=subprocess.PIPE,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.STDOUT)
+
+        stdout_data = proc.communicate()[0].decode("utf8")
+ if proc.returncode != 0:
+ templ = "SSH:{0} Cmd {1!r} failed with code {2}. Output: {3}"
+ raise OSError(templ.format(self, cmd, proc.returncode, stdout_data))
+
+ return stdout_data
+
+
+def connect(conn_url: str, conn_timeout: int = 60) -> ISSHHost:
+ if conn_url == 'local':
+ return LocalHost()
+ else:
+ return SSHHost(*ssh_connect(parse_ssh_uri(conn_url), conn_timeout))
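+
+# Usage sketch: connect("local") runs commands in-process via subprocess,
+# while connect("ssh://root@10.20.0.4") returns a paramiko-backed SSHHost.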
+
+
+class RPCNode(IRPCNode):
+ """Node object"""
+
+ def __init__(self, conn: agent.Client, info: NodeInfo) -> None:
+ self.info = info
+ self.conn = conn
+
+ # if self.ssh_conn_url is not None:
+ # self.ssh_cred = parse_ssh_uri(self.ssh_conn_url)
+ # self.node_id = "{0.host}:{0.port}".format(self.ssh_cred)
+ # else:
+ # self.ssh_cred = None
+ # self.node_id = None
+
+ def __str__(self) -> str:
+ return "<Node: url={!r} roles={!r} hops=/>".format(self.info.ssh_conn_url, ",".join(self.info.roles))
def __repr__(self) -> str:
return str(self)
- def connect_ssh(self, timeout: int=None) -> None:
- self.ssh_conn = connect(self.ssh_conn_url)
-
- def connect_rpc(self) -> None:
- raise NotImplementedError()
-
- def prepare_rpc(self) -> None:
- raise NotImplementedError()
-
- def get_ip(self) -> str:
- """get node connection ip address"""
-
- if self.ssh_conn_url == 'local':
- return '127.0.0.1'
- return self.ssh_cred.host
-
- def get_user(self) -> str:
- """"get ssh connection username"""
- if self.ssh_conn_url == 'local':
- return getpass.getuser()
- return self.ssh_cred.user
-
- def run(self, cmd: str, stdin_data: str=None, timeout: int=60, nolog: bool=False) -> Tuple[int, str]:
- """Run command on node. Will use rpc connection, if available"""
-
- if self.rpc_conn is None:
- return run_over_ssh(self.ssh_conn, cmd,
- stdin_data=stdin_data, timeout=timeout,
- nolog=nolog, node=self)
- assert not stdin_data
- proc_id = self.rpc_conn.cli.spawn(cmd)
- exit_code = None
- output = ""
-
- while exit_code is None:
- exit_code, stdout_data, stderr_data = self.rpc_conn.cli.get_updates(proc_id)
- output += stdout_data + stderr_data
-
- return exit_code, output
-
- def discover_hardware_info(self) -> None:
- raise NotImplementedError()
-
def get_file_content(self, path: str) -> str:
raise NotImplementedError()
def forward_port(self, ip: str, remote_port: int, local_port: int = None) -> int:
raise NotImplementedError()
- def get_interface(self, ip: str) -> str:
- """Get node external interface for given IP"""
- data = self.run("ip a", nolog=True)
- curr_iface = None
- for line in data.split("\n"):
- match1 = re.match(r"\d+:\s+(?P<name>.*?):\s\<", line)
- if match1 is not None:
- curr_iface = match1.group('name')
+def setup_rpc(node: ISSHHost, rpc_server_code: bytes, port: int = 0,
+ rpc_conn_callback: RPCBeforeConnCallback = None) -> IRPCNode:
+ code_file = node.run("mktemp").strip()
+ log_file = node.run("mktemp").strip()
+ node.put_to_file(code_file, rpc_server_code)
+ cmd = "python {code_file} server --listen-addr={listen_ip}:{port} --daemon " + \
+ "--show-settings --stdout-file={out_file}"
+ params_js = node.run(cmd.format(code_file=code_file,
+ listen_addr=node.get_ip(),
+ out_file=log_file,
+ port=port)).strip()
+ params = json.loads(params_js)
+ params['log_file'] = log_file
- match2 = re.match(r"\s+inet\s+(?P<ip>[0-9.]+)/", line)
- if match2 is not None:
- if match2.group('ip') == ip:
- assert curr_iface is not None
- return curr_iface
+ if rpc_conn_callback:
+ ip, port = rpc_conn_callback(node, port)
+ else:
+ ip = node.get_ip()
+ port = int(params['addr'].split(":")[1])
- raise KeyError("Can't found interface for ip {0}".format(ip))
+ rpc_conn = agent.connect((ip, port))
+ node.info.params.update(params)
+ return RPCNode(rpc_conn, node.info)
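+
+
+# Usage sketch (the agent file name is an assumption, not from this module):
+#     ssh_node = connect("ssh://root@10.20.0.4")
+#     with open("agent.py", "rb") as fd:
+#         rpc_node = setup_rpc(ssh_node, fd.read())
+#     print(rpc_node.info.params["log_file"])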
- def sync_hw_info(self) -> None:
- pass
- def sync_sw_info(self) -> None:
- pass
\ No newline at end of file
+
+    # class RemoteNode(node_interfaces.IRPCNode):
+    #     def __init__(self, node_info: node_interfaces.NodeInfo, rpc_conn: agent.RPCClient):
+    #         self.info = node_info
+    #         self.rpc = rpc_conn
+    #
+ # def get_interface(self, ip: str) -> str:
+ # """Get node external interface for given IP"""
+ # data = self.run("ip a", nolog=True)
+ # curr_iface = None
+ #
+ # for line in data.split("\n"):
+ # match1 = re.match(r"\d+:\s+(?P<name>.*?):\s\<", line)
+ # if match1 is not None:
+ # curr_iface = match1.group('name')
+ #
+ # match2 = re.match(r"\s+inet\s+(?P<ip>[0-9.]+)/", line)
+ # if match2 is not None:
+ # if match2.group('ip') == ip:
+ # assert curr_iface is not None
+ # return curr_iface
+ #
+ # raise KeyError("Can't found interface for ip {0}".format(ip))
+ #
+ # def get_user(self) -> str:
+ # """"get ssh connection username"""
+ # if self.ssh_conn_url == 'local':
+ # return getpass.getuser()
+ # return self.ssh_cred.user
+ #
+ #
+ # def run(self, cmd: str, stdin_data: str = None, timeout: int = 60, nolog: bool = False) -> Tuple[int, str]:
+ # """Run command on node. Will use rpc connection, if available"""
+ #
+ # if self.rpc_conn is None:
+ # return run_over_ssh(self.ssh_conn, cmd,
+ # stdin_data=stdin_data, timeout=timeout,
+ # nolog=nolog, node=self)
+ # assert not stdin_data
+ # proc_id = self.rpc_conn.cli.spawn(cmd)
+ # exit_code = None
+ # output = ""
+ #
+ # while exit_code is None:
+ # exit_code, stdout_data, stderr_data = self.rpc_conn.cli.get_updates(proc_id)
+ # output += stdout_data + stderr_data
+ #
+ # return exit_code, output
+
+
diff --git a/wally/node_interfaces.py b/wally/node_interfaces.py
new file mode 100644
index 0000000..e0b56aa
--- /dev/null
+++ b/wally/node_interfaces.py
@@ -0,0 +1,90 @@
+import abc
+from typing import Any, Set, Optional, List, Dict, Callable, Tuple
+
+
+class NodeInfo:
+ """Node information object, result of dicovery process or config parsing"""
+
+ def __init__(self,
+ ssh_conn_url: str,
+ roles: Set[str],
+ hops: List['NodeInfo'] = None,
+ ssh_key: bytes = None) -> None:
+
+ self.hops = [] # type: List[NodeInfo]
+ if hops is not None:
+ self.hops = hops
+
+ self.ssh_conn_url = ssh_conn_url # type: str
+ self.rpc_conn_url = None # type: str
+ self.roles = roles # type: Set[str]
+ self.os_vm_id = None # type: Optional[int]
+ self.ssh_key = ssh_key # type: Optional[bytes]
+ self.params = {} # type: Dict[str, Any]
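+
+    # Illustrative construction: NodeInfo("ssh://root@10.20.0.4:22", {"testnode"})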
+
+
+class ISSHHost(metaclass=abc.ABCMeta):
+ """Minimal interface, required to setup RPC connection"""
+ info = None # type: NodeInfo
+
+ @abc.abstractmethod
+ def run(self, cmd: str, timeout: int = 60, nolog: bool = False) -> str:
+ pass
+
+ @abc.abstractmethod
+ def get_ip(self) -> str:
+ pass
+
+ @abc.abstractmethod
+ def __str__(self) -> str:
+ pass
+
+ @abc.abstractmethod
+ def put_to_file(self, path: str, content: bytes) -> None:
+ pass
+
+
+class IRPCNode(metaclass=abc.ABCMeta):
+ """Remote filesystem interface"""
+ info = None # type: NodeInfo
+
+ @abc.abstractmethod
+ def run(self, cmd: str, timeout: int = 60, nolog: bool = False) -> str:
+ pass
+
+ @abc.abstractmethod
+ def copy_file(self, local_path: str, remote_path: str = None) -> str:
+ pass
+
+ @abc.abstractmethod
+ def get_file_content(self, path: str) -> bytes:
+ pass
+
+ @abc.abstractmethod
+    def put_to_file(self, path: str, content: bytes) -> None:
+ pass
+
+ @abc.abstractmethod
+ def forward_port(self, ip: str, remote_port: int, local_port: int = None) -> int:
+ pass
+
+ @abc.abstractmethod
+ def get_interface(self, ip: str) -> str:
+ pass
+
+ @abc.abstractmethod
+    def stat_file(self, path: str) -> Any:
+ pass
+
+ @abc.abstractmethod
+ def node_id(self) -> str:
+ pass
+
+    @abc.abstractmethod
+    def disconnect(self) -> None:
+        pass
+
+
+# Called just before connecting to the RPC server (e.g. to set up a port
+# forward); returns the (ip, port) pair the client should actually dial.
+RPCBeforeConnCallback = Callable[[ISSHHost, int], Tuple[str, int]]
diff --git a/wally/run_test.py b/wally/run_test.py
index b96ecf8..e62cfb1 100755
--- a/wally/run_test.py
+++ b/wally/run_test.py
@@ -4,13 +4,14 @@
import functools
import contextlib
import collections
-from typing import List, Dict, Iterable, Any, Iterator, Mapping, Callable, Tuple, Optional
+from typing import List, Dict, Iterable, Any, Iterator, Mapping, Callable, Tuple, Optional, Union, cast
from concurrent.futures import ThreadPoolExecutor, Future
-from .inode import INode
-from .discover import discover
+from .node_interfaces import NodeInfo, IRPCNode, RPCBeforeConnCallback
+from .node import connect, setup_rpc
from .test_run_class import TestRun
+from .discover import discover
from . import pretty_yaml, utils, report, ssh_utils, start_vms, hw_info
+from .config import ConfigBlock, Config
from .suits.mysql import MysqlTest
from .suits.itest import TestConfig
@@ -30,37 +31,35 @@
logger = logging.getLogger("wally")
-
-def connect_all(nodes: Iterable[INode],
+def connect_all(nodes_info: List[NodeInfo],
pool: ThreadPoolExecutor,
conn_timeout: int = 30,
- rpc_conn_callback: ssh_utils.RPCBeforeConnCallback = None) -> None:
- """Connect to all nodes, log errors
- nodes - list of nodes
- """
+                rpc_conn_callback: RPCBeforeConnCallback = None) -> List[IRPCNode]:
+ """Connect to all nodes, log errors"""
- logger.info("Connecting to %s nodes", len(nodes))
+ logger.info("Connecting to %s nodes", len(nodes_info))
- def connect_ext(node: INode) -> bool:
+ def connect_ext(node_info: NodeInfo) -> Tuple[bool, Union[IRPCNode, NodeInfo]]:
try:
- node.connect_ssh(conn_timeout)
- node.rpc, node.rpc_params = ssh_utils.setup_rpc(node, rpc_conn_callback=rpc_conn_callback)
- return True
+            ssh_node = connect(node_info.ssh_conn_url, conn_timeout=conn_timeout)
+            # TODO(koder): rpc_server_code needs to be passed to setup_rpc here
+            return True, setup_rpc(ssh_node, rpc_conn_callback=rpc_conn_callback)
except Exception as exc:
logger.error("During connect to {}: {!s}".format(node, exc))
- return False
+ return False, node_info
- list(pool.map(connect_ext, nodes))
+ failed_testnodes = [] # type: List[NodeInfo]
+ failed_nodes = [] # type: List[NodeInfo]
+ ready = [] # type: List[IRPCNode]
- failed_testnodes = []
- failed_nodes = []
-
- for node in nodes:
- if not node.is_connected():
+ for ok, node in pool.map(connect_ext, nodes_info):
+ if not ok:
+ node = cast(NodeInfo, node)
if 'testnode' in node.roles:
failed_testnodes.append(node)
else:
failed_nodes.append(node)
+ else:
+ ready.append(cast(IRPCNode, node))
if failed_nodes:
msg = "Node(s) {} would be excluded - can't connect"
@@ -75,15 +74,17 @@
if not failed_nodes:
logger.info("All nodes connected successfully")
+ return ready
-def collect_info_stage(ctx: TestRun, nodes: Iterable[INode]) -> None:
- futures = {} # type: Dict[str, Future]
+
+def collect_info_stage(ctx: TestRun, nodes: Iterable[IRPCNode]) -> None:
+    futures = {}  # type: Dict[str, Future]
with ctx.get_pool() as pool:
for node in nodes:
hw_info_path = "hw_info/{}".format(node.node_id())
if hw_info_path not in ctx.storage:
- futures[hw_info_path] = pool.submit(hw_info.get_hw_info, node)
+            futures[hw_info_path] = pool.submit(hw_info.get_hw_info, node)
sw_info_path = "sw_info/{}".format(node.node_id())
if sw_info_path not in ctx.storage:
@@ -94,174 +95,118 @@
@contextlib.contextmanager
-def suspend_vm_nodes_ctx(unused_nodes: List[INode]) -> Iterator[List[int]]:
+def suspend_vm_nodes_ctx(unused_nodes: List[IRPCNode]) -> Iterator[List[int]]:
- pausable_nodes_ids = [node.os_vm_id for node in unused_nodes
- if node.os_vm_id is not None]
+ pausable_nodes_ids = [cast(int, node.info.os_vm_id)
+ for node in unused_nodes
+ if node.info.os_vm_id is not None]
non_pausable = len(unused_nodes) - len(pausable_nodes_ids)
- if 0 != non_pausable:
- logger.warning("Can't pause {} nodes".format(
- non_pausable))
+ if non_pausable:
+ logger.warning("Can't pause {} nodes".format(non_pausable))
- if len(pausable_nodes_ids) != 0:
- logger.debug("Try to pause {} unused nodes".format(
- len(pausable_nodes_ids)))
+ if pausable_nodes_ids:
+ logger.debug("Try to pause {} unused nodes".format(len(pausable_nodes_ids)))
start_vms.pause(pausable_nodes_ids)
try:
yield pausable_nodes_ids
finally:
- if len(pausable_nodes_ids) != 0:
- logger.debug("Unpausing {} nodes".format(
- len(pausable_nodes_ids)))
+ if pausable_nodes_ids:
+ logger.debug("Unpausing {} nodes".format(len(pausable_nodes_ids)))
start_vms.unpause(pausable_nodes_ids)
-def generate_result_dir_name(results: str, name: str, params: Dict[str, Any]) -> str:
- # make a directory for results
- all_tests_dirs = os.listdir(results)
-
- if 'name' in params:
- dir_name = "{}_{}".format(name, params['name'])
- else:
- for idx in range(len(all_tests_dirs) + 1):
- dir_name = "{}_{}".format(name, idx)
- if dir_name not in all_tests_dirs:
- break
- else:
- raise utils.StopTestError("Can't select directory for test results")
-
- return os.path.join(results, dir_name)
 
 
 @contextlib.contextmanager
-def sensor_monitoring(sensor_cfg: Any, nodes: Iterable[INode]) -> Iterator[None]:
+def sensor_monitoring(sensor_cfg: ConfigBlock, nodes: List[IRPCNode]) -> Iterator[None]:
     # TODO(koder): write this function
-    pass
+    yield
 
 
-def run_tests(cfg: Config,
- test_block: Dict[str, Dict[str, Any]],
- nodes: Iterable[INode]) -> Iterator[Tuple[str, List[Any]]]:
+def run_tests(ctx: TestRun, test_block: ConfigBlock, nodes: List[IRPCNode]) -> None:
"""Run test from test block"""
- test_nodes = [node for node in nodes if 'testnode' in node.roles]
+ test_nodes = [node for node in nodes if 'testnode' in node.info.roles]
- if len(test_nodes) == 0:
+ if not test_nodes:
logger.error("No test nodes found")
return
for name, params in test_block.items():
- results = []
+ vm_count = params.get('node_limit', None) # type: Optional[int]
- # iterate over all node counts
- limit = params.get('node_limit', len(test_nodes))
- if isinstance(limit, int):
- vm_limits = [limit] # type: List[int]
+ # select test nodes
+ if vm_count is None:
+ curr_test_nodes = test_nodes
+ unused_nodes = []
else:
- list_or_tpl = isinstance(limit, (tuple, list))
- all_ints = list_or_tpl and all(isinstance(climit, int)
- for climit in limit)
- if not all_ints:
- msg = "'node_limit' parameter ion config should" + \
- "be either int or list if integers, not {0!r}".format(limit)
- raise ValueError(msg)
- vm_limits = limit # type: List[int]
+ curr_test_nodes = test_nodes[:vm_count]
+ unused_nodes = test_nodes[vm_count:]
- for vm_count in vm_limits:
- # select test nodes
- if vm_count == 'all':
- curr_test_nodes = test_nodes
- unused_nodes = []
- else:
- curr_test_nodes = test_nodes[:vm_count]
- unused_nodes = test_nodes[vm_count:]
+ if not curr_test_nodes:
+ logger.error("No nodes found for test, skipping it.")
+ continue
- if 0 == len(curr_test_nodes):
- continue
+ # results_path = generate_result_dir_name(cfg.results_storage, name, params)
+ # utils.mkdirs_if_unxists(results_path)
- results_path = generate_result_dir_name(cfg.results_storage, name, params)
- utils.mkdirs_if_unxists(results_path)
+ # suspend all unused virtual nodes
+ if ctx.config.get('suspend_unused_vms', True):
+ suspend_ctx = suspend_vm_nodes_ctx(unused_nodes)
+ else:
+ suspend_ctx = utils.empty_ctx()
- # suspend all unused virtual nodes
- if cfg.settings.get('suspend_unused_vms', True):
- suspend_ctx = suspend_vm_nodes_ctx(unused_nodes)
- else:
- suspend_ctx = utils.empty_ctx()
+ with suspend_ctx:
+ resumable_nodes_ids = [cast(int, node.info.os_vm_id)
+ for node in curr_test_nodes
+ if node.info.os_vm_id is not None]
- with suspend_ctx:
- resumable_nodes_ids = [node.os_vm_id for node in curr_test_nodes
- if node.os_vm_id is not None]
+ if resumable_nodes_ids:
+ logger.debug("Check and unpause {} nodes".format(len(resumable_nodes_ids)))
+ start_vms.unpause(resumable_nodes_ids)
- if len(resumable_nodes_ids) != 0:
- logger.debug("Check and unpause {} nodes".format(
- len(resumable_nodes_ids)))
- start_vms.unpause(resumable_nodes_ids)
+ test_cls = TOOL_TYPE_MAPPER[name]
- test_cls = TOOL_TYPE_MAPPER[name]
+ remote_dir = ctx.config.default_test_local_folder.format(name=name, uuid=ctx.config.run_uuid)
- remote_dir = cfg.default_test_local_folder.format(name=name)
+ test_cfg = TestConfig(test_cls.__name__,
+ params=params,
+ run_uuid=ctx.config.run_uuid,
+                                  nodes=curr_test_nodes,
+ storage=ctx.storage,
+ remote_dir=remote_dir)
- test_cfg = TestConfig(test_cls.__name__,
- params=params,
- test_uuid=cfg.run_uuid,
- nodes=test_nodes,
- log_directory=results_path,
- remote_dir=remote_dir)
-
- t_start = time.time()
- res = test_cls(test_cfg).run()
- t_end = time.time()
-
- results.append(res)
-
- yield name, results
+ test_cls(test_cfg).run()
-def connect_stage(cfg: Config, ctx: TestRun) -> None:
+def connect_stage(ctx: TestRun) -> None:
ctx.clear_calls_stack.append(disconnect_stage)
- connect_all(ctx.nodes)
- ctx.nodes = [node for node in ctx.nodes if node.is_connected()]
+
+ with ctx.get_pool() as pool:
+ ctx.nodes = connect_all(ctx.nodes_info, pool, rpc_conn_callback=ctx.before_conn_callback)
-def discover_stage(cfg: Config, ctx: TestRun) -> None:
+def discover_stage(ctx: TestRun) -> None:
"""discover clusters and nodes stage"""
- if cfg.get('discover') is not None:
- discover_objs = [i.strip() for i in cfg.discover.strip().split(",")]
+ discover_info = ctx.config.get('discover')
+ if discover_info:
+ discover_objs = [i.strip() for i in discover_info.strip().split(",")]
- nodes = discover(ctx,
- discover_objs,
- cfg.clouds,
- cfg.results_storage,
- not cfg.dont_discover_nodes)
+ nodes_info = discover.discover(ctx, discover_objs,
+ ctx.config.clouds,
+ ctx.storage,
+ not ctx.config.dont_discover_nodes)
- ctx.nodes.extend(nodes)
+ ctx.nodes_info.extend(nodes_info)
- for url, roles in cfg.get('explicit_nodes', {}).items():
- ctx.nodes.append(Node(url, roles.split(",")))
+ for url, roles in ctx.config.get('explicit_nodes', {}).items():
+ ctx.nodes_info.append(NodeInfo(url, set(roles.split(","))))
-def save_nodes_stage(cfg: Config, ctx: TestRun) -> None:
+def save_nodes_stage(ctx: TestRun) -> None:
"""Save nodes list to file"""
- cluster = {}
- for node in ctx.nodes:
- roles = node.roles[:]
- if 'testnode' in roles:
- roles.remove('testnode')
-
- if len(roles) != 0:
- cluster[node.ssh_conn_url] = roles
-
- with open(cfg.nodes_report_file, "w") as fd:
- fd.write(pretty_yaml.dumps(cluster))
+ ctx.storage['nodes'] = ctx.nodes_info
-def reuse_vms_stage(cfg: Config, ctx: TestRun) -> None:
- vms_patterns = cfg.get('clouds', {}).get('openstack', {}).get('vms', [])
- private_key_path = get_vm_keypair(cfg)['keypair_file_private']
+def reuse_vms_stage(ctx: TestRun) -> None:
+ vms_patterns = ctx.config.get('clouds/openstack/vms', [])
+ private_key_path = get_vm_keypair(ctx.config)['keypair_file_private']
for creds in vms_patterns:
user_name, vm_name_pattern = creds.split("@", 1)
@@ -272,7 +217,7 @@
logger.debug(msg)
if not start_vms.is_connected():
- os_creds = get_OS_credentials(cfg, ctx)
+ os_creds = get_OS_credentials(ctx)
else:
os_creds = None
@@ -281,15 +226,16 @@
conn_url = "ssh://{user}@{ip}::{key}".format(user=user_name,
ip=ip,
key=private_key_path)
- node = Node(conn_url, ['testnode'])
- node.os_vm_id = vm_id
- ctx.nodes.append(node)
+            node_info = NodeInfo(conn_url, {'testnode'})
+ node_info.os_vm_id = vm_id
+ ctx.nodes_info.append(node_info)
-def get_OS_credentials(cfg: Config, ctx: TestRun) -> None:
+def get_OS_credentials(ctx: TestRun) -> start_vms.OSCreds:
creds = None
os_creds = None
force_insecure = False
+ cfg = ctx.config
if 'openstack' in cfg.clouds:
os_cfg = cfg.clouds['openstack']
@@ -336,75 +282,65 @@
return creds
-def get_vm_keypair(cfg: Config) -> Dict[str, str]:
- res = {} # type: Dict[str, str]
- for field, ext in (('keypair_file_private', 'pem'),
- ('keypair_file_public', 'pub')):
- fpath = cfg.vm_configs.get(field)
-
- if fpath is None:
- fpath = cfg.vm_configs['keypair_name'] + "." + ext
-
- if os.path.isabs(fpath):
- res[field] = fpath
- else:
- res[field] = os.path.join(cfg.config_folder, fpath)
- return res
+def get_vm_keypair(cfg: Config) -> Dict[str, str]:
+    key_name = cfg.vm_configs['keypair_name']
+    # callers index this by field name and merge it into vm params, so keep the dict interface
+    return {'keypair_file_private': os.path.join(cfg.settings_dir, key_name + "_private.pem"),
+            'keypair_file_public': os.path.join(cfg.settings_dir, key_name + "_public.pub")}
@contextlib.contextmanager
-def create_vms_ctx(ctx: TestRun, cfg: Config, config, already_has_count: int=0) -> Iterator[List[INode]]:
- if config['count'].startswith('='):
- count = int(config['count'][1:])
+def create_vms_ctx(ctx: TestRun, vm_config: ConfigBlock, already_has_count: int = 0) -> Iterator[List[NodeInfo]]:
+ if vm_config['count'].startswith('='):
+ count = int(vm_config['count'][1:])
if count <= already_has_count:
logger.debug("Not need new vms")
yield []
return
- params = cfg.vm_configs[config['cfg_name']].copy()
- os_nodes_ids = []
-
if not start_vms.is_connected():
- os_creds = get_OS_credentials(cfg, ctx)
+ os_creds = get_OS_credentials(ctx)
else:
os_creds = None
nova = start_vms.nova_connect(os_creds)
- params.update(config)
- params.update(get_vm_keypair(cfg))
+    os_nodes_ids = ctx.storage.get('spawned_vm_ids', [])  # type: List[int]
+    new_nodes = []  # type: List[NodeInfo]
- params['group_name'] = cfg.run_uuid
- params['keypair_name'] = cfg.vm_configs['keypair_name']
+ if not os_nodes_ids:
+ params = ctx.config.vm_configs[vm_config['cfg_name']].copy()
+ params.update(vm_config)
+ params.update(get_vm_keypair(ctx.config))
+ params['group_name'] = ctx.config.run_uuid
+ params['keypair_name'] = ctx.config.vm_configs['keypair_name']
- if not config.get('skip_preparation', False):
- logger.info("Preparing openstack")
- start_vms.prepare_os(nova, params, os_creds)
+ if not vm_config.get('skip_preparation', False):
+ logger.info("Preparing openstack")
+ start_vms.prepare_os(nova, params, os_creds)
+ else:
+ # TODO(koder): reconnect to old VM's
+ raise NotImplementedError("Reconnect to old vms is not implemented")
- new_nodes = []
+ already_has_count += len(os_nodes_ids)
-    old_nodes = ctx.nodes[:]
- try:
- for new_node, node_id in start_vms.launch_vms(nova, params, already_has_count):
- new_node.roles.append('testnode')
- ctx.nodes.append(new_node)
- os_nodes_ids.append(node_id)
- new_nodes.append(new_node)
- store_nodes_in_log(cfg, os_nodes_ids)
- ctx.openstack_nodes_ids = os_nodes_ids
+ for node_info, node_id in start_vms.launch_vms(nova, params, already_has_count):
+        node_info.roles.add('testnode')
+ os_nodes_ids.append(node_id)
+ new_nodes.append(node_info)
+ ctx.storage['spawned_vm_ids'] = os_nodes_ids
- yield new_nodes
+ yield new_nodes
- finally:
- if not cfg.keep_vm:
- shut_down_vms_stage(cfg, ctx)
- ctx.nodes = old_nodes
+ # keep nodes in case of error for future test restart
+ if not ctx.config.keep_vm:
+ shut_down_vms_stage(ctx, os_nodes_ids)
+ ctx.storage['spawned_vm_ids'] = []
-def run_tests_stage(cfg: Config, ctx: TestRun) -> None:
- ctx.results = collections.defaultdict(lambda: [])
-
- for group in cfg.get('tests', []):
+def run_tests_stage(ctx: TestRun) -> None:
+ for group in ctx.config.get('tests', []):
gitems = list(group.items())
if len(gitems) != 1:
msg = "Items in tests section should have len == 1"
@@ -419,159 +355,138 @@
logger.error(msg)
raise utils.StopTestError(msg)
- num_test_nodes = 0
- for node in ctx.nodes:
- if 'testnode' in node.roles:
- num_test_nodes += 1
-
- vm_ctx = create_vms_ctx(ctx, cfg, config['openstack'],
- num_test_nodes)
+ num_test_nodes = len([node for node in ctx.nodes if 'testnode' in node.info.roles])
+ vm_ctx = create_vms_ctx(ctx, config['openstack'], num_test_nodes)
tests = config.get('tests', [])
else:
vm_ctx = utils.empty_ctx([])
tests = [group]
- if cfg.get('sensors') is None:
- sensor_ctx = utils.empty_ctx()
- else:
- sensor_ctx = sensor_monitoring(cfg.get('sensors'), ctx.nodes)
+ with vm_ctx as new_nodes: # type: List[NodeInfo]
+ if new_nodes:
+ with ctx.get_pool() as pool:
+ new_rpc_nodes = connect_all(new_nodes, pool, rpc_conn_callback=ctx.before_conn_callback)
- with vm_ctx as new_nodes:
- if len(new_nodes) != 0:
- connect_all(new_nodes, True)
+        else:
+            new_rpc_nodes = []
+
+        test_nodes = ctx.nodes + new_rpc_nodes
- if not cfg.no_tests:
+        if ctx.config.get('sensors'):
+            sensor_ctx = sensor_monitoring(ctx.config.get('sensors'), test_nodes)
+        else:
+            sensor_ctx = utils.empty_ctx([])
+
+ if not ctx.config.no_tests:
for test_group in tests:
with sensor_ctx:
- it = run_tests(cfg, test_group, ctx.nodes)
- for tp, res in it:
- ctx.results[tp].extend(res)
+ run_tests(ctx, test_group, test_nodes)
+
+        # disconnect once, after all test groups are done
+        for node in new_rpc_nodes:
+            node.disconnect()
-def shut_down_vms_stage(cfg: Config, ctx: TestRun) -> None:
- vm_ids_fname = cfg.vm_ids_fname
- if ctx.openstack_nodes_ids is None:
- nodes_ids = open(vm_ids_fname).read().split()
- else:
- nodes_ids = ctx.openstack_nodes_ids
-
- if len(nodes_ids) != 0:
+def shut_down_vms_stage(ctx: TestRun, nodes_ids: List[int]) -> None:
+ if nodes_ids:
logger.info("Removing nodes")
start_vms.clear_nodes(nodes_ids)
logger.info("Nodes has been removed")
- if os.path.exists(vm_ids_fname):
- os.remove(vm_ids_fname)
+
+def clear_enviroment(ctx: TestRun) -> None:
+ shut_down_vms_stage(ctx, ctx.storage.get('spawned_vm_ids', []))
+ ctx.storage['spawned_vm_ids'] = []
-def store_nodes_in_log(cfg: Config, nodes_ids: Iterable[str]) -> None:
- with open(cfg.vm_ids_fname, 'w') as fd:
- fd.write("\n".join(nodes_ids))
-
-
-def clear_enviroment(cfg: Config, ctx: TestRun) -> None:
- if os.path.exists(cfg.vm_ids_fname):
- shut_down_vms_stage(cfg, ctx)
-
-
-def disconnect_stage(cfg: Config, ctx: TestRun) -> None:
- ssh_utils.close_all_sessions()
+def disconnect_stage(ctx: TestRun) -> None:
+ # TODO(koder): what next line was for?
+ # ssh_utils.close_all_sessions()
for node in ctx.nodes:
node.disconnect()
-def store_raw_results_stage(cfg: Config, ctx: TestRun) -> None:
- if os.path.exists(cfg.raw_results):
- cont = yaml_load(open(cfg.raw_results).read())
- else:
- cont = []
-
- cont.extend(utils.yamable(ctx.results).items())
- raw_data = pretty_yaml.dumps(cont)
-
- with open(cfg.raw_results, "w") as fd:
- fd.write(raw_data)
+def console_report_stage(ctx: TestRun) -> None:
+ # TODO(koder): load data from storage
+ raise NotImplementedError("...")
+ # first_report = True
+ # text_rep_fname = ctx.config.text_report_file
+ #
+ # with open(text_rep_fname, "w") as fd:
+ # for tp, data in ctx.results.items():
+ # if 'io' == tp and data is not None:
+ # rep_lst = []
+ # for result in data:
+ # rep_lst.append(
+ # IOPerfTest.format_for_console(list(result)))
+ # rep = "\n\n".join(rep_lst)
+ # elif tp in ['mysql', 'pgbench'] and data is not None:
+ # rep = MysqlTest.format_for_console(data)
+ # elif tp == 'omg':
+ # rep = OmgTest.format_for_console(data)
+ # else:
+ # logger.warning("Can't generate text report for " + tp)
+ # continue
+ #
+ # fd.write(rep)
+ # fd.write("\n")
+ #
+ # if first_report:
+ # logger.info("Text report were stored in " + text_rep_fname)
+ # first_report = False
+ #
+ # print("\n" + rep + "\n")
-def console_report_stage(cfg: Config, ctx: TestRun) -> None:
- first_report = True
- text_rep_fname = cfg.text_report_file
- with open(text_rep_fname, "w") as fd:
- for tp, data in ctx.results.items():
- if 'io' == tp and data is not None:
- rep_lst = []
- for result in data:
- rep_lst.append(
- IOPerfTest.format_for_console(list(result)))
- rep = "\n\n".join(rep_lst)
- elif tp in ['mysql', 'pgbench'] and data is not None:
- rep = MysqlTest.format_for_console(data)
- elif tp == 'omg':
- rep = OmgTest.format_for_console(data)
- else:
- logger.warning("Can't generate text report for " + tp)
- continue
+# def test_load_report_stage(cfg: Config, ctx: TestRun) -> None:
+# load_rep_fname = cfg.load_report_file
+# found = False
+# for idx, (tp, data) in enumerate(ctx.results.items()):
+# if 'io' == tp and data is not None:
+# if found:
+# logger.error("Making reports for more than one " +
+# "io block isn't supported! All " +
+# "report, except first are skipped")
+# continue
+# found = True
+# report.make_load_report(idx, cfg['results'], load_rep_fname)
+#
+#
- fd.write(rep)
- fd.write("\n")
+def html_report_stage(ctx: TestRun) -> None:
+ # TODO(koder): load data from storage
+ raise NotImplementedError("...")
+ # html_rep_fname = cfg.html_report_file
+ # found = False
+ # for tp, data in ctx.results.items():
+ # if 'io' == tp and data is not None:
+ # if found or len(data) > 1:
+ # logger.error("Making reports for more than one " +
+ # "io block isn't supported! All " +
+ # "report, except first are skipped")
+ # continue
+ # found = True
+ # report.make_io_report(list(data[0]),
+ # cfg.get('comment', ''),
+ # html_rep_fname,
+ # lab_info=ctx.nodes)
- if first_report:
- logger.info("Text report were stored in " + text_rep_fname)
- first_report = False
-
- print("\n" + rep + "\n")
-
-
-def test_load_report_stage(cfg: Config, ctx: TestRun) -> None:
- load_rep_fname = cfg.load_report_file
- found = False
- for idx, (tp, data) in enumerate(ctx.results.items()):
- if 'io' == tp and data is not None:
- if found:
- logger.error("Making reports for more than one " +
- "io block isn't supported! All " +
- "report, except first are skipped")
- continue
- found = True
- report.make_load_report(idx, cfg['results'], load_rep_fname)
-
-
-def html_report_stage(cfg: Config, ctx: TestRun) -> None:
- html_rep_fname = cfg.html_report_file
- found = False
- for tp, data in ctx.results.items():
- if 'io' == tp and data is not None:
- if found or len(data) > 1:
- logger.error("Making reports for more than one " +
- "io block isn't supported! All " +
- "report, except first are skipped")
- continue
- found = True
- report.make_io_report(list(data[0]),
- cfg.get('comment', ''),
- html_rep_fname,
- lab_info=ctx.nodes)
-
-
-def load_data_from_path(test_res_dir: str) -> Mapping[str, List[Any]]:
- files = get_test_files(test_res_dir)
- raw_res = yaml_load(open(files['raw_results']).read())
- res = collections.defaultdict(list)
-
- for tp, test_lists in raw_res:
- for tests in test_lists:
- for suite_name, suite_data in tests.items():
- result_folder = suite_data[0]
- res[tp].append(TOOL_TYPE_MAPPER[tp].load(suite_name, result_folder))
-
- return res
-
-
-def load_data_from_path_stage(var_dir: str, _, ctx: TestRun) -> None:
- for tp, vals in load_data_from_path(var_dir).items():
- ctx.results.setdefault(tp, []).extend(vals)
-
-
-def load_data_from(var_dir: str) -> Callable[[TestRun], None]:
- return functools.partial(load_data_from_path_stage, var_dir)
+#
+# def load_data_from_path(test_res_dir: str) -> Mapping[str, List[Any]]:
+# files = get_test_files(test_res_dir)
+# raw_res = yaml_load(open(files['raw_results']).read())
+# res = collections.defaultdict(list)
+#
+# for tp, test_lists in raw_res:
+# for tests in test_lists:
+# for suite_name, suite_data in tests.items():
+# result_folder = suite_data[0]
+# res[tp].append(TOOL_TYPE_MAPPER[tp].load(suite_name, result_folder))
+#
+# return res
+#
+#
+# def load_data_from_path_stage(var_dir: str, _, ctx: TestRun) -> None:
+# for tp, vals in load_data_from_path(var_dir).items():
+# ctx.results.setdefault(tp, []).extend(vals)
+#
+#
+# def load_data_from(var_dir: str) -> Callable[[TestRun], None]:
+# return functools.partial(load_data_from_path_stage, var_dir)
diff --git a/wally/ssh_utils.py b/wally/ssh_utils.py
index 2941e7c..7728dfd 100644
--- a/wally/ssh_utils.py
+++ b/wally/ssh_utils.py
@@ -1,27 +1,25 @@
import re
-import json
import time
+import errno
import socket
import logging
import os.path
import getpass
+import selectors
from io import BytesIO
-import subprocess
-from typing import Union, Optional, cast, Dict, List, Tuple, Any, Callable
-from concurrent.futures import ThreadPoolExecutor
+from typing import Union, Optional, cast, Dict, List, Tuple
import paramiko
-import agent
-
-from . import interfaces, utils
+from . import utils
logger = logging.getLogger("wally")
+IPAddr = Tuple[str, int]
-class URIsNamespace(object):
- class ReParts(object):
+class URIsNamespace:
+ class ReParts:
user_rr = "[^:]*?"
host_rr = "[^:@]*?"
port_rr = "\\d+"
@@ -58,41 +56,27 @@
class ConnCreds:
conn_uri_attrs = ("user", "passwd", "host", "port", "key_file")
- def __init__(self) -> None:
- self.user = None # type: Optional[str]
- self.passwd = None # type: Optional[str]
- self.host = None # type: str
- self.port = 22 # type: int
- self.key_file = None # type: Optional[str]
+ def __init__(self, host: str, user: str, passwd: str = None, port: int = 22, key_file: str = None) -> None:
+ self.user = user
+ self.passwd = passwd
+ self.host = host
+ self.port = port
+ self.key_file = key_file
def __str__(self) -> str:
return str(self.__dict__)
-SSHCredsType = Union[str, ConnCreds]
-
-
def parse_ssh_uri(uri: str) -> ConnCreds:
- # [ssh://]+
- # user:passwd@ip_host:port
- # user:passwd@ip_host
- # user@ip_host:port
- # user@ip_host
- # ip_host:port
- # ip_host
- # user@ip_host:port:path_to_key_file
- # user@ip_host::path_to_key_file
- # ip_host:port:path_to_key_file
- # ip_host::path_to_key_file
+ """Parse ssh connection URL from one of following form
+ [ssh://]user:passwd@host[:port]
+ [ssh://][user@]host[:port][:key_file]
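+
+    Examples (illustrative):
+        "root:secret@10.20.0.2"     -> explicit user/password, port defaults to 22
+        "10.20.0.2::/root/.ssh/key" -> key-file auth, user defaults to the current one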
+ """
if uri.startswith("ssh://"):
uri = uri[len("ssh://"):]
- res = ConnCreds()
- res.port = 22
- res.key_file = None
- res.passwd = None
- res.user = getpass.getuser()
+ res = ConnCreds("", getpass.getuser())
for rr in URIsNamespace.uri_reg_exprs:
rrm = re.match(rr, uri)
@@ -103,109 +87,18 @@
raise ValueError("Can't parse {0!r} as ssh uri value".format(uri))
-class LocalHost(interfaces.IHost):
- def __str__(self):
- return "<Local>"
-
- def get_ip(self) -> str:
- return 'localhost'
-
- def put_to_file(self, path: str, content: bytes) -> None:
- dirname = os.path.dirname(path)
- if not os.path.exists(dirname):
- os.makedirs(dirname)
- with open(path, "wb") as fd:
- fd.write(content)
-
- def run(self, cmd: str, timeout: int = 60, nolog: bool = False) -> str:
- proc = subprocess.Popen(cmd, shell=True,
- stdin=subprocess.PIPE,
- stdout=subprocess.PIPE,
- stderr=subprocess.STDOUT)
-
- stdout_data, _ = proc.communicate()
- if proc.returncode != 0:
- templ = "SSH:{0} Cmd {1!r} failed with code {2}. Output: {3}"
- raise OSError(templ.format(self, cmd, proc.returncode, stdout_data))
-
- return stdout_data
+NODE_KEYS = {} # type: Dict[IPAddr, paramiko.RSAKey]
-class SSHHost(interfaces.IHost):
- def __init__(self, ssh_conn, node_name: str, ip: str) -> None:
- self.conn = ssh_conn
- self.node_name = node_name
- self.ip = ip
-
- def get_ip(self) -> str:
- return self.ip
-
- def __str__(self) -> str:
- return self.node_name
-
- def put_to_file(self, path: str, content: bytes) -> None:
- with self.conn.open_sftp() as sftp:
- with sftp.open(path, "wb") as fd:
- fd.write(content)
-
- def run(self, cmd: str, timeout: int = 60, nolog: bool = False) -> str:
- transport = self.conn.get_transport()
- session = transport.open_session()
-
- try:
- session.set_combine_stderr(True)
-
- stime = time.time()
-
- if not nolog:
- logger.debug("SSH:{0} Exec {1!r}".format(self, cmd))
-
- session.exec_command(cmd)
- session.settimeout(1)
- session.shutdown_write()
- output = ""
-
- while True:
- try:
- ndata = session.recv(1024)
- output += ndata
- if "" == ndata:
- break
- except socket.timeout:
- pass
-
- if time.time() - stime > timeout:
- raise OSError(output + "\nExecution timeout")
-
- code = session.recv_exit_status()
- finally:
- found = False
-
- if found:
- session.close()
-
- if code != 0:
- templ = "SSH:{0} Cmd {1!r} failed with code {2}. Output: {3}"
- raise OSError(templ.format(self, cmd, code, output))
-
- return output
+def set_key_for_node(host_port: IPAddr, key: bytes) -> None:
+ with BytesIO(key) as sio:
+ NODE_KEYS[host_port] = paramiko.RSAKey.from_private_key(sio)
-NODE_KEYS = {} # type: Dict[Tuple[str, int], paramiko.RSAKey]
-
-
-def set_key_for_node(host_port: Tuple[str, int], key: bytes) -> None:
- sio = BytesIO(key)
- NODE_KEYS[host_port] = paramiko.RSAKey.from_private_key(sio)
- sio.close()
-
-
-def ssh_connect(creds: SSHCredsType, conn_timeout: int = 60) -> interfaces.IHost:
- if creds == 'local':
- return LocalHost()
-
- tcp_timeout = 15
- default_banner_timeout = 30
+def ssh_connect(creds: ConnCreds,
+ conn_timeout: int = 60,
+ tcp_timeout: int = 15,
+ default_banner_timeout: int = 30) -> Tuple[paramiko.SSHClient, str, str]:
ssh = paramiko.SSHClient()
ssh.load_host_keys('/dev/null')
@@ -223,8 +116,6 @@
if paramiko.__version_info__ >= (1, 15, 2):
banner_timeout_arg['banner_timeout'] = int(min(default_banner_timeout, time_left))
- creds = cast(ConnCreds, creds)
-
if creds.passwd is not None:
ssh.connect(creds.host,
timeout=c_tcp_timeout,
@@ -259,7 +150,7 @@
look_for_keys=False,
port=creds.port,
**banner_timeout_arg)
- return SSHHost(ssh, "{0.host}:{0.port}".format(creds), creds.host)
+ return ssh, "{0.host}:{0.port}".format(creds), creds.host
except paramiko.PasswordRequiredException:
raise
except (socket.error, paramiko.SSHException):
@@ -268,68 +159,35 @@
time.sleep(1)
-def connect(uri: str, **params) -> interfaces.IHost:
- if uri == 'local':
- res = LocalHost()
- else:
- creds = parse_ssh_uri(uri)
- creds.port = int(creds.port)
- res = ssh_connect(creds, **params)
- return res
-
-
-SetupResult = Tuple[interfaces.IRPC, Dict[str, Any]]
-
-
-RPCBeforeConnCallback = Callable[[interfaces.IHost, int], None]
-
-
-def setup_rpc(node: interfaces.IHost,
- rpc_server_code: bytes,
- port: int = 0,
- rpc_conn_callback: RPCBeforeConnCallback = None) -> SetupResult:
- code_file = node.run("mktemp").strip()
- log_file = node.run("mktemp").strip()
- node.put_to_file(code_file, rpc_server_code)
- cmd = "python {code_file} server --listen-addr={listen_ip}:{port} --daemon " + \
- "--show-settings --stdout-file={out_file}"
- params_js = node.run(cmd.format(code_file=code_file,
- listen_addr=node.get_ip(),
- out_file=log_file,
- port=port)).strip()
- params = json.loads(params_js)
- params['log_file'] = log_file
-
- if rpc_conn_callback:
- ip, port = rpc_conn_callback(node, port)
- else:
- ip = node.get_ip()
- port = int(params['addr'].split(":")[1])
-
- return agent.connect((ip, port)), params
-
-
-def wait_ssh_awailable(addrs: List[Tuple[str, int]],
+def wait_ssh_available(addrs: List[IPAddr],
timeout: int = 300,
- tcp_timeout: float = 1.0,
- max_threads: int = 32) -> None:
- addrs = addrs[:]
- tout = utils.Timeout(timeout)
+ tcp_timeout: float = 1.0) -> None:
+ addrs = set(addrs)
+ for _ in utils.Timeout(timeout):
+ with selectors.DefaultSelector() as selector: # type: selectors.BaseSelector
+ for addr in addrs:
+ sock = socket.socket()
+ sock.setblocking(False)
+ try:
+ sock.connect(addr)
+ except BlockingIOError:
+ pass
+ selector.register(sock, selectors.EVENT_READ, data=addr)
- def check_sock(addr):
- s = socket.socket()
- s.settimeout(tcp_timeout)
- try:
- s.connect(addr)
- return True
- except (socket.timeout, ConnectionRefusedError):
- return False
+            etime = time.time() + tcp_timeout
+            ltime = etime - time.time()
+            while ltime > 0:
+                for key, _ in selector.select(timeout=ltime):
+                    selector.unregister(key.fileobj)
+                    try:
+                        # getpeername fails with ENOTCONN while the connect is still pending
+                        key.fileobj.getpeername()
+                        addrs.remove(key.data)
+                    except OSError as exc:
+                        if exc.errno != errno.ENOTCONN:
+                            raise
+                    finally:
+                        key.fileobj.close()
+                ltime = etime - time.time()
- with ThreadPoolExecutor(max_workers=max_threads) as pool:
- while addrs:
- check_result = pool.map(check_sock, addrs)
- addrs = [addr for ok, addr in zip(check_result, addrs) if not ok] # type: List[Tuple[str, int]]
- tout.tick()
-
+            # close sockets which did not become ready in time - they will be retried
+            for key in list(selector.get_map().values()):
+                selector.unregister(key.fileobj)
+                key.fileobj.close()
+
+            if not addrs:
+                break
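+
+
+# Usage sketch: block until sshd answers on every node, e.g.
+#     wait_ssh_available([("10.20.0.4", 22), ("10.20.0.5", 22)], timeout=300)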
diff --git a/wally/start_vms.py b/wally/start_vms.py
index 075f348..af81463 100644
--- a/wally/start_vms.py
+++ b/wally/start_vms.py
@@ -2,22 +2,28 @@
import os
import stat
import time
-import urllib
import os.path
import logging
-
-from typing import Dict, Any, Iterable, Generator, NamedTuple
+import tempfile
+import subprocess
+import urllib.request
+from typing import Dict, Any, Iterable, Iterator, NamedTuple, Optional, List, Tuple
from concurrent.futures import ThreadPoolExecutor
+from keystoneauth1 import loading, session
from novaclient.exceptions import NotFound
-from novaclient.client import Client as n_client
-from cinderclient.v1.client import Client as c_client
+from novaclient.client import Client as NovaClient
+from cinderclient.client import Client as CinderClient
+from glanceclient import Client as GlanceClient
-from .inode import NodeInfo
+
+from .utils import Timeout
+from .node_interfaces import NodeInfo
+
__doc__ = """
Module used to reliably spawn set of VM's, evenly distributed across
-openstack cluster. Main functions:
+compute servers in openstack cluster. Main functions:
-    get_OS_credentials - extract openstack credentials from different sources
-    nova_connect - connect to nova api
+    get_openstack_credentials - extract openstack credentials from different sources
+    os_connect - connect to nova, cinder and glance apis
@@ -32,15 +38,6 @@
logger = logging.getLogger("wally.vms")
-STORED_OPENSTACK_CREDS = None
-NOVA_CONNECTION = None
-CINDER_CONNECTION = None
-
-
-def is_connected() -> bool:
- return NOVA_CONNECTION is not None
-
-
OSCreds = NamedTuple("OSCreds",
[("name", str),
("passwd", str),
@@ -49,57 +46,41 @@
("insecure", bool)])
-def ostack_get_creds() -> OSCreds:
- if STORED_OPENSTACK_CREDS is None:
- is_insecure = \
- os.environ.get('OS_INSECURE', 'False').lower() in ('true', 'yes')
- return OSCreds(os.environ.get('OS_USERNAME'),
- os.environ.get('OS_PASSWORD'),
- os.environ.get('OS_TENANT_NAME'),
- os.environ.get('OS_AUTH_URL'),
- is_insecure)
- else:
- return STORED_OPENSTACK_CREDS
+# TODO(koder): should correctly process different sources, not only env????
+def get_openstack_credentials() -> OSCreds:
+ is_insecure = os.environ.get('OS_INSECURE', 'false').lower() in ('true', 'yes')
+
+ return OSCreds(os.environ.get('OS_USERNAME'),
+ os.environ.get('OS_PASSWORD'),
+ os.environ.get('OS_TENANT_NAME'),
+ os.environ.get('OS_AUTH_URL'),
+ is_insecure)
-def nova_connect(os_creds: OSCreds=None) -> n_client:
- global NOVA_CONNECTION
- global STORED_OPENSTACK_CREDS
-
- if NOVA_CONNECTION is None:
- if os_creds is None:
- os_creds = ostack_get_creds()
- else:
- STORED_OPENSTACK_CREDS = os_creds
-
- NOVA_CONNECTION = n_client('1.1',
- os_creds.name,
- os_creds.passwd,
- os_creds.tenant,
- os_creds.auth_url,
- insecure=os_creds.insecure)
- return NOVA_CONNECTION
+class OSConnection:
+ def __init__(self, nova: NovaClient, cinder: CinderClient, glance: GlanceClient) -> None:
+ self.nova = nova
+ self.cinder = cinder
+ self.glance = glance
-def cinder_connect(os_creds: OSCreds=None) -> c_client:
- global CINDER_CONNECTION
- global STORED_OPENSTACK_CREDS
+def os_connect(os_creds: OSCreds, version: str = "2") -> OSConnection:
+ loader = loading.get_plugin_loader('password')
+ auth = loader.load_from_options(auth_url=os_creds.auth_url,
+ username=os_creds.name,
+ password=os_creds.passwd,
+ project_id=os_creds.tenant)
+ auth_sess = session.Session(auth=auth)
- if CINDER_CONNECTION is None:
- if os_creds is None:
- os_creds = ostack_get_creds()
- else:
- STORED_OPENSTACK_CREDS = os_creds
- CINDER_CONNECTION = c_client(os_creds.name,
- os_creds.passwd,
- os_creds.tenant,
- os_creds.auth_url,
- insecure=os_creds.insecure)
- return CINDER_CONNECTION
+ glance = GlanceClient(version, session=auth_sess)
+ nova = NovaClient(version, session=auth_sess)
+    cinder = CinderClient(version, os_creds.name, os_creds.passwd, os_creds.tenant,
+                          os_creds.auth_url, insecure=os_creds.insecure)
+ return OSConnection(nova, cinder, glance)
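+
+
+# Usage sketch: conn = os_connect(get_openstack_credentials()) yields nova,
+# cinder and glance clients sharing one authenticated keystone session.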
-def find_vms(nova: n_client, name_prefix: str) -> Iterable[str, int]:
- for srv in nova.servers.list():
+def find_vms(conn: OSConnection, name_prefix: str) -> Iterator[Tuple[str, int]]:
+ for srv in conn.nova.servers.list():
if srv.name.startswith(name_prefix):
for ips in srv.addresses.values():
for ip in ips:
@@ -108,43 +89,33 @@
break
-def pause(ids: Iterable[str]) -> None:
- def pause_vm(conn: n_client, vm_id: str) -> None:
- vm = conn.servers.get(vm_id)
+def pause(conn: OSConnection, ids: Iterable[int], executor: ThreadPoolExecutor) -> None:
+ def pause_vm(vm_id: str) -> None:
+ vm = conn.nova.servers.get(vm_id)
if vm.status == 'ACTIVE':
vm.pause()
- conn = nova_connect()
- with ThreadPoolExecutor(max_workers=16) as executor:
- futures = [executor.submit(pause_vm, conn, vm_id)
- for vm_id in ids]
- for future in futures:
- future.result()
+ # executor.map yields results, not futures; draining the iterator propagates worker exceptions
+ for _ in executor.map(pause_vm, ids):
+ pass
-def unpause(ids: Iterable[str], max_resume_time=10) -> None:
- def unpause(conn: n_client, vm_id: str) -> None:
- vm = conn.servers.get(vm_id)
+def unpause(conn: OSConnection, ids: Iterable[str], executor: ThreadPoolExecutor, max_resume_time: int = 10) -> None:
+ def unpause_vm(vm_id: str) -> None:
+ vm = conn.nova.servers.get(vm_id)
if vm.status == 'PAUSED':
vm.unpause()
- for i in range(max_resume_time * 10):
- vm = conn.servers.get(vm_id)
+ for _ in Timeout(max_resume_time):
+ vm = conn.nova.servers.get(vm_id)
if vm.status != 'PAUSED':
return
- time.sleep(0.1)
raise RuntimeError("Can't unpause vm {0}".format(vm_id))
- conn = nova_connect()
- with ThreadPoolExecutor(max_workers=16) as executor:
- futures = [executor.submit(unpause, conn, vm_id)
- for vm_id in ids]
-
- for future in futures:
- future.result()
+ # executor.map yields results, not futures; draining the iterator propagates worker exceptions
+ for _ in executor.map(unpause_vm, ids):
+ pass
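pause/unpause no longer create their own pools; the caller owns the executor and can reuse it
across calls. A sketch of the intended pattern (vm_ids and run_tests are hypothetical):

    from concurrent.futures import ThreadPoolExecutor

    with ThreadPoolExecutor(max_workers=16) as pool:
        pause(conn, vm_ids, pool)     # suspend unused VMs before the test
        run_tests()                   # hypothetical test body
        unpause(conn, vm_ids, pool)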
-def prepare_os(nova: n_client, params: Dict[str, Any], os_creds: OSCreds) -> None:
+def prepare_os(conn: OSConnection, params: Dict[str, Any], max_vm_per_node: int = 8) -> None:
"""prepare openstack for futher usage
Creates server groups, security rules, keypair, flavor
@@ -153,7 +124,7 @@
Don't check that existing objects have the required attributes
params:
- nova: novaclient connection
+ conn: OSConnection
params: dict {
security_group:str - security group name with allowed ssh and ping
aa_group_name:str - template for anti-affinity group names. Should
@@ -174,30 +145,18 @@
max_vm_per_node: int=8 - maximum expected number of VMs per
compute host. Used to create an appropriate
count of server groups for even placement
-
- returns: None
"""
- allow_ssh(nova, params['security_group'])
+ allow_ssh_and_ping(conn, params['security_group'])
- MAX_VM_PER_NODE = 8
- serv_groups = map(params['aa_group_name'].format,
- range(MAX_VM_PER_NODE))
+ for idx in range(max_vm_per_node):
+ get_or_create_aa_group(conn, params['aa_group_name'].format(idx))
- for serv_groups in serv_groups:
- get_or_create_aa_group(nova, serv_groups)
-
- create_keypair(nova,
- params['keypair_name'],
- params['keypair_file_public'],
- params['keypair_file_private'])
-
- create_image(os_creds, nova, params['image']['name'],
- params['image']['url'])
-
- create_flavor(nova, **params['flavor'])
+ create_keypair(conn, params['keypair_name'], params['keypair_file_public'], params['keypair_file_private'])
+ create_image(conn, params['image']['name'], params['image']['url'])
+ create_flavor(conn, **params['flavor'])
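For reference, a hedged sketch of the params block prepare_os expects, matching the docstring
above; all values below are hypothetical:

    params = {
        'security_group': 'wally_ssh',
        'aa_group_name': 'wally-aa-{}',
        'keypair_name': 'wally',
        'keypair_file_public': '/home/user/.wally/wally.pub',
        'keypair_file_private': '/home/user/.wally/wally.pem',
        'image': {'name': 'ubuntu-16.04', 'url': 'https://example.com/image.qcow2'},
        'flavor': {'name': 'wally_1024', 'ram_size': 1024, 'hdd_size': 50, 'cpu_count': 2},
    }
    prepare_os(conn, params)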
-def create_keypair(nova: n_client, name: str, pub_key_path: str, priv_key_path: str):
+def create_keypair(conn: OSConnection, name: str, pub_key_path: str, priv_key_path: str):
"""create and upload keypair into nova, if doesn't exists yet
Create and upload keypair into nova, if keypair with given bane
@@ -205,19 +164,17 @@
create new keys and store them into files.
parameters:
- nova: nova connection
+ conn: OSConnection
name: str - keypair name
pub_key_path: str - path for public key
priv_key_path: str - path for private key
-
- returns: None
"""
pub_key_exists = os.path.exists(pub_key_path)
priv_key_exists = os.path.exists(priv_key_path)
try:
- kpair = nova.keypairs.find(name=name)
+ kpair = conn.nova.keypairs.find(name=name)
# if file not found - delete and recreate
except NotFound:
kpair = None
@@ -231,9 +188,9 @@
if kpair is None:
if pub_key_exists:
with open(pub_key_path) as pub_key_fd:
- return nova.keypairs.create(name, pub_key_fd.read())
+ return conn.nova.keypairs.create(name, pub_key_fd.read())
else:
- key = nova.keypairs.create(name)
+ key = conn.nova.keypairs.create(name)
with open(priv_key_path, "w") as priv_key_fd:
priv_key_fd.write(key.private_key)
@@ -248,54 +205,50 @@
" or remove key from openstack")
-def get_or_create_aa_group(nova: n_client, name: str) -> int:
+def get_or_create_aa_group(conn: OSConnection, name: str) -> str:
"""create anti-affinity server group, if it doesn't exist yet
parameters:
- nova: nova connection
+ conn: OSConnection
name: str - group name
returns: str - group id
"""
try:
- group = nova.server_groups.find(name=name)
+ return conn.nova.server_groups.find(name=name).id
except NotFound:
- group = nova.server_groups.create(name=name,
- policies=['anti-affinity'])
-
- return group.id
+ return conn.nova.server_groups.create(name=name, policies=['anti-affinity']).id
-def allow_ssh(nova: n_client, group_name: str) -> int:
+def allow_ssh_and_ping(conn: OSConnection, group_name: str) -> int:
"""create sequrity group for ping and ssh
parameters:
- nova: nova connection
+ conn: OSConnection
group_name: str - group name
returns: str - group id
"""
try:
- secgroup = nova.security_groups.find(name=group_name)
+ secgroup = conn.nova.security_groups.find(name=group_name)
except NotFound:
- secgroup = nova.security_groups.create(group_name,
- "allow ssh/ping to node")
+ secgroup = conn.nova.security_groups.create(group_name, "allow ssh/ping to node")
- nova.security_group_rules.create(secgroup.id,
- ip_protocol="tcp",
- from_port="22",
- to_port="22",
- cidr="0.0.0.0/0")
+ conn.nova.security_group_rules.create(secgroup.id,
+ ip_protocol="tcp",
+ from_port="22",
+ to_port="22",
+ cidr="0.0.0.0/0")
- nova.security_group_rules.create(secgroup.id,
- ip_protocol="icmp",
- from_port=-1,
- cidr="0.0.0.0/0",
- to_port=-1)
+ conn.nova.security_group_rules.create(secgroup.id,
+ ip_protocol="icmp",
+ from_port=-1,
+ cidr="0.0.0.0/0",
+ to_port=-1)
return secgroup.id
-def create_image(nova: n_client, os_creds: OSCreds, name: str, url: str):
+def create_image(conn: OSConnection, name: str, url: str) -> None:
"""upload image into glance from given URL, if given image doesn't exisis yet
parameters:
@@ -308,33 +261,31 @@
returns: None
"""
try:
- nova.images.find(name=name)
+ conn.nova.images.find(name=name)
return
except NotFound:
pass
- tempnam = os.tempnam()
+ ok = False
+ with tempfile.NamedTemporaryFile() as temp_fd:
+ try:
+ cmd = "wget --dns-timeout=30 --connect-timeout=30 --read-timeout=30 -o {} {}"
+ subprocess.check_call(cmd.format(temp_fd.name, url))
+ ok = True
- try:
- urllib.urlretrieve(url, tempnam)
+ # TODO(koder): add proper error handling
+ except Exception:
+ logger.warning("wget failed to fetch %r, falling back to urlretrieve", url)
- cmd = "OS_USERNAME={0.name}"
- cmd += " OS_PASSWORD={0.passwd}"
- cmd += " OS_TENANT_NAME={0.tenant}"
- cmd += " OS_AUTH_URL={0.auth_url}"
- cmd += " glance {1} image-create --name {2} $opts --file {3}"
- cmd += " --disk-format qcow2 --container-format bare --is-public true"
+ if not ok:
+ urllib.request.urlretrieve(url, temp_fd.name)
- cmd = cmd.format(os_creds,
- '--insecure' if os_creds.insecure else "",
- name,
- tempnam)
- finally:
- if os.path.exists(tempnam):
- os.unlink(tempnam)
+ image = conn.glance.images.create(name=name)
+ with open(temp_fd.name, 'rb') as fd:
+ conn.glance.images.upload(image.id, fd)
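create_image is idempotent (it returns early when the image already exists), so it is safe
to call on every run. A hedged call sketch with a hypothetical image URL:

    create_image(conn, 'ubuntu-16.04', 'https://example.com/xenial-server-cloudimg-amd64-disk1.img')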
-def create_flavor(nova: n_client, name: str, ram_size: int, hdd_size: int, cpu_count: int):
+def create_flavor(conn: OSConnection, name: str, ram_size: int, hdd_size: int, cpu_count: int) -> None:
"""create flavor, if doesn't exisis yet
parameters:
@@ -347,17 +298,16 @@
returns: None
"""
try:
- nova.flavors.find(name)
+ conn.nova.flavors.find(name=name)
return
except NotFound:
pass
- nova.flavors.create(name, cpu_count, ram_size, hdd_size)
+ # novaclient expects create(name, ram, vcpus, disk)
+ conn.nova.flavors.create(name, ram_size, cpu_count, hdd_size)
-def create_volume(size: int, name: str):
- cinder = cinder_connect()
- vol = cinder.volumes.create(size=size, display_name=name)
+def create_volume(conn: OSConnection, size: int, name: str) -> Any:
+ vol = conn.cinder.volumes.create(size=size, display_name=name)
err_count = 0
while vol.status != 'available':
@@ -367,16 +317,16 @@
raise RuntimeError("Fail to create volume")
else:
err_count += 1
- cinder.volumes.delete(vol)
+ conn.cinder.volumes.delete(vol)
time.sleep(1)
- vol = cinder.volumes.create(size=size, display_name=name)
+ vol = conn.cinder.volumes.create(size=size, display_name=name)
continue
time.sleep(1)
- vol = cinder.volumes.get(vol.id)
+ vol = conn.cinder.volumes.get(vol.id)
return vol
-def wait_for_server_active(nova: n_client, server, timeout: int=300)-> None:
+def wait_for_server_active(conn: OSConnection, server: Any, timeout: int = 300) -> bool:
"""wait till server becomes active
parameters:
@@ -387,29 +337,25 @@
returns: bool - True if the server became active, False on error or timeout
"""
- t = time.time()
- while True:
- time.sleep(1)
- sstate = getattr(server, 'OS-EXT-STS:vm_state').lower()
+ for _ in Timeout(timeout, no_exc=True):
+ server_state = getattr(server, 'OS-EXT-STS:vm_state').lower()
- if sstate == 'active':
+ if server_state == 'active':
return True
- if sstate == 'error':
+ if server_state == 'error':
return False
- if time.time() - t > timeout:
- return False
-
- server = nova.servers.get(server)
+ server = conn.nova.servers.get(server)
+ return False
class Allocate(object):
pass
-def get_floating_ips(nova, pool, amount):
- """allocate flationg ips
+def get_floating_ips(conn: OSConnection, pool: Optional[str], amount: int) -> List[str]:
+ """allocate floating ips
parameters:
conn: OSConnection
@@ -418,7 +364,7 @@
returns: [ip object]
"""
- ip_list = nova.floating_ips.list()
+ ip_list = conn.nova.floating_ips.list()
if pool is not None:
ip_list = [ip for ip in ip_list if ip.pool == pool]
@@ -426,7 +372,10 @@
return [ip for ip in ip_list if ip.instance_id is None][:amount]
-def launch_vms(nova, params, already_has_count=0) -> Iterator[NodeInfo]:
+def launch_vms(conn: OSConnection,
+ params: Dict[str, Any],
+ executor: ThreadPoolExecutor,
+ already_has_count: int = 0) -> Iterator[NodeInfo]:
"""launch virtual servers
Parameters:
@@ -454,13 +403,12 @@
already_has_count: int=0 - how many servers already exist. Used to distribute
new servers evenly across all compute nodes, taking
old servers into account
- returns: generator of str - server credentials, in format USER@IP:KEY_PATH
+ returns: generator of NodeInfo - node info with conn_uri in format USER@IP:KEY_PATH
"""
logger.debug("Calculating new vm count")
- count = params['count']
- nova = nova_connect()
- lst = nova.services.list(binary='nova-compute')
+ count = params['count'] # type: int
+ lst = conn.nova.services.list(binary='nova-compute')
srv_count = len([srv for srv in lst if srv.status == 'enabled'])
if isinstance(count, str):
@@ -500,12 +448,14 @@
private_key_path = params['keypair_file_private']
creds = params['image']['creds']
- for ip, os_node in create_vms_mt(NOVA_CONNECTION, count, **vm_params):
+ for ip, os_node in create_vms_mt(conn, count, executor, **vm_params):
conn_uri = creds.format(ip=ip, private_key_path=private_key_path)
- yield NodeInfo(conn_uri, []), os_node.id
+ info = NodeInfo(conn_uri, set())
+ info.os_vm_id = os_node.id
+ yield info
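Callers drain the generator to obtain NodeInfo objects with the OpenStack VM id attached;
a sketch (params as described in the docstring above):

    with ThreadPoolExecutor(max_workers=16) as pool:
        new_nodes = list(launch_vms(conn, params, pool))
    for info in new_nodes:
        print(info.conn_uri, info.os_vm_id)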
-def get_free_server_grpoups(nova, template):
+def get_free_server_groups(conn: OSConnection, template: str) -> Iterator[str]:
"""get fre server groups, that match given name template
parameters:
@@ -515,113 +465,117 @@
returns: generator of str - server group ids
"""
- for g in nova.server_groups.list():
- if g.members == []:
- if re.match(template, g.name):
- yield str(g.id)
+ for server_group in conn.nova.server_groups.list():
+ if not server_group.members:
+ if re.match(template, server_group.name):
+ yield str(server_group.id)
-def create_vms_mt(nova, amount, group_name, keypair_name, img_name,
- flavor_name, vol_sz=None, network_zone_name=None,
- flt_ip_pool=None, name_templ='wally-{id}',
- scheduler_hints=None, security_group=None,
- sec_group_size=None):
+def create_vms_mt(conn: OSConnection,
+ amount: int,
+ executor: ThreadPoolExecutor,
+ group_name: str,
+ keypair_name: str,
+ img_name: str,
+ flavor_name: str,
+ vol_sz: int = None,
+ network_zone_name: str = None,
+ flt_ip_pool: str = None,
+ name_templ: str = 'wally-{id}',
+ scheduler_hints: Dict = None,
+ security_group: str = None,
+ sec_group_size: int = None) -> List[Tuple[str, Any]]:
- with ThreadPoolExecutor(max_workers=16) as executor:
- if network_zone_name is not None:
- network_future = executor.submit(nova.networks.find,
- label=network_zone_name)
- else:
- network_future = None
+ if network_zone_name is not None:
+ network_future = executor.submit(conn.nova.networks.find,
+ label=network_zone_name)
+ else:
+ network_future = None
- fl_future = executor.submit(nova.flavors.find, name=flavor_name)
- img_future = executor.submit(nova.images.find, name=img_name)
+ fl_future = executor.submit(conn.nova.flavors.find, name=flavor_name)
+ img_future = executor.submit(conn.nova.images.find, name=img_name)
- if flt_ip_pool is not None:
- ips_future = executor.submit(get_floating_ips,
- nova, flt_ip_pool, amount)
- logger.debug("Wait for floating ip")
- ips = ips_future.result()
- ips += [Allocate] * (amount - len(ips))
- else:
- ips = [None] * amount
+ if flt_ip_pool is not None:
+ ips_future = executor.submit(get_floating_ips,
+ conn, flt_ip_pool, amount)
+ logger.debug("Wait for floating ip")
+ ips = ips_future.result()
+ ips += [Allocate] * (amount - len(ips))
+ else:
+ ips = [None] * amount
- logger.debug("Getting flavor object")
- fl = fl_future.result()
- logger.debug("Getting image object")
- img = img_future.result()
+ logger.debug("Getting flavor object")
+ fl = fl_future.result()
+ logger.debug("Getting image object")
+ img = img_future.result()
- if network_future is not None:
- logger.debug("Waiting for network results")
- nics = [{'net-id': network_future.result().id}]
- else:
- nics = None
+ if network_future is not None:
+ logger.debug("Waiting for network results")
+ nics = [{'net-id': network_future.result().id}]
+ else:
+ nics = None
- names = []
- for i in range(amount):
- names.append(name_templ.format(group=group_name, id=i))
+ names = [] # type: List[str]
+ for i in range(amount):
+ names.append(name_templ.format(group=group_name, id=i))
- futures = []
- logger.debug("Requesting new vm's")
+ futures = []
+ logger.debug("Requesting new vm's")
- orig_scheduler_hints = scheduler_hints.copy()
+ orig_scheduler_hints = scheduler_hints.copy()
+ group_name_template = scheduler_hints['group'].format("\\d+")
+ groups = list(get_free_server_groups(conn, group_name_template + "$"))
+ groups.sort()
- MAX_SHED_GROUPS = 32
- for start_idx in range(MAX_SHED_GROUPS):
- pass
+ for idx, (name, flt_ip) in enumerate(zip(names, ips), 2):
- group_name_template = scheduler_hints['group'].format("\\d+")
- groups = list(get_free_server_grpoups(nova, group_name_template + "$"))
- groups.sort()
-
- for idx, (name, flt_ip) in enumerate(zip(names, ips), 2):
-
- scheduler_hints = None
- if orig_scheduler_hints is not None and sec_group_size is not None:
- if "group" in orig_scheduler_hints:
- scheduler_hints = orig_scheduler_hints.copy()
- scheduler_hints['group'] = groups[idx // sec_group_size]
-
- if scheduler_hints is None:
+ scheduler_hints = None
+ if orig_scheduler_hints is not None and sec_group_size is not None:
+ if "group" in orig_scheduler_hints:
scheduler_hints = orig_scheduler_hints.copy()
+ scheduler_hints['group'] = groups[idx // sec_group_size]
- params = (nova, name, keypair_name, img, fl,
- nics, vol_sz, flt_ip, scheduler_hints,
- flt_ip_pool, [security_group])
+ if scheduler_hints is None:
+ scheduler_hints = orig_scheduler_hints.copy()
- futures.append(executor.submit(create_vm, *params))
- res = [future.result() for future in futures]
- logger.debug("Done spawning")
- return res
+ params = (conn, name, keypair_name, img, fl,
+ nics, vol_sz, flt_ip, scheduler_hints,
+ flt_ip_pool, [security_group])
+
+ futures.append(executor.submit(create_vm, *params))
+ res = [future.result() for future in futures]
+ logger.debug("Done spawning")
+ return res
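A hedged call sketch for create_vms_mt; every value below is hypothetical, and the
anti-affinity groups matching the 'group' hint template are assumed to have been created
by prepare_os:

    with ThreadPoolExecutor(max_workers=16) as pool:
        servers = create_vms_mt(conn, amount=4, executor=pool,
                                group_name='wally', keypair_name='wally',
                                img_name='ubuntu-16.04', flavor_name='wally_1024',
                                flt_ip_pool='ext-net', security_group='wally_ssh',
                                scheduler_hints={'group': 'wally-aa-{}'}, sec_group_size=8)
    # servers is a list of (floating_ip, nova server object) pairs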
-def create_vm(nova, name, keypair_name, img,
- fl, nics, vol_sz=None,
- flt_ip=False,
- scheduler_hints=None,
- pool=None,
- security_groups=None):
- for i in range(3):
- srv = nova.servers.create(name,
- flavor=fl,
- image=img,
- nics=nics,
- key_name=keypair_name,
- scheduler_hints=scheduler_hints,
- security_groups=security_groups)
+def create_vm(conn: OSConnection,
+ name: str,
+ keypair_name: str,
+ img: Any,
+ flavor: Any,
+ nics: List,
+ vol_sz: int = None,
+ flt_ip: Any = False,
+ scheduler_hints: Dict = None,
+ pool: str = None,
+ security_groups=None,
+ max_retry: int = 3,
+ delete_timeout: int = 120) -> Tuple[str, Any]:
- if not wait_for_server_active(nova, srv):
+ # make mypy/pylint happy
+ srv = None # type: Any
+ for i in range(max_retry):
+ srv = conn.nova.servers.create(name, flavor=flavor, image=img, nics=nics, key_name=keypair_name,
+ scheduler_hints=scheduler_hints, security_groups=security_groups)
+
+ if not wait_for_server_active(conn, srv):
msg = "Server {0} fails to start. Kill it and try again"
logger.debug(msg.format(srv))
- nova.servers.delete(srv)
+ conn.nova.servers.delete(srv)
try:
- for j in range(120):
- srv = nova.servers.get(srv.id)
- time.sleep(1)
- else:
- msg = "Server {0} delete timeout".format(srv.id)
- raise RuntimeError(msg)
+ for _ in Timeout(delete_timeout, "Server {0} delete timeout".format(srv.id)):
+ srv = conn.nova.servers.get(srv.id)
except NotFound:
pass
else:
@@ -630,27 +584,22 @@
raise RuntimeError("Failed to start server".format(srv.id))
if vol_sz is not None:
- vol = create_volume(vol_sz, name)
- nova.volumes.create_server_volume(srv.id, vol.id, None)
+ vol = create_volume(conn, vol_sz, name)
+ conn.nova.volumes.create_server_volume(srv.id, vol.id, None)
if flt_ip is Allocate:
- flt_ip = nova.floating_ips.create(pool)
+ flt_ip = conn.nova.floating_ips.create(pool)
if flt_ip is not None:
srv.add_floating_ip(flt_ip)
- return flt_ip.ip, nova.servers.get(srv.id)
+ return flt_ip.ip, conn.nova.servers.get(srv.id)
-def clear_nodes(nodes_ids):
- clear_all(NOVA_CONNECTION, nodes_ids, None)
-
-
-MAX_SERVER_DELETE_TIME = 120
-
-
-def clear_all(nova, ids=None, name_templ=None,
- max_server_delete_time=MAX_SERVER_DELETE_TIME):
+def clear_nodes(conn: OSConnection,
+ ids: List[str] = None,
+ name_templ: str = None,
+ max_server_delete_time: int = 120) -> None:
try:
def need_delete(srv):
if name_templ is not None:
@@ -659,43 +608,42 @@
return srv.id in ids
volumes_to_delete = []
- cinder = cinder_connect()
- for vol in cinder.volumes.list():
+ for vol in conn.cinder.volumes.list():
for attachment in vol.attachments:
if attachment['server_id'] in ids:
volumes_to_delete.append(vol)
break
- deleted_srvs = set()
- for srv in nova.servers.list():
+ still_alive = set()
+ for srv in conn.nova.servers.list():
if need_delete(srv):
logger.debug("Deleting server {0}".format(srv.name))
- nova.servers.delete(srv)
- deleted_srvs.add(srv.id)
+ conn.nova.servers.delete(srv)
+ still_alive.add(srv.id)
- count = 0
- while count < max_server_delete_time:
- if count % 60 == 0:
- logger.debug("Waiting till all servers are actually deleted")
- all_id = set(srv.id for srv in nova.servers.list())
- if len(all_id.intersection(deleted_srvs)) == 0:
- break
- count += 1
- time.sleep(1)
- else:
- logger.warning("Failed to remove servers. " +
- "You, probably, need to remove them manually")
- return
- logger.debug("Done, deleting volumes")
+ if still_alive:
+ logger.debug("Waiting till all servers are actually deleted")
+ tout = Timeout(max_server_delete_time, no_exc=True)
+ while tout.tick() and still_alive:
+ all_id = set(srv.id for srv in conn.nova.servers.list())
+ still_alive = still_alive.intersection(all_id)
- # wait till vm actually deleted
+ if still_alive:
+ logger.warning("Failed to remove servers {}. ".format(",".join(still_alive)) +
+ "You, probably, need to remove them manually (and volumes as well)")
+ return
- # logger.warning("Volume deletion commented out")
- for vol in volumes_to_delete:
- logger.debug("Deleting volume " + vol.display_name)
- cinder.volumes.delete(vol)
+ if volumes_to_delete:
+ logger.debug("Deleting volumes")
- logger.debug("Clearing done (yet some volumes may still deleting)")
- except:
+ for vol in volumes_to_delete:
+ logger.debug("Deleting volume " + vol.display_name)
+ conn.cinder.volumes.delete(vol)
+
+ logger.debug("Clearing complete (yet some volumes may still be deleting)")
+ except Exception:
logger.exception("During removing servers. " +
"You, probably, need to remove them manually")
diff --git a/wally/storage.py b/wally/storage.py
index 5212f4a..02de173 100644
--- a/wally/storage.py
+++ b/wally/storage.py
@@ -2,8 +2,9 @@
This module contains interfaces for storage classes
"""
+import os
import abc
-from typing import Any, Iterable, TypeVar, Type, IO
+from typing import Any, Iterable, TypeVar, Type, IO, Tuple, Union, Dict, List
class IStorable(metaclass=abc.ABCMeta):
@@ -32,46 +33,11 @@
ObjClass = TypeVar('ObjClass')
-class IStorage(metaclass=abc.ABCMeta):
- """interface for storage"""
- @abc.abstractmethod
- def __init__(self, path: str, existing_storage: bool = False) -> None:
- pass
-
- @abc.abstractmethod
- def __setitem__(self, path: str, value: IStorable) -> None:
- pass
-
- @abc.abstractmethod
- def __getitem__(self, path: str) -> IStorable:
- pass
-
- @abc.abstractmethod
- def __contains__(self, path: str) -> bool:
- pass
-
- @abc.abstractmethod
- def list(self, path: str) -> Iterable[str]:
- pass
-
- @abc.abstractmethod
- def load(self, path: str, obj_class: Type[ObjClass]) -> ObjClass:
- pass
-
- @abc.abstractmethod
- def get_stream(self, path: str) -> IO:
- pass
-
-
class ISimpleStorage(metaclass=abc.ABCMeta):
"""interface for low-level storage, which doesn't support serialization
and can operate only on bytes"""
@abc.abstractmethod
- def __init__(self, path: str) -> None:
- pass
-
- @abc.abstractmethod
def __setitem__(self, path: str, value: bytes) -> None:
pass
@@ -103,13 +69,47 @@
pass
-# TODO(koder): this is concrete storage and serializer classes to be implemented
-class FSStorage(IStorage):
+class FSStorage(ISimpleStorage):
"""Store all data in files on FS"""
+ def __init__(self, root_path: str, existing: bool) -> None:
+ self.root_path = root_path
+ if existing:
+ if not os.path.isdir(self.root_path):
+ raise ValueError("No storage found at {!r}".format(root_path))
+
+ def ensure_dir(self, path: str) -> None:
+ os.makedirs(path, exist_ok=True)
+
- @abc.abstractmethod
- def __init__(self, root_path: str, serializer: ISerializer, existing: bool = False) -> None:
- pass
+ def __setitem__(self, path: str, value: bytes) -> None:
+ path = os.path.join(self.root_path, path)
+ self.ensure_dir(os.path.dirname(path))
+ with open(path, "wb") as fd:
+ fd.write(value)
+
+ def __getitem__(self, path: str) -> bytes:
+ path = os.path.join(self.root_path, path)
+ with open(path, "rb") as fd:
+ return fd.read()
+
+ def __contains__(self, path: str) -> bool:
+ path = os.path.join(self.root_path, path)
+ return os.path.exists(path)
+
+ def list(self, path: str) -> Iterable[Tuple[bool, str]]:
+ path = os.path.join(self.root_path, path)
+ for entry in os.scandir(path):
+ if entry.name not in ('..', '.'):
+ yield entry.is_file(), entry.name
+
+ def get_stream(self, path: str, mode: str = "rb") -> IO:
+ path = os.path.join(self.root_path, path)
+ return open(path, mode)
class YAMLSerializer(ISerializer):
@@ -117,6 +117,55 @@
pass
-def make_storage(url: str, existing: bool = False) -> IStorage:
- return FSStorage(url, YAMLSerializer(), existing)
+ISimpleStorable = Union[Dict, List, int, str, None, bool]
+
+
+class Storage:
+ """interface for storage"""
+ def __init__(self, storage: ISimpleStorage, serializer: ISerializer):
+ self.storage = storage
+ self.serializer = serializer
+
+ def __setitem__(self, path: str, value: IStorable) -> None:
+ self.storage[path] = self.serializer.pack(value)
+
+ def __getitem__(self, path: str) -> ISimpleStorable:
+ return self.serializer.unpack(self.storage[path])
+
+ def __contains__(self, path: str) -> bool:
+ return path in self.storage
+
+ def list(self, path: str) -> Iterable[Tuple[bool, str]]:
+ return self.storage.list(path)
+
+ def load(self, path: str, obj_class: Type[ObjClass]) -> ObjClass:
+ raw_val = self[path]
+ if obj_class in (int, str, dict, list, None):
+ # isinstance() can't take None directly - map it to type(None)
+ real_cls = type(None) if obj_class is None else obj_class
+ if not isinstance(raw_val, real_cls):
+ raise ValueError("Can't load path {!r} into type {}. Real type is {}"
+ .format(path, obj_class, type(raw_val)))
+ return raw_val
+
+ if not isinstance(raw_val, dict):
+ raise ValueError("Can't load path {!r} into python type. Raw value not dict".format(path))
+
+ if not all(isinstance(key, str) for key in raw_val):
+ raise ValueError("Can't load path {!r} into python type: ".format(path) +
+ "not all keys in raw value are strings")
+
+ obj = obj_class.__new__(obj_class)  # ObjClass is only a TypeVar; instantiate the runtime class
+ obj.__dict__.update(raw_val)
+ return obj
+
+ def get_stream(self, path: str) -> IO:
+ return self.storage.get_stream(path)
+
+
+def make_storage(url: str, existing: bool = False) -> Storage:
+ return Storage(FSStorage(url, existing), YAMLSerializer())
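A hedged sketch of how the reworked stack composes: FSStorage handles raw bytes on disk,
YAMLSerializer (still a stub above, so pack/unpack behaviour is assumed) handles
serialization, and Storage glues them together:

    storage = make_storage('/tmp/wally_run')
    storage['config/run_uuid'] = '1234'          # packed by the serializer, written by FSStorage
    assert 'config/run_uuid' in storage
    uuid = storage.load('config/run_uuid', str)  # type-checked load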
diff --git a/wally/suits/itest.py b/wally/suits/itest.py
index 00492c9..7004b8e 100644
--- a/wally/suits/itest.py
+++ b/wally/suits/itest.py
@@ -9,9 +9,8 @@
from ..utils import Barrier, StopTestError
from ..statistic import data_property
-from ..ssh_utils import copy_paths
from ..inode import INode
-
+from ..storage import Storage
logger = logging.getLogger("wally")
@@ -31,15 +30,15 @@
def __init__(self,
test_type: str,
params: Dict[str, Any],
- test_uuid: str,
+ run_uuid: str,
nodes: List[INode],
- log_directory: str,
+ storage: Storage,
remote_dir: str):
self.test_type = test_type
self.params = params
- self.test_uuid = test_uuid
- self.log_directory = log_directory
+ self.run_uuid = run_uuid
self.nodes = nodes
+ self.storage = storage
self.remote_dir = remote_dir
@@ -200,7 +199,7 @@
pass
@abc.abstractmethod
- def run(self) -> List[TestResults]:
+ def run(self) -> None:
pass
@abc.abstractmethod
diff --git a/wally/test_run_class.py b/wally/test_run_class.py
index e937300..cac893c 100644
--- a/wally/test_run_class.py
+++ b/wally/test_run_class.py
@@ -3,35 +3,37 @@
from .timeseries import SensorDatastore
-from . import inode
-from .start_vms import OSCreds
-from .storage import IStorage
+from .node_interfaces import NodeInfo, IRPCNode, RPCBeforeConnCallback
+from .start_vms import OSCreds, NovaClient, CinderClient
+from .storage import Storage
from .config import Config
class TestRun:
"""Test run information"""
- def __init__(self, config: Config, storage: IStorage):
+ def __init__(self, config: Config, storage: Storage):
# NodesInfo list
- self.nodes_info = [] # type: List[inode.NodeInfo]
+ self.nodes_info = [] # type: List[NodeInfo]
# Nodes list
- self.nodes = [] # type: List[inode.INode]
+ self.nodes = [] # type: List[IRPCNode]
self.build_meta = {} # type: Dict[str,Any]
self.clear_calls_stack = [] # type: List[Callable[['TestRun'], None]]
-
- # created openstack nodes
- self.openstack_nodes_ids = [] # type: List[str]
self.sensors_mon_q = None
# openstack credentials
self.fuel_openstack_creds = None # type: Optional[OSCreds]
+ self.os_creds = None # type: Optional[OSCreds]
+ self.nova_client = None # type: Optional[NovaClient]
+ self.cinder_client = None # type: Optional[CinderClient]
self.storage = storage
self.config = config
self.sensors_data = SensorDatastore()
+ self.before_conn_callback = None # type: RPCBeforeConnCallback
+
def get_pool(self):
return ThreadPoolExecutor(self.config.get('worker_pool_sz', 32))
diff --git a/wally/utils.py b/wally/utils.py
index d2b867e..14e3a6c 100644
--- a/wally/utils.py
+++ b/wally/utils.py
@@ -11,7 +11,7 @@
import subprocess
import collections
-from .interfaces import IRemoteNode
+from .node_interfaces import IRPCNode
from typing import Any, Tuple, Union, List, Iterator, Dict, Callable, Iterable, Optional, IO, Sequence
try:
@@ -412,15 +412,36 @@
class Timeout:
- def __init__(self, timeout: int, message: str = None) -> None:
+ def __init__(self, timeout: int, message: str = None, min_tick: int = 1, no_exc: bool = False) -> None:
self.etime = time.time() + timeout
self.message = message
+ self.min_tick = min_tick
+ self.prev_tick_at = time.time()
+ self.no_exc = no_exc
- def tick(self) -> None:
- if time.time() > self.etime:
+ def tick(self) -> bool:
+ ctime = time.time()
+ if ctime > self.etime:
if self.message:
msg = "Timeout: {}".format(self.message)
else:
msg = "Timeout"
- raise TimeoutError(msg)
\ No newline at end of file
+ if self.no_exc:
+ return False
+ raise TimeoutError(msg)
+
+ dtime = self.min_tick - (ctime - self.prev_tick_at)
+ if dtime > 0:
+ time.sleep(dtime)
+
+ self.prev_tick_at = time.time()
+ return True
+
+ def __iter__(self) -> 'Timeout':
+ return self
+
+ def __next__(self) -> float:
+ if not self.tick():
+ raise StopIteration()
+ return self.etime - time.time()
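The iterator protocol makes retry loops read naturally: each iteration yields the remaining
seconds, sleeps at least min_tick, and raises on expiry unless no_exc is set. A sketch
(server_is_up() is a hypothetical predicate):

    for time_left in Timeout(300, 'server did not boot'):
        if server_is_up():
            break
    # raises TimeoutError('Timeout: server did not boot') if the loop never breaks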