blob: 13cd675e8a0220a173892b9a06d356993ab41858 [file] [log] [blame]
import re
import os
import abc
import sys
import math
import time
import uuid
import socket
import logging
import ipaddress
import threading
import contextlib
import subprocess
import collections
from .node_interfaces import IRPCNode
from typing import (Any, Tuple, Union, List, Iterator, Dict, Iterable, Optional,
IO, Sequence, NamedTuple, cast, TypeVar)
try:
import psutil
except ImportError:
psutil = None
try:
    from petname import Generate as pet_generate
except ImportError:
    def pet_generate(x: int, y: str) -> str:
        """Fallback for a missing ``petname`` package: return a random UUID4 string.

        Callers pass an int and a separator string (e.g. ``pet_generate(2, "_")``),
        matching the real ``petname.Generate`` call shape; both arguments are
        ignored here.
        """
        return str(uuid.uuid4())
# Module-wide logger; LogError falls back to it when the module running the
# `with` block defines no `logger` global of its own.
logger = logging.getLogger("wally")

# Alias for parameters that accept any real number.
Number = Union[int, float]
# Constrained type variable: a function typed with it returns the same numeric
# type (int or float) that it was given.
TNumber = TypeVar("TNumber", int, float)
class StopTestError(RuntimeError):
    """Raised to abort the current test run.

    The helpers in this module log the reason before raising it, so handlers
    usually do not need to log it again.
    """
class LogError:
    """Context manager that converts unexpected exceptions into StopTestError.

    Any exception escaping the ``with`` body (other than StopTestError itself)
    is logged with its traceback under ``message`` and re-raised as
    ``StopTestError(message)``, chained to the original exception.
    StopTestError and a clean exit pass through untouched.
    """

    def __init__(self, message: str, exc_logger: Optional[logging.Logger] = None) -> None:
        self.message = message
        # None means: use the `logger` global of the module that runs the
        # `with` statement, or this module's logger if it has none.
        self.exc_logger = exc_logger

    def __enter__(self) -> 'LogError':
        return self

    def __exit__(self, tp: type, value: Exception, traceback: Any) -> bool:
        passthrough = value is None or isinstance(value, StopTestError)
        if passthrough:
            return False

        exc_logger = self.exc_logger
        if exc_logger is None:
            # Frame 1 is the caller of __exit__ -- normally the function that
            # contains the `with` statement.
            exc_logger = sys._getframe(1).f_globals.get('logger', logger)

        exc_logger.exception(self.message, exc_info=(tp, value, traceback))
        raise StopTestError(self.message) from value
class IStorable(metaclass=abc.ABCMeta):
    """Interface for types that can be stored.

    Subclasses must implement both :meth:`raw` and the :meth:`fromraw`
    classmethod; the class cannot be instantiated otherwise.
    """

    @abc.abstractmethod
    def raw(self) -> Dict[str, Any]:
        """Serialize the object into a ``Dict[str, Any]``."""

    # `abc.abstractclassmethod` has been deprecated since Python 3.3; stacking
    # classmethod over abstractmethod is the supported equivalent.
    @classmethod
    @abc.abstractmethod
    def fromraw(cls, data: Dict[str, Any]) -> 'IStorable':
        """Build an instance from ``data`` (presumably the output of :meth:`raw`)."""
# Scalar (non-container) value types.
Basic = Union[int, str, bytes, bool, None]

# Every type accepted where a storable value is expected: IStorable objects,
# str-keyed dicts, lists, and the scalar types.
Storable = Union[IStorable,
                 Dict[str, Any],
                 List[Any],
                 int, str, bytes, bool, None]
class TaskFinished(Exception):
    """Raised by ``Barrier.wait`` once the barrier has been shut down via ``exit()``."""
