refactoring is on the way
diff --git a/wally/common_types.py b/wally/common_types.py
new file mode 100644
index 0000000..884cd44
--- /dev/null
+++ b/wally/common_types.py
@@ -0,0 +1,4 @@
+from typing import NamedTuple
+
+IP = str
+IPAddr = NamedTuple("IPAddr", [("host", IP), ("port", int)])
diff --git a/wally/discover/ceph.py b/wally/discover/ceph.py
index ae06cbc..4a72bfb 100644
--- a/wally/discover/ceph.py
+++ b/wally/discover/ceph.py
@@ -1,64 +1,91 @@
""" Collect data about ceph nodes"""
import json
import logging
-from typing import Iterable
+from typing import List, Set, Dict
-from ..node import NodeInfo, Node
-
+from ..node_interfaces import NodeInfo, IRPCNode
+from ..ssh_utils import ConnCreds
+from ..common_types import IP
logger = logging.getLogger("wally.discover")
-def discover_ceph_nodes(node: Node) -> Iterable[NodeInfo]:
+def discover_ceph_nodes(node: IRPCNode,
+ cluster: str = "ceph",
+ conf: str = None,
+ key: str = None) -> List[NodeInfo]:
"""Return list of ceph's nodes NodeInfo"""
- ips = {}
- osd_ips = get_osds_ips(node, get_osds_list(node))
- mon_ips = get_mons_or_mds_ips(node, "mon")
- mds_ips = get_mons_or_mds_ips(node, "mds")
+ if conf is None:
+ conf = "/etc/ceph/{}.conf".format(cluster)
+ if key is None:
+ key = "/etc/ceph/{}.client.admin.keyring".format(cluster)
+
+ try:
+ osd_ips = get_osds_ips(node, conf, key)
+ except Exception as exc:
+ logger.error("OSD discovery failed: %s", exc)
+ osd_ips = set()
+
+ try:
+ mon_ips = get_mons_ips(node, conf, key)
+ except Exception as exc:
+ logger.error("MON discovery failed: %s", exc)
+ mon_ips = set()
+
+ ips = {} # type: Dict[str, List[str]]
for ip in osd_ips:
- url = "ssh://%s" % ip
- ips.setdefault(url, []).append("ceph-osd")
+ ips.setdefault(ip, []).append("ceph-osd")
for ip in mon_ips:
- url = "ssh://%s" % ip
- ips.setdefault(url, []).append("ceph-mon")
+ ips.setdefault(ip, []).append("ceph-mon")
- for ip in mds_ips:
- url = "ssh://%s" % ip
- ips.setdefault(url, []).append("ceph-mds")
-
- return [NodeInfo(url, set(roles)) for url, roles in ips.items()]
+ ssh_key = node.get_file_content("~/.ssh/id_rsa")
+ return [NodeInfo(ConnCreds(host=ip, user="root", key=ssh_key), set(roles)) for ip, roles in ips.items()]
-def get_osds_list(node: Node) -> Iterable[str]:
- """Get list of osd's id"""
- return filter(None, node.run("ceph osd ls").split("\n"))
+def get_osds_ips(node: IRPCNode, conf: str, key: str) -> Set[IP]:
+ """Get set of osd's ip"""
-
-def get_mons_or_mds_ips(node: Node, who: str) -> Iterable[str]:
- """Return mon ip list. who - mon/mds"""
- assert who in ("mon", "mds"), \
- "{!r} in get_mons_or_mds_ips instead of mon/mds".format(who)
-
- line_res = node.run("ceph {0} dump".format(who)).split("\n")
-
- ips = set()
- for line in line_res:
- fields = line.split()
- if len(fields) > 2 and who in fields[2]:
- ips.add(fields[1].split(":")[0])
-
+ data = node.run("ceph -c {} -k {} --format json osd dump".format(conf, key))
+ jdata = json.loads(data)
+ ips = set() # type: Set[IP]
+ first_error = True
+ for osd_data in jdata["osds"]:
+ if "public_addr" not in osd_data:
+ if first_error:
+ osd_id = osd_data.get("osd", "<OSD_ID_MISSED>")
+ logger.warning("No 'public_addr' field in 'ceph osd dump' output for osd %s" +
+ "(all subsequent errors omitted)", osd_id)
+ first_error = False
+ else:
+ ip_port = osd_data["public_addr"]
+ if '/' in ip_port:
+ ip_port = ip_port.split("/", 1)[0]
+ ips.add(IP(ip_port.split(":")[0]))
return ips
-def get_osds_ips(node: Node, osd_list: Iterable[str]) -> Iterable[str]:
- """Get osd's ips. osd_list - list of osd names from osd ls command"""
- ips = set()
- for osd_id in osd_list:
- out = node.run("ceph osd find {0}".format(osd_id))
- ip = json.loads(out)["ip"]
- ips.add(str(ip.split(":")[0]))
+def get_mons_ips(node: IRPCNode, conf: str, key: str) -> Set[IP]:
+ """Return mon ip set"""
+
+ data = node.run("ceph -c {} -k {} --format json mon_status".format(conf, key))
+ jdata = json.loads(data)
+ ips = set() # type: Set[IP]
+ first_error = True
+ for mon_data in jdata["monmap"]["mons"]:
+ if "addr" not in mon_data:
+ if first_error:
+ mon_name = mon_data.get("name", "<MON_NAME_MISSED>")
+ logger.warning("No 'addr' field in 'ceph mon_status' output for mon %s" +
+ "(all subsequent errors omitted)", mon_name)
+ first_error = False
+ else:
+ ip_port = mon_data["addr"]
+ if '/' in ip_port:
+ ip_port = ip_port.split("/", 1)[0]
+ ips.add(IP(ip_port.split(":")[0]))
+
return ips
diff --git a/wally/discover/discover.py b/wally/discover/discover.py
index d1eb9ac..e0d7dab 100644
--- a/wally/discover/discover.py
+++ b/wally/discover/discover.py
@@ -1,14 +1,18 @@
import os.path
import logging
+from typing import Dict, NamedTuple, List, Optional, cast
-from paramiko import AuthenticationException
+from paramiko.ssh_exception import AuthenticationException
from . import ceph
from . import fuel
from . import openstack
from ..utils import parse_creds, StopTestError
-from ..test_run_class import TestRun
-from ..node import Node
+from ..config import ConfigBlock
+from ..start_vms import OSCreds
+from ..node_interfaces import NodeInfo
+from ..node import connect, setup_rpc
+from ..ssh_utils import parse_ssh_uri
logger = logging.getLogger("wally.discover")
@@ -32,95 +36,81 @@
"""
-def discover(testrun: TestRun, discover_cfg, clusters_info, var_dir, discover_nodes=True):
+DiscoveryResult = NamedTuple("DiscoveryResult", [("os_creds", Optional[OSCreds]), ("nodes", List[NodeInfo])])
+
+
+def discover(discover_list: List[str], clusters_info: ConfigBlock, discover_nodes: bool = True) -> DiscoveryResult:
"""Discover nodes in clusters"""
- nodes_to_run = []
- clean_data = None
- for cluster in discover_cfg:
- if cluster == "openstack" and not discover_nodes:
- logger.warning("Skip openstack cluster discovery")
- elif cluster == "openstack" and discover_nodes:
- cluster_info = clusters_info["openstack"]
- conn = cluster_info['connection']
- user, passwd, tenant = parse_creds(conn['creds'])
+ new_nodes = [] # type: List[NodeInfo]
+ os_creds = None # type: Optional[OSCreds]
- auth_data = dict(
- auth_url=conn['auth_url'],
- username=user,
- api_key=passwd,
- project_id=tenant)
-
- if not conn:
- logger.error("No connection provided for %s. Skipping"
- % cluster)
+ for cluster in discover_list:
+ if cluster == "openstack":
+ if not discover_nodes:
+ logger.warning("Skip openstack cluster discovery")
continue
- logger.debug("Discovering openstack nodes "
- "with connection details: %r" %
- conn)
+ cluster_info = clusters_info["openstack"] # type: ConfigBlock
- os_nodes = openstack.discover_openstack_nodes(auth_data,
- cluster_info)
- nodes_to_run.extend(os_nodes)
+ conn = cluster_info['connection'] # type: ConfigBlock
+ if not conn:
+ logger.error("No connection provided for %s. Skipping", cluster)
+ continue
+
+ user, passwd, tenant = parse_creds(conn['creds'])
+
+ auth_data = dict(auth_url=conn['auth_url'],
+ username=user,
+ api_key=passwd,
+ project_id=tenant) # type: Dict[str, str]
+
+ logger.debug("Discovering openstack nodes with connection details: %r", conn)
+ new_nodes.extend(openstack.discover_openstack_nodes(auth_data, cluster_info))
elif cluster == "fuel" or cluster == "fuel_openrc_only":
if cluster == "fuel_openrc_only":
discover_nodes = False
- ssh_creds = clusters_info['fuel']['ssh_creds']
- fuel_node = Node(NodeInfo(ssh_creds, {'fuel_master'}))
-
+ fuel_node_info = NodeInfo(parse_ssh_uri(clusters_info['fuel']['ssh_creds']), {'fuel_master'})
try:
- fuel_node.connect_ssh()
+ fuel_rpc_conn = setup_rpc(connect(fuel_node_info))
except AuthenticationException:
raise StopTestError("Wrong fuel credentials")
except Exception:
logger.exception("While connection to FUEL")
raise StopTestError("Failed to connect to FUEL")
- fuel_node.connect_rpc()
+ with fuel_rpc_conn:
+ nodes, fuel_info = fuel.discover_fuel_nodes(fuel_rpc_conn, clusters_info['fuel'], discover_nodes)
+ new_nodes.extend(nodes)
- res = fuel.discover_fuel_nodes(fuel_node,
- clusters_info['fuel'],
- discover_nodes)
- nodes, clean_data, openrc_dict, version = res
+ if fuel_info.openrc:
+ auth_url = cast(str, fuel_info.openrc['os_auth_url'])
+ if fuel_info.version >= [8, 0] and auth_url.startswith("https://"):
+ logger.warning("Fixing FUEL 8.0 AUTH url - replace https://->http://")
+ auth_url = auth_url.replace("https", "http", 1)
- if openrc_dict:
- if version >= [8, 0] and openrc_dict['os_auth_url'].startswith("https://"):
- logger.warning("Fixing FUEL 8.0 AUTH url - replace https://->http://")
- openrc_dict['os_auth_url'] = "http" + openrc_dict['os_auth_url'][5:]
-
- testrun.fuel_openstack_creds = {
- 'name': openrc_dict['username'],
- 'passwd': openrc_dict['password'],
- 'tenant': openrc_dict['tenant_name'],
- 'auth_url': openrc_dict['os_auth_url'],
- 'insecure': openrc_dict['insecure']}
-
- env_name = clusters_info['fuel']['openstack_env']
- env_f_name = env_name
- for char in "-+ {}()[]":
- env_f_name = env_f_name.replace(char, '_')
-
- fuel_openrc_fname = os.path.join(var_dir,
- env_f_name + "_openrc")
-
- if testrun.fuel_openstack_creds is not None:
- with open(fuel_openrc_fname, "w") as fd:
- fd.write(openrc_templ.format(**testrun.fuel_openstack_creds))
- msg = "Openrc for cluster {0} saves into {1}"
- logger.info(msg.format(env_name, fuel_openrc_fname))
- nodes_to_run.extend(nodes)
+ os_creds = OSCreds(name=cast(str, fuel_info.openrc['username']),
+ passwd=cast(str, fuel_info.openrc['password']),
+ tenant=cast(str, fuel_info.openrc['tenant_name']),
+ auth_url=cast(str, auth_url),
+ insecure=cast(bool, fuel_info.openrc['insecure']))
elif cluster == "ceph":
if discover_nodes:
cluster_info = clusters_info["ceph"]
- nodes_to_run.extend(ceph.discover_ceph_nodes(cluster_info))
+ root_node_uri = cast(str, cluster_info["root_node"])
+ cluster = clusters_info["ceph"].get("cluster", "ceph")
+ conf = clusters_info["ceph"].get("conf")
+ key = clusters_info["ceph"].get("key")
+ info = NodeInfo(parse_ssh_uri(root_node_uri), set())
+ with setup_rpc(connect(info)) as ceph_root_conn:
+ new_nodes.extend(ceph.discover_ceph_nodes(ceph_root_conn, cluster=cluster, conf=conf, key=key))
else:
logger.warning("Skip ceph cluster discovery")
else:
- msg_templ = "Unknown cluster type in 'discover' parameter: {0!r}"
+ msg_templ = "Unknown cluster type in 'discover' parameter: {!r}"
raise ValueError(msg_templ.format(cluster))
- return nodes_to_run
+ return DiscoveryResult(os_creds, new_nodes)
diff --git a/wally/discover/fuel.py b/wally/discover/fuel.py
index 1443a9f..119cbd0 100644
--- a/wally/discover/fuel.py
+++ b/wally/discover/fuel.py
@@ -1,20 +1,25 @@
-import socket
import logging
-from typing import Dict, Any, Tuple, List
+import socket
+from typing import Dict, Any, Tuple, List, NamedTuple, Union
from urllib.parse import urlparse
-
from .. import fuel_rest_api
+from ..node_interfaces import NodeInfo, IRPCNode
+from ..ssh_utils import ConnCreds
from ..utils import parse_creds, check_input_param
-from ..node import NodeInfo, Node, FuelNodeInfo
-
logger = logging.getLogger("wally.discover")
-def discover_fuel_nodes(fuel_master_node: Node,
+FuelNodeInfo = NamedTuple("FuelNodeInfo",
+ [("version", List[int]),
+ ("fuel_ext_iface", str),
+ ("openrc", Dict[str, Union[str, bool]])])
+
+
+def discover_fuel_nodes(fuel_master_node: IRPCNode,
fuel_data: Dict[str, Any],
- discover_nodes: bool=True) -> Tuple[List[NodeInfo], FuelNodeInfo]:
+ discover_nodes: bool = True) -> Tuple[List[NodeInfo], FuelNodeInfo]:
"""Discover nodes in fuel cluster, get openrc for selected cluster"""
# parse FUEL REST credentials
@@ -51,17 +56,10 @@
logger.debug("Downloading fuel master key")
fuel_key = fuel_master_node.get_file_content('/root/.ssh/id_rsa')
- # forward ports of cluster nodes to FUEL master
- logger.info("Forwarding ssh ports from FUEL nodes to localhost")
- ips = [str(fuel_node.get_ip(network)) for fuel_node in fuel_nodes]
- port_fw = [fuel_master_node.forward_port(ip, 22) for ip in ips]
- listen_ip = fuel_master_node.get_ip()
-
nodes = []
- for port, fuel_node, ip in zip(port_fw, fuel_nodes, ips):
- logger.debug("SSH port forwarding {} => {}:{}".format(ip, listen_ip, port))
- conn_url = "ssh://root@{}:{}".format(listen_ip, port)
- nodes.append(NodeInfo(conn_url, fuel_node['roles'], listen_ip, fuel_key))
+ for fuel_node in fuel_nodes:
+ ip = str(fuel_node.get_ip(network))
+ nodes.append(NodeInfo(ConnCreds(ip, "root", key=fuel_key), roles=set(fuel_node.get_roles())))
logger.debug("Found {} fuel nodes for env {}".format(len(nodes), fuel_data['openstack_env']))
diff --git a/wally/discover/openstack.py b/wally/discover/openstack.py
index bb128c3..88e8656 100644
--- a/wally/discover/openstack.py
+++ b/wally/discover/openstack.py
@@ -1,18 +1,19 @@
import socket
import logging
-from typing import Iterable, Dict, Any
+from typing import Dict, Any, List
from novaclient.client import Client
-from ..node import NodeInfo
-from wally.utils import parse_creds
+from ..node_interfaces import NodeInfo
+from ..config import ConfigBlock
+from ..utils import parse_creds
logger = logging.getLogger("wally.discover")
-def get_floating_ip(vm) -> str:
+def get_floating_ip(vm: Any) -> str:
"""Get VM floating IP address"""
for net_name, ifaces in vm.addresses.items():
@@ -34,22 +35,21 @@
return "ssh://{}@{}::{}".format(user, ip, key)
-def discover_vms(client: Client, search_opts) -> Iterable[NodeInfo]:
+def discover_vms(client: Client, search_opts: Dict) -> List[NodeInfo]:
"""Discover virtual machines"""
user, password, key = parse_creds(search_opts.pop('auth'))
servers = client.servers.list(search_opts=search_opts)
logger.debug("Found %s openstack vms" % len(servers))
- nodes = []
+ nodes = [] # type: List[NodeInfo]
for server in servers:
ip = get_floating_ip(server)
- nodes.append(NodeInfo(get_ssh_url(user, password, ip, key), ["test_vm"]))
-
+ nodes.append(NodeInfo(get_ssh_url(user, password, ip, key), roles={"test_vm"}))
return nodes
-def discover_services(client: Client, opts: Dict[str, Any]) -> Iterable[NodeInfo]:
+def discover_services(client: Client, opts: Dict[str, Any]) -> List[NodeInfo]:
"""Discover openstack services for given cluster"""
user, password, key = parse_creds(opts.pop('auth'))
@@ -63,15 +63,16 @@
for s in opts['service']:
services.extend(client.services.list(binary=s))
- host_services_mapping = {}
+ host_services_mapping = {} # type: Dict[str, [str]]
for service in services:
ip = socket.gethostbyname(service.host)
- host_services_mapping[ip].append(service.binary)
+ host_services_mapping.get(ip, []).append(service.binary)
logger.debug("Found %s openstack service nodes" %
len(host_services_mapping))
- nodes = []
+
+ nodes = [] # type: List[NodeInfo]
for host, services in host_services_mapping.items():
ssh_url = get_ssh_url(user, password, host, key)
nodes.append(NodeInfo(ssh_url, services))
@@ -79,7 +80,7 @@
return nodes
-def discover_openstack_nodes(conn_details: Dict[str, str], conf: Dict[str, Any]) -> Iterable[NodeInfo]:
+def discover_openstack_nodes(conn_details: Dict[str, str], conf: ConfigBlock) -> List[NodeInfo]:
"""Discover vms running in openstack
conn_details - dict with openstack connection details -
auth_url, api_key (password), username
diff --git a/wally/fuel_rest_api.py b/wally/fuel_rest_api.py
index 03beb37..2addaf4 100644
--- a/wally/fuel_rest_api.py
+++ b/wally/fuel_rest_api.py
@@ -1,29 +1,38 @@
import re
+import abc
import json
-import time
import logging
import urllib.request
import urllib.parse
-from typing import Dict, Any, Iterator, Match
-from functools import partial, wraps
+from typing import Dict, Any, Iterator, Match, List, Callable
+from functools import partial
import netaddr
-
-from keystoneclient.v2_0 import Client as keystoneclient
from keystoneclient import exceptions
+from keystoneclient.v2_0 import Client as keystoneclient
logger = logging.getLogger("wally.fuel_api")
-class Urllib2HTTP:
+class Connection(metaclass=abc.ABCMeta):
+ @abc.abstractmethod
+ def do(self, method: str, path: str, params: Dict = None) -> Dict:
+ pass
+
+ @abc.abstractmethod
+ def get(self, path: str, params: Dict = None) -> Dict:
+ pass
+
+
+class Urllib2HTTP(Connection):
"""
class for making HTTP requests
"""
allowed_methods = ('get', 'put', 'post', 'delete', 'patch', 'head')
- def __init__(self, root_url: str, headers: Dict[str, str]=None):
+ def __init__(self, root_url: str, headers: Dict[str, str] = None) -> None:
"""
"""
if root_url.endswith('/'):
@@ -31,12 +40,14 @@
else:
self.root_url = root_url
- self.headers = headers if headers is not None else {}
+ self.host = urllib.parse.urlparse(self.root_url).hostname
- def host(self) -> str:
- return self.root_url.split('/')[2]
+ if headers is None:
+ self.headers = {} # type: Dict[str, str]
+ else:
+ self.headers = headers
- def do(self, method: str, path: str, params: Dict[Any, Any]=None) -> Dict[str, Any]:
+ def do(self, method: str, path: str, params: Dict = None) -> Dict:
if path.startswith('/'):
url = self.root_url + path
else:
@@ -120,11 +131,11 @@
name = None
id = None
- def __init__(self, conn, **kwargs):
+ def __init__(self, conn, **kwargs) -> None:
self.__dict__.update(kwargs)
self.__connection__ = conn
- def __str__(self):
+ def __str__(self) -> str:
res = ["{0}({1}):".format(self.__class__.__name__, self.name)]
for k, v in sorted(self.__dict__.items()):
if k.startswith('__') or k.endswith('__'):
@@ -133,12 +144,12 @@
res.append(" {0}={1!r}".format(k, v))
return "\n".join(res)
- def __getitem__(self, item):
+ def __getitem__(self, item: str) -> Any:
return getattr(self, item)
-def make_call(method: str, url: str):
- def closure(obj, entire_obj=None, **data):
+def make_call(method: str, url: str) -> Callable[[Any, Any], Dict]:
+ def closure(obj: Any, entire_obj: Any = None, **data) -> Dict:
inline_params_vals = {}
for name in get_inline_param_list(url):
if name in data:
@@ -160,30 +171,10 @@
GET = partial(make_call, 'get')
DELETE = partial(make_call, 'delete')
-
-def with_timeout(tout, message):
- def closure(func):
- @wraps(func)
- def closure2(*dt, **mp):
- ctime = time.time()
- etime = ctime + tout
-
- while ctime < etime:
- if func(*dt, **mp):
- return
- sleep_time = ctime + 1 - time.time()
- if sleep_time > 0:
- time.sleep(sleep_time)
- ctime = time.time()
- raise RuntimeError("Timeout during " + message)
- return closure2
- return closure
-
-
# ------------------------------- ORM ----------------------------------------
-def get_fuel_info(url):
+def get_fuel_info(url: str) -> 'FuelInfo':
conn = Urllib2HTTP(url)
return FuelInfo(conn)
@@ -198,27 +189,27 @@
get_info = GET('api/releases')
@property
- def nodes(self):
+ def nodes(self) -> 'NodeList':
"""Get all fuel nodes"""
return NodeList([Node(self.__connection__, **node) for node
in self.get_nodes()])
@property
- def free_nodes(self):
+ def free_nodes(self) -> 'NodeList':
"""Get unallocated nodes"""
return NodeList([Node(self.__connection__, **node) for node in
self.get_nodes() if not node['cluster']])
@property
- def clusters(self):
+ def clusters(self) -> List['Cluster']:
"""List clusters in fuel"""
return [Cluster(self.__connection__, **cluster) for cluster
in self.get_clusters()]
- def get_version(self):
+ def get_version(self) -> List[int]:
for info in self.get_info():
vers = info['version'].split("-")[1].split('.')
- return map(int, vers)
+ return list(map(int, vers))
raise ValueError("No version found")
@@ -228,23 +219,18 @@
get_info = GET('/api/nodes/{id}')
get_interfaces = GET('/api/nodes/{id}/interfaces')
- def get_network_data(self):
+ def get_network_data(self) -> Dict:
"""Returns node network data"""
- node_info = self.get_info()
- return node_info.get('network_data')
+ return self.get_info().get('network_data')
- def get_roles(self, pending=False):
+ def get_roles(self) -> List[str]:
"""Get node roles
Returns: (roles, pending_roles)
"""
- node_info = self.get_info()
- if pending:
- return node_info.get('roles'), node_info.get('pending_roles')
- else:
- return node_info.get('roles')
+ return self.get_info().get('roles')
- def get_ip(self, network='public'):
+ def get_ip(self, network='public') -> netaddr.IPAddress:
"""Get node ip
:param network: network to pick
@@ -267,7 +253,7 @@
allowed_roles = ['controller', 'compute', 'cinder', 'ceph-osd', 'mongo',
'zabbix-server']
- def __getattr__(self, name):
+ def __getattr__(self, name: str) -> List[Node]:
if name in self.allowed_roles:
return [node for node in self if name in node.roles]
@@ -280,13 +266,13 @@
get_attributes = GET('api/clusters/{id}/attributes')
_get_nodes = GET('api/nodes?cluster_id={id}')
- def __init__(self, *dt, **mp):
+ def __init__(self, *dt, **mp) -> None:
super(Cluster, self).__init__(*dt, **mp)
self.nodes = NodeList([Node(self.__connection__, **node) for node in
self._get_nodes()])
self.network_roles = {}
- def check_exists(self):
+ def check_exists(self) -> bool:
"""Check if cluster exists"""
try:
self.get_status()
@@ -296,7 +282,7 @@
return False
raise
- def get_openrc(self):
+ def get_openrc(self) -> Dict[str, str]:
access = self.get_attributes()['editable']['access']
creds = {'username': access['user']['value'],
'password': access['password']['value'],
@@ -313,31 +299,31 @@
self.get_networks()['public_vip'])
return creds
- def get_nodes(self):
+ def get_nodes(self) -> Iterator[Node]:
for node_descr in self._get_nodes():
yield Node(self.__connection__, **node_descr)
-def reflect_cluster(conn, cluster_id):
+def reflect_cluster(conn: Connection, cluster_id: int) -> Cluster:
"""Returns cluster object by id"""
c = Cluster(conn, id=cluster_id)
c.nodes = NodeList(list(c.get_nodes()))
return c
-def get_all_nodes(conn):
+def get_all_nodes(conn: Connection) -> Iterator[Node]:
"""Get all nodes from Fuel"""
for node_desc in conn.get('api/nodes'):
yield Node(conn, **node_desc)
-def get_all_clusters(conn):
+def get_all_clusters(conn: Connection) -> Iterator[Cluster]:
"""Get all clusters"""
for cluster_desc in conn.get('api/clusters'):
yield Cluster(conn, **cluster_desc)
-def get_cluster_id(conn, name):
+def get_cluster_id(conn: Connection, name: str) -> int:
"""Get cluster id by name"""
for cluster in get_all_clusters(conn):
if cluster.name == name:
diff --git a/wally/hw_info.py b/wally/hw_info.py
index 812921e..764d126 100644
--- a/wally/hw_info.py
+++ b/wally/hw_info.py
@@ -1,7 +1,7 @@
import re
from typing import Dict, Iterable
import xml.etree.ElementTree as ET
-from typing import List, Tuple
+from typing import List, Tuple, cast, Optional
from . import utils
from .node_interfaces import IRPCNode
@@ -24,7 +24,7 @@
self.disks_raw_info = {} # type: Dict[str, str]
# name => (speed, is_full_diplex, ip_addresses)
- self.net_info = {} # type: Dict[str, Tuple[int, bool, str]]
+ self.net_info = {} # type: Dict[str, Tuple[Optional[int], Optional[bool], List[str]]]
self.ram_size = 0 # type: int
self.sys_name = None # type: str
@@ -107,11 +107,9 @@
def __init__(self) -> None:
self.partitions = None # type: str
self.kernel_version = None # type: str
- self.fio_version = None # type: str
self.libvirt_version = None # type: str
- self.kvm_version = None # type: str
self.qemu_version = None # type: str
- self.OS_version = None # type: str
+ self.OS_version = None # type: utils.OSRelease
self.ceph_version = None # type: str
@@ -119,11 +117,11 @@
res = SWInfo()
res.OS_version = utils.get_os(node)
- res.kernel_version = node.get_file_content('/proc/version')
- res.partitions = node.get_file_content('/etc/mtab')
- res.libvirt_version = node.run("virsh -v", nolog=True)
- res.qemu_version = node.run("qemu-system-x86_64 --version", nolog=True)
- res.ceph_version = node.run("ceph --version", nolog=True)
+ res.kernel_version = node.get_file_content('/proc/version').decode('utf8').strip()
+ res.partitions = node.get_file_content('/etc/mtab').decode('utf8').strip()
+ res.libvirt_version = node.run("virsh -v", nolog=True).strip()
+ res.qemu_version = node.run("qemu-system-x86_64 --version", nolog=True).strip()
+ res.ceph_version = node.run("ceph --version", nolog=True).strip()
return res
@@ -136,13 +134,14 @@
lshw_et = ET.fromstring(lshw_out)
try:
- res.hostname = lshw_et.find("node").attrib['id']
+ res.hostname = cast(str, lshw_et.find("node").attrib['id'])
except Exception:
pass
try:
- res.sys_name = (lshw_et.find("node/vendor").text + " " +
- lshw_et.find("node/product").text)
+
+ res.sys_name = cast(str, lshw_et.find("node/vendor").text) + " " + \
+ cast(str, lshw_et.find("node/product").text)
res.sys_name = res.sys_name.replace("(To be filled by O.E.M.)", "")
res.sys_name = res.sys_name.replace("(To be Filled by O.E.M.)", "")
except Exception:
@@ -150,17 +149,17 @@
core = lshw_et.find("node/node[@id='core']")
if core is None:
- return
+ return res
try:
- res.mb = " ".join(core.find(node).text
+ res.mb = " ".join(cast(str, core.find(node).text)
for node in ['vendor', 'product', 'version'])
except Exception:
pass
for cpu in core.findall("node[@class='processor']"):
try:
- model = cpu.find('product').text
+ model = cast(str, cpu.find('product').text)
threads_node = cpu.find("configuration/setting[@id='threads']")
if threads_node is None:
threads = 1
@@ -192,21 +191,22 @@
try:
link = net.find("configuration/setting[@id='link']")
if link.attrib['value'] == 'yes':
- name = net.find("logicalname").text
+ name = cast(str, net.find("logicalname").text)
speed_node = net.find("configuration/setting[@id='speed']")
if speed_node is None:
speed = None
else:
- speed = speed_node.attrib['value']
+ speed = int(speed_node.attrib['value'])
dup_node = net.find("configuration/setting[@id='duplex']")
if dup_node is None:
dup = None
else:
- dup = dup_node.attrib['value']
+ dup = cast(str, dup_node.attrib['value']).lower() == 'yes'
- res.net_info[name] = (speed, dup, [])
+ ips = [] # type: List[str]
+ res.net_info[name] = (speed, dup, ips)
except Exception:
pass
@@ -231,7 +231,7 @@
try:
lname_node = disk.find('logicalname')
if lname_node is not None:
- dev = lname_node.text.split('/')[-1]
+ dev = cast(str, lname_node.text).split('/')[-1]
if dev == "" or dev[-1].isdigit():
continue
@@ -250,7 +250,7 @@
full_descr = "{0} {1} {2} {3} {4}".format(
description, product, vendor, version, serial)
- businfo = disk.find('businfo').text
+ businfo = cast(str, disk.find('businfo').text)
res.disks_raw_info[businfo] = full_descr
except Exception:
pass
diff --git a/wally/keystone.py b/wally/keystone.py
deleted file mode 100644
index 358d128..0000000
--- a/wally/keystone.py
+++ /dev/null
@@ -1,90 +0,0 @@
-import json
-import urllib.request
-from functools import partial
-from typing import Dict, Any
-
-from keystoneclient import exceptions
-from keystoneclient.v2_0 import Client as keystoneclient
-
-
-class Urllib2HTTP:
- """
- class for making HTTP requests
- """
-
- allowed_methods = ('get', 'put', 'post', 'delete', 'patch', 'head')
-
- def __init__(self, root_url:str, headers:Dict[str, str]=None, echo: bool=False) -> None:
- """"""
- if root_url.endswith('/'):
- self.root_url = root_url[:-1]
- else:
- self.root_url = root_url
-
- self.headers = headers if headers is not None else {}
- self.echo = echo
-
- def do(self, method: str, path: str, params: Dict[str, str]=None) -> Any:
- if path.startswith('/'):
- url = self.root_url + path
- else:
- url = self.root_url + '/' + path
-
- if method == 'get':
- assert params == {} or params is None
- data_json = None
- else:
- data_json = json.dumps(params)
-
- request = urllib.request.Request(url,
- data=data_json,
- headers=self.headers)
- if data_json is not None:
- request.add_header('Content-Type', 'application/json')
-
- request.get_method = lambda: method.upper()
- response = urllib.request.urlopen(request)
-
- if response.code < 200 or response.code > 209:
- raise IndexError(url)
-
- content = response.read()
-
- if '' == content:
- return None
-
- return json.loads(content)
-
- def __getattr__(self, name):
- if name in self.allowed_methods:
- return partial(self.do, name)
- raise AttributeError(name)
-
-
-class KeystoneAuth(Urllib2HTTP):
- def __init__(self, root_url: str, creds: Dict[str, str], headers: Dict[str, str]=None, echo: bool=False,
- admin_node_ip: str=None):
- super(KeystoneAuth, self).__init__(root_url, headers, echo)
- self.keystone_url = "http://{0}:5000/v2.0".format(admin_node_ip)
- self.keystone = keystoneclient(
- auth_url=self.keystone_url, **creds)
- self.refresh_token()
-
- def refresh_token(self) -> None:
- """Get new token from keystone and update headers"""
- try:
- self.keystone.authenticate()
- self.headers['X-Auth-Token'] = self.keystone.auth_token
- except exceptions.AuthorizationFailure:
- raise
-
- def do(self, method: str, path: str, params: Dict[str, str]=None) -> Any:
- """Do request. If gets 401 refresh token"""
- try:
- return super(KeystoneAuth, self).do(method, path, params)
- except urllib.request.HTTPError as e:
- if e.code == 401:
- self.refresh_token()
- return super(KeystoneAuth, self).do(method, path, params)
- else:
- raise
diff --git a/wally/logger.py b/wally/logger.py
index a8cbf2a..e8c916d 100644
--- a/wally/logger.py
+++ b/wally/logger.py
@@ -1,5 +1,5 @@
import logging
-from typing import Callable, IO
+from typing import Callable, IO, Optional
def color_me(color: int) -> Callable[[str], str]:
@@ -49,22 +49,25 @@
def setup_loggers(def_level: int = logging.DEBUG, log_fname: str = None, log_fd: IO = None) -> None:
- logger = logging.getLogger('wally')
- logger.setLevel(logging.DEBUG)
- sh = logging.StreamHandler()
- sh.setLevel(def_level)
log_format = '%(asctime)s - %(levelname)s - %(name)-15s - %(message)s'
colored_formatter = ColoredFormatter(log_format, datefmt="%H:%M:%S")
+ sh = logging.StreamHandler()
+ sh.setLevel(def_level)
sh.setFormatter(colored_formatter)
+
+ logger = logging.getLogger('wally')
+ logger.setLevel(logging.DEBUG)
logger.addHandler(sh)
logger_api = logging.getLogger("wally.fuel_api")
+ logger_api.setLevel(logging.WARNING)
+ logger_api.addHandler(sh)
if log_fname or log_fd:
if log_fname:
- handler = logging.FileHandler(log_fname)
+ handler = logging.FileHandler(log_fname) # type: Optional[logging.Handler]
else:
handler = logging.StreamHandler(log_fd)
@@ -72,16 +75,8 @@
formatter = logging.Formatter(log_format, datefmt="%H:%M:%S")
handler.setFormatter(formatter)
handler.setLevel(logging.DEBUG)
+
logger.addHandler(handler)
logger_api.addHandler(handler)
- else:
- fh = None
- logger_api.addHandler(sh)
- logger_api.setLevel(logging.WARNING)
-
- logger = logging.getLogger('paramiko')
- logger.setLevel(logging.WARNING)
- # logger.addHandler(sh)
- if fh is not None:
- logger.addHandler(fh)
+ logging.getLogger('paramiko').setLevel(logging.WARNING)
diff --git a/wally/main.py b/wally/main.py
index bfb9af7..360506e 100644
--- a/wally/main.py
+++ b/wally/main.py
@@ -40,7 +40,7 @@
def list_results(path: str) -> List[Tuple[str, str, str, str]]:
- results = []
+ results = [] # type: List[Tuple[float, str, str, str, str]]
for dir_name in os.listdir(path):
full_path = os.path.join(path, dir_name)
@@ -50,14 +50,14 @@
except Exception as exc:
logger.warning("Can't load folder {}. Error {}".format(full_path, exc))
- comment = stor['info/comment']
- run_uuid = stor['info/run_uuid']
- run_time = stor['info/run_time']
+ comment = cast(str, stor['info/comment'])
+ run_uuid = cast(str, stor['info/run_uuid'])
+ run_time = cast(float, stor['info/run_time'])
test_types = ""
- results.append((time.ctime(run_time),
+ results.append((run_time,
run_uuid,
test_types,
- run_time,
+ time.ctime(run_time),
'-' if comment is None else comment))
results.sort()
@@ -147,8 +147,7 @@
with open(file_name) as fd:
config = Config(yaml_load(fd.read())) # type: ignore
- config.run_uuid = utils.get_uniq_path_uuid(config.results_dir)
- config.storage_url = os.path.join(config.results_dir, config.run_uuid)
+ config.storage_url, config.run_uuid = utils.get_uniq_path_uuid(config.results_dir)
config.comment = opts.comment
config.keep_vm = opts.keep_vm
config.no_tests = opts.no_tests
@@ -160,7 +159,7 @@
storage = make_storage(config.storage_url)
- storage['config'] = config
+ storage['config'] = config # type: ignore
stages.extend([
run_test.discover_stage,
@@ -170,11 +169,10 @@
run_test.connect_stage])
if config.get("collect_info", True):
- stages.append(run_test.collect_hw_info_stage)
+ stages.append(run_test.collect_info_stage)
stages.extend([
run_test.run_tests_stage,
- run_test.store_raw_results_stage,
])
elif opts.subparser_name == 'ls':
@@ -191,10 +189,10 @@
config.settings_dir = get_config_path(config, opts.settings_dir)
elif opts.subparser_name == 'compare':
- x = run_test.load_data_from_path(opts.data_path1)
- y = run_test.load_data_from_path(opts.data_path2)
- print(run_test.IOPerfTest.format_diff_for_console(
- [x['io'][0], y['io'][0]]))
+ # x = run_test.load_data_from_path(opts.data_path1)
+ # y = run_test.load_data_from_path(opts.data_path2)
+ # print(run_test.IOPerfTest.format_diff_for_console(
+ # [x['io'][0], y['io'][0]]))
return 0
if not getattr(opts, "no_report", False):
diff --git a/wally/meta_info.py b/wally/meta_info.py
index 3de33cb..e0e2b30 100644
--- a/wally/meta_info.py
+++ b/wally/meta_info.py
@@ -1,70 +1,54 @@
-from typing import Any, Dict
-from urllib.parse import urlparse
+from typing import Any, Dict, Union, List
+from .fuel_rest_api import KeystoneAuth, FuelInfo
-from .keystone import KeystoneAuth
+def total_lab_info(nodes: List[Dict[str, Any]]) -> Dict[str, int]:
+ lab_data = {'nodes_count': len(nodes),
+ 'total_memory': 0,
+ 'total_disk': 0,
+ 'processor_count': 0} # type: Dict[str, int]
-
-def total_lab_info(data: Dict[str, Any]) -> Dict[str, Any]:
- lab_data = {}
- lab_data['nodes_count'] = len(data['nodes'])
- lab_data['total_memory'] = 0
- lab_data['total_disk'] = 0
- lab_data['processor_count'] = 0
-
- for node in data['nodes']:
+ for node in nodes:
lab_data['total_memory'] += node['memory']['total']
lab_data['processor_count'] += len(node['processors'])
for disk in node['disks']:
lab_data['total_disk'] += disk['size']
- def to_gb(x):
- return x / (1024 ** 3)
+ lab_data['total_memory'] /= (1024 ** 3)
+ lab_data['total_disk'] /= (1024 ** 3)
- lab_data['total_memory'] = to_gb(lab_data['total_memory'])
- lab_data['total_disk'] = to_gb(lab_data['total_disk'])
return lab_data
-def collect_lab_data(url: str, cred: Dict[str, str]) -> Dict[str, Any]:
- u = urlparse(url)
- keystone = KeystoneAuth(root_url=url, creds=cred, admin_node_ip=u.hostname)
- lab_info = keystone.do(method='get', path="/api/nodes")
- fuel_version = keystone.do(method='get', path="/api/version/")
+def collect_lab_data(url: str, cred: Dict[str, str]) -> Dict[str, Union[List[Dict[str, str]], str]]:
+ finfo = FuelInfo(KeystoneAuth(url, cred))
- nodes = []
+ nodes = [] # type: List[Dict[str, str]]
result = {}
- for node in lab_info:
- # TODO(koder): give p,i,d,... vars meaningful names
- d = {}
- d['name'] = node['name']
- p = []
- i = []
- disks = []
- devices = []
+ for node in finfo.get_nodes():
+ node_info = {
+ 'name': node['name'],
+ 'processors': [],
+ 'interfaces': [],
+ 'disks': [],
+ 'devices': [],
+ 'memory': node['meta']['memory'].copy()
+ }
for processor in node['meta']['cpu']['spec']:
- p.append(processor)
+ node_info['processors'].append(processor)
for iface in node['meta']['interfaces']:
- i.append(iface)
-
- m = node['meta']['memory'].copy()
+ node_info['interfaces'].append(iface)
for disk in node['meta']['disks']:
- disks.append(disk)
+ node_info['disks'].append(disk)
- d['memory'] = m
- d['disks'] = disks
- d['devices'] = devices
- d['interfaces'] = i
- d['processors'] = p
-
- nodes.append(d)
+ nodes.append(node_info)
result['nodes'] = nodes
- result['fuel_version'] = fuel_version['release']
-
+ result['fuel_version'] = finfo.get_version()
+ result['total_info'] = total_lab_info(nodes)
return result
diff --git a/wally/node.py b/wally/node.py
index 3bc52fc..2b58571 100644
--- a/wally/node.py
+++ b/wally/node.py
@@ -4,34 +4,36 @@
import socket
import logging
import subprocess
-from typing import Callable
+from typing import Union, cast, Any
+
import agent
+import paramiko
-from .node_interfaces import IRPCNode, NodeInfo, ISSHHost, RPCBeforeConnCallback
-from .ssh_utils import parse_ssh_uri, ssh_connect
+
+from .node_interfaces import IRPCNode, NodeInfo, ISSHHost
+from .ssh import connect as ssh_connect
logger = logging.getLogger("wally")
class SSHHost(ISSHHost):
- def __init__(self, ssh_conn, node_name: str, ip: str) -> None:
- self.conn = ssh_conn
- self.node_name = node_name
- self.ip = ip
-
- def get_ip(self) -> str:
- return self.ip
+ def __init__(self, conn: paramiko.SSHClient, info: NodeInfo) -> None:
+ self.conn = conn
+ self.info = info
def __str__(self) -> str:
- return self.node_name
+ return self.info.node_id()
def put_to_file(self, path: str, content: bytes) -> None:
with self.conn.open_sftp() as sftp:
with sftp.open(path, "wb") as fd:
fd.write(content)
+ def disconnect(self):
+ self.conn.close()
+
def run(self, cmd: str, timeout: int = 60, nolog: bool = False) -> str:
transport = self.conn.get_transport()
session = transport.open_session()
@@ -102,12 +104,16 @@
return stdout_data
+ def disconnect(self):
+ pass
-def connect(conn_url: str, conn_timeout: int = 60) -> ISSHHost:
- if conn_url == 'local':
+
+def connect(info: Union[str, NodeInfo], conn_timeout: int = 60) -> ISSHHost:
+ if info == 'local':
return LocalHost()
else:
- return SSHHost(*ssh_connect(parse_ssh_uri(conn_url), conn_timeout))
+ info_c = cast(NodeInfo, info)
+ return SSHHost(ssh_connect(info_c.ssh_creds, conn_timeout), info_c)
class RPCNode(IRPCNode):
@@ -117,46 +123,50 @@
self.info = info
self.conn = conn
- # if self.ssh_conn_url is not None:
- # self.ssh_cred = parse_ssh_uri(self.ssh_conn_url)
- # self.node_id = "{0.host}:{0.port}".format(self.ssh_cred)
- # else:
- # self.ssh_cred = None
- # self.node_id = None
-
def __str__(self) -> str:
- return "<Node: url={!r} roles={!r} hops=/>".format(self.info.ssh_conn_url, ",".join(self.info.roles))
+ return "<Node: url={!s} roles={!r} hops=/>".format(self.info.ssh_creds, ",".join(self.info.roles))
def __repr__(self) -> str:
return str(self)
- def get_file_content(self, path: str) -> str:
+ def get_file_content(self, path: str) -> bytes:
raise NotImplementedError()
- def forward_port(self, ip: str, remote_port: int, local_port: int = None) -> int:
- raise NotImplementedError()
+ def run(self, cmd: str, timeout: int = 60, nolog: bool = False) -> str:
+ raise NotImplemented()
+
+ def copy_file(self, local_path: str, remote_path: str = None) -> str:
+ raise NotImplemented()
+
+ def put_to_file(self, path: str, content: bytes) -> None:
+ raise NotImplemented()
+
+ def get_interface(self, ip: str) -> str:
+ raise NotImplemented()
+
+ def stat_file(self, path: str) -> Any:
+ raise NotImplemented()
+
+ def disconnect(self) -> str:
+ raise NotImplemented()
-def setup_rpc(node: ISSHHost, rpc_server_code: bytes, port: int = 0,
- rpc_conn_callback: RPCBeforeConnCallback = None) -> IRPCNode:
+def setup_rpc(node: ISSHHost, rpc_server_code: bytes, port: int = 0) -> IRPCNode:
code_file = node.run("mktemp").strip()
log_file = node.run("mktemp").strip()
node.put_to_file(code_file, rpc_server_code)
cmd = "python {code_file} server --listen-addr={listen_ip}:{port} --daemon " + \
"--show-settings --stdout-file={out_file}"
+
+ ip = node.info.ssh_creds.addr.host
+
params_js = node.run(cmd.format(code_file=code_file,
- listen_addr=node.get_ip(),
+ listen_addr=ip,
out_file=log_file,
port=port)).strip()
params = json.loads(params_js)
params['log_file'] = log_file
-
- if rpc_conn_callback:
- ip, port = rpc_conn_callback(node, port)
- else:
- ip = node.get_ip()
- port = int(params['addr'].split(":")[1])
-
+ port = int(params['addr'].split(":")[1])
rpc_conn = agent.connect((ip, port))
node.info.params.update(params)
return RPCNode(rpc_conn, node.info)
diff --git a/wally/node_interfaces.py b/wally/node_interfaces.py
index e0b56aa..e75c2a3 100644
--- a/wally/node_interfaces.py
+++ b/wally/node_interfaces.py
@@ -1,27 +1,27 @@
import abc
-from typing import Any, Set, Optional, List, Dict, Callable
+from typing import Any, Set, Optional, List, Dict, Callable, NamedTuple
+from .ssh_utils import ConnCreds
+from .common_types import IPAddr
+
+
+RPCCreds = NamedTuple("RPCCreds", [("addr", IPAddr), ("key_file", str), ("cert_file", str)])
class NodeInfo:
- """Node information object, result of dicovery process or config parsing"""
+ """Node information object, result of discovery process or config parsing"""
+ def __init__(self, ssh_creds: ConnCreds, roles: Set[str]) -> None:
- def __init__(self,
- ssh_conn_url: str,
- roles: Set[str],
- hops: List['NodeInfo'] = None,
- ssh_key: bytes = None) -> None:
-
- self.hops = [] # type: List[NodeInfo]
- if hops is not None:
- self.hops = hops
-
- self.ssh_conn_url = ssh_conn_url # type: str
- self.rpc_conn_url = None # type: str
- self.roles = roles # type: Set[str]
+ # ssh credentials
+ self.ssh_creds = ssh_creds
+ # credentials for RPC connection
+ self.rpc_creds = None # type: Optional[RPCCreds]
+ self.roles = roles
self.os_vm_id = None # type: Optional[int]
- self.ssh_key = ssh_key # type: Optional[bytes]
self.params = {} # type: Dict[str, Any]
+ def node_id(self) -> str:
+ return "{0.host}:{0.port}".format(self.ssh_creds.addr)
+
class ISSHHost(metaclass=abc.ABCMeta):
"""Minimal interface, required to setup RPC connection"""
@@ -32,17 +32,24 @@
pass
@abc.abstractmethod
- def get_ip(self) -> str:
+ def __str__(self) -> str:
pass
@abc.abstractmethod
- def __str__(self) -> str:
+ def disconnect(self) -> None:
pass
@abc.abstractmethod
def put_to_file(self, path: str, content: bytes) -> None:
pass
+ def __enter__(self) -> 'ISSHHost':
+ return self
+
+ def __exit__(self, x, y, z) -> bool:
+ self.disconnect()
+ return False
+
class IRPCNode(metaclass=abc.ABCMeta):
"""Remote filesystem interface"""
@@ -65,10 +72,6 @@
pass
@abc.abstractmethod
- def forward_port(self, ip: str, remote_port: int, local_port: int = None) -> int:
- pass
-
- @abc.abstractmethod
def get_interface(self, ip: str) -> str:
pass
@@ -77,14 +80,13 @@
pass
@abc.abstractmethod
- def node_id(self) -> str:
- pass
-
-
- @abc.abstractmethod
def disconnect(self) -> str:
pass
+ def __enter__(self) -> 'IRPCNode':
+ return self
+ def __exit__(self, x, y, z) -> bool:
+ self.disconnect()
+ return False
-RPCBeforeConnCallback = Callable[[NodeInfo, int], None]
\ No newline at end of file
diff --git a/wally/run_test.py b/wally/run_test.py
index e62cfb1..d8ef685 100755
--- a/wally/run_test.py
+++ b/wally/run_test.py
@@ -1,16 +1,14 @@
import os
-import time
import logging
-import functools
import contextlib
-import collections
-from typing import List, Dict, Iterable, Any, Iterator, Mapping, Callable, Tuple, Optional, Union, cast
+from typing import List, Dict, Iterable, Iterator, Tuple, Optional, Union, cast
from concurrent.futures import ThreadPoolExecutor, Future
from .node_interfaces import NodeInfo, IRPCNode
from .test_run_class import TestRun
from .discover import discover
from . import pretty_yaml, utils, report, ssh_utils, start_vms, hw_info
+from .node import setup_rpc, connect
from .config import ConfigBlock, Config
from .suits.mysql import MysqlTest
@@ -31,18 +29,16 @@
logger = logging.getLogger("wally")
-def connect_all(nodes_info: List[NodeInfo],
- pool: ThreadPoolExecutor,
- conn_timeout: int = 30,
- rpc_conn_callback: ssh_utils.RPCBeforeConnCallback = None) -> List[IRPCNode]:
+def connect_all(nodes_info: List[NodeInfo], pool: ThreadPoolExecutor, conn_timeout: int = 30) -> List[IRPCNode]:
"""Connect to all nodes, log errors"""
logger.info("Connecting to %s nodes", len(nodes_info))
def connect_ext(node_info: NodeInfo) -> Tuple[bool, Union[IRPCNode, NodeInfo]]:
try:
- ssh_node = ssh_utils.connect(node_info.ssh_conn_url, conn_timeout=conn_timeout)
- return True, ssh_utils.setup_rpc(ssh_node, rpc_conn_callback=rpc_conn_callback)
+ ssh_node = connect(node_info, conn_timeout=conn_timeout)
+ # TODO(koder): need to pass all required rpc bytes to this call
+ return True, setup_rpc(ssh_node, b"")
except Exception as exc:
logger.error("During connect to {}: {!s}".format(node, exc))
return False, node_info
@@ -77,16 +73,16 @@
return ready
-def collect_info_stage(ctx: TestRun, nodes: Iterable[IRPCNode]) -> None:
- futures = {} # type: Dict[str, Future]]
+def collect_info_stage(ctx: TestRun) -> None:
+ futures = {} # type: Dict[str, Future]
with ctx.get_pool() as pool:
- for node in nodes:
- hw_info_path = "hw_info/{}".format(node.node_id())
+ for node in ctx.nodes:
+ hw_info_path = "hw_info/{}".format(node.info.node_id())
if hw_info_path not in ctx.storage:
futures[hw_info_path] = pool.submit(hw_info.get_hw_info, node), node
- sw_info_path = "sw_info/{}".format(node.node_id())
+ sw_info_path = "sw_info/{}".format(node.info.node_id())
if sw_info_path not in ctx.storage:
futures[sw_info_path] = pool.submit(hw_info.get_sw_info, node)
@@ -95,7 +91,7 @@
@contextlib.contextmanager
-def suspend_vm_nodes_ctx(unused_nodes: List[IRPCNode]) -> Iterator[List[int]]:
+def suspend_vm_nodes_ctx(ctx: TestRun, unused_nodes: List[IRPCNode]) -> Iterator[List[int]]:
pausable_nodes_ids = [cast(int, node.info.os_vm_id)
for node in unused_nodes
@@ -108,14 +104,16 @@
if pausable_nodes_ids:
logger.debug("Try to pause {} unused nodes".format(len(pausable_nodes_ids)))
- start_vms.pause(pausable_nodes_ids)
+ with ctx.get_pool() as pool:
+ start_vms.pause(ctx.os_connection, pausable_nodes_ids, pool)
try:
yield pausable_nodes_ids
finally:
if pausable_nodes_ids:
logger.debug("Unpausing {} nodes".format(len(pausable_nodes_ids)))
- start_vms.unpause(pausable_nodes_ids)
+ with ctx.get_pool() as pool:
+ start_vms.unpause(ctx.os_connection, pausable_nodes_ids, pool)
def run_tests(ctx: TestRun, test_block: ConfigBlock, nodes: List[IRPCNode]) -> None:
@@ -133,7 +131,7 @@
# select test nodes
if vm_count is None:
curr_test_nodes = test_nodes
- unused_nodes = []
+ unused_nodes = [] # type: List[IRPCNode]
else:
curr_test_nodes = test_nodes[:vm_count]
unused_nodes = test_nodes[vm_count:]
@@ -147,23 +145,23 @@
# suspend all unused virtual nodes
if ctx.config.get('suspend_unused_vms', True):
- suspend_ctx = suspend_vm_nodes_ctx(unused_nodes)
+ suspend_ctx = suspend_vm_nodes_ctx(ctx, unused_nodes)
else:
suspend_ctx = utils.empty_ctx()
+ resumable_nodes_ids = [cast(int, node.info.os_vm_id)
+ for node in curr_test_nodes
+ if node.info.os_vm_id is not None]
+
+ if resumable_nodes_ids:
+ logger.debug("Check and unpause {} nodes".format(len(resumable_nodes_ids)))
+
+ with ctx.get_pool() as pool:
+ start_vms.unpause(ctx.os_connection, resumable_nodes_ids, pool)
+
with suspend_ctx:
- resumable_nodes_ids = [cast(int, node.info.os_vm_id)
- for node in curr_test_nodes
- if node.info.os_vm_id is not None]
-
- if resumable_nodes_ids:
- logger.debug("Check and unpause {} nodes".format(len(resumable_nodes_ids)))
- start_vms.unpause(resumable_nodes_ids)
-
test_cls = TOOL_TYPE_MAPPER[name]
-
remote_dir = ctx.config.default_test_local_folder.format(name=name, uuid=ctx.config.run_uuid)
-
test_cfg = TestConfig(test_cls.__name__,
params=params,
run_uuid=ctx.config.run_uuid,
@@ -178,60 +176,81 @@
ctx.clear_calls_stack.append(disconnect_stage)
with ctx.get_pool() as pool:
- ctx.nodes = connect_all(ctx.nodes_info, pool, rpc_conn_callback=ctx.before_conn_callback)
+ ctx.nodes = connect_all(ctx.nodes_info, pool)
def discover_stage(ctx: TestRun) -> None:
"""discover clusters and nodes stage"""
+ # TODO(koder): Properly store discovery info and check if it available to skip phase
+
discover_info = ctx.config.get('discover')
if discover_info:
- discover_objs = [i.strip() for i in discover_info.strip().split(",")]
+ if "discovered_nodes" in ctx.storage:
+ nodes = ctx.storage.load_list("discovered_nodes", NodeInfo)
+ ctx.fuel_openstack_creds = ctx.storage.load("fuel_openstack_creds", start_vms.OSCreds)
+ else:
+ discover_objs = [i.strip() for i in discover_info.strip().split(",")]
- nodes_info = discover.discover(ctx, discover_objs,
- ctx.config.clouds,
- ctx.storage,
- not ctx.config.dont_discover_nodes)
+ ctx.fuel_openstack_creds, nodes = discover.discover(discover_objs,
+ ctx.config.clouds,
+ not ctx.config.dont_discover_nodes)
- ctx.nodes_info.extend(nodes_info)
+ ctx.storage["fuel_openstack_creds"] = ctx.fuel_openstack_creds # type: ignore
+ ctx.storage["discovered_nodes"] = nodes # type: ignore
+ ctx.nodes_info.extend(nodes)
for url, roles in ctx.config.get('explicit_nodes', {}).items():
- ctx.nodes_info.append(NodeInfo(url, set(roles.split(","))))
+ creds = ssh_utils.parse_ssh_uri(url)
+ roles = set(roles.split(","))
+ ctx.nodes_info.append(NodeInfo(creds, roles))
def save_nodes_stage(ctx: TestRun) -> None:
"""Save nodes list to file"""
- ctx.storage['nodes'] = ctx.nodes_info
+ ctx.storage['nodes'] = ctx.nodes_info # type: ignore
+
+
+def ensure_connected_to_openstack(ctx: TestRun) -> None:
+ if not ctx.os_connection is None:
+ if ctx.os_creds is None:
+ ctx.os_creds = get_OS_credentials(ctx)
+ ctx.os_connection = start_vms.os_connect(ctx.os_creds)
def reuse_vms_stage(ctx: TestRun) -> None:
- vms_patterns = ctx.config.get('clouds/openstack/vms', [])
- private_key_path = get_vm_keypair(ctx.config)['keypair_file_private']
+ if "reused_nodes" in ctx.storage:
+ ctx.nodes_info.extend(ctx.storage.load_list("reused_nodes", NodeInfo))
+ else:
+ reused_nodes = []
+ vms_patterns = ctx.config.get('clouds/openstack/vms', [])
+ private_key_path = get_vm_keypair_path(ctx.config)[0]
- for creds in vms_patterns:
- user_name, vm_name_pattern = creds.split("@", 1)
- msg = "Vm like {} lookup failed".format(vm_name_pattern)
+ for creds in vms_patterns:
+ user_name, vm_name_pattern = creds.split("@", 1)
+ msg = "Vm like {} lookup failed".format(vm_name_pattern)
- with utils.LogError(msg):
- msg = "Looking for vm with name like {0}".format(vm_name_pattern)
- logger.debug(msg)
+ with utils.LogError(msg):
+ msg = "Looking for vm with name like {0}".format(vm_name_pattern)
+ logger.debug(msg)
- if not start_vms.is_connected():
- os_creds = get_OS_credentials(ctx)
- else:
- os_creds = None
+ ensure_connected_to_openstack(ctx)
- conn = start_vms.nova_connect(os_creds)
- for ip, vm_id in start_vms.find_vms(conn, vm_name_pattern):
- conn_url = "ssh://{user}@{ip}::{key}".format(user=user_name,
- ip=ip,
- key=private_key_path)
- node_info = NodeInfo(conn_url, ['testnode'])
- node_info.os_vm_id = vm_id
- ctx.nodes_info.append(node_info)
+ for ip, vm_id in start_vms.find_vms(ctx.os_connection, vm_name_pattern):
+ creds = ssh_utils.ConnCreds(host=ip, user=user_name, key_file=private_key_path)
+ node_info = NodeInfo(creds, {'testnode'})
+ node_info.os_vm_id = vm_id
+ reused_nodes.append(node_info)
+ ctx.nodes_info.append(node_info)
+
+ ctx.storage["reused_nodes"] = reused_nodes # type: ignore
-def get_OS_credentials(ctx: TestRun) -> None:
+def get_OS_credentials(ctx: TestRun) -> start_vms.OSCreds:
+
+ if "openstack_openrc" in ctx.storage:
+ return ctx.storage.load("openstack_openrc", start_vms.OSCreds)
+
creds = None
os_creds = None
force_insecure = False
@@ -245,7 +264,7 @@
os_creds = start_vms.OSCreds(*creds_tuple)
elif 'ENV' in os_cfg:
logger.info("Using OS credentials from shell environment")
- os_creds = start_vms.ostack_get_creds()
+ os_creds = start_vms.get_openstack_credentials()
elif 'OS_TENANT_NAME' in os_cfg:
logger.info("Using predefined credentials")
os_creds = start_vms.OSCreds(os_cfg['OS_USERNAME'].strip(),
@@ -257,11 +276,10 @@
elif 'OS_INSECURE' in os_cfg:
force_insecure = os_cfg.get('OS_INSECURE', False)
- if os_creds is None and 'fuel' in cfg.clouds and \
- 'openstack_env' in cfg.clouds['fuel'] and \
- ctx.fuel_openstack_creds is not None:
+ if os_creds is None and 'fuel' in cfg.clouds and 'openstack_env' in cfg.clouds['fuel'] and \
+ ctx.fuel_openstack_creds is not None:
logger.info("Using fuel creds")
- creds = start_vms.OSCreds(**ctx.fuel_openstack_creds)
+ creds = ctx.fuel_openstack_creds
elif os_creds is None:
logger.error("Can't found OS credentials")
raise utils.StopTestError("Can't found OS credentials", None)
@@ -279,10 +297,11 @@
logger.debug(("OS_CREDS: user={0.name} tenant={0.tenant} " +
"auth_url={0.auth_url} insecure={0.insecure}").format(creds))
+ ctx.storage["openstack_openrc"] = creds # type: ignore
return creds
-def get_vm_keypair(cfg: Config) -> Tuple[str, str]:
+def get_vm_keypair_path(cfg: Config) -> Tuple[str, str]:
key_name = cfg.vm_configs['keypair_name']
private_path = os.path.join(cfg.settings_dir, key_name + "_private.pem")
public_path = os.path.join(cfg.settings_dir, key_name + "_public.pub")
@@ -291,52 +310,54 @@
@contextlib.contextmanager
def create_vms_ctx(ctx: TestRun, vm_config: ConfigBlock, already_has_count: int = 0) -> Iterator[List[NodeInfo]]:
- if vm_config['count'].startswith('='):
- count = int(vm_config['count'][1:])
- if count <= already_has_count:
- logger.debug("Not need new vms")
- yield []
- return
- if not start_vms.is_connected():
- os_creds = get_OS_credentials(ctx)
- else:
- os_creds = None
+ if 'spawned_vm_ids' in ctx.storage:
+ os_nodes_ids = ctx.storage.get('spawned_vm_ids', []) # type: List[int]
+ new_nodes = [] # type: List[NodeInfo]
- nova = start_vms.nova_connect(os_creds)
-
- os_nodes_ids = ctx.storage.get('spawned_vm_ids', []) # # type: List[int]
- new_nodes = [] # type: List[IRPCNode]
-
- if not os_nodes_ids:
- params = ctx.config.vm_configs[vm_config['cfg_name']].copy()
- params.update(vm_config)
- params.update(get_vm_keypair(ctx.config))
- params['group_name'] = ctx.config.run_uuid
- params['keypair_name'] = ctx.config.vm_configs['keypair_name']
-
- if not vm_config.get('skip_preparation', False):
- logger.info("Preparing openstack")
- start_vms.prepare_os(nova, params, os_creds)
- else:
# TODO(koder): reconnect to old VM's
raise NotImplementedError("Reconnect to old vms is not implemented")
+ else:
+ os_nodes_ids = []
+ new_nodes = []
+ no_spawn = False
+ if vm_config['count'].startswith('='):
+ count = int(vm_config['count'][1:])
+ if count <= already_has_count:
+ logger.debug("Not need new vms")
+ no_spawn = True
- already_has_count += len(os_nodes_ids)
- old_nodes = ctx.nodes[:]
+ if not no_spawn:
+ ensure_connected_to_openstack(ctx)
+ params = ctx.config.vm_configs[vm_config['cfg_name']].copy()
+ params.update(vm_config)
+ params.update(get_vm_keypair_path(ctx.config))
+ params['group_name'] = ctx.config.run_uuid
+ params['keypair_name'] = ctx.config.vm_configs['keypair_name']
- for node_info, node_id in start_vms.launch_vms(nova, params, already_has_count):
- node_info.roles.append('testnode')
- os_nodes_ids.append(node_id)
- new_nodes.append(node_info)
- ctx.storage['spawned_vm_ids'] = os_nodes_ids
+ if not vm_config.get('skip_preparation', False):
+ logger.info("Preparing openstack")
+ start_vms.prepare_os(ctx.os_connection, params)
- yield new_nodes
+ with ctx.get_pool() as pool:
+ for node_info in start_vms.launch_vms(ctx.os_connection, params, pool, already_has_count):
+ node_info.roles.add('testnode')
+ os_nodes_ids.append(node_info.os_vm_id)
+ new_nodes.append(node_info)
- # keep nodes in case of error for future test restart
- if not ctx.config.keep_vm:
- shut_down_vms_stage(ctx, os_nodes_ids)
- ctx.storage['spawned_vm_ids'] = []
+ ctx.storage['spawned_vm_ids'] = os_nodes_ids # type: ignore
+ yield new_nodes
+
+ # keep nodes in case of error for future test restart
+ if not ctx.config.keep_vm:
+ shut_down_vms_stage(ctx, os_nodes_ids)
+
+ del ctx.storage['spawned_vm_ids']
+
+
+@contextlib.contextmanager
+def sensor_monitoring(ctx: TestRun, cfg: ConfigBlock, nodes: List[IRPCNode]) -> Iterator[None]:
+ yield
def run_tests_stage(ctx: TestRun) -> None:
@@ -362,15 +383,18 @@
vm_ctx = utils.empty_ctx([])
tests = [group]
- with vm_ctx as new_nodes: # type: List[NodeInfo]
+ # make mypy happy
+ new_nodes = [] # type: List[NodeInfo]
+
+ with vm_ctx as new_nodes:
if new_nodes:
with ctx.get_pool() as pool:
- new_rpc_nodes = connect_all(new_nodes, pool, rpc_conn_callback=ctx.before_conn_callback)
+ new_rpc_nodes = connect_all(new_nodes, pool)
test_nodes = ctx.nodes + new_rpc_nodes
if ctx.config.get('sensors'):
- sensor_ctx = sensor_monitoring(ctx.config.get('sensors'), test_nodes)
+ sensor_ctx = sensor_monitoring(ctx, ctx.config.get('sensors'), test_nodes)
else:
sensor_ctx = utils.empty_ctx([])
@@ -386,13 +410,13 @@
def shut_down_vms_stage(ctx: TestRun, nodes_ids: List[int]) -> None:
if nodes_ids:
logger.info("Removing nodes")
- start_vms.clear_nodes(nodes_ids)
+ start_vms.clear_nodes(ctx.os_connection, nodes_ids)
logger.info("Nodes has been removed")
def clear_enviroment(ctx: TestRun) -> None:
shut_down_vms_stage(ctx, ctx.storage.get('spawned_vm_ids', []))
- ctx.storage['spawned_vm_ids'] = []
+ ctx.storage['spawned_vm_ids'] = [] # type: ignore
def disconnect_stage(ctx: TestRun) -> None:
diff --git a/wally/ssh.py b/wally/ssh.py
new file mode 100644
index 0000000..f1786f8
--- /dev/null
+++ b/wally/ssh.py
@@ -0,0 +1,132 @@
+import time
+import errno
+import socket
+import logging
+import os.path
+import selectors
+from io import BytesIO
+from typing import cast, Dict, List, Set
+
+import paramiko
+
+from . import utils
+from .ssh_utils import ConnCreds, IPAddr
+
+logger = logging.getLogger("wally")
+NODE_KEYS = {} # type: Dict[IPAddr, paramiko.RSAKey]
+
+
+def set_key_for_node(host_port: IPAddr, key: bytes) -> None:
+ with BytesIO(key) as sio:
+ NODE_KEYS[host_port] = paramiko.RSAKey.from_private_key(sio)
+
+
+def connect(creds: ConnCreds,
+ conn_timeout: int = 60,
+ tcp_timeout: int = 15,
+ default_banner_timeout: int = 30) -> paramiko.SSHClient:
+
+ ssh = paramiko.SSHClient()
+ ssh.load_host_keys('/dev/null')
+ ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
+ ssh.known_hosts = None
+
+ end_time = time.time() + conn_timeout # type: float
+
+ while True:
+ try:
+ time_left = end_time - time.time()
+ c_tcp_timeout = min(tcp_timeout, time_left)
+
+ banner_timeout_arg = {} # type: Dict[str, int]
+ if paramiko.__version_info__ >= (1, 15, 2):
+ banner_timeout_arg['banner_timeout'] = int(min(default_banner_timeout, time_left))
+
+ if creds.passwd is not None:
+ ssh.connect(creds.addr.host,
+ timeout=c_tcp_timeout,
+ username=creds.user,
+ password=cast(str, creds.passwd),
+ port=creds.addr.port,
+ allow_agent=False,
+ look_for_keys=False,
+ **banner_timeout_arg)
+ elif creds.key_file is not None:
+ ssh.connect(creds.addr.host,
+ username=creds.user,
+ timeout=c_tcp_timeout,
+ key_filename=cast(str, creds.key_file),
+ look_for_keys=False,
+ port=creds.addr.port,
+ **banner_timeout_arg)
+ elif creds.key is not None:
+ with BytesIO(creds.key) as sio:
+ ssh.connect(creds.addr.host,
+ username=creds.user,
+ timeout=c_tcp_timeout,
+ pkey=paramiko.RSAKey.from_private_key(sio),
+ look_for_keys=False,
+ port=creds.addr.port,
+ **banner_timeout_arg)
+ elif (creds.addr.host, creds.addr.port) in NODE_KEYS:
+ ssh.connect(creds.addr.host,
+ username=creds.user,
+ timeout=c_tcp_timeout,
+ pkey=NODE_KEYS[creds.addr],
+ look_for_keys=False,
+ port=creds.addr.port,
+ **banner_timeout_arg)
+ else:
+ key_file = os.path.expanduser('~/.ssh/id_rsa')
+ ssh.connect(creds.addr.host,
+ username=creds.user,
+ timeout=c_tcp_timeout,
+ key_filename=key_file,
+ look_for_keys=False,
+ port=creds.addr.port,
+ **banner_timeout_arg)
+ return ssh
+ except paramiko.PasswordRequiredException:
+ raise
+ except (socket.error, paramiko.SSHException):
+ if time.time() > end_time:
+ raise
+ time.sleep(1)
+
+
+def wait_ssh_available(addrs: List[IPAddr],
+ timeout: int = 300,
+ tcp_timeout: float = 1.0) -> None:
+
+ addrs_set = set(addrs) # type: Set[IPAddr]
+
+ for _ in utils.Timeout(timeout):
+ selector = selectors.DefaultSelector() # type: selectors.BaseSelector
+ with selector:
+ for addr in addrs_set:
+ sock = socket.socket()
+ sock.setblocking(False)
+ try:
+ sock.connect(addr)
+ except BlockingIOError:
+ pass
+ selector.register(sock, selectors.EVENT_READ, data=addr)
+
+ etime = time.time() + tcp_timeout
+ ltime = etime - time.time()
+ while ltime > 0:
+ # convert to greater or equal integer
+ for key, _ in selector.select(timeout=int(ltime + 0.99999)):
+ selector.unregister(key.fileobj)
+ try:
+ key.fileobj.getpeername() # type: ignore
+ addrs_set.remove(key.data)
+ except OSError as exc:
+ if exc.errno == errno.ENOTCONN:
+ pass
+ ltime = etime - time.time()
+
+ if not addrs_set:
+ break
+
+
diff --git a/wally/ssh_utils.py b/wally/ssh_utils.py
index 7728dfd..43ba44a 100644
--- a/wally/ssh_utils.py
+++ b/wally/ssh_utils.py
@@ -1,21 +1,9 @@
import re
-import time
-import errno
-import socket
-import logging
-import os.path
import getpass
-import selectors
-from io import BytesIO
-from typing import Union, Optional, cast, Dict, List, Tuple
-
-import paramiko
-
-from . import utils
+from typing import List
-logger = logging.getLogger("wally")
-IPAddr = Tuple[str, int]
+from .common_types import IPAddr
class URIsNamespace:
@@ -54,17 +42,16 @@
class ConnCreds:
- conn_uri_attrs = ("user", "passwd", "host", "port", "key_file")
-
- def __init__(self, host: str, user: str, passwd: str = None, port: int = 22, key_file: str = None) -> None:
+ def __init__(self, host: str, user: str, passwd: str = None, port: int = 22,
+ key_file: str = None, key: bytes = None) -> None:
self.user = user
self.passwd = passwd
- self.host = host
- self.port = port
+ self.addr = IPAddr(host, port)
self.key_file = key_file
+ self.key = key
def __str__(self) -> str:
- return str(self.__dict__)
+ return "{}@{}:{}".format(self.user, self.addr.host, self.addr.port)
def parse_ssh_uri(uri: str) -> ConnCreds:
@@ -87,107 +74,3 @@
raise ValueError("Can't parse {0!r} as ssh uri value".format(uri))
-NODE_KEYS = {} # type: Dict[IPAddr, paramiko.RSAKey]
-
-
-def set_key_for_node(host_port: IPAddr, key: bytes) -> None:
- with BytesIO(key) as sio:
- NODE_KEYS[host_port] = paramiko.RSAKey.from_private_key(sio)
-
-
-def ssh_connect(creds: ConnCreds,
- conn_timeout: int = 60,
- tcp_timeout: int = 15,
- default_banner_timeout: int = 30) -> Tuple[paramiko.SSHClient, str, str]:
-
- ssh = paramiko.SSHClient()
- ssh.load_host_keys('/dev/null')
- ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
- ssh.known_hosts = None
-
- end_time = time.time() + conn_timeout # type: float
-
- while True:
- try:
- time_left = end_time - time.time()
- c_tcp_timeout = min(tcp_timeout, time_left)
-
- banner_timeout_arg = {} # type: Dict[str, int]
- if paramiko.__version_info__ >= (1, 15, 2):
- banner_timeout_arg['banner_timeout'] = int(min(default_banner_timeout, time_left))
-
- if creds.passwd is not None:
- ssh.connect(creds.host,
- timeout=c_tcp_timeout,
- username=creds.user,
- password=cast(str, creds.passwd),
- port=creds.port,
- allow_agent=False,
- look_for_keys=False,
- **banner_timeout_arg)
- elif creds.key_file is not None:
- ssh.connect(creds.host,
- username=creds.user,
- timeout=c_tcp_timeout,
- key_filename=cast(str, creds.key_file),
- look_for_keys=False,
- port=creds.port,
- **banner_timeout_arg)
- elif (creds.host, creds.port) in NODE_KEYS:
- ssh.connect(creds.host,
- username=creds.user,
- timeout=c_tcp_timeout,
- pkey=NODE_KEYS[(creds.host, creds.port)],
- look_for_keys=False,
- port=creds.port,
- **banner_timeout_arg)
- else:
- key_file = os.path.expanduser('~/.ssh/id_rsa')
- ssh.connect(creds.host,
- username=creds.user,
- timeout=c_tcp_timeout,
- key_filename=key_file,
- look_for_keys=False,
- port=creds.port,
- **banner_timeout_arg)
- return ssh, "{0.host}:{0.port}".format(creds), creds.host
- except paramiko.PasswordRequiredException:
- raise
- except (socket.error, paramiko.SSHException):
- if time.time() > end_time:
- raise
- time.sleep(1)
-
-
-def wait_ssh_available(addrs: List[IPAddr],
- timeout: int = 300,
- tcp_timeout: float = 1.0) -> None:
- addrs = set(addrs)
- for _ in utils.Timeout(timeout):
- with selectors.DefaultSelector() as selector: # type: selectors.BaseSelector
- for addr in addrs:
- sock = socket.socket()
- sock.setblocking(False)
- try:
- sock.connect(addr)
- except BlockingIOError:
- pass
- selector.register(sock, selectors.EVENT_READ, data=addr)
-
- etime = time.time() + tcp_timeout
- ltime = etime - time.time()
- while ltime > 0:
- for key, _ in selector.select(timeout=ltime):
- selector.unregister(key.fileobj)
- try:
- key.fileobj.getpeername()
- addrs.remove(key.data)
- except OSError as exc:
- if exc.errno == errno.ENOTCONN:
- pass
- ltime = etime - time.time()
-
- if not addrs:
- break
-
-
diff --git a/wally/start_vms.py b/wally/start_vms.py
index af81463..a55fdbf 100644
--- a/wally/start_vms.py
+++ b/wally/start_vms.py
@@ -16,22 +16,21 @@
from cinderclient.client import Client as CinderClient
from glanceclient import Client as GlanceClient
-
from .utils import Timeout
from .node_interfaces import NodeInfo
+from .storage import IStorable
__doc__ = """
Module used to reliably spawn set of VM's, evenly distributed across
compute servers in openstack cluster. Main functions:
- get_OS_credentials - extract openstack credentials from different sources
- nova_connect - connect to nova api
- cinder_connect - connect to cinder api
- find - find VM with given prefix in name
- prepare_OS - prepare tenant for usage
+ get_openstack_credentials - extract openstack credentials from different sources
+ os_connect - connect to nova, cinder and glance API
+ find_vms - find VM's with given prefix in name
+ prepare_os - prepare tenant for usage
launch_vms - reliably start set of VM in parallel with volumes and floating IP
- clear_all - clear VM and volumes
+ clear_nodes - clear VM and volumes
"""
@@ -79,14 +78,20 @@
return OSConnection(nova, cinder, glance)
-def find_vms(conn: OSConnection, name_prefix: str) -> Iterable[str, int]:
+def find_vms(conn: OSConnection, name_prefix: str) -> Iterable[Tuple[str, int]]:
for srv in conn.nova.servers.list():
if srv.name.startswith(name_prefix):
+
+ # need to exit after found server first external IP
+ # so have to rollout two cycles to avoid using exceptions
+ all_ip = [] # type: List[Any]
for ips in srv.addresses.values():
- for ip in ips:
- if ip.get("OS-EXT-IPS:type", None) == 'floating':
- yield ip['addr'], srv.id
- break
+ all_ip.extend(ips)
+
+ for ip in all_ip:
+ if ip.get("OS-EXT-IPS:type", None) == 'floating':
+ yield ip['addr'], srv.id
+ break
def pause(conn: OSConnection, ids: Iterable[int], executor: ThreadPoolExecutor) -> None:
@@ -326,7 +331,7 @@
return vol
-def wait_for_server_active(conn: OSConnection, server: Any, timeout: int = 300)-> None:
+def wait_for_server_active(conn: OSConnection, server: Any, timeout: int = 300) -> bool:
"""waiting till server became active
parameters:
diff --git a/wally/storage.py b/wally/storage.py
index 02de173..349995f 100644
--- a/wally/storage.py
+++ b/wally/storage.py
@@ -4,7 +4,7 @@
import os
import abc
-from typing import Any, Iterable, TypeVar, Type, IO, Tuple, Union, Dict, List
+from typing import Any, Iterable, TypeVar, Type, IO, Tuple, cast, List
class IStorable(metaclass=abc.ABCMeta):
@@ -46,11 +46,15 @@
pass
@abc.abstractmethod
+ def __delitem__(self, path: str) -> None:
+ pass
+
+ @abc.abstractmethod
def __contains__(self, path: str) -> bool:
pass
@abc.abstractmethod
- def list(self, path: str) -> Iterable[str]:
+ def list(self, path: str) -> Iterable[Tuple[bool, str]]:
pass
@abc.abstractmethod
@@ -78,35 +82,33 @@
if not os.path.isdir(self.root_path):
raise ValueError("No storage found at {!r}".format(root_path))
- def ensure_dir(self, path):
- os.makedirs(path, exist_ok=True)
-
- @abc.abstractmethod
def __setitem__(self, path: str, value: bytes) -> None:
path = os.path.join(self.root_path, path)
- self.ensure_dir(os.path.dirname(path))
+ os.makedirs(os.path.dirname(path), exist_ok=True)
with open(path, "wb") as fd:
fd.write(value)
- @abc.abstractmethod
+ def __delitem__(self, path: str) -> None:
+ try:
+ os.unlink(path)
+ except FileNotFoundError:
+ pass
+
def __getitem__(self, path: str) -> bytes:
path = os.path.join(self.root_path, path)
with open(path, "rb") as fd:
return fd.read()
- @abc.abstractmethod
def __contains__(self, path: str) -> bool:
path = os.path.join(self.root_path, path)
return os.path.exists(path)
- @abc.abstractmethod
def list(self, path: str) -> Iterable[Tuple[bool, str]]:
path = os.path.join(self.root_path, path)
for entry in os.scandir(path):
if not entry.name in ('..', '.'):
yield entry.is_file(), entry.name
- @abc.abstractmethod
def get_stream(self, path: str, mode: str = "rb") -> IO:
path = os.path.join(self.root_path, path)
return open(path, mode)
@@ -114,41 +116,40 @@
class YAMLSerializer(ISerializer):
"""Serialize data to yaml"""
- pass
+ def pack(self, value: IStorable) -> bytes:
+ raise NotImplementedError()
-
-ISimpleStorable = Union[Dict, List, int, str, None, bool]
+ def unpack(self, data: bytes) -> IStorable:
+ raise NotImplementedError()
class Storage:
"""interface for storage"""
- def __init__(self, storage: ISimpleStorage, serializer: ISerializer):
+ def __init__(self, storage: ISimpleStorage, serializer: ISerializer) -> None:
self.storage = storage
self.serializer = serializer
def __setitem__(self, path: str, value: IStorable) -> None:
self.storage[path] = self.serializer.pack(value)
- @abc.abstractmethod
- def __getitem__(self, path: str) -> ISimpleStorable:
+ def __getitem__(self, path: str) -> IStorable:
return self.serializer.unpack(self.storage[path])
- @abc.abstractmethod
+ def __delitem__(self, path: str) -> None:
+ del self.storage[path]
+
def __contains__(self, path: str) -> bool:
return path in self.storage
- @abc.abstractmethod
def list(self, path: str) -> Iterable[Tuple[bool, str]]:
return self.storage.list(path)
- @abc.abstractmethod
- def load(self, path: str, obj_class: Type[ObjClass]) -> ObjClass:
- raw_val = self[path]
+ def construct(self, path: str, raw_val: IStorable, obj_class: Type[ObjClass]) -> ObjClass:
if obj_class in (int, str, dict, list, None):
if not isinstance(raw_val, obj_class):
raise ValueError("Can't load path {!r} into type {}. Real type is {}"
.format(path, obj_class, type(raw_val)))
- return raw_val
+ return cast(ObjClass, raw_val)
if not isinstance(raw_val, dict):
raise ValueError("Can't load path {!r} into python type. Raw value not dict".format(path))
@@ -157,14 +158,27 @@
raise ValueError("Can't load path {!r} into python type.".format(path) +
"Raw not all keys in raw value is strings")
- obj = ObjClass.__new__(ObjClass)
+ obj = obj_class.__new__(obj_class) # type: ObjClass
obj.__dict__.update(raw_val)
return obj
- @abc.abstractmethod
+ def load_list(self, path: str, obj_class: Type[ObjClass]) -> List[ObjClass]:
+ raw_val = self[path]
+ assert isinstance(raw_val, list)
+ return [self.construct(path, val, obj_class) for val in cast(list, raw_val)]
+
+ def load(self, path: str, obj_class: Type[ObjClass]) -> ObjClass:
+ return self.construct(path, self[path], obj_class)
+
def get_stream(self, path: str) -> IO:
return self.storage.get_stream(path)
+ def get(self, path: str, default: Any = None) -> Any:
+ try:
+ return self[path]
+ except KeyError:
+ return default
+
def make_storage(url: str, existing: bool = False) -> Storage:
return Storage(FSStorage(url, existing), YAMLSerializer())
diff --git a/wally/storage_structure.txt b/wally/storage_structure.txt
new file mode 100644
index 0000000..8ba89c1
--- /dev/null
+++ b/wally/storage_structure.txt
@@ -0,0 +1,11 @@
+config: Config - full configuration
+nodes: List[NodeInfo] - all nodes
+fuel_openstack_creds: OSCreds - openstack creds, discovered from fuel (or None)
+openstack_openrc: OSCreds - openrc used for openstack cluster
+discovered_nodes: List[NodeInfo] - list of discovered nodes
+reused_nodes: List[NodeInfo] - list of reused nodes from cluster
+spawned_vm_ids: List[int] - list of openstack VM id's, spawned for test
+__types__ = type of data in keys
+info/comment : str - run comment
+info/run_uuid : str - run uuid
+info/run_time : float - run unix time
\ No newline at end of file
diff --git a/wally/test_run_class.py b/wally/test_run_class.py
index cac893c..ecdfa4f 100644
--- a/wally/test_run_class.py
+++ b/wally/test_run_class.py
@@ -3,8 +3,8 @@
from .timeseries import SensorDatastore
-from .node_interfaces import NodeInfo, IRPCNode, RPCBeforeConnCallback
-from .start_vms import OSCreds, NovaClient, CinderClient
+from .node_interfaces import NodeInfo, IRPCNode
+from .start_vms import OSCreds, OSConnection
from .storage import Storage
from .config import Config
@@ -25,15 +25,12 @@
# openstack credentials
self.fuel_openstack_creds = None # type: Optional[OSCreds]
self.os_creds = None # type: Optional[OSCreds]
- self.nova_client = None # type: Optional[NovaClient]
- self.cinder_client = None # type: Optional[CinderClient]
+ self.os_connection = None # type: Optional[OSConnection]
self.storage = storage
self.config = config
self.sensors_data = SensorDatastore()
- self.before_conn_callback = None # type: RPCBeforeConnCallback
-
def get_pool(self):
return ThreadPoolExecutor(self.config.get('worker_pool_sz', 32))
diff --git a/wally/utils.py b/wally/utils.py
index 14e3a6c..b27b1ba 100644
--- a/wally/utils.py
+++ b/wally/utils.py
@@ -12,7 +12,8 @@
import collections
from .node_interfaces import IRPCNode
-from typing import Any, Tuple, Union, List, Iterator, Dict, Callable, Iterable, Optional, IO, Sequence
+from typing import (Any, Tuple, Union, List, Iterator, Dict, Callable, Iterable, Optional,
+ IO, Sequence, NamedTuple, cast)
try:
import psutil
@@ -184,6 +185,7 @@
shell = True
cmd_str = cmd
else:
+ shell = False
cmd_str = " ".join(cmd)
proc = subprocess.Popen(cmd,
@@ -330,17 +332,20 @@
return user, passwd, tenant, auth_url, insecure
-os_release = collections.namedtuple("Distro", ["distro", "release", "arch"])
+OSRelease = NamedTuple("OSRelease",
+ [("distro", str),
+ ("release", str),
+ ("arch", str)])
-def get_os(node: IRemoteNode) -> os_release:
+def get_os(node: IRPCNode) -> OSRelease:
"""return os type, release and architecture for node.
"""
arch = node.run("arch", nolog=True).strip()
try:
node.run("ls -l /etc/redhat-release", nolog=True)
- return os_release('redhat', None, arch)
+ return OSRelease('redhat', None, arch)
except:
pass
@@ -356,7 +361,7 @@
if opt == 'Codename':
release = val.strip()
- return os_release('ubuntu', release, arch)
+ return OSRelease('ubuntu', release, arch)
except:
pass
@@ -364,21 +369,16 @@
@contextlib.contextmanager
-def empty_ctx(val: Any=None) -> Iterator[Any]:
+def empty_ctx(val: Any = None) -> Iterator[Any]:
yield val
-def mkdirs_if_unxists(path: str) -> None:
- if not os.path.exists(path):
- os.makedirs(path)
-
-
-def log_nodes_statistic(nodes: Sequence[IRemoteNode]) -> None:
+def log_nodes_statistic(nodes: Sequence[IRPCNode]) -> None:
logger.info("Found {0} nodes total".format(len(nodes)))
per_role = collections.defaultdict(int) # type: Dict[str, int]
for node in nodes:
- for role in node.roles:
+ for role in node.info.roles:
per_role[role] += 1
for role, count in sorted(per_role.items()):
@@ -411,17 +411,18 @@
return results_dir, run_uuid
-class Timeout:
+class Timeout(Iterable[float]):
def __init__(self, timeout: int, message: str = None, min_tick: int = 1, no_exc: bool = False) -> None:
- self.etime = time.time() + timeout
+ self.end_time = time.time() + timeout
self.message = message
self.min_tick = min_tick
self.prev_tick_at = time.time()
self.no_exc = no_exc
def tick(self) -> bool:
- ctime = time.time()
- if ctime > self.etime:
+ current_time = time.time()
+
+ if current_time > self.end_time:
if self.message:
msg = "Timeout: {}".format(self.message)
else:
@@ -429,19 +430,22 @@
if self.no_exc:
return False
+
raise TimeoutError(msg)
- dtime = self.min_tick - (ctime - self.prev_tick_at)
- if dtime > 0:
- time.sleep(dtime)
+ sleep_time = self.min_tick - (current_time - self.prev_tick_at)
+ if sleep_time > 0:
+ time.sleep(sleep_time)
+ self.prev_tick_at = time.time()
+ else:
+ self.prev_tick_at = current_time
- self.prev_tick_at = time.time()
return True
- def __iter__(self):
- return self
+ def __iter__(self) -> Iterator[float]:
+ return cast(Iterator[float], self)
def __next__(self) -> float:
if not self.tick():
raise StopIteration()
- return self.etime - time.time()
+ return self.end_time - time.time()