#    Copyright 2019 Mirantis, Inc.
#
#    Licensed under the Apache License, Version 2.0 (the "License"); you may
#    not use this file except in compliance with the License. You may obtain
#    a copy of the License at
#
#         http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
#    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
#    License for the specific language governing permissions and limitations
#    under the License.

import base64
import enum
import exec_helpers
import ipaddress
import json
import os
import paramiko
import pytz
import random
import re
import requests
import secrets
import shutil
import string
import tabulate
import time
import traceback
import yaml
from contextlib import contextmanager
from datetime import datetime, timezone, timedelta
from functools import wraps
from io import StringIO
from kubernetes.client.rest import ApiException
from operator import itemgetter

from si_tests.clients.k8s.base import K8sNamespacedResource
from si_tests.utils import packaging_version as version
from pathlib import Path
from requests import exceptions as rexc
from retry import retry
from threading import Thread, Event
from typing import Iterable, List

from si_tests import logger
from si_tests.utils import templates as templates_utils
from si_tests.settings import ARTIFACTS_DIR, ENV_NAME
from si_tests import settings

# Module-wide logger shared by all helpers below.
LOG = logger.logger

# Maps a label key to a short machine role name ('control').
# NOTE(review): presumably matched against Machine labels by callers elsewhere
# in the project — confirm.
MACHINE_TYPES = {
    'cluster.x-k8s.io/control-plane': 'control'
}

# NOTE(review): path of a marker file; from its name it presumably flags a host
# as managed by ksi-tests. Not referenced in this part of the file — confirm.
FUSE_FILENAME = '/var/lib/ksi-tests/.states/ksi_managed'