class Barrier:
    """Reusable rendezvous point for ``count`` threads.

    Each call to :meth:`wait` registers one arrival. The arrival that completes
    the group returns True and releases the others, which return False.
    After :meth:`exit`, threads blocked in ``wait`` and all later callers get
    TaskFinished.
    """

    def __init__(self, count: int) -> None:
        self.count = count
        self.curr_count = 0
        self.cond = threading.Condition()
        self.exited = False
        # Bumped every time the barrier trips so a waiter can tell a real
        # release apart from a spurious wakeup or a timeout.
        self._generation = 0

    def wait(self, timeout: Optional[float] = None) -> bool:
        """Block until ``count`` threads have arrived.

        :param timeout: seconds to wait; None waits indefinitely.
        :returns: True for the arrival that tripped the barrier; False for the
            threads it released and for a caller whose wait timed out. A
            timed-out arrival is withdrawn so it does not count toward a later
            trip.
        :raises TaskFinished: if :meth:`exit` was called before or during the wait.
        """
        with self.cond:
            if self.exited:
                raise TaskFinished()
            self.curr_count += 1
            if self.curr_count == self.count:
                self.curr_count = 0
                self._generation += 1
                self.cond.notify_all()
                return True

            generation = self._generation
            self.cond.wait_for(
                lambda: self._generation != generation or self.exited,
                timeout=timeout)

            if self._generation != generation:
                # Released by the thread that completed this round.
                return False
            if self.exited:
                raise TaskFinished()
            # Timed out: withdraw this arrival so the round count stays accurate.
            self.curr_count -= 1
            return False

    def exit(self) -> None:
        """Shut the barrier down and wake every thread blocked in :meth:`wait`."""
        with self.cond:
            self.exited = True
            # Without this, waiters would sleep until timeout (or forever).
            self.cond.notify_all()
class Timeout(Iterable[float]):
    """Deadline helper usable as an iterator in retry loops.

    Each iteration yields the seconds remaining until the deadline. Between
    iterations it sleeps so that consecutive ticks are at least ``min_tick``
    seconds apart. This also applies to the first tick, measured from
    construction. Once the deadline has passed, iteration stops when
    ``no_exc`` is set, and otherwise raises TimeoutError.
    """

    def __init__(self, timeout: float, message: Optional[str] = None, min_tick: float = 1,
                 no_exc: bool = False) -> None:
        self.end_time = time.time() + timeout
        self.message = message
        self.min_tick = min_tick
        self.prev_tick_at = time.time()
        self.no_exc = no_exc

    def tick(self) -> bool:
        """Wait for the next tick.

        :returns: True while the deadline has not passed; False once it has and
            ``no_exc`` is set.
        :raises TimeoutError: once the deadline has passed and ``no_exc`` is False.
        """
        current_time = time.time()
        if current_time > self.end_time:
            if self.no_exc:
                return False
            if self.message:
                msg = "Timeout: {}".format(self.message)
            else:
                msg = "Timeout"
            raise TimeoutError(msg)

        # Keep ticks at least min_tick apart, but never sleep past the deadline.
        sleep_time = min(self.min_tick - (current_time - self.prev_tick_at),
                         self.end_time - current_time)
        if sleep_time > 0:
            time.sleep(sleep_time)
            self.prev_tick_at = time.time()
        else:
            self.prev_tick_at = current_time
        return True

    def __iter__(self) -> Iterator[float]:
        return cast(Iterator[float], self)

    def __next__(self) -> float:
        if not self.tick():
            raise StopIteration()
        # May be ~0 (or marginally negative) on the last tick before expiry.
        return self.end_time - time.time()
def greater_digit_pos(val: Number) -> int:
    """Return the position of the most significant decimal digit of ``val``.

    Positions count from the decimal point: 1234 -> 4, 7 -> 1, 0.05 -> -1.
    The sign of ``val`` is ignored. Values extremely close to a power of ten
    may be affected by ``math.log10`` rounding.

    :raises ValueError: if ``val`` is zero (it has no significant digit).
    """
    if val == 0:
        raise ValueError("greater_digit_pos() is undefined for zero")
    # abs(): log10 is only defined for positive values.
    return int(math.floor(math.log10(abs(val)))) + 1
