from datetime import datetime, timedelta
from contextlib import contextmanager
from functools import partial
import pytest
import pytz
import os
import tempfile
import yaml

from si_tests import settings
from si_tests.deployments.utils import commons, kubectl_utils, file_utils
from si_tests.deployments.utils.namespace import NAMESPACE
from si_tests.managers.openstack_client_manager import OpenStackClientManager
from si_tests.utils import waiters


class TestMariadbBackup(object):
    """MariaDB physical backup and restore scenarios for an OpenStack deployment.

    NOTE: every attribute below is a class attribute, so the OpenStack client
    is created and the OpenStackDeployment (osdpl) object is fetched while the
    class body is executed, not when a test starts.
    """
    # Name prefixes used to look up the backup and restore jobs.
    backup_job_name = "mariadb-phy-backup"
    restore_job_name = "mariadb-phy-restore"
    # Names of the VMs that mark the database state before each backup.
    vms = ["test-vm-0", "test-vm-1"]
    os_namespace = NAMESPACE.openstack
    # Used (in seconds) as the base for job trigger timeouts and cron offsets.
    osdpl_applying_time = 600
    os_client = OpenStackClientManager()
    os_manager = os_client.os_manager
    osdpl = os_manager.get_osdpl_deployment()
    # Backup storage backend from the osdpl spec; "pvc" when it is not set.
    backend = osdpl.data["spec"]["features"].get("database", {}).get("backup", {}).get("backend", "pvc")
    sync_remote = os_manager.is_db_backup_sync_remote_enabled()

    # The attributes below exist only when remote sync is enabled.
    if sync_remote:
        # Only the first configured remote is used.
        remote_backend = (list(osdpl.data["spec"]["features"]["database"]["backup"]["sync_remote"]["remotes"]))[0]
        bucket = osdpl.data["spec"]["features"]["database"]["backup"]["sync_remote"]["remotes"][remote_backend]["path"]
        images = os_manager.get_os_helmbundle_images()
        # Image set on the helper pod container when remote sync is enabled.
        helperImage = images["openstack-database"]["openstack-mariadb"]["mariadb"]

    @classmethod
    def setup_class(cls):
        """Purge backups on the remote storage before the tests run.

        Does nothing when remote sync is disabled.
        """
        if not cls.sync_remote:
            return
        cls.cleanup_remote_backups()

    @classmethod
    def teardown_class(cls):
        """Delete the test VMs and purge the remote backups.

        The remote cleanup runs in a ``finally`` clause so that backups are
        purged even when deleting a VM raises; the original error still
        propagates afterwards.
        """
        try:
            for vm_name in cls.vms:
                commons.LOG.info("Delete vm: %s", vm_name)
                cls.os_client.server.delete([vm_name])
        finally:
            if cls.sync_remote:
                cls.cleanup_remote_backups()

    def assert_clean_env(self):
        """Check that the environment is ready for the backup scenario.

        Asserts that no job with the backup or restore name prefix exists in
        the OpenStack namespace and that the backup cronjob is suspended.
        """
        leftovers = (
            (self.backup_job_name, "backup job"),
            (self.restore_job_name, "restore job"),
        )
        for prefix, description in leftovers:
            commons.LOG.info("Assert there is no %s", description)
            found = self.os_manager.api.jobs.list(
                namespace=self.os_namespace, name_prefix=prefix)
            assert len(found) == 0

        commons.LOG.info("Assert there is created cronjob %s in suspended state",
                         self.backup_job_name)
        backup_cronjob = self.os_manager.api.cronjobs.get(
            self.backup_job_name, namespace=self.os_namespace)
        assert backup_cronjob.read().spec.suspend is True

    def wait_for_new_job(self, job_name, filter_func=None, filter_args=None):
        """Return the single job whose name starts with ``job_name``.

        :param job_name: job name prefix to list by.
        :param filter_func: optional callable ``(job, *filter_args)``; only
            jobs for which it returns a truthy value are kept.
        :param filter_args: extra positional arguments for ``filter_func``.
        :raises AssertionError: unless exactly one job is left, which lets a
            waiter retry this check.
        """
        extra_args = filter_args or []
        candidates = self.os_manager.api.jobs.list(
            namespace=self.os_namespace, name_prefix=job_name)
        if filter_func:
            candidates = [job for job in candidates
                          if filter_func(job, *extra_args)]
        assert len(candidates) == 1
        return candidates[0]

    def wait_job_complete(self, run_job):
        """Check that ``run_job`` has succeeded.

        :raises ValueError: if the job reports failures.
        :raises AssertionError: if the job has not succeeded exactly once.
        """
        status = run_job.read().status
        if status.failed:
            raise ValueError("Job {} failed to complete".format(run_job.name))
        assert status.succeeded == 1

    def wait_for_job(self, job_name, filter_func=None, filter_args=None,
                     trigger_timeout=600, complete_timeout=600):
        """Wait until a job appears and then until it succeeds.

        :param job_name: job name prefix to wait for.
        :param filter_func: optional job filter, see ``wait_for_new_job``.
        :param filter_args: extra positional arguments for ``filter_func``.
        :param trigger_timeout: timeout for the job to appear.
        :param complete_timeout: timeout for the job to succeed.
        :return: the job object that was waited for.
        """
        # Both phases retry on AssertionError with the same polling interval.
        retry_options = {"expected": AssertionError, "interval": 30}

        commons.LOG.info("Wait for %s job to trigger", job_name)
        triggered_job = waiters.wait_pass(
            self.wait_for_new_job,
            timeout=trigger_timeout,
            predicate_args=(job_name, filter_func, filter_args),
            **retry_options)

        commons.LOG.info("Wait for %s job to succeed", job_name)
        waiters.wait_pass(
            self.wait_job_complete,
            timeout=complete_timeout,
            predicate_args=(triggered_job,),
            **retry_options)

        commons.LOG.info("%s: \n%s\n", triggered_job.name,
                         triggered_job.read().status)
        return triggered_job

    def get_backup_dict(self, **backup):
        """Wrap the keyword arguments as ``spec.features.database.backup``."""
        database = {"backup": backup}
        features = {"database": database}
        return {"spec": {"features": features}}

    def create_backups(self):
        """Make a base backup and then an incremental one.

        The first backup runs while only the first test VM is provided; the
        second runs while both test VMs are provided.
        """
        first_vm, second_vm = self.vms[0], self.vms[1]
        with self.os_client.provide_vm(first_vm):
            commons.LOG.info("Perform base backup")
            base_job = self.make_backup([])
            commons.LOG.info("Vms: %s", self.os_client.server.list([]))

            with self.os_client.provide_vm(second_vm):
                commons.LOG.info("Perform incremental backup")
                # Pass the base job name so it is not mistaken for the new job.
                self.make_backup([base_job.name])
                commons.LOG.info("Vms: %s", self.os_client.server.list([]))

    def make_backup(self, known):
        """Schedule a backup through osdpl and wait for its job to succeed.

        :param known: names of backup jobs that already exist and must be
            ignored while waiting for the new one.
        :return: the backup job object.
        """
        def is_unknown(job, known_names):
            return job.name not in known_names

        # The schedule is placed after the osdpl applying time plus a margin.
        delay = self.osdpl_applying_time + 30
        fire_at = datetime.now() + timedelta(seconds=delay)
        cron_schedule = fire_at.astimezone(pytz.utc).strftime("%M %H %d %m *")

        backup_patch = self.get_backup_dict(
            backup_type="incremental",
            enabled=True,
            schedule_time=cron_schedule,
        )
        self.osdpl.patch(backup_patch)

        return self.wait_for_job(
            self.backup_job_name,
            filter_func=is_unknown,
            filter_args=(known,),
            trigger_timeout=delay + 30,
            complete_timeout=settings.MARIADB_BACKUP_JOB_TIMEOUT,
        )

    @classmethod
    @contextmanager
    def spawn_service_pod(cls):
        """Create the helper pod, yield it once Running, delete it on exit.

        The pod definition is loaded from ``templates/mariadb_check_pod.yaml``
        next to this file and extended with the backup volume for the
        configured backend. When remote sync is enabled, the helper image and
        the rclone config mount are added as well.
        """
        kubectl = kubectl_utils.Kubectl()
        template_dir = os.path.dirname(os.path.abspath(__file__))
        pod_yaml = file_utils.join(
            template_dir, "templates/mariadb_check_pod.yaml")

        with open(pod_yaml, "r") as template:
            pod_data = yaml.safe_load(template)
        pod_spec = pod_data["spec"]

        if cls.backend == "hostpath":
            backup_volumes = [
                {
                    "name": "mysql-backup",
                    "hostPath": {"path": "/var/lib/openstack-helm/mariadb-backup"},
                },
                {
                    "name": "mysql-data",
                    "persistentVolumeClaim": {"claimName": "mysql-data-mariadb-server-2"},
                },
            ]
        else:
            backup_volumes = [
                {
                    "name": "mysql-backup",
                    "persistentVolumeClaim": {"claimName": "mariadb-phy-backup-data"},
                },
            ]
        pod_spec["volumes"].extend(backup_volumes)

        if cls.sync_remote:
            helper_container = pod_spec["containers"][0]
            helper_container["image"] = cls.helperImage
            helper_container["volumeMounts"].append(
                {
                    "name": "mariadb-secrets",
                    "mountPath": "/tmp/.rclone.conf",
                    "subPath": "rclone.conf",
                    "readOnly": True,
                }
            )

        with tempfile.NamedTemporaryFile(mode="w") as rendered:
            yaml.dump(pod_data, rendered)
            kubectl.apply(rendered.name)

        def _wait_pod():
            pod = cls.os_manager.api.pods.get(
                "si-tests-helper", namespace=cls.os_namespace)
            assert pod.read().status.phase == "Running"
            return pod

        try:
            commons.LOG.info("Wait for si-tests-helper pod to become Running")
            running_pod = waiters.wait_pass(
                _wait_pod, expected=Exception, timeout=120)
            yield running_pod
        finally:
            # Deletion goes through the template file, not the rendered copy.
            kubectl.delete("", "-f {}".format(pod_yaml), cls.os_namespace, cwd=".")

    def get_backup_names_by_location(self, location="local"):
        """Validate the backup layout and return the backup names.

        Spawns the helper pod and lists the ``base`` and ``incr`` folders,
        either under ``/var/backup/`` in the pod or on the remote storage
        through ``rclone``. Every backup folder name must match
        ``%Y-%m-%d_%H-%M-%S`` and every backup folder must contain
        ``backup.stream.gz`` and ``backup.successful``.

        :param location: ``"local"`` or ``"remote"``.
        :return: tuple of the first base backup name and the first
            incremental backup name found for that base backup.
        :raises ValueError: if a backup folder name is not a valid timestamp.
        :raises AssertionError: if an expected file or folder is missing.
        :raises Exception: if ``location`` is not supported.
        """
        commons.LOG.info("Spawn pod to get backup name")

        def check_datetime(name):
            try:
                datetime.strptime(name, "%Y-%m-%d_%H-%M-%S")
            except ValueError:
                commons.LOG.exception("Backup name %s is not valid", name)
                raise

        def assert_backup_complete(files):
            # Strip the names so base and incremental listings are compared
            # in the same way.
            names = [x.strip() for x in files]
            assert "backup.stream.gz" in names
            assert "backup.successful" in names

        with self.spawn_service_pod() as check_pod:

            def get_files(path):
                if location == "remote":
                    raw_list = (
                        check_pod.exec(
                            [
                                "rclone",
                                "--no-check-certificate",
                                "--config=/tmp/.rclone.conf",
                                "lsf",
                                "{}:{}/{}".format(self.remote_backend, self.bucket, path),
                            ]
                        )
                        .strip()
                        .split("\n")
                    )
                    # Directory entries are listed with a trailing slash.
                    return [x.rstrip("/") for x in raw_list]
                elif location == "local":
                    return (
                        check_pod.exec(["ls", os.path.join("/var/backup/", path)])
                        .strip()
                        .split("\n")
                    )
                else:
                    raise Exception("Unsupported location")

            commons.LOG.info("Get base backup name")
            backup_name = get_files("base")[0]

            check_datetime(backup_name)

            commons.LOG.info("Found backup %s", backup_name)

            assert_backup_complete(get_files("base/{}".format(backup_name)))

            commons.LOG.info("Get incremental backups")
            incr_backups_folders = get_files("incr")

            # An explicit loop is required: map() is lazy in Python 3, so an
            # unconsumed map() would never run the validation.
            for folder in incr_backups_folders:
                check_datetime(folder)

            assert backup_name in incr_backups_folders

            incr_backups = get_files("incr/{}".format(backup_name))

            commons.LOG.info("Incremental backups are %s", incr_backups)

            for incr_backup in incr_backups:
                check_datetime(incr_backup)
                assert_backup_complete(get_files("incr/{}/{}".format(
                    backup_name, incr_backup)))

        return backup_name, incr_backups[0]

    def get_restore_dict(self, **kwargs):
        """Wrap the keyword arguments as ``spec.services.database.mariadb.values``."""
        mariadb = {"values": kwargs}
        database = {"mariadb": mariadb}
        return {"spec": {"services": {"database": database}}}

    def check_restore_deleted(self):
        """Assert that no job with the restore name prefix exists."""
        restore_jobs = self.os_manager.api.jobs.list(
            namespace=self.os_namespace, name_prefix=self.restore_job_name)
        assert len(restore_jobs) == 0

    def restore_backup(self, backup_name):
        """Restore the database from ``backup_name`` and wait for recovery.

        Enables the restore job through osdpl, waits for it to succeed,
        disables it again, waits for the job to be deleted and finally waits
        for the OpenStack deployment health status.

        :param backup_name: backup to restore, as passed to
            ``phy_restore.backup_name``.
        """
        commons.LOG.info("Perform DB restore")
        enable_restore = self.get_restore_dict(
            conf={"phy_restore": {"backup_name": backup_name}},
            manifests={"job_mariadb_phy_restore": True},
        )
        self.osdpl.patch(enable_restore)

        self.wait_for_job(
            self.restore_job_name,
            trigger_timeout=self.osdpl_applying_time,
            complete_timeout=settings.MARIADB_RESTORE_JOB_TIMEOUT,
        )

        commons.LOG.info("Turn off the restore")
        disable_restore = self.get_restore_dict(
            manifests={"job_mariadb_phy_restore": False})
        self.osdpl.patch(disable_restore)

        commons.LOG.info("Wait for restore job to be deleted")
        waiters.wait_pass(
            self.check_restore_deleted,
            expected=AssertionError,
            interval=30,
            timeout=self.osdpl_applying_time,
        )
        # The restore shuts down all mariadb server pods, so wait until the
        # services are healthy again before returning.
        self.os_manager.wait_openstackdeployment_health_status()

    def assert_vms(self, present=None, absent=None):
        """Assert which VM names are in the server list.

        :param present: names that must be listed (default: none).
        :param absent: names that must not be listed (default: none).
        """
        expected_names = present or []
        forbidden_names = absent or []
        listed_names = [server["Name"]
                        for server in self.os_client.server.list([])]

        for name in expected_names:
            assert name in listed_names
        for name in forbidden_names:
            assert name not in listed_names

    def break_local_backup(self, backups, backup_type="base"):
        """Overwrite the start of a local backup archive with random data.

        :param backups: dict with the ``"base"`` backup name and, for the
            incremental type, the ``"incr"`` backup name.
        :param backup_type: ``"base"`` or ``"incremental"``.
        :raises Exception: if ``backup_type`` is unknown.
        """
        commons.LOG.info("Corrupt {} backup archives on the local storage.".format(backup_type))
        with self.spawn_service_pod() as work_pod:
            if backup_type == "base":
                path_parts = ("/var/backup/base/", backups["base"])
            elif backup_type == "incremental":
                path_parts = ("/var/backup/incr/", backups["base"], backups["incr"])
            else:
                raise Exception("Unknown backup type \"{}\".".format(backup_type))
            target = os.path.join(*path_parts, "backup.stream.gz")
            # conv=notrunc: only the first 1M block of the archive is replaced.
            dd_command = [
                "dd",
                "if=/dev/urandom",
                "of={}".format(target),
                "conv=notrunc",
                "bs=1M",
                "count=1",
            ]
            work_pod.exec(dd_command)

    @classmethod
    def cleanup_remote_backups(cls):
        """Purge the ``base`` and ``incr`` folders on the remote storage.

        Runs ``rclone purge`` for each folder inside the helper pod.
        """
        commons.LOG.info("Clean up backup on the remote storage.")
        rclone_purge = [
            "rclone",
            "--no-check-certificate",
            "--config=/tmp/.rclone.conf",
            "purge",
        ]
        with cls.spawn_service_pod() as work_pod:
            for folder in ("base", "incr"):
                remote_path = "{}:{}/{}".format(cls.remote_backend, cls.bucket, folder)
                work_pod.exec(rclone_purge + [remote_path])

    def clean_env(self):
        """Delete every job whose name starts with the backup job name."""
        kubectl = kubectl_utils.Kubectl()
        commons.LOG.info("Remove all existing backup jobs")
        stale_jobs = self.os_manager.api.jobs.list(
            namespace=self.os_namespace, name_prefix=self.backup_job_name)
        for stale_job in stale_jobs:
            kubectl.delete("job", stale_job.name, namespace=self.os_namespace)

    @pytest.mark.usefixtures("mos_workload_downtime_report")
    @pytest.mark.usefixtures("mos_per_node_workload_check_after_test")
    @pytest.mark.dependency(name="backup_basic")
    def test_backup_basic(self):
        """ Test scenario:
            1. Check that the environment is clean enough to run the test.
            2. Create VM.
            3. Create base backup.
            4. Create second VM.
            5. Create incremental backup.
            6. Remove both VMs.
            7. Verify that local backups exist.
        """
        self.assert_clean_env()
        self.create_backups()
        found = self.get_backup_names_by_location(location="local")
        full_name, incr_name = found
        assert full_name not in (None, ''), "Full backup is not found."
        assert incr_name not in (None, ''), "Incremental backup is not found."

    @pytest.mark.usefixtures("mos_workload_downtime_report")
    @pytest.mark.usefixtures("mos_per_node_workload_check_after_test")
    @pytest.mark.usefixtures('skip_by_mariadb_backup_remote_sync')
    @pytest.mark.mariadb_backup_remote_sync('enabled')
    @pytest.mark.dependency(name="corrupt_backup_basic", depends=["backup_basic"])
    def test_backup_with_remote_fix_corrupted_full(self):
        """ Test scenario:
            1. Remove existing backup jobs.
            2. Corrupt the base backup archive on the local storage.
            3. Create an incremental backup. The backup job is expected to
               repair the corrupted backup by downloading the correct one
               from the remote storage.
            The test case 'test_restore_incremental' verifies that the restore
            job succeeds afterwards.
        """
        self.clean_env()
        names = self.get_backup_names_by_location(location="local")
        corrupted = {"base": names[0], "incr": names[1]}
        self.break_local_backup(corrupted)
        self.make_backup([])

    @pytest.mark.usefixtures('skip_by_mariadb_backup_remote_sync')
    @pytest.mark.mariadb_backup_remote_sync('enabled')
    @pytest.mark.dependency(name="verify_remote_backup", depends=["corrupt_backup_basic"])
    def test_backup_remote_storage(self):
        """ Test scenario:
            1. Get the backup files with their md5 sums from the remote storage.
            2. Get the backup files with their md5 sums from the local filesystem.
            3. Check that both lists are equal.
        """
        def make_list(raw_list, prefix=""):
            # Each line is "<hash> <path>"; the prefix is removed from the path.
            fields = (line.split() for line in raw_list)
            return [
                {"hash": pair[0], "path": pair[1].replace(prefix, "")}
                for pair in fields
            ]

        def is_list_eq(list1, list2):
            same_size = len(list1) == len(list2)
            return same_size and all(item in list2 for item in list1)

        remote_md5_command = [
            "rclone",
            "--no-check-certificate",
            "--config=/tmp/.rclone.conf",
            "md5sum",
            "{}:{}".format(self.remote_backend, self.bucket),
        ]
        local_md5_command = [
            "find",
            "/var/backup/",
            "-type",
            "f",
            "-exec",
            "md5sum",
            "{}",
            ";",
        ]

        with self.spawn_service_pod() as work_pod:
            remote_output = work_pod.exec(remote_md5_command)
            remote_backups = make_list(remote_output.strip().split("\n"))

            local_output = work_pod.exec(local_md5_command)
            local_backups = make_list(
                local_output.strip().split("\n"), "/var/backup/")

            commons.LOG.info("Assert if the remote backup matches the local one.")
            assert is_list_eq(local_backups, remote_backups)

    @pytest.mark.usefixtures("mos_workload_downtime_report")
    @pytest.mark.usefixtures("mos_per_node_workload_check_after_test")
    @pytest.mark.dependency(name="restore_full", depends=["backup_basic"])
    def test_restore_full(self):
        """ Test scenario:
            1. Check that both test VMs are absent.
            2. Get the name of the base backup.
            3. Restore the base backup.
            4. Check that only the first test VM is present.
        """
        commons.LOG.info("Check VMs before restore.")
        self.assert_vms(absent=self.vms)
        backup_names = self.get_backup_names_by_location(location="local")
        full_name = backup_names[0]
        commons.LOG.info("Process full backup restoration.")
        self.restore_backup(full_name)
        commons.LOG.info("Check if only first VM is present.")
        self.assert_vms(present=self.vms[:1], absent=self.vms[1:])

    @pytest.mark.usefixtures("mos_workload_downtime_report")
    @pytest.mark.usefixtures("mos_per_node_workload_check_after_test")
    @pytest.mark.dependency(name="restore_incremental", depends=["restore_full"])
    def test_restore_incremental(self):
        """ Test scenario:
            1. Get the names of the full and incremental backups.
            2. Restore the incremental backup.
            3. Check that both test VMs are present.
        """
        full_name, incr_name = self.get_backup_names_by_location(location="local")
        restore_target = "{}/{}".format(full_name, incr_name)
        commons.LOG.info("Process incrementall backup restoration.")
        self.restore_backup(restore_target)
        commons.LOG.info("Check if both test VMs are present.")
        self.assert_vms(present=self.vms)

    @pytest.mark.usefixtures("mos_workload_downtime_report")
    @pytest.mark.usefixtures("mos_per_node_workload_check_after_test")
    @pytest.mark.dependency(name="restore_remote_full", depends=["restore_incremental", "verify_remote_backup"])
    def test_restore_from_remote_local_storage_corrupted_full(self):
        """ Test scenario:
            1. Get the names of the base and incremental backups from the
               remote storage.
            2. Corrupt the base backup archive on the local storage.
            3. Restore the base backup.
            4. Check that only the first test VM is present.
        """
        full_name, incr_name = self.get_backup_names_by_location(location="remote")
        corrupted = {"base": full_name, "incr": incr_name}
        self.break_local_backup(corrupted)
        commons.LOG.info("Process full backup restoraton.")
        self.restore_backup(full_name)
        commons.LOG.info("Check if only first VM is present.")
        self.assert_vms(present=self.vms[:1], absent=self.vms[1:])

    @pytest.mark.usefixtures("mos_workload_downtime_report")
    @pytest.mark.usefixtures("mos_per_node_workload_check_after_test")
    @pytest.mark.dependency(depends=["restore_remote_full"])
    def test_restore_from_remote_local_storage_corrupted_incremental(self):
        """ Test scenario:
            1. Get the names of the base and incremental backups from the
               remote storage.
            2. Corrupt the incremental backup archive on the local storage.
            3. Restore the incremental backup.
            4. Check that both test VMs are present.
        """
        full_name, incr_name = self.get_backup_names_by_location(location="remote")
        corrupted = {"base": full_name, "incr": incr_name}
        self.break_local_backup(corrupted, backup_type="incremental")
        restore_target = "{}/{}".format(full_name, incr_name)
        commons.LOG.info("Process incremental backup restoration.")
        self.restore_backup(restore_target)
        commons.LOG.info("Check if both test VMs are present.")
        self.assert_vms(present=self.vms)