class LogTime(object):
    """Context manager that records the wall-clock duration of its body.

    On exit an entry keyed by ``name`` is written into the YAML file at
    ``file_path`` (through templates_utils.YamlEditor). The entry holds
    ENV_NAME, the raw duration in seconds and a human readable form such as
    "1d 2h 3m 4.5s". Exceptions raised in the ``with`` body are not
    suppressed.
    """

    def __init__(self, name, file_path):
        self.name = name
        self.file_path = file_path
        self.start_time = 0

    def __enter__(self):
        self.start_time = time.time()
        return self

    def __exit__(self, x, y, z):
        elapsed = time.time() - self.start_time
        day_count = int(elapsed // 86400)
        hour_count = int(elapsed // 3600 % 24)
        minute_count = int(elapsed // 60 % 60)
        second_count = round(elapsed % 60, 2)

        # Zero-valued leading units are omitted; seconds are always shown.
        pieces = []
        if day_count > 0:
            pieces.append(str(day_count) + "d ")
        if hour_count > 0:
            pieces.append(str(hour_count) + "h ")
        if minute_count > 0:
            pieces.append(str(minute_count) + "m ")
        pieces.append("{}s".format(second_count))
        human_format = "".join(pieces)

        with templates_utils.YamlEditor(file_path=self.file_path) as editor:
            content = editor.content
            content[self.name] = {
                "environment": ENV_NAME,
                "raw_duration": elapsed,  # raw seconds
                "duration": human_format,  # human readable format
            }
            # Re-assign so the editor sees the modified content.
            editor.content = content


class Tail(Thread):
    """
    Tail thread allows to watch remote file continuously in parallel
    with running main thread. Typical use case is watching remote log
    file (from bootstrap.sh) while main thread runs ansible script.

    Tail will read all available lines from the file during each _read()
    iteration and will try to read everything else after stop() is called
    to ensure the entire file was read.

    Can be used as a context manager: ``with Tail(path) as tail: ...``
    starts the thread on entry, stops it and joins it on exit.

    Keyword arguments consumed by Tail (the rest go to Thread):
        prefix   -- string prepended to every logged line
        remote   -- object with an ``open(filename)`` method (e.g. SFTP-like);
                    when None the file is opened locally
        logger   -- logger to write lines to (defaults to the module LOG)
        interval -- seconds to wait between reads (default 10)

    NOTE: It is not possible to interrupt the thread while it is in _read().
          Entire file will be read to its end if no exception occurs.
    """

    def __init__(self, filename, *args, **kwargs):
        self.filename = filename
        self.prefix = kwargs.pop('prefix', '')
        self.remote = kwargs.pop('remote', None)
        self.logger = kwargs.pop('logger', LOG)
        self.interval = float(kwargs.pop('interval', 10))
        self._stop_event = Event()
        # Offset where the next _read() resumes, so lines are logged once.
        self._position = 0
        super(Tail, self).__init__(*args, **kwargs)

    def __enter__(self):
        self.start()
        # Return the thread so ``with Tail(...) as tail`` binds it.
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.stop()
        self.join()

    def __str__(self):
        return "Tail (filename={}, interval={}, remote={})" \
            .format(self.filename, self.interval, self.remote is not None)

    def _open(self):
        if self.remote is None:
            return open(self.filename)
        else:
            return self.remote.open(self.filename)

    def _read(self):
        """Log every line added since the previous read."""
        file_handler = None
        try:
            file_handler = self._open()
            file_handler.seek(self._position)

            for line in file_handler:
                self.logger.info(self.format(line.strip()))

            self._position = file_handler.tell()
        except OSError:
            # The file may not exist yet; try again on the next iteration.
            self.logger.warning(
                self.format("Not found {}".format(self.filename)))
        except Exception:
            self.logger.error(self.format("Got Exception:",
                                          traceback.format_exc()))
            self.stop()
        finally:
            if file_handler:
                file_handler.close()

    def run(self):
        self.logger.info(self.format("Started {}".format(str(self))))
        try:
            while not self._stop_event.is_set():
                self._read()
                self._stop_event.wait(self.interval)
            # One last pass so lines written right before stop() are not lost.
            self._read()
        finally:
            self.logger.info(self.format("Stopped {}".format(str(self))))

    def stop(self):
        self.logger.info(self.format("Stopping {} ...".format(str(self))))
        self._stop_event.set()

    def format(self, *args):
        return '{} {}'.format(self.prefix, '\n'.join(args))


@contextmanager
def pushd(path):
    """Temporarily switch the current working directory to ``path``.

    ``~`` in ``path`` is expanded. The previous working directory is
    restored when the ``with`` block exits, even if it raises.
    """
    previous_dir = os.getcwd()
    try:
        target_dir = os.path.expanduser(path)
        os.chdir(target_dir)
        yield
    finally:
        os.chdir(previous_dir)


def reduce_occurrences(items, text):
    """Remove one occurrence of every item from ``text``.

    Items are processed in order. Each one removes its first occurrence from
    the text left over by the previous items.

        Args:
            items: iterable of substrings to remove
            text: source string
        Returns:
            the remaining string
        Raise:
            AssertionError if any substring is not present in the remaining text
    """
    remaining = text
    for substring in items:
        LOG.debug(
            "Verifying string {} is shown in "
            "\"\"\"\n{}\n\"\"\"".format(substring, remaining))
        assert substring in remaining
        remaining = remaining.replace(substring, "", 1)
    return remaining


def generate_keys(bits=1024):
    """Generate a fresh RSA key pair with paramiko.

    :param bits: RSA key size. Defaults to 1024 to keep the historical
                 behavior. 1024-bit RSA is considered weak nowadays, so pass
                 2048 or more where the key protects anything real.
    :return: dict with 'private' (text written by write_private_key) and
             'public' (the value of key.get_base64())
    """
    key = paramiko.RSAKey.generate(bits)
    with StringIO() as buffer:
        key.write_private_key(buffer)
        private = buffer.getvalue()
    return {'private': private,
            'public': key.get_base64()}


def load_keyfile(file_path):
    """Read an RSA private key file and derive its public part.

    :param file_path: path of the private key file
    :return: dict with 'private' (file text) and 'public' (key.get_base64())
    """
    with open(file_path, 'r') as key_file:
        private_text = key_file.read()
    rsa_key = paramiko.RSAKey(file_obj=StringIO(private_text))
    return dict(private=private_text, public=rsa_key.get_base64())


def get_rsa_key(private_key):
    """Build a paramiko RSAKey object from private key text."""
    return paramiko.rsakey.RSAKey.from_private_key(StringIO(private_key))


def dump_keyfile(file_path, key, mode=0o644):
    """Write the private part of ``key`` to ``file_path``.

    :param file_path: destination path
    :param key: dict with a 'private' entry holding the private key text
                (e.g. the result of generate_keys() / load_keyfile())
    :param mode: permissions applied to the written file. Defaults to 0o644
                 to keep the historical behavior. OpenSSH refuses private key
                 files readable by group/others, so pass 0o600 when the file
                 is meant for the ``ssh`` CLI.
    """
    rsa_key = paramiko.RSAKey(file_obj=StringIO(key['private']))
    rsa_key.write_private_key_file(file_path)
    os.chmod(file_path, mode)


def clean_dir(dirpath):
    """Recursively delete ``dirpath`` together with everything inside it.

    Errors from shutil.rmtree (e.g. a missing directory) propagate.
    """
    remove_tree = shutil.rmtree
    remove_tree(dirpath)


def backup_file(filepath, postfix=None, remote=None):
    """Copy ``filepath`` to ``filepath + postfix`` locally or on a remote host.

    :param filepath: file to back up
    :param postfix: suffix for the copy; defaults to "_YYYYmmdd_HHMMSS.bkp"
    :param remote: object with ``check_call(cmd)``. When given, the copy is
                   made with ``cp`` on that host instead of locally.
    """
    suffix = postfix or time.strftime("_%Y%m%d_%H%M%S") + ".bkp"
    backup_path = filepath + suffix

    if remote:
        remote.check_call(f"cp {filepath} {backup_path}")
        location = "[remote]"
    else:
        shutil.copyfile(filepath, backup_path)
        location = "[local]"

    LOG.info(f"{location}: '{filepath}' backup was saved as '{backup_path}'")


def check_test_result(request, test_results):
    """Tell whether the test's call-phase report has one of the given outcomes.

    :param request: pytest request object
    :param test_results: attribute names to check on ``request.node.rep_call``
                         (e.g. 'failed', 'passed')
    :rtype: boolean
    """
    node = request.node
    if not hasattr(node, 'rep_call'):
        return False
    for outcome in test_results:
        if getattr(node.rep_call, outcome):
            LOG.debug(f"Test result is {outcome}")
            return True
    return False


def extract_name_from_mark(mark, info='name'):
    """Return the value carried by a pytest mark.

    The first positional argument wins. Otherwise the keyword argument
    ``info`` is used.

    :param mark: pytest mark object (anything with ``args`` and ``kwargs``)
    :param info: keyword argument to read when there are no positional args
    :rtype: string or None
    """
    if not mark:
        return None
    if mark.args:
        return mark.args[0]
    return mark.kwargs[info] if info in mark.kwargs else None


def get_top_fixtures_marks(request, mark_name):
    """Order marks according to fixtures order

    When a test uses fixtures that depend on each other in some order,
    those fixtures can have the same pytest mark.

    This method extracts such marks from fixtures that are used in the
    current test and returns the content of the marks ordered by the
    fixture dependencies.
    If the test case has the same mark, then the content of this mark
    will be the first element in the resulting list.

    :param request: pytest 'request' fixture
    :param mark_name: name of the mark to search on the fixtures and the test

    :rtype list: marks content, from last to first executed.
    """

    def _find_mark(func):
        """Return (found, mark) for ``mark_name`` attached to ``func``."""
        # Older pytest releases stored a MarkInfo object named after the mark
        # directly on the function. The original code read it through the
        # Python 2 only ``func.func_dict``, which raises AttributeError on
        # Python 3; ``__dict__`` is the Python 3 equivalent.
        func_attrs = getattr(func, '__dict__', {})
        if mark_name in func_attrs:
            return True, func_attrs[mark_name]
        # Newer pytest keeps the applied Mark objects in ``pytestmark``.
        stored = getattr(func, 'pytestmark', None) or []
        if not isinstance(stored, (list, tuple)):
            stored = [stored]
        for mark in stored:
            if getattr(mark, 'name', None) == mark_name:
                return True, mark
        return False, None

    fixtureinfo = request.session._fixturemanager.getfixtureinfo(
        request.node, request.function, request.cls)

    # Peel fixtures layer by layer: on each pass, pick the not-yet-ordered
    # fixtures that no other remaining fixture depends on.
    top_fixtures_names = []
    for _ in range(len(fixtureinfo.name2fixturedefs)):
        parent_fixtures = set()
        child_fixtures = set()
        for name in sorted(fixtureinfo.name2fixturedefs):
            if name in top_fixtures_names:
                continue
            parent_fixtures.add(name)
            child_fixtures.update(
                fixtureinfo.name2fixturedefs[name][0].argnames)
        top_fixtures_names.extend(list(parent_fixtures - child_fixtures))

    top_fixtures_marks = []

    found, test_mark = _find_mark(request.function)
    if found:
        # The top priority is the mark placed on the test itself
        top_fixtures_marks.append(extract_name_from_mark(test_mark))

    for top_fixtures_name in top_fixtures_names:
        fd = fixtureinfo.name2fixturedefs[top_fixtures_name][0]
        found, fixture_mark_obj = _find_mark(fd.func)
        if found:
            # Append the mark contents in the order that fixtures are called
            # starting from the last called fixture to the first one
            top_fixtures_marks.append(extract_name_from_mark(fixture_mark_obj))

    LOG.debug("Fixtures ordered from last to first called: {0}"
              .format(top_fixtures_names))
    LOG.debug("Marks ordered from most to least preffered: {0}"
              .format(top_fixtures_marks))

    return top_fixtures_marks


def gen_random_string(size):
    """Return ``size`` random lowercase ASCII letters.

    Uses the ``random`` module, so the result is not suitable for secrets.
    """
    alphabet = string.ascii_lowercase
    picked = [random.choice(alphabet) for _ in range(size)]
    return ''.join(picked)


def gen_random_password(size):
    """Generate a random URL-safe password.

    ``size`` is the number of random bytes, NOT the length of the result.
    The base64 encoding makes the password roughly 1.3 * size characters
    long (e.g. 16 bytes -> 22 characters).
    """
    password = secrets.token_urlsafe(size)
    return password


def convert_to_bytes(size_str):
    """
    Convert a size string to a number of bytes.

    Units are case-insensitive and spaces are ignored. Every supported unit
    is a power of 1024: 'KB'/'Ki'/'KiB', 'MB'/'Mi'/'MiB', 'GB'/'Gi'/'GiB',
    'TB'/'Ti'/'TiB', 'PB'/'Pi'/'PiB'.

    :param size_str: can be 1KB, 2mb, 3.2gB, 4.5TB, 15729372Ki. If 100 or
                     100B format is used, the value is assumed to already be
                     in bytes. Non-string input (e.g. int 100) is converted
                     with str() first.
    :return: int value in bytes (fractions of a byte are truncated)
    :raises ValueError: if the value is negative or contains no digits
    :raises KeyError: if the unit is not recognized
    """
    if not size_str:
        return 0
    units = {'TB': 2 ** 40, 'GB': 2 ** 30, 'MB': 2 ** 20, 'KB': 2 ** 10,
             'TI': 2 ** 40, 'GI': 2 ** 30, 'MI': 2 ** 20, 'KI': 2 ** 10,
             # aliases on top of the historical set
             'PB': 2 ** 50, 'PI': 2 ** 50,
             'KIB': 2 ** 10, 'MIB': 2 ** 20, 'GIB': 2 ** 30,
             'TIB': 2 ** 40, 'PIB': 2 ** 50}
    size_str = str(size_str).replace(" ", "")
    # The numeric part ends right after the last digit; the rest is the unit.
    digit_ends = [match.end() for match in re.finditer(r'\d', size_str)]
    if not digit_ends:
        # Previously this surfaced as an unhelpful IndexError.
        raise ValueError(f"Wrong size '{size_str}': no numeric value found")
    last_index = digit_ends[-1]
    value = float(size_str[:last_index])
    if value < 0:
        raise ValueError("Wrong size. Should be greater then 0")
    unit = size_str[last_index:].upper()
    if not unit or unit == 'B':
        return int(value)
    return int(value * units[unit])


def convert_to_gb(size_str: str) -> float:
    """
    Convert a size string like '15729372Ki' to gibibytes (GiB).

    Parameters
    ----------
    size_str : str
        Size string with an optional unit suffix, parsed by convert_to_bytes().

    Returns
    -------
    float
        Size in GiB, truncated (not rounded) to one decimal digit.
    """
    size_in_bytes = convert_to_bytes(size_str)
    gib = size_in_bytes / 2 ** 30
    # Keep a single decimal digit by truncating the rest.
    return int(gib * 10) / 10


def print_pods_status(pods):
    """Log a kubectl-like table describing ``pods`` and return its rows.

    :param pods: list of K8sPod objects which produced
                 by kubectl_client.pods.list_all()
                 or kubectl_client.pods.list(namespace=...)
    :return: list of table rows; the first row is the header
    """

    def _age(start_time):
        # kubectl-style age: '2d 3h', '3h 5m' or '5m 7s'; '' if not started.
        if not start_time:
            return ''
        delta = datetime.now(timezone.utc) - start_time
        mins, secs = divmod(delta.seconds, 60)
        hrs, mins = divmod(mins, 60)
        days_part = f"{delta.days}d " if delta.days else ''
        hours_part = f"{hrs}h " if hrs else ''
        minutes_part = f"{mins}m " if (mins and not days_part) else ''
        seconds_part = f"{secs}s" if (not days_part and not hours_part) else ''
        return days_part + hours_part + minutes_part + seconds_part

    rows = [['NAMESPACE', 'NAME', 'READY', 'STATUS',
             'RESTARTS', 'AGE', 'IP', 'NODE', 'REASON']]
    for pod in pods:
        status = pod['status']
        statuses = status['container_statuses'] or []
        ready_count = len([cs for cs in statuses if cs['ready'] is True])
        restart_total = sum(cs['restart_count'] for cs in statuses)
        age = _age(status['start_time'])

        rows.append([
            pod['metadata']['namespace'],
            pod['metadata']['name'],
            f"{ready_count}/{len(statuses)}",
            status['phase'],
            str(restart_total),
            str(age),
            status['pod_ip'] or '',
            pod['spec']['node_name'],
            status['reason'] or ''
        ])

    LOG.debug("Raw rows: {0}".format(rows))

    # Column width = longest cell; None cells are printed as 'None'.
    widths = [max(len(cell or 'None') for cell in column)
              for column in zip(*rows)]
    format_str = '  '.join('{{{0}!s:<{1}}}'.format(idx, width)
                           for idx, width in enumerate(widths))

    LOG.debug("Format string for rows: {0}".format(format_str))
    LOG.info('\n' + '\n'.join(format_str.format(*row) for row in rows))
    return rows


def is_divisible_by_3(number):
    """Return True if ``number`` is divisible by 3 (negative numbers included).

    Uses the modulo operator directly. The previous digit-sum trick gave the
    same answers for ints but crashed on floats such as 3.0.
    """
    return number % 3 == 0


def make_export_env_strting(envs):
    """Build a '; '-separated shell snippet exporting every item of ``envs``.

    Each value is wrapped in single quotes. A single quote inside a value is
    written as '"'"' (close quote, double-quoted quote, reopen), so the
    snippet stays valid shell and the value is exported verbatim. Values
    without single quotes produce exactly the same text as before.

    :param envs: mapping of variable name -> value
    :return: e.g. "export A='1'; export B='x y'"
    """
    def _quote(value):
        return "'{}'".format("{}".format(value).replace("'", "'\"'\"'"))

    return '; '.join("export {}={}".format(name, _quote(envs[name]))
                     for name in envs)


def merge_dicts(src, added, path=None):
    """Recursively merge ``added`` into ``src`` in place.

    Nested dicts present on both sides are merged key by key. Any other
    value from ``added`` replaces the one in ``src``; equal leaf values leave
    the existing object untouched.

    :param src: Initial dict to merge to (modified in place).
    :param added: Dict which is going to be merged into src.
    :param path: Keys leading to the current nested level, tracked during
                 recursion (optional, defaults to []).
    :return: The merged source dictionary (the same object as ``src``).
    """
    path = path or []
    for key in added:
        new_value = added[key]
        if key not in src:
            src[key] = new_value
            continue
        old_value = src[key]
        if isinstance(old_value, dict) and isinstance(new_value, dict):
            merge_dicts(old_value, new_value, path + [str(key)])
        elif old_value == new_value:
            continue  # same leaf value, keep the existing object
        else:
            src[key] = new_value
    return src


def is_list_has_overlap(a, b) -> bool:
    """Return True if at least one element of ``a`` is also contained in ``b``."""
    for element in a:
        if element in b:
            return True
    return False


def log_method_time():
    """Decorator factory that records how long the decorated function runs.

    Each call is wrapped in LogTime, which stores the duration under the
    function's name in ``time_spent.yaml`` inside ARTIFACTS_DIR.

    Usage::

        @log_method_time()
        def deploy(...): ...
    """
    def log(func):
        @wraps(func)
        def wrapped(*args, **kwargs):
            # os.path.join works whether or not ARTIFACTS_DIR ends with a
            # separator; plain concatenation wrote "<dir>time_spent.yaml"
            # next to the directory when it did not (other helpers here build
            # paths as "<ARTIFACTS_DIR>/name").
            time_file = os.path.join(ARTIFACTS_DIR, 'time_spent.yaml')
            with LogTime(func.__name__, file_path=time_file):
                return func(*args, **kwargs)

        return wrapped

    return log


def get_expected_pods(path, target_nss=None):
    """Load the expected-pods map from the template at ``path``.

    :param path: template path passed to templates_utils.render_template()
    :param target_nss: optional iterable of namespaces to keep; None or empty
                       keeps all namespaces
    :return: dict merged from the per-namespace pod dicts found under the
             top-level 'ucp' key
    """
    LOG.info("Fetching lists of expected pods")
    template = templates_utils.render_template(path)
    expected_pods_file = yaml.load(template, Loader=yaml.SafeLoader)['ucp']
    expected_pods = {}
    for ns, pod_dict in expected_pods_file.items():
        if target_nss and ns not in target_nss:
            continue
        # we are in trouble if we have identical pod names
        # in different ns: later namespaces overwrite earlier ones
        expected_pods.update(pod_dict)
    return expected_pods


def generate_list_pods(kaas_manager, target_nss=None):
    """Match running pods of the target cluster to its expected pod prefixes.

    :param kaas_manager: manager used to look up settings.TARGET_NAMESPACE /
                         settings.TARGET_CLUSTER
    :param target_nss: optional iterable of namespaces to restrict to;
                       None or empty means all namespaces
    :return: {pod_prefix: (pod_prefix, expected_count, [matching pods])};
             {} if the cluster is not present
    """
    namespace = kaas_manager.get_namespace(settings.TARGET_NAMESPACE)
    cluster = namespace.get_cluster(settings.TARGET_CLUSTER)
    if not cluster.present():
        return {}
    LOG.info("Make sure we have all needed pods")
    cluster.check.check_actual_expected_pods()
    LOG.info("Generating list of pods for {0} cluster "
             "in {1} namespace".format(cluster.name, namespace.name))
    kubectl_client = cluster.k8sclient
    ep = cluster.expected_pods
    expected_pods = {}
    for pod_ns in ep:
        if not target_nss or pod_ns in target_nss:
            pod_names = [x.split("/")[0] for x in ep[pod_ns].keys()]
            # Longest names first so a short prefix does not swallow pods
            # that belong to a longer, more specific prefix below.
            for pod_name in sorted(pod_names, key=lambda x: len(x),
                                   reverse=True):
                pod_num = ep[pod_ns][pod_name] if pod_name in ep[pod_ns].keys() \
                    else ep[pod_ns][f"{pod_name}/no_owner"]
                # order is guaranteed
                # https://stackoverflow.com/questions/39980323/are-dictionaries-ordered-in-python-3-6
                # https://docs.python.org/3.7/library/stdtypes.html#typesmapping
                expected_pods.update({pod_name: pod_num})
    actual_pods = []
    if target_nss:
        for pod_ns in target_nss:
            actual_pods += kubectl_client.pods.list(
                namespace=pod_ns,
                field_selector="status.phase=Running"
            )
    else:
        actual_pods = kubectl_client.pods.list_all(
            field_selector="status.phase=Running"
        )
    result = {}

    for pod_name, pod_num in expected_pods.items():
        pods = [x for x in actual_pods if x.name.startswith(pod_name)]
        # order is guaranteed
        result[pod_name] = (pod_name, pod_num, pods)
        # Drop matched pods so shorter prefixes processed later don't claim them.
        if not target_nss:
            actual_pods = [x for x in actual_pods
                           if not x.name.startswith(pod_name)]
        else:
            actual_pods = [x for x in actual_pods if x.namespace in target_nss
                           and not x.name.startswith(pod_name)]

    LOG.debug("Generated list: {}".format(result))
    return result


@retry((rexc.Timeout, rexc.ConnectTimeout, rexc.ReadTimeout, rexc.ConnectionError), delay=15, tries=5, logger=LOG)
def get_latest_k8s_image_version(url, major_version, timeout=60):
    """Return the newest conformance image version listed on an HTML index page.

    :param url: index page with '<a href="vX.Y.Z/">vX.Y.Z/</a>'-style links
    :param major_version: substring a line must contain to be considered;
                          lines containing 'att' are ignored
    :param timeout: requests timeout in seconds. Without one the Timeout
                    errors retried by the decorator could never occur and
                    a stuck server would hang the call forever.
    :return: highest parsed version, with '.post' rendered as '-'
    :raises requests.exceptions.ConnectionError: on a non-OK response
    :raises ValueError: if no matching version is found
    """
    LOG.info("Fetching list of available k8s conformance "
             "images from {}".format(url))
    resp = requests.get(url, verify=False, timeout=timeout)
    if not resp.ok:
        raise rexc.ConnectionError("Request failed {}".format(resp.text))
    link_marker = '<a href="v'
    lst = []
    for txt in resp.text.split("\n"):
        # Lines without a version link previously raised IndexError.
        if major_version in txt and 'att' not in txt and link_marker in txt:
            lst.append(
                txt.split(link_marker)[1].split('/">v')[0])
    LOG.info("Available versions: {0}. "
             "Filtered by {1}".format(lst, major_version))
    if not lst:
        raise ValueError(f"No k8s conformance image versions matching "
                         f"'{major_version}' found at {url}")
    return str(max(version.parse(x) for x in lst)).replace(".post", "-")


def get_docker_version(machine):
    """Return a one-line summary of docker server/component versions on ``machine``.

    Any failure (command error, unexpected output) is logged, and the error
    text is returned instead of raising.
    """
    cmd = "docker version --format '{{ json . }}'"
    try:
        output = machine.exec_pod_cmd(cmd, get_events=False, verbose=False)
        server = yaml.safe_load(output["logs"])["Server"]
        parts = [f"Server Version: {server['Version']}"]
        if "Components" in server:
            parts.extend(f"{comp['Name']}: {comp['Version']}"
                         for comp in server["Components"])
        return '   '.join(parts)
    except Exception as error:
        details = f"Unable to read docker data from {machine.name}: {error}"
        LOG.error(f"Unable to read docker data from {machine.name} , see debug log")
        LOG.debug(details)
        return details


def get_kernel_version(machine):
    """Return ``uname -rvi`` output from ``machine``, or the error text on failure."""
    try:
        output = machine.exec_pod_cmd("uname -rvi", get_events=False, verbose=False)
        return output["logs"].strip()
    except Exception as error:
        details = f"Unable to read kernel version from {machine.name}: {error}"
        LOG.error(f"Unable to read kernel version from {machine.name} , see debug log")
        LOG.debug(details)
        return details


def get_system_version(machine):
    """Return the OS release string of ``machine``, or the error text on failure.

    Prints /etc/redhat-release or /etc/centos-release when present, otherwise
    sources /etc/os-release and echoes "$NAME $VERSION".
    NOTE(review): in the shell '||' and '&&' share precedence, so the final
    echo also runs after a successful cat. With NAME/VERSION unset it only
    adds blank output, which strip() removes.
    """
    cmd = ("set -a;"
           "cat /etc/redhat-release 2>/dev/null ||"
           "cat /etc/centos-release 2>/dev/null ||"
           " . /etc/os-release && echo $NAME $VERSION")
    try:
        return machine.exec_pod_cmd(cmd, get_events=False, verbose=False)["logs"].strip()
    except Exception as error:
        details = f"Unable to read system version from {machine.name}: {error}"
        LOG.error(f"Unable to read system version from {machine.name} , see debug log")
        LOG.debug(details)
        return details


def verify(expression, failure_msg, success_msg=None):
    """
    Check a condition and fail the test immediately if it is falsy.

    Despite the historical "soft assertion" wording, this is a hard check:
    an AssertionError is raised on failure. It is raised explicitly rather
    than via ``assert``, so the check still works when Python runs with -O.

    Args:
        :param expression: Expression that must be verified to be True
        :param failure_msg: Message of the AssertionError if expression is falsy
        :param success_msg: Message to log (INFO) if expression is truthy
        :raises AssertionError: if expression is falsy
    """
    if not expression:
        raise AssertionError(failure_msg)
    if success_msg is not None:
        LOG.info(success_msg)


def get_binary_path(binary):
    """
    Locate an executable.

    Looks for ``binary`` in settings.SI_BINARIES_DIR first, then falls back
    to ``which`` (i.e. the directories in $PATH).

    :param binary: The name of binary to look for
    :returns: full path to the binary, or the placeholder string
              '/binary/not/found' when it cannot be located (no exception
              is raised)
    """
    candidate = os.path.join(settings.SI_BINARIES_DIR, binary)
    if os.path.isfile(candidate):
        return candidate
    lookup = exec_helpers.Subprocess().execute(f"which {binary}", verbose=False)
    return lookup.stdout_str if lookup.exit_code == 0 else '/binary/not/found'


def read_event(event, read_event_data=True):
    """Return the event data to parse.

    :param event: event object; it is re-read with ``event.read(cached=True)``
                  when read_event_data is True
    :param read_event_data: if False, ``event`` is returned unchanged
    :return: event data, or None if reading failed with an ApiException
    """
    if not read_event_data:
        # Assume that 'event' already contains all required fields
        return event
    LOG.debug(f"Read event {event.name} from API")
    try:
        return event.read(cached=True)
    except ApiException:
        return None


def parse_events(events, event_prefix=None, sort=True, read_event_data=True, filtered_events=False,
                 truncate_events_num=10):
    """Group events by involved object and flatten each into a dict.

    :param events: iterable of event objects (see read_event())
    :param event_prefix: if set, only events whose name starts with it are used
    :param sort: not used by the implementation
    :param read_event_data: passed to read_event(); unreadable events are skipped
    :param filtered_events: if True, drop groups whose latest event type is 'Normal'
    :param truncate_events_num: keep only this many latest events per group
                                (a falsy value keeps all of them)
    :return: dict mapping a human-readable group title to the list of parsed
             event dicts, sorted by their 'event_end_timestamp' string
    """
    grouped = {}
    for raw_event in events:
        if event_prefix and not raw_event.name.startswith(event_prefix):
            continue

        data = read_event(raw_event, read_event_data)
        if data is None:
            continue

        first_ts = data.first_timestamp
        last_ts = data.last_timestamp
        source = data.source
        involved = data.involved_object
        parsed_data = {
            'data': data,
            'namespace': data.metadata.namespace,
            'event_date': str(first_ts.date()) if first_ts else '',
            'event_start': str(first_ts.time()) if first_ts else '',
            'event_end_date': str(last_ts.date()) if last_ts else '',
            'event_end_time': str(last_ts.time()) if last_ts else '',
            'event_end_timestamp': str(last_ts.timestamp()) if last_ts else '',
            'message': data.message,
            'reason': data.reason,
            'event_type': data.type,
            'component': source.component,
            'host': source.host or '-',
            'object_kind': involved.kind,
            'object_name': involved.name,
            'object_namespace': involved.namespace,
            'object_uid': involved.uid,
        }
        grouped.setdefault(involved.uid, []).append(parsed_data)

    result = {}
    for object_uid, items in grouped.items():
        items.sort(key=itemgetter('event_end_timestamp'))
        latest = items[-1]
        title = (f"{latest['object_kind']} {latest['object_namespace']}/{latest['object_name']}"
                 f" [Events from: {latest['namespace']}/{latest['component']}/{latest['host']}]"
                 f" [Object ID: {object_uid}]")

        if filtered_events and latest['event_type'] == 'Normal':
            # Only groups whose latest status is not 'Normal' are kept
            LOG.debug(f"### Skipping log for events for {latest['object_kind']}"
                      f" {latest['object_namespace']}/{latest['object_name']}"
                      f" from {latest['namespace']}/{latest['component']}/{latest['host']}")
            continue

        # Leave only {truncate_events_num} latest events when requested
        result[title] = items[-truncate_events_num:] if truncate_events_num else items

    return result


def create_events_msg(events, prefix="  Events:", header=""):
    """Render grouped events as a multi-line string.

    :param events: dict mapping a group title to a list of event dicts with
                   'event_end_date', 'event_end_time', 'event_type', 'reason'
                   and 'message' keys (as produced by parse_events())
    :param prefix: str placed at the very beginning of the output
    :param header: format string shown before each group; receives
                   ``group_msg`` and ``events_data``. When empty, a framed
                   title line is used instead.
    :return: the rendered text, ending with a newline
    """
    lines = []
    for title, group_events in sorted(events.items()):
        if header:
            banner = header.format(group_msg=title, events_data=group_events)
        else:
            inner = f"|| {title} ||"
            rule = '-' * len(inner)
            banner = f"\n{rule}\n{inner}\n{rule}"
        lines.append(banner)
        lines.extend(f"  >> {ev['event_end_date']} {ev['event_end_time']} "
                     f"{ev['event_type']} {ev['reason']} "
                     f"{ev['message']}"
                     for ev in group_events)

    return f"{prefix}" + "\n".join(lines) + "\n"


def get_credential_type_by_provider(provider_name: str) -> str:
    """Return the credentials resource name for a provider, or None if unknown."""
    credential_kinds = {
        settings.AWS_PROVIDER_NAME: 'awscredentials',
        settings.AZURE_PROVIDER_NAME: 'azurecredentials',
        settings.OPENSTACK_PROVIDER_NAME: 'openstackcredentials',
        settings.VSPHERE_PROVIDER_NAME: 'vspherecredentials',
    }
    return credential_kinds.get(provider_name)


def is_valid_ip(chk_str):
    """Check whether ``chk_str`` is a valid IPv4 or IPv6 address.

    :param chk_str: String to determine is IP valid or not
    :return: True for a valid address, False otherwise
    """
    LOG.info(f"Check IP address '{chk_str}'")
    try:
        ipaddress.ip_address(chk_str)
    except ValueError:
        # ip_address() reports every parse failure as a plain ValueError
        # (it swallows AddressValueError/NetmaskValueError internally), so
        # separate handlers for those subclasses could never run.
        LOG.info(f"It's not ip, looks like hostname: {chk_str}")
        return False
    LOG.info("IP address is correct")
    return True


def get_provider_specific_path(base: str, provider_name: str, extra_path: str = None) -> str:
    """
    Return a provider-specific path.

    For the OpenStack provider the provider name is not part of the path
    (``base/extra_path``). For every other provider it is inserted
    (``base/provider_name/extra_path``). ``extra_path`` is appended only
    when truthy.
    """
    if provider_name == settings.OPENSTACK_PROVIDER_NAME:
        provider_dir = base
    else:
        provider_dir = os.path.join(base, provider_name)
    return os.path.join(provider_dir, extra_path) if extra_path else provider_dir


def save_cluster_name_artifact(namespace, name):
    """Write "<namespace>/<name>" to the 'cluster_name' file in ARTIFACTS_DIR."""
    artifact_path = f"{settings.ARTIFACTS_DIR}/cluster_name"
    LOG.info(f"Save cluster namespace/name to '{artifact_path}'")
    with open(artifact_path, 'w') as artifact:
        artifact.write(f"{namespace}/{name}")


def get_np_yaml_from_netconfigfiles(netconfigfiles):
    """Extract the decoded Netplan YAML from an IPAM netconfigFiles list.

    The first entry whose 'path' is the LCM netplan file wins. Its
    base64 'content' is decoded. When no entry matches, b'' is returned.

    :netconfigfiles list: netconfigFiles field from IpamHost status
    :rtype: bytes
    """
    netplan_path = "/etc/netplan/60-kaas-lcm-netplan.yaml"
    encoded = next(
        (section.get("content", "") for section in netconfigfiles
         if section.get("path", "") == netplan_path),
        "")
    return base64.b64decode(encoded)


def get_np_struct_from_netconfigfiles(netconfigfiles):
    """Parse the Netplan YAML found in an IPAM netconfigFiles list.

    :netconfigfiles list: netconfigFiles field from IpamHost status
    :rtype: dict (whatever yaml.safe_load returns; None for an empty document)
    """
    return yaml.safe_load(get_np_yaml_from_netconfigfiles(netconfigfiles))


def remove_k8s_obj_annotations(si_clientobject, annotations_remove=None):
    """
    Remove annotations from a k8s object.
    Example:
    >> remove_k8s_obj_annotations(bmh1, annotations_remove=['annotation1', 'annotation2'])

    Requested annotations that are not present on the object are skipped.
    If none of them exist, no patch request is sent.

    :param K8sNamespacedResource si_clientobject: modified object
    :param list annotations_remove: list of annotation keys to remove
    :raises AssertionError: on a wrong object type or a non-list annotations_remove
    """

    if not annotations_remove:
        LOG.warning("Annotations to remove are not set. Skipping")
        return

    assert isinstance(si_clientobject, K8sNamespacedResource), \
        f"si_clientobject must be instance of K8sNamespacedResource, type:{type(si_clientobject)}"
    assert isinstance(annotations_remove, list), "Please, use list of annotations keys"

    obj_typename = f"objectType:{type(si_clientobject)} with name {si_clientobject.name}"
    metadata = si_clientobject.data.get('metadata') or {}
    # 'annotations' can be present but null, which previously caused a
    # TypeError on the membership test below.
    existing_annotations = metadata.get('annotations') or {}
    annotations_patch = {}
    for annotation in annotations_remove:
        if annotation in existing_annotations:
            LOG.info(
                f"Annotation {annotation}: {existing_annotations[annotation]} will be removed from {obj_typename}"
            )
            # NOTE(review): relies on patch() sending a merge patch where a
            # null value deletes the key — confirm against the client.
            annotations_patch[annotation] = None
        else:
            LOG.info(
                f"Annotation {annotation} not found in existing annotations for {obj_typename}. Skipping"
            )

    if not annotations_patch:
        LOG.info(f"Nothing to remove from {obj_typename}. Skipping patch")
        return

    si_clientobject.patch({'metadata': {'annotations': annotations_patch}})


def generate_batch(lst, batch_size):
    """Yield consecutive slices of ``lst`` holding at most ``batch_size`` items.

    A None or non-positive batch size logs a warning and yields nothing.
    The None check must come first: ``None <= 0`` raises TypeError on
    Python 3, so the old condition order crashed instead of warning.
    """
    if batch_size is None or batch_size <= 0:
        LOG.warning(f"generate_batch: batch size is {batch_size}")
        return
    for start in range(0, len(lst), batch_size):
        yield lst[start:start + batch_size]


def get_datetime_utc(datetime_str, format_str="%Y-%m-%dT%H:%M:%SZ"):
    """Parse ``datetime_str`` into a timezone-aware UTC datetime.

    :param datetime_str: timestamp text, by default like '2024-01-31T12:00:00Z'
    :param format_str: strptime format. Values parsed without an offset are
                       taken to already be UTC (the default format ends in a
                       literal 'Z'). Values parsed with an offset (%z) are
                       converted to UTC.
    :return: aware datetime in pytz.UTC
    """
    parsed = datetime.strptime(datetime_str, format_str)
    if parsed.tzinfo is None:
        # astimezone() on a naive value assumes *local* time, which shifted
        # the result whenever the host timezone was not UTC.
        return parsed.replace(tzinfo=pytz.UTC)
    return parsed.astimezone(pytz.UTC)


def is_k8s_res_exist(k8s_res: K8sNamespacedResource):
    """Return True if ``k8s_res.read()`` succeeds, False on any ApiException.

    A 404 / NotFound answer is logged as "not found". Any other API error is
    logged as an error, and False is still returned.
    """
    try:
        k8s_res.read()
    except ApiException as ex:
        try:
            error_body = json.loads(ex.body)
        except (TypeError, json.decoder.JSONDecodeError):
            error_body = {}
        status_not_found = str(ex.status) == "404" and ex.reason == "Not Found"
        if status_not_found or (error_body.get('reason') == "NotFound"
                                and str(error_body.get('code')) == "404"):
            LOG.info(f"Resource {k8s_res} is not found.")
        else:
            LOG.error(f"Got api error {ex.body} while finding {k8s_res}.")
        return False
    LOG.info(f"Resource {k8s_res} is present.")
    return True


def get_schedule_for_cronjob(time_delta):
    """Get schedule in Cron-like manner for future run.

    The returned expression pins minute, hour, day of month and month of the
    moment ``time_delta`` seconds from now (in UTC), with any day of week.

    :param time_delta: Seconds between now() and planned time
    :return: Cron-like schedule for future, e.g. "05 14 13 11 *"
    """
    # Do the arithmetic on an aware UTC timestamp. Adding a timedelta to a naive
    # local time is wall-clock arithmetic, which lands an hour off when the
    # interval crosses a DST transition.
    schedule = datetime.now(timezone.utc) + timedelta(seconds=time_delta)
    return schedule.strftime("%M %H %d %m *")


def ssh_k8s_node(hostname):
    """Build a sudo-enabled SSH client for a node listed in ``settings.NODES_INFO``.

    The NODES_INFO YAML file is read and the entry for ``hostname`` must
    provide ``ssh.private_key``, ``ssh.username`` and ``ip.address``.

    :param hostname: key of the node in the NODES_INFO file
    :return: ``exec_helpers.SSHClient`` connected on port 22 with sudo_mode on
    :raises AssertionError: if the node entry or any required field is missing
    """
    with open(settings.NODES_INFO) as f:
        # An empty file loads as None; treat it as "no nodes".
        nodes_info = yaml.safe_load(f) or {}
    node_info = nodes_info.get(hostname) or {}
    ssh_info = node_info.get('ssh') or {}
    ip_info = node_info.get('ip') or {}

    # Validate explicitly so a missing host/field yields a clear message
    # instead of "'NoneType' object is not subscriptable".
    private_key = ssh_info.get('private_key')
    assert private_key, "Private key is not defined"
    node_ip = ip_info.get('address')
    assert node_ip, f"IP address is not defined for node {hostname}"
    username = ssh_info.get('username')
    assert username, f"SSH username is not defined for node {hostname}"

    pkey = get_rsa_key(private_key)
    auth = exec_helpers.SSHAuth(username=username, password='', key=pkey)
    ssh = exec_helpers.SSHClient(host=node_ip, port=22, auth=auth)
    ssh.logger.addHandler(logger.console)
    ssh.sudo_mode = True
    return ssh


@retry(paramiko.ssh_exception.NoValidConnectionsError, delay=10, tries=4, logger=LOG)
def basic_ssh_command(cmd, host, user, password='', private_key='', port=22, verbose=True, timeout=120):
    """Run ``cmd`` on ``host`` over SSH and return the execution result.

    Key-based auth is used whenever ``private_key`` is given (``password`` is
    then ignored); otherwise password auth is used. Connection failures of type
    ``NoValidConnectionsError`` are retried by the decorator (tries=4, delay=10).

    :param cmd: command line to execute remotely
    :param host: target host
    :param user: SSH username
    :param password: password for password auth
    :param private_key: key for key-based auth (takes precedence over password)
    :param port: SSH port
    :param verbose: passed through to ``execute``
    :param timeout: passed through to ``execute``
    :return: result object of ``SSHClient.execute``
    """
    assert password or private_key, "Auth method is not defined!"
    auth = (
        exec_helpers.SSHAuth(username=user, password='', key=private_key)
        if private_key
        else exec_helpers.SSHAuth(username=user, password=password)
    )
    client = exec_helpers.SSHClient(host=host, port=port, auth=auth)
    client.logger.addHandler(logger.console)
    return client.execute(cmd, verbose=verbose, timeout=timeout)


def get_services_pids(machine, services, ssh_key):
    """Find process ids of the given services on a machine.

    Runs ``ps -xaeo pid,command`` via ``machine._run_cmd`` and checks every
    output line for each service name as a plain substring. The first
    whitespace-separated field of a matching line is taken as the pid. When
    several lines match the same service, the last one wins.

    :param machine: object exposing ``_run_cmd(cmd, verbose=..., ssh_key=...)``
        whose result has a ``stdout`` iterable of byte lines
    :param services: service names to look for
    :param ssh_key: key used to access the machine
    :return: dict mapping service name -> pid string
    """
    ps_output = machine._run_cmd(
        "ps -xaeo pid,command", verbose=False,
        ssh_key=ssh_key).stdout
    found = {}
    for raw_line in ps_output:
        entry = raw_line.decode().strip()
        hits = [name for name in services if name in entry]
        if not hits:
            continue
        first_field = entry.split()[0]
        if first_field:
            found.update(dict.fromkeys(hits, first_field))
    return found


def get_local_executor():
    """Return a new ``exec_helpers.Subprocess`` for running commands locally."""
    local_runner = exec_helpers.Subprocess()
    return local_runner


def get_types_by_labels(labels):
    """Translate node/machine labels into machine types from MACHINE_TYPES.

    A MACHINE_TYPES label counts when it is present in ``labels`` and its value
    is the machine type itself, 'true', or any other value whose ``str()`` is
    non-empty. Bare-metal machines may yield several types.

    :param labels: mapping of label -> value
    :return: sorted list of unique types; ``['worker']`` (with a warning) when
        no known label matched
    """
    # NOTE(review): an empty-string label value is NOT counted, because all
    # three checks fail for ''. Confirm this matches how the labels are set.
    found_types = [
        MACHINE_TYPES[label]
        for label, machine_type in MACHINE_TYPES.items()
        if label in labels and any([labels[label] == machine_type,
                                    labels[label] == 'true',
                                    len(str(labels[label])) >= 1])
    ]

    if not found_types:
        # Workaround for missing labels for node role on worker nodes
        found_types = ['worker']
        LOG.warning("Looks like passed 'machine' dont have "
                    "expected 'labels' for kcm. "
                    f"Using by default: {found_types}")
    return sorted(set(found_types))


def get_type_by_labels(labels):
    """Pick a single k8s node type (e.g. 'worker' / 'control') from labels.

    Delegates to ``get_types_by_labels``. A worker+storage combination resolves
    to 'worker'; otherwise the first (alphabetically smallest) type is used.

    :param labels: mapping of label -> value
    :return: one machine type string
    """
    detected = get_types_by_labels(labels)
    # No storage type exists yet, but it may appear later; keep it as worker.
    if 'worker' in detected and 'storage' in detected:
        return 'worker'
    if detected:
        return detected[0]
    # Defensive fallback: get_types_by_labels() already substitutes ['worker']
    # for an empty result, so this is not expected to be reached.
    fallback = 'worker'
    LOG.warning(f"Using default machine type:{fallback}")
    return fallback


def get_timestamp():
    """Return the current local time formatted as ``YYYYMMDD_HH_MM_SS``."""
    local_now = time.localtime()
    return time.strftime("%Y%m%d_%H_%M_%S", local_now)


def _hide_ksi_secret_settings():
    """Register every string ``settings.KSI_SECRET_*`` value (plain and base64) with ``LOG.hide``."""
    secret_names = [i for i in vars(settings).keys() if i.startswith('KSI_SECRET_')]
    for secret_name in secret_names:
        secret = getattr(settings, secret_name)
        # Unset secrets (None) and non-string values are skipped.
        if not isinstance(secret, str):
            continue
        LOG.hide(secret)
        # We also need to hide the base64 form, since some secrets are
        # encoded on apply/create.
        LOG.hide(base64.b64encode(secret.encode("utf-8")).decode("utf-8"))


def render_and_save_data(render_opts=None, template_path=None):
    """Render a template, dump the rendered text to artifacts and return its first YAML document item.

    Purposes:
    - dump debug data to artifacts
    - doesn't pass all os_env by default (no extra env vars are passed)

    :param render_opts: options merged over the defaults
        ``target_cluster`` / ``target_namespace`` before rendering
    :param template_path: path of the template passed to ``render_template``
    :return: element ``[0]`` of the loaded YAML (the template is expected to
        render a YAML list)
    """
    default_render_opts = {
        "target_cluster": 'RENDERED_cluster_name',
        "target_namespace": 'RENDERED_target_namespace',
    }
    # Mask secrets before anything below gets logged.
    _hide_ksi_secret_settings()

    render_opts_data = {**default_render_opts, **(render_opts or {})}
    LOG.debug(f"render_data resulted with render_opts_data:\n{render_opts_data}")
    # Deliberately empty: the environment is not forwarded to the template.
    j2_extra_vars = {}
    rendered_text = templates_utils.render_template(template_path,
                                                    options=render_opts_data,
                                                    extra_env_vars=j2_extra_vars)
    data = yaml.safe_load(rendered_text)[0]

    LOG.debug(f"Rendered data from file:\n"
              f"{template_path}:\n"
              f"{data}")
    # Save for debug purposes only. Use an explicit encoding so non-ASCII
    # template content does not fail on hosts with a non-UTF-8 locale.
    # NOTE(review): the raw rendered text is written here; LOG.hide() masking
    # is not applied to this file by this block. Confirm that is acceptable.
    debug_path = os.path.join(settings.ARTIFACTS_DIR, f'si_debug_rendered_data_{get_timestamp()}.yaml')
    with open(debug_path, "w", encoding="utf-8") as f:
        f.write(rendered_text)
    return data


def ksi_seed_check_fuse(fuse_filename=None):
    """Fail fast unless the KSI "fuse" marker file exists on this node.

    :param fuse_filename: path of the marker file; defaults to FUSE_FILENAME
    :raises RuntimeError: if the marker file does not exist
    """
    path = FUSE_FILENAME if fuse_filename is None else fuse_filename
    if not os.path.exists(path):
        # Raise a real exception type: raising a bare str is itself a
        # TypeError in Python 3 and would hide this message.
        raise RuntimeError(f'{path} not exists! Attempt to run tests on '
                           f'unexpected env node!')


def list_files_in_dir(dir_path: str,
                      patterns: Iterable[str] = ("*.yaml",),
                      absolute: bool = False) -> List[str]:
    """
    Return files from dir_path matched by the given glob pattern(s),
    sorted *lexically* by file name, like OpenSSH Include (ssh_config/sshd_config).

    Dotfiles won't match unless the pattern's last component starts with '.'
    (glob(3) semantics; pathlib.Path.glob alone would match them).
    Only regular files (incl. symlinks to files) are returned, and a file
    matched by several patterns is listed once.

    Example: patterns=("*.conf",) would mimic /etc/ssh/sshd_config.d/*.conf

    :param dir_path: directory to search
    :param patterns: a single glob pattern (str or bytes) or an iterable of them
    :param absolute: if True return ``dir_path``-joined paths, else bare names
    :return: list of matched file names or paths
    """
    base = Path(dir_path)
    if isinstance(patterns, (str, bytes)):
        patterns = (patterns,)

    # Keyed by path so overlapping patterns don't produce duplicates;
    # dict preserves first-seen order for the stable sort below.
    matched = {}
    for pat in patterns:
        # Path.glob() only accepts str patterns.
        pat = os.fsdecode(pat)
        allow_hidden = Path(pat).name.startswith('.')
        for candidate in base.glob(pat):
            if candidate.name.startswith('.') and not allow_hidden:
                continue
            if candidate.is_file():
                matched.setdefault(str(candidate), candidate)

    files = sorted(matched.values(), key=lambda x: x.name)
    return [str(f if absolute else f.name) for f in files]


def check_expected_pods_template_exists(name, path_prefix='cluster/'):
    """Verify that the expected-pods template ``<name>.yaml`` exists.

    The path is built by plain string concatenation (no separators inserted):
    ``settings.EXPECTED_PODS_TEMPLATES_DIR + path_prefix + name + '.yaml'``.

    :param name: template base name, without the ``.yaml`` suffix
    :param path_prefix: sub-directory to look in; currently 'cluster/' or 'kcm/'
        depending on what is being checked. Other values make no sense.
    :raises RuntimeError: if the template is missing while
        ``settings.SKIP_EXPECTED_POD_CHECK`` is False
    """
    template_path = settings.EXPECTED_PODS_TEMPLATES_DIR + path_prefix + name + '.yaml'
    if settings.SKIP_EXPECTED_POD_CHECK:
        return
    if os.path.exists(template_path):
        return
    raise RuntimeError(f"SKIP_EXPECTED_POD_CHECK is False and template {template_path} does not exist. "
                       f"Please update templates or set SKIP_EXPECTED_POD_CHECK to True")


def discover_and_check_expected_pods_template_path(name, path_prefixes=None):
    """Locate the expected-pods template ``<name>.yaml`` among several sub-directories.

    Each prefix is tried in order under ``settings.EXPECTED_PODS_TEMPLATES_DIR``
    (plain string concatenation); the first existing file wins.

    :param name: template base name, without the ``.yaml`` suffix
    :param path_prefixes: sub-directories to search; defaults to
        ['cluster/', 'cluster-oot/']
    :return: path of the found template, or None if none was found and
        ``settings.SKIP_EXPECTED_POD_CHECK`` is True
    :raises RuntimeError: if nothing was found and
        ``settings.SKIP_EXPECTED_POD_CHECK`` is False
    """
    prefixes = ['cluster/', 'cluster-oot/'] if path_prefixes is None else path_prefixes
    base_dir = settings.EXPECTED_PODS_TEMPLATES_DIR
    for prefix in prefixes:
        search_dir = base_dir + prefix
        candidate = search_dir + name + '.yaml'
        if not os.path.exists(candidate):
            LOG.warning(f"Expected pods template not found in directory "
                        f"{search_dir}")
            continue
        LOG.info(f"Expected pods template found in directory {search_dir}")
        return candidate

    if settings.SKIP_EXPECTED_POD_CHECK:
        return None
    searched = [base_dir + prefix for prefix in prefixes]
    raise RuntimeError(f"SKIP_EXPECTED_POD_CHECK is False and expected pods template {name}.yaml "
                       f"does not exist in any of paths: "
                       f"{searched}. "
                       f"Please update templates or set SKIP_EXPECTED_POD_CHECK to True")


def is_version_dev(version_str: str) -> bool:
    """Return True if ``version_str`` looks like a dev build version.

    Accepted shapes (presumably ``git describe`` output: tag, optional commit
    count, abbreviated hash), e.g.
    0.1.0-rc1-g3f33c21, 0.1.0-1-g3f33c21, 0.1.0-rc43-g3f33c23,
    0.1.0-rc43-2-g3f33c23, 0.1.0-rc43-12-g3f33c23
    """
    # The commit count after an rc tag may have any number of digits.
    rx_dev = r'^\d+\.\d+\.\d+-(rc\d+|\d+|rc\d+-\d+)-[0-9a-z]{8,10}$'
    return bool(re.match(rx_dev, version_str))


def is_version_enterprise(name_str: str) -> bool:
    """Return True if a release *name* contains the ``-enterprise-`` marker.

    Note: this checks the name, not the version, e.g.
    k0rdent-enterprise-1-0-0-rc3 or k0rdent-enterprise-1-0-0.
    The marker must appear on the first line of ``name_str``.
    """
    return re.match(r".*-enterprise-.*", name_str) is not None


def _log_conditions_table(headers, rows):
    """Log condition rows as a 'presto' table."""
    status_msg = tabulate.tabulate(rows, tablefmt="presto", headers=headers)
    LOG.info(f"\n{status_msg}\n")


def _classify_legacy_conditions(conditions, expected_fails, verbose):
    """Classify conditions in the old format (type/status/reason/message)."""
    result = {'ready': [], 'not_ready': [], 'skipped': []}
    rows = []
    for condition in conditions:
        status = condition.get('status')
        # Only the literal string 'True' means ready; 'False', 'Unknown' or a
        # missing status count as not ready. (Plain comparison instead of eval().)
        ready = status == 'True'
        condition_type = condition.get('type')
        # 'message' may be absent; treat it as '' for substring matching.
        condition_message = condition.get('message')
        if (not ready and condition_type in expected_fails and
                expected_fails[condition_type] in (condition_message or '')):
            result['skipped'].append(condition_type)
        elif ready:
            result['ready'].append(condition_type)
        else:
            result['not_ready'].append(condition_type)
        rows.append([condition_type, status, condition_message])
    if verbose:
        _log_conditions_table(["Type", "Status", "Message"], rows)
    return result


def _classify_state_conditions(conditions, expected_fails, verbose):
    """Classify conditions in the new (kcm >= 1.3.0) 'state' format."""
    result = {'ready': [], 'not_ready': [], 'skipped': []}
    rows = []
    for condition in conditions:
        ready = condition.get('state') == 'Deployed'
        condition_type = f"{condition.get('namespace')}.{condition.get('name')}/{condition.get('type')}"
        # NOTE: for this format only the expected_fails key is matched; its
        # substring value is not compared against failureMessage.
        if not ready and condition_type in expected_fails:
            result['skipped'].append(condition_type)
        elif ready:
            result['ready'].append(condition_type)
        else:
            result['not_ready'].append(condition_type)
        rows.append([condition_type,
                     f"{condition.get('template')}/{condition.get('version')}",
                     condition.get('state'),
                     condition.get('failureMessage') or ''])
    if verbose:
        _log_conditions_table(["Type", "Template", "State", "Failure Message"], rows)
    return result


def check_conditions(conditions, expected_fails=None, verbose=True):
    """Checks k0rdent-specific conditions structure

    :param conditions: list of conditions in k0rdent-specific format, like:
      - lastTransitionTime: "2025-11-13T10:33:48Z"
        message: Release cilium/cilium
        reason: Managing
        status: "True"
        type: cilium.cilium/SveltosHelmReleaseReady
    or new format implemented from kcm 1.3.0:
      - lastStateTransitionTime: "2025-11-13T20:39:13Z"
        name: cilium
        namespace: cilium
        state: Deployed
        template: cilium-0-1-0
        type: Helm
        version: cilium-0-1-0
    In the new format the condition 'type' used below is
    "<namespace>.<name>/<type>" (e.g. "cilium.cilium/Helm").

    :param expected_fails: dict of conditions 'type' in keys, and substring to match in values.
                           In case of matching, the condition will be skipped from check.
                           For the new format only the key is matched.
    :param verbose: log a table of all conditions

    Returns a dict with three keys:
    - 'ready':     list of conditions 'type' where status is 'True' (new format: state 'Deployed')
    - 'not_ready': list of other conditions 'type' not found in expected_fails
    - 'skipped':   list of not-ready conditions 'type' which found in expected_fails dict

    :raises Exception: if the conditions mix both formats or match neither
    """
    expected_fails = expected_fails or {}
    if all('reason' in condition for condition in conditions):
        # Looks like conditions in old format
        return _classify_legacy_conditions(conditions, expected_fails, verbose)
    if all('state' in condition for condition in conditions):
        # Looks like conditions in new format
        return _classify_state_conditions(conditions, expected_fails, verbose)
    raise Exception(f"Conditions are in mixed or unknown format:\n{yaml.dump(conditions)}")


class Provider(enum.Enum):
    """Infrastructure providers known to the tests.

    Each member value is ``(provider_name, api_version, machine_spec, infrakind)``:
    ``machine_spec`` is the provider Machine kind (None where the provider has
    none listed) and ``infrakind`` is the Kind used in the infrastructure reference.
    """
    aws = ("aws", "v1beta2", "AWSMachine", "AWSCluster")
    # NOTE(review): "Metal3Machines" differs from the singular "<X>Machine"
    # pattern used by the other members; confirm it is the intended kind.
    metal3 = ("metal3", "v1beta1", "Metal3Machines", "Metal3Cluster")
    azure = ("azure", "v1beta1", "AzureMachine", "AzureCluster")
    vsphere = ("vsphere", "v1beta1", "VSphereMachine", "VSphereCluster")
    openstack = ("openstack", "v1beta1", "OpenStackMachine", "OpenStackCluster")
    gcp = ("gcp", "v1beta1", "GCPMachine", "GCPCluster")
    eks = ("eks", "v1beta2", "AWSMachine", "AWSManagedCluster")
    aks = ("aks", "v1alpha1", None, "AzureASOManagedCluster")
    gke = ("gke", "v1beta1", None, "GCPManagedCluster")
    remotessh = ("remotessh", "v1beta1", "RemoteMachine", "RemoteCluster")

    def __init__(self, provider_name, api_version, machine_spec, infrakind):
        self.provider_name = provider_name
        self.api_version = api_version
        self.machine_spec = machine_spec
        self.infrakind = infrakind

    @classmethod
    def get_all_machine_spec(cls):
        """
        Get all Machine Provider Spec

        Returns: List of Machine Provider Spec (in member order; may contain None)

        """
        return [e.machine_spec for e in cls]

    @classmethod
    def get_provider_by_name(cls, provider_name):
        """
        Get Provider object by Provider name
        Args:
            provider_name: provider name

        Returns: Provider object, or None if not found

        """
        return next((e for e in cls if e.provider_name == provider_name), None)

    @classmethod
    def get_provider_by_machine(cls, machine_spec, api_version='v1alpha1'):
        """
        Get Provider object by Machine Spec

        Args:
            machine_spec: Machine Spec
            api_version: API version

        Returns: first Provider whose machine_spec and api_version both match,
            or None if not found

        """
        return next((e for e in cls if e.machine_spec == machine_spec
                     and e.api_version == api_version), None)

    def __str__(self):
        return self.provider_name

    @classmethod
    def get_provider_by_infrakind(cls, infrakind):
        """
        Get provider object by corresponding infrastructure reference Kind
        :param infrakind: InfraRef Kind
        :return: Provider object, or None if no provider uses that Kind
        """
        # Default to None like the other lookups. A bare next() would leak
        # StopIteration, which is confusing and turns into RuntimeError when
        # raised inside a generator.
        return next((e for e in cls if e.infrakind == infrakind), None)