def round_digits(val: TNumber, num_digits: int = 3) -> TNumber:
    """Truncate ``val`` (toward zero) to its ``num_digits`` most significant digits.

    The result has the same type as ``val``. Examples:
    ``round_digits(123456) == 123000``, ``round_digits(0.012345) == 0.0123``.
    Zero is returned unchanged; negative values keep their sign.
    """
    if val == 0:
        return val
    exponent = greater_digit_pos(abs(val)) - num_digits
    if exponent >= 0:
        scale = 10 ** exponent
        return type(val)(int(val / scale) * scale)
    # Negative exponent: divide by an exact integer power of ten instead of
    # multiplying by an inexact float such as 0.0001, which adds rounding noise.
    scale = 10 ** -exponent
    return type(val)(int(val * scale) / scale)
def is_ip(data: str) -> bool:
    """Return True when ``data`` parses as an IPv4 or IPv6 address."""
    try:
        ipaddress.ip_address(data)
    except ValueError:
        return False
    return True
def log_block(message: str, exc_logger: Optional[logging.Logger] = None) -> LogError:
    """Log the start of a block at debug level and return a LogError guard for it.

    Intended for ``with log_block("doing X"): ...``. Exceptions raised inside
    are logged under ``message`` and converted to StopTestError by LogError.
    """
    logger.debug("Starts : " + message)
    return LogError(message, exc_logger=exc_logger)
def check_input_param(is_ok: bool, message: str) -> None:
    """Unless ``is_ok`` is truthy, log ``message`` as an error and raise StopTestError."""
    if is_ok:
        return
    logger.error(message)
    raise StopTestError(message)
def parse_creds(creds: str) -> Tuple[str, Optional[str], Optional[str]]:
    """Parse the simple credentials format ``user[:passwd][@host]``.

    :returns: ``(user, passwd, host)``. ``passwd`` and ``host`` are None when
        absent.

    The user is everything before the first ':' and the host is everything
    after the last '@', so the password may itself contain ':' and '@'.
    """
    if ':' in creds:
        user, passwd_host = creds.split(":", 1)
        if '@' in passwd_host:
            passwd, host = passwd_host.rsplit('@', 1)
        else:
            passwd, host = passwd_host, None
        return user, passwd, host

    # No password given: "user" or "user@host".
    if '@' in creds:
        user, host = creds.rsplit('@', 1)
        return user, None, host
    return creds, None, None
SMAP = dict(k=1024, m=1024 ** 2, g=1024 ** 3, t=1024 ** 4)


def ssize2b(ssize: Union[str, int]) -> int:
    """Convert a size such as ``"10k"``, ``"4G"`` or ``"512"`` to bytes.

    Suffixes k/m/g/t (case-insensitive) are powers of 1024. Ints are returned
    unchanged; surrounding whitespace in strings is ignored.

    :raises ValueError: for anything that is not a valid size (including "").
    """
    try:
        if isinstance(ssize, int):
            return ssize
        ssize = ssize.strip().lower()
        # Guard the suffix lookup: ssize[-1] on "" would raise IndexError,
        # which is not in the translated exception list below.
        if ssize and ssize[-1] in SMAP:
            return int(ssize[:-1]) * SMAP[ssize[-1]]
        return int(ssize)
    except (ValueError, TypeError, AttributeError):
        raise ValueError("Unknow size format {!r}".format(ssize))
RSMAP = [('K', 1024),
         ('M', 1024 ** 2),
         ('G', 1024 ** 3),
         ('T', 1024 ** 4)]


