import exec_helpers
import multiprocessing
import os
import yaml
from abc import abstractmethod

from concurrent import futures
from retry import retry
from paramiko.ssh_exception import NoValidConnectionsError

from si_tests import logger
from si_tests import settings
from si_tests.utils import utils, templates
from si_tests.deployments.utils import file_utils
from si_tests.managers.openstack_client_manager import OpenStackClientManager

# Module-level logger (from si_tests.logger) used by the workload classes below.
LOG = logger.logger


class WorkloadBase:
    """Common interface for workloads: deploy, check, delete, collect logs.

    NOTE(review): this class does not derive from ``abc.ABC`` (nor use
    ``ABCMeta``), so the ``@abstractmethod`` markers below are informational
    only. The class can be instantiated directly, and a subclass that does
    not override a hook inherits a no-op that returns ``None``.
    """

    @abstractmethod
    def deploy(self):
        """Subclass hook: deploy the workload. The base version does nothing."""

    @abstractmethod
    def delete(self):
        """Subclass hook: remove the workload. The base version does nothing."""

    @abstractmethod
    def check(self):
        """Subclass hook: verify the workload. The base version does nothing."""

    @abstractmethod
    def collect_logs(self):
        """Subclass hook: gather workload logs. The base version does nothing."""


class PerNodeWorkload(WorkloadBase):
    """Heat-stack workload that creates one test stack per compute node.

    Stacks are named ``<stack_prefix><hostname>`` and built from
    ``heat_stack_host.yaml`` (see ``stacks_data`` for which hosts are used).
    Multi-stack operations (deploy/check/delete/collect_logs/reboot) run one
    ``multiprocessing.Process`` per stack via ``_execute_stacks_tasks``.
    """

    def __init__(self, ostcm):
        """Store the OpenStack client manager and prepare per-run data.

        :param ostcm: OpenStack client manager; its ``os_manager`` attribute
            is also kept as ``self.os_manager``.
        """
        self.ostcm = ostcm
        self.os_manager = self.ostcm.os_manager
        self.stack_prefix = "pernode-test"
        # Filled lazily by the ``stacks_data`` property.
        self._stacks_data = {}
        self.template_path = file_utils.join(
            os.path.dirname(os.path.abspath(__file__)),
            "../../tests/lcm/templates/heat_stack_host.yaml",
        )
        # Key pair rendered into the Heat template by ``deploy``.
        self.ssh_keys = utils.generate_keys()

    def _ensure_hugepages_flavor(self, hostname):
        """Return the name of a per-host huge-pages flavor, creating it if missing.

        The flavor is ``m1.tiny.dpdk-<hostname>`` with 1024 MB RAM, a 5 GB
        disk, 2 vCPUs and ``hw:mem_page_size=2048``.
        """
        LOG.info("Create a flavor with huge pages")
        flavor_name = "m1.tiny.dpdk-" + hostname
        flavors = self.ostcm.flavor.list(["--all", "--long"])
        if any(flavor.get("Name") == flavor_name for flavor in flavors):
            LOG.info(f"Flavor with name: {flavor_name} already exists")
        else:
            LOG.info(f"Creating new flavor with name: {flavor_name}")
            flavor_params = [
                "--ram",
                "1024",
                "--disk",
                "5",
                "--vcpus",
                "2",
                "--property",
                "hw:mem_page_size=2048",
                flavor_name,
            ]
            self.ostcm.flavor.create(flavor_params)
        return flavor_name

    def _ensure_stack_exists(self, stack_name, stack_data):
        """Create ``stack_name`` from the Heat template unless it already exists.

        :param stack_data: entry from ``stacks_data`` with ``hostname``,
            ``node_labels`` and ``az`` keys.
        :raises ValueError: if the node has neither the
            ``openstack-compute-node`` nor the ``openstack-compute-node-dpdk``
            label set to ``enabled``. Previously this case failed with an
            UnboundLocalError.
        """
        if self.ostcm.stack.exists(stack_name):
            LOG.info(f"Stack {stack_name} exists, skipping creation")
            return

        LOG.info(f"Stack {stack_name} was not found, creating")
        hostname = stack_data["hostname"]
        node_labels = stack_data["node_labels"]
        az = stack_data["az"]

        is_dpdk = node_labels.get("openstack-compute-node-dpdk", "") == "enabled"
        is_compute = node_labels.get("openstack-compute-node", "") == "enabled"
        if not (is_dpdk or is_compute):
            raise ValueError(
                f"Node {hostname} has no enabled openstack compute label, "
                f"cannot create stack {stack_name}"
            )

        # DPDK nodes need a huge-pages flavor. It is ensured before the
        # parameters are built, preserving the original call order.
        flavor_name = self._ensure_hugepages_flavor(hostname) if is_dpdk else None
        stack_params = {
            "host_name": az + ":" + hostname,
            "key_name": "key-" + hostname,
            "image_name_1": self.ostcm.cirros_image_name,
        }
        if flavor_name is not None:
            stack_params["flavor1"] = flavor_name

        self.ostcm.create_stack(
            '', stack_name, self.template_path, stack_params=stack_params, finalizer=False
        )

    def ssh_instance(self, stack_name, ip):
        """Return an SSH client to the CirrOS VM at ``ip`` in ``stack_name``.

        The private key is taken from the stack's ``heat_key`` output. Client
        creation is retried up to 5 times, 30 s apart, on TimeoutError,
        NoValidConnectionsError and ConnectionResetError. The client runs in
        sudo mode.
        """
        ssh_access_key = self.ostcm.stack.output_show(
            [stack_name, "heat_key", "--column", "output_value"]
        )["output_value"]
        pkey = utils.get_rsa_key(ssh_access_key)

        @retry((TimeoutError, NoValidConnectionsError, ConnectionResetError), tries=5, delay=30)
        def _ssh_instance(ip):
            LOG.info(
                f"Attempt to connect to VM with ip: {ip} in stack {stack_name}"
            )
            auth = exec_helpers.SSHAuth(username="cirros", password="", key=pkey)
            ssh = exec_helpers.SSHClient(host=ip, port=22, auth=auth)
            ssh.logger.addHandler(logger.console)
            ssh.sudo_mode = True
            return ssh

        return _ssh_instance(ip)

    def _verify_stack_workload(self, stack_name):
        """Run in-guest checks on every server of ``stack_name``.

        Each server gets a random string written to and read back from
        ``test.txt``. For servers with a non-bootable attached volume, the
        check also confirms the mount point and writes a fixed string to a
        file on the volume and reads it back.

        :raises AssertionError: if no server has a non-bootable attached
            volume, or if any check returned unexpected output.
        """
        servers_to_verify = self.get_stack_servers_to_verify(stack_name)
        failed_results = {}

        # We should have consistent non-random paths and strings to verify
        # volume is available and contains a valid data after cluster upgrade
        volume_mount_file = '/mnt/test-volume/test-file'
        volume_mount_string = 'test volume expected string'

        server_with_attached_volume_exists = False
        for server_name, server_data in servers_to_verify.items():
            ip = server_data['fip']
            ssh = self.ssh_instance(stack_name, ip)
            expected_string = utils.gen_random_string(10)
            ssh.check_call(f"echo {expected_string} > test.txt")
            res = ssh.check_call("cat test.txt")
            if res.stdout_brief != expected_string:
                failed_results[ip] = {
                    "actual": res.stdout_brief,
                    "expected": expected_string,
                }
                LOG.error(f"Instance with IP {ip} failed to check")
                continue

            volume = server_data.get('volume', {})
            if not volume:
                continue
            server_with_attached_volume_exists = True

            # NOTE(review): exec_helpers ``check_call`` is expected to raise on
            # a non-zero exit code, which would make the ``exit_code != 0``
            # branches below unreachable. Confirm, and consider ``execute``
            # if failures should be aggregated instead.

            # volume mount is executed and waited during stack creation,
            # but in case check is called after lcm operation we need to
            # check that mount point still exists
            res = ssh.check_call(
                f"findmnt --source $(sudo cat /root/test_volume_{volume})",
                verbose=True
            )
            if res.exit_code != 0:
                failed_results[server_name] = {
                    'device_check': {'actual': res.stdout_brief, 'error': res.stderr_brief}}
                LOG.error(f"Instance {server_name} failed to check volume mountpoint")
                continue
            # write file to a mounted dir of an attached volume
            res = ssh.check_call(f"echo {volume_mount_string} > {volume_mount_file}")
            if res.exit_code != 0:
                failed_results[server_name] = {
                    'device_check': {'actual': res.stdout_brief, 'error': res.stderr_brief}}
                LOG.error(f"Instance {server_name} failed to check volume accessibility")
                continue
            # read and verify wrote file is readable and has expected content
            res = ssh.check_call(f"cat {volume_mount_file}")
            if res.exit_code != 0:
                failed_results[server_name] = {
                    'device_check': {'actual': res.stdout_brief, 'error': res.stderr_brief}}
                LOG.error(f"Instance {server_name} failed to check volume consistency")
                continue
            if res.stdout_brief != volume_mount_string:
                failed_results[server_name] = {
                    'device_check': {'actual': res.stdout_brief, 'expected': volume_mount_string}}
                LOG.error(f"Instance {server_name} failed to check volume consistency")

        assert server_with_attached_volume_exists, \
            f"Test has not found any server with non-bootable attached volume for stack: {stack_name}"
        assert (
            not failed_results
        ), f"Some servers return wrong response in stack {stack_name}:\n{yaml.dump(failed_results)}"
        LOG.info("All instances checks passed")

    def get_stack_servers(self, stack_name):
        """Return the resource names of all ``OS::Nova::Server`` resources in the stack."""
        resources = self.ostcm.stack.resource_list(
            [stack_name, "--filter", "type=OS::Nova::Server"], yaml_output=True
        )
        return [resource["resource_name"] for resource in resources]

    def get_stack_servers_to_verify(self, stack_name):
        """Collect the floating IP and non-bootable attached volume of each server.

        Servers are inspected concurrently in a thread pool.

        :returns: ``{server_resource_name: {'fip': ..., 'volume': ...}}``.
            Either value is ``None`` when nothing was found.
        :raises AssertionError: if the stack has no servers.
        :raises ValueError: if inspecting any server fails. The original
            exception is chained as the cause.
        """
        servers = self.get_stack_servers(stack_name)
        assert servers, f"No servers existed for stack {stack_name}"

        def _get_fip_and_volume_attachment(server, volumes):
            # Use own instance as kubernetesclient is not thread safe
            ocm = OpenStackClientManager(self.ostcm.os_manager.kubeconfig)
            server_info = ocm.stack.resource_show([stack_name, server])
            attributes = server_info.get("attributes")
            result = {server: {'fip': None, 'volume': None}}
            addresses = attributes.get("addresses")
            # Only the first network listed for the server is inspected.
            network = list(addresses.keys())[0]
            for address in addresses.get(network):
                if address['OS-EXT-IPS:type'] == 'floating':
                    LOG.info("Verify instance {0} with floating ip {1}".format(
                        attributes.get('OS-EXT-SRV-ATTR:hostname'), address['addr']))
                    result[server]['fip'] = address['addr']
                    break
            attached_volumes = attributes.get("os-extended-volumes:volumes_attached")
            if attached_volumes:
                # iterate over all attached volumes and find those which are not
                # bootable ones
                for volume_attach in attached_volumes:
                    for volume in volumes:
                        if volume['ID'] != volume_attach['id']:
                            continue
                        # do not include instances with boot from volume where
                        # volume is also attached but is used as a root disk
                        if volume['Bootable'] != 'true':
                            result[server]['volume'] = volume['ID']
                            LOG.info(f"Verify instance {server} with volume {volume['ID']}")
                        break

            return result

        servers_to_verify = {}
        volumes = self.ostcm.volume.list(["--long"])
        with futures.ThreadPoolExecutor() as executor:
            jobs = {
                executor.submit(_get_fip_and_volume_attachment, server, volumes): server for server in servers
            }
            for future in futures.as_completed(jobs):
                server = jobs[future]
                try:
                    data = future.result()
                except Exception as e:
                    raise ValueError(
                        f"Getting server {server} IP and attached volume finished with exception: {e}"
                    ) from e
                servers_to_verify.update(data)

        return servers_to_verify

    def get_servers_reboot_count(self, stack_name):
        """Return ``{server_name: n}``, where n is the number of lines in the
        guest's ``/var/log/acpid.log`` that start with ``acpid: PWRF``.

        Presumably these lines count power-button events delivered to the
        guest (TODO confirm). Only servers returned by
        ``get_stack_servers_to_verify`` are queried.
        """
        def _count_reboots(acpid_stdout):
            # Lines are bytes here; decode before matching.
            return sum(
                1 for line in acpid_stdout
                if line.decode().startswith('acpid: PWRF')
            )

        reboot_data = {}
        for server_id, server in self.get_stack_servers_to_verify(stack_name).items():
            ssh = self.ssh_instance(stack_name, server["fip"])
            acpid_res = ssh.check_call("cat /var/log/acpid.log")
            reboot_data[server_id] = _count_reboots(acpid_res.stdout)
        return reboot_data

    def _reboot_stack_instance_with_volume(self, stack_name):
        """Reboot the first server in ``stack_name`` that reports attached volumes.

        NOTE(review): unlike ``get_stack_servers_to_verify``, this does not
        skip bootable (root) volumes, so a boot-from-volume server also
        qualifies.

        :raises AssertionError: if the stack has no servers, or none of them
            has an attached volume.
        """
        servers = self.get_stack_servers(stack_name)
        assert servers, f"No servers existed for stack {stack_name}"

        # Pre-initialise so that the assert below reports a clear message
        # instead of raising UnboundLocalError when no server has a volume.
        server_id = None
        for server_to_reboot in servers:
            server_info = self.ostcm.stack.resource_show([stack_name, server_to_reboot])
            attributes = server_info.get("attributes")
            if attributes.get("os-extended-volumes:volumes_attached"):
                server_id = attributes.get("id")
                break
        assert server_id, "Test hasn't found server with attached volume"
        LOG.info(f"Rebooting instance {server_id} from stack {stack_name}")
        self.ostcm.server_reboot(server_id)

        LOG.info("Stack instance reboot check passed")

    def _execute_stacks_tasks(self, tasks_map):
        """Run each task in its own process, wait for all, then assert success.

        :param tasks_map: ``{stack_name: {"target": callable, "args": tuple,
            "name": str}}``.
        :raises AssertionError: if any process exits with a non-zero code. The
            message lists every failed task and its stack.
        """
        process_data = {}
        for stack_name, task in tasks_map.items():
            process = multiprocessing.Process(
                target=task["target"], args=task["args"], name=task["name"]
            )
            process_data[stack_name] = {"process": process, "exit_code": None}
            process.start()

        # Waiting for all processes finished
        for data in process_data.values():
            data["process"].join()
            data["exit_code"] = data["process"].exitcode

        # Check processes exit codes. The message used to be a tuple (because
        # of a trailing comma) and printed as a tuple repr.
        failed_tasks = [
            f"{data['process'].name} for stack {stack_name}"
            for stack_name, data in process_data.items()
            if data["exit_code"] != 0
        ]
        assert not failed_tasks, "Next tasks are FAILED:\n" + "\n".join(failed_tasks)

    @retry(Exception, delay=10, tries=3, jitter=1, logger=LOG)
    def _collect_stack_logs(self, stack_name, store_path):
        """Save ``stack show`` output and each server's console log under
        ``<store_path>/<stack_name>``.

        Retried up to 3 times (10 s delay, jitter 1) on any exception.

        :raises Exception: if the stack has no ``OS::Nova::Server`` resources.
        """
        server_resources = self.ostcm.stack.resource_list(
            [stack_name, "--filter", "type=OS::Nova::Server"], yaml_output=True
        )
        if not server_resources:
            raise Exception(f'No stack found with stack_name: {stack_name}')
        servers = {
            x["resource_name"]: x["physical_resource_id"]
            for x in server_resources
        }
        base_path = file_utils.join(store_path, stack_name)
        os.makedirs(base_path, exist_ok=True)
        stack_info = self.ostcm.stack.show([stack_name])
        with open(os.path.join(base_path, "stack_info.yaml"), "w") as f:
            yaml.safe_dump(stack_info, f)

        # The CirrOS does not allow to use SFTP to download files from VM
        # We only collect console output for each VM
        for name, uuid in servers.items():
            server_path = file_utils.join(base_path, name)
            os.makedirs(server_path, exist_ok=True)
            self.ostcm.server_collect_console(uuid, arch_path=server_path)

    @property
    def stacks_data(self):
        """Map stack name to ``{"hostname", "node_labels", "az"}`` for each target host.

        A host is included when its ``Service`` is ``compute``, its name does
        not contain ``ironic``, and its Kubernetes node has
        ``openstack-compute-node`` or ``openstack-compute-node-dpdk`` set to
        ``enabled``. The result is computed once and cached in
        ``self._stacks_data``; it is recomputed while the cache is empty.
        """
        if not self._stacks_data:
            for hs in self.ostcm.host.list([]):
                if hs.get("Service") == "compute" and "ironic" not in hs.get(
                    "Host Name"
                ):
                    hostname = hs["Host Name"]
                    node_labels = (
                        self.os_manager.api.nodes.get(name=hostname)
                        .read()
                        .to_dict()
                        .get("metadata", {})
                        .get("labels", {})
                    )
                    if (
                        node_labels.get("openstack-compute-node", "") == "enabled"
                        or node_labels.get("openstack-compute-node-dpdk", "")
                        == "enabled"
                    ):
                        stack_name = self.stack_prefix + hostname
                        self._stacks_data[stack_name] = {
                            "hostname": hostname,
                            "node_labels": node_labels,
                            "az": hs["Zone"],
                        }
        return self._stacks_data

    def deploy(self):
        """Unset project quotas, render the Heat template and ensure every stack exists."""
        if settings.OPENSTACK_ENCRYPTED_VOLUME:
            assert self.os_manager.is_encrypted_volume_type_described(), \
                "Encrypted volume type WAS NOT found in osdpl, " \
                "but settings.OPENSTACK_ENCRYPTED_VOLUME is enabled"
        LOG.info("Unset quotas for workloads")
        self.ostcm.quota.set(
            [
                "--volumes",
                "-1",
                "--instances",
                "-1",
                "--cores",
                "-1",
                "--routers",
                "-1",
                settings.MOSK_WORKLOAD_PROJECT_NAME
            ]
        )
        template = templates.render_template(
            self.template_path, options={"ssh_public_key": self.ssh_keys["public"],
                                         "ssh_private_key": self.ssh_keys["private"]})
        heat_template = yaml.safe_load(template)
        # NOTE(review): this overwrites the template file in place with the
        # rendered result, so later renders see already-substituted keys.
        with open(self.template_path, 'w') as f:
            f.write(yaml.dump(heat_template))
        tasks_map = {
            stack_name: {
                "target": self._ensure_stack_exists,
                "args": (stack_name, data),
                "name": "ensure_stack_exists",
            }
            for stack_name, data in self.stacks_data.items()
        }
        self._execute_stacks_tasks(tasks_map)

    def check(self):
        """Check every stack's state, then (when reachable) verify its guests.

        In-guest verification is skipped when TF is enabled and
        ``settings.RUN_ON_REMOTE`` is off, because floating IPs are then
        unreachable.
        """
        LOG.info("Checking Per node stacks workloads")
        check_tasks = {
            stack_name: {
                "target": self.ostcm.check_stack,
                "args": (stack_name,),
                "name": "check_stack",
            }
            for stack_name in self.stacks_data
        }
        self._execute_stacks_tasks(check_tasks)

        # Equivalent to: not tf or (tf and RUN_ON_REMOTE).
        if not self.os_manager.is_tf_enabled or settings.RUN_ON_REMOTE:
            # Create and run processes for verifying stacks workloads
            verify_tasks = {
                stack_name: {
                    "target": self._verify_stack_workload,
                    "args": (stack_name,),
                    "name": "verify_stack",
                }
                for stack_name in self.stacks_data
            }
            self._execute_stacks_tasks(verify_tasks)
        else:
            LOG.info("Skip checking stack workloads when TF without RUN_ON_REMOTE, due to fip")

    def delete(self):
        """Delete every per-node stack (``-y --wait``) in parallel."""
        LOG.info("Deleting Per node stacks")
        tasks_map = {}
        for stack_name in self.stacks_data:
            LOG.info(f"Removing stack {stack_name}")
            tasks_map[stack_name] = {
                "target": self.ostcm.stack.delete,
                "args": ([stack_name, "-y", "--wait"],),
                "name": "delete_stack",
            }
        self._execute_stacks_tasks(tasks_map)

    def collect_logs(self):
        """Collect stack logs into ``ARTIFACTS_DIR/per_node_stack_logs.tar.gz``.

        The uncompressed directory is removed afterwards.
        """
        LOG.info("Collecting logs from Per node stacks")
        tasks_map = {}
        archive_name = "per_node_stack_logs"
        store_path = file_utils.join(settings.ARTIFACTS_DIR, archive_name)
        for stack_name in self.stacks_data:
            LOG.info(f"Collect logs from stack {stack_name}")
            tasks_map[stack_name] = {
                "target": self._collect_stack_logs,
                "args": (stack_name, store_path),
                "name": "collect_stack_logs",
            }
        self._execute_stacks_tasks(tasks_map)
        localsh = exec_helpers.Subprocess()
        localsh.check_call(f"tar -czf {store_path}.tar.gz -C {settings.ARTIFACTS_DIR} {archive_name}")
        localsh.check_call(f"rm -rf {store_path}")

    def reboot_instance_with_volume(self):
        """Reboot one instance with an attached volume in every per-node stack.

        In each stack, the first server (in resource-list order) that reports
        attached volumes is rebooted. This is not a random choice. It helps
        to confirm that a VM reboot is not stuck due to connectivity issues
        between nova and libvirt, or libvirt and ceph, for ceph-backed
        instances. Rebooting a single instance per stack is considered
        sufficient.
        """
        LOG.info("Rebooting Per node stack instance with attached volume")
        tasks_map = {}
        for stack_name in self.stacks_data:
            LOG.info(f"Rebooting instance with attached volume from stack {stack_name}")
            tasks_map[stack_name] = {
                "target": self._reboot_stack_instance_with_volume,
                "args": (stack_name,),
                "name": "reboot_stack_instance",
            }
        self._execute_stacks_tasks(tasks_map)