def b2ssize(size: int) -> str:
    """Format a byte count with binary (1024-based) units, e.g. ``"1.5 Ki"``.

    Values below 1024 are returned as plain numbers. Exact multiples print
    without decimals, and other values get one decimal place. Sizes beyond
    the largest unit are expressed in that unit with the same formatting.
    """
    if size < 1024:
        return str(size)
    # Pick the smallest unit that keeps the number below 1024. If none does,
    # the loop leaves the largest unit bound (the default also satisfies mypy).
    name, scale = RSMAP[-1]
    for name, scale in RSMAP:
        if size < 1024 * scale:
            break
    if size % scale == 0:
        return "{} {}i".format(size // scale, name)
    return "{:.1f} {}i".format(float(size) / scale, name)
RSMAP_10 = [('k', 1000),
            ('m', 1000 ** 2),
            ('g', 1000 ** 3),
            ('t', 1000 ** 4)]


def b2ssize_10(size: int) -> str:
    """Format a number with decimal (1000-based) unit prefixes, e.g. ``"1.5 k"``.

    Values below 1000 are returned as plain numbers. Exact multiples print
    without decimals, and other values get one decimal place. Values beyond
    the largest prefix are expressed in that prefix with the same formatting.
    """
    if size < 1000:
        return str(size)
    # Pick the smallest prefix that keeps the number below 1000. If none does,
    # the loop leaves the largest prefix bound (the default also satisfies mypy).
    name, scale = RSMAP_10[-1]
    for name, scale in RSMAP_10:
        if size < 1000 * scale:
            break
    if size % scale == 0:
        return "{} {}".format(size // scale, name)
    return "{:.1f} {}".format(float(size) / scale, name)
def _kill_proc_tree(proc: subprocess.Popen) -> None:
    """Kill ``proc`` and, when psutil is available, all of its descendants."""
    if psutil is not None:
        try:
            children = psutil.Process(proc.pid).children(recursive=True)
        except psutil.NoSuchProcess:
            children = []
        for child in children:
            try:
                child.kill()
            except psutil.NoSuchProcess:
                # Already exited between listing and killing.
                pass
    proc.kill()


def run_locally(cmd: Union[str, List[str]], input_data: str = "", timeout: float = 20) -> str:
    """Run ``cmd`` on the local host and return its stdout decoded as UTF-8.

    A str ``cmd`` is run through the shell; a list is executed directly.
    ``input_data`` is written to the process's stdin.

    :raises RuntimeError: if the command does not finish within ``timeout``
        seconds. The process (and its descendants, when psutil is available)
        is killed first.
    :raises subprocess.CalledProcessError: on a non-zero exit status. Its
        ``output`` is stdout followed by stderr.
    """
    if isinstance(cmd, str):
        shell = True
        cmd_str = cmd
    else:
        shell = False
        cmd_str = " ".join(cmd)

    proc = subprocess.Popen(cmd,
                            shell=shell,
                            stdin=subprocess.PIPE,
                            stdout=subprocess.PIPE,
                            stderr=subprocess.PIPE)
    try:
        out_bytes, err_bytes = proc.communicate(input_data.encode("utf8"), timeout=timeout)
    except subprocess.TimeoutExpired:
        _kill_proc_tree(proc)
        try:
            # Reap the child and drain the pipes. This is bounded because
            # descendants that survived (no psutil) may keep the pipes open.
            proc.communicate(timeout=1)
        except subprocess.TimeoutExpired:
            pass
        raise RuntimeError("Local process timeout: " + cmd_str)

    out = out_bytes.decode("utf8")
    err = err_bytes.decode("utf8")
    if 0 != proc.returncode:
        raise subprocess.CalledProcessError(proc.returncode,
                                            cmd_str, out + err)
    return out
def get_ip_for_target(target_ip: str) -> str:
    """Return the local source address used to reach ``target_ip``.

    ``target_ip`` may be a hostname; it is resolved with
    ``socket.gethostbyname`` first. IPv4 loopback targets (127.0.0.0/8)
    return '127.0.0.1' directly. Otherwise the ``src`` field of the first line
    of ``ip route get to <target>`` is returned.

    :raises OSError: if the route output contains no ``src`` address.
    """
    try:
        addr = ipaddress.ip_address(target_ip)
    except ValueError:
        target_ip = socket.gethostbyname(target_ip)
        addr = ipaddress.ip_address(target_ip)

    if addr.version == 4 and addr.is_loopback:
        return '127.0.0.1'

    data = run_locally(['ip', 'route', 'get', 'to', target_ip])
    # Typical line: "10.1.2.3 via 10.0.0.1 dev eth0 src 10.0.0.5 uid 1000".
    # Newer iproute2 versions may append fields after `src`, and local targets
    # get a "local ..." prefix, so search for `src` instead of anchoring the
    # whole line.
    data_line = data.split("\n")[0].strip()
    match = re.search(r'\bsrc\s+(?P<ip>[0-9a-fA-F.:]+)', data_line)
    if match is not None:
        return match.group('ip')
    raise OSError("Can't define interface for {0}".format(target_ip))
def open_for_append_or_create(fname: str) -> IO[str]:
    """Open ``fname`` as a text file positioned at its end, creating it if absent.

    An existing file is opened read/write (mode "r+") and moved to EOF, so
    existing content is kept. A missing file is created with mode "w".
    """
    try:
        fd = open(fname, 'r+')
    except FileNotFoundError:
        return open(fname, "w")
    fd.seek(0, os.SEEK_END)
    return fd
def sec_to_str(seconds: float) -> str:
    """Format a duration in seconds as ``H:MM:SS``.

    Hours are not wrapped (90061 -> "25:01:01"). Fractional seconds are
    truncated, and negative durations get a leading '-' ("-0:01:01").
    """
    # int(): "{:02d}" rejects floats.
    total = int(seconds)
    sign = "-" if total < 0 else ""
    minutes, secs = divmod(abs(total), 60)
    hours, minutes = divmod(minutes, 60)
    return "{}{}:{:02d}:{:02d}".format(sign, hours, minutes, secs)
def yamable(data: Any) -> Any:
    """Recursively convert ``data`` into builtin lists and dicts.

    Tuples and lists become lists, and dicts are rebuilt with converted
    values. Dict keys are left as they are: they must stay hashable, so a
    tuple key cannot become a list. All other values are returned unchanged.
    """
    if isinstance(data, (tuple, list)):
        # Build the list eagerly: in Python 3 `map` returns a lazy iterator,
        # which is neither a list nor serializable.
        return [yamable(item) for item in data]
    if isinstance(data, dict):
        return {key: yamable(value) for key, value in data.items()}
    return data
def flatten(data: Iterable[Any]) -> List[Any]:
    """Return the leaves of arbitrarily nested lists, tuples and sets in order.

    Any other object, including strings, is treated as a single leaf.
    """
    def _walk(items: Iterable[Any]) -> Iterator[Any]:
        for item in items:
            if isinstance(item, (list, tuple, set)):
                yield from _walk(item)
            else:
                yield item

    return list(_walk(data))
def get_creds_openrc(path: str) -> Tuple[str, str, str, str, bool]:
    """Read OpenStack credentials from an openrc file.

    The file contents are fed to ``/bin/bash``, followed by an ``echo`` of the
    relevant ``OS_*`` variables, and the echoed line is parsed.

    :returns: ``(user, passwd, tenant, auth_url, insecure)``.
    :raises StopTestError: if bash fails or its output cannot be parsed,
        including an auth URL that starts with neither http:// nor https://.
    """
    # Close the file promptly instead of leaking the handle.
    with open(path) as openrc:
        fc = openrc.read()

    # The field order here must match the unpacking below:
    # INSECURE, TENANT_NAME, USERNAME, PASSWORD@AUTH_URL.
    echo = 'echo "$OS_INSECURE:$OS_TENANT_NAME:$OS_USERNAME:$OS_PASSWORD@$OS_AUTH_URL"'
    msg = "Failed to get creads from openrc file"
    with LogError(msg):
        data = run_locally(['/bin/bash'], input_data=fc + "\n" + echo)

    # NOTE(review): this message embeds the echoed line, password included,
    # and LogError logs it on failure -- consider redacting.
    msg = "Failed to get creads from openrc file: " + data
    with LogError(msg):
        data = data.strip()
        insecure_str, tenant, user, passwd_auth_url = data.split(':', 3)
        insecure = (insecure_str in ('1', 'True', 'true'))
        passwd, auth_url = passwd_auth_url.rsplit("@", 1)
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not auth_url.startswith(("https://", "http://")):
            raise ValueError("Unsupported auth url: {!r}".format(auth_url))
    return user, passwd, tenant, auth_url, insecure
# Result of get_os(): distribution family, release codename (None when it
# could not be determined, e.g. always for redhat) and machine architecture.
OSRelease = NamedTuple("OSRelease",
                       [("distro", str),
                        ("release", Optional[str]),
                        ("arch", str)])
def _debian_codename(node: IRPCNode) -> Optional[str]:
    """Return the last ``Codename`` reported by ``lsb_release -a`` on node, or None."""
    try:
        output = node.run("lsb_release -a", nolog=True)
    except Exception:
        # lsb_release is presumably absent on minimal installs; the distro
        # family is still known, only the release is not.
        return None
    release = None
    for line in output.split("\n"):
        if ':' not in line:
            continue
        opt, val = line.split(":", 1)
        if opt == 'Codename':
            release = val.strip()
    return release


def get_os(node: IRPCNode) -> OSRelease:
    """Return os type, release and architecture for node.

    Detection is based on marker files:
    ``/etc/redhat-release`` -> 'redhat' (release left None), and
    ``/etc/debian_version`` -> 'ubuntu' with the ``lsb_release`` codename
    (None if unavailable).

    :raises RuntimeError: if neither marker file is found.
    """
    arch = node.run("arch", nolog=True).strip()

    # NOTE(review): absence is detected by node.run raising, presumably when
    # the remote command exits non-zero -- confirm against IRPCNode.run.
    try:
        node.run("ls -l /etc/redhat-release", nolog=True)
    except Exception:
        pass
    else:
        return OSRelease('redhat', None, arch)

    try:
        node.run("ls -l /etc/debian_version", nolog=True)
    except Exception:
        pass
    else:
        return OSRelease('ubuntu', _debian_codename(node), arch)

    raise RuntimeError("Unknown os")
@contextlib.contextmanager
def empty_ctx(val: Any = None) -> Iterator[Any]:
    """No-op context manager: yields ``val`` and does nothing on exit.

    Presumably used as a stand-in where a real context manager is optional.
    """
    yield val
def log_nodes_statistic(nodes: Sequence[IRPCNode]) -> None:
    """Log the total node count at info level and per-role counts at debug level.

    Per-role lines are emitted in sorted role order.
    """
    logger.info("Found {0} nodes total".format(len(nodes)))
    role_counts = collections.Counter(
        role for node in nodes for role in node.info.roles)
    for role in sorted(role_counts):
        logger.debug("Found {0} nodes with role {1}".format(role_counts[role], role))
def which(program: str) -> Optional[str]:
    """Return the full path of executable ``program`` found on ``$PATH``, or None.

    PATH entries are tried in order, with surrounding double quotes stripped
    from each. When PATH is unset, ``os.defpath`` is searched instead of
    raising KeyError.
    """
    def is_exe(fpath: str) -> bool:
        return os.path.isfile(fpath) and os.access(fpath, os.X_OK)

    for path in os.environ.get("PATH", os.defpath).split(os.pathsep):
        exe_file = os.path.join(path.strip('"'), program)
        if is_exe(exe_file):
            return exe_file
    return None
def get_uniq_path_uuid(path: str, max_iter: int = 10) -> Tuple[str, str]:
    """Pick a subdirectory name under ``path`` that does not exist yet.

    Up to ``max_iter`` names from ``pet_generate(2, "_")`` are tried. If all of
    them already exist, a random UUID4 is used instead (without an existence
    check).

    :returns: ``(results_dir, run_uuid)``, where results_dir is path/run_uuid.
    """
    for _ in range(max_iter):
        candidate = pet_generate(2, "_")
        candidate_dir = os.path.join(path, candidate)
        if not os.path.exists(candidate_dir):
            return candidate_dir, candidate

    fallback = str(uuid.uuid4())
    return os.path.join(path, fallback), fallback
def to_ip(host_or_ip: str) -> str:
    """Return ``host_or_ip`` unchanged if it is an IP address, else resolve it.

    Hostnames are resolved with ``socket.gethostbyname``, and the substitution
    is logged at info level.
    """
    try:
        ipaddress.ip_address(host_or_ip)
    except ValueError:
        resolved = socket.gethostbyname(host_or_ip)
        logger.info("Will use ip_addr %r instead of hostname %r", resolved, host_or_ip)
        return resolved
    return host_or_ip