#    Copyright 2025 Mirantis, Inc.
#
#    Licensed under the Apache License, Version 2.0 (the "License"); you may
#    not use this file except in compliance with the License. You may obtain
#    a copy of the License at
#
#         http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
#    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
#    License for the specific language governing permissions and limitations
#    under the License.

import cachetools.func as cachetools_func

from si_tests import settings
from si_tests import logger
from si_tests.utils import utils

from typing import Union

LOG = logger.logger

# Type alias covering every machine wrapper class defined in this module.
# The members are string forward references because the classes are
# declared further down in the file.
MachineProviders = Union[
    "AwsProviderMachine",
    "AzureProviderMachine",
    "EKSProviderMachine",
    "GCPProviderMachine",
    "Machine",
    "OpenStackProviderMachine",
    "VSphereProviderMachine",
    "Metal3Machines",
    "RemoteMachine",
]


class Machine(object):
    """Provider independent methods for the kaas machine.

    Wraps a CAPI Machine object: exposes its name/namespace, cached
    metadata/spec/uid, type helpers based on labels, and lookups of the
    backing k8s node. Provider-specific subclasses are expected to implement
    ``provider``, ``internal_ip`` and ``public_ip``.
    """

    def __init__(self, manager, clusterdeployment, capimachine):
        """
        manager: <Manager> instance
        clusterdeployment: <ClusterDeployment> instance
        capimachine: <Machine> instance (the wrapped CAPI Machine object)
        """
        self.__manager = manager
        self.__clusterdeployment = clusterdeployment
        self.__capimachine = capimachine
        self.__uid = None
        self.__metadata = None
        self.__spec = None
        self.__machine_type = None
        self.__machine_types = None

        # infrastructureRef points at the provider-specific machine object;
        # keep its kind and the version part of its apiVersion ("group/vX").
        machine_spec = capimachine.spec
        self.capimachine_provider = machine_spec['infrastructureRef']['kind']
        self.capimachine_api_version = machine_spec['infrastructureRef']['apiVersion'].split("/")[-1]

    @property
    def provider(self):
        raise NotImplementedError("Abstract method")

    @property
    def internal_ip(self):
        raise NotImplementedError("Abstract method")

    @property
    def public_ip(self):
        raise NotImplementedError("Abstract method ")

    def power_off(self):
        raise NotImplementedError("Method is not implemented for current provider yet")

    def power_on(self):
        raise NotImplementedError("Method is not implemented for current provider yet")

    def get_power_status(self):
        raise NotImplementedError("Method is not implemented for current provider yet")

    @property
    def _manager(self):
        return self.__manager

    @property
    def _clusterdeployment(self):
        # Protected accessor: subclasses must use this, because
        # ``self.__clusterdeployment`` in a subclass is name-mangled to a
        # different (non-existent) attribute.
        return self.__clusterdeployment

    # NOTE(review): ttl_cache is shared across all instances and keys on
    # ``self``, so cached instances stay referenced until evicted/expired.
    @property
    @cachetools_func.ttl_cache(ttl=3600)
    def name(self):
        """Machine name"""
        return self.__capimachine.name

    @property
    @cachetools_func.ttl_cache(ttl=3600)
    def namespace(self):
        """Machine namespace"""
        return self.__capimachine.namespace

    @property
    def metadata(self):
        """Cached machine metadata (read once from the API object)"""
        if self.__metadata is None:
            self.__metadata = self.data['metadata']
        return self.__metadata

    @property
    def spec(self):
        """Cached machine spec (read once from the API object)"""
        if self.__spec is None:
            self.__spec = self.data['spec']
        return self.__spec

    @property
    def uid(self):
        """Machine uid"""
        if self.__uid is None:
            self.__uid = self.metadata['uid']
        return self.__uid

    @property
    def status(self):
        """Current (uncached) machine status; {} when absent or empty"""
        return self.data.get('status') or {}

    @property
    def machine_type(self):
        """Cached machine type
        :rtype string: One of ('control', 'worker', 'storage')
        """
        if self.__machine_type is None:
            self.__machine_type = \
                utils.get_type_by_labels(self.metadata['labels'])
        return self.__machine_type

    @property
    def data(self):
        """Returns dict of k8s object

        Data contains keys like api_version, kind,

        metadata, spec, status or items. Every access re-reads the object.
        """

        return self.__capimachine.read().to_dict()

    def patch(self, *args, **kwargs):
        self.__capimachine.patch(*args, **kwargs)

    def replace(self, *args, **kwargs):
        self.__capimachine.replace(*args, **kwargs)

    @property
    def machine_types(self):
        """Cached machine types
        :rtype list: with any of valid machine types:
            'control', 'worker'
        """
        if self.__machine_types is None:
            self.__machine_types = \
                utils.get_types_by_labels(self.metadata['labels'])
        return self.__machine_types

    def is_machine_type(self, mtype):
        """Check whether the machine has the given type.

        :mtype string: One of ('control', 'worker')
        :rtype bool:
        """
        return mtype in self.machine_types

    @property
    def machine_phase(self):
        """Machine phase from capi

        Status object may disappear for a short periods, return None
        :rtype string:
        """
        return self.__capimachine.phase or None

    def has_k8s_labels(self, labels) -> bool:
        """Check labels on K8sNode

        :param labels: dict of label name -> expected value
        :return: True when every label is present on the node with the
            expected value; False otherwise, including when the k8s node
            for this machine is not known (yet)
        """
        k8s_node = self.get_k8s_node()
        if k8s_node is None:
            return False
        # metadata/labels may be missing or explicitly None
        node_labels = (k8s_node.data.get('metadata') or {}).get('labels') or {}
        return all(
            label in node_labels and node_labels[label] == value
            for label, value in labels.items()
        )

    def get_k8s_node_name(self):
        return self.__capimachine.node_name

    def get_k8s_node(self):
        """Return the k8s node object for this machine, or None if unknown"""
        node_name = self.get_k8s_node_name()
        if not node_name:
            return None
        return self._clusterdeployment.k8sclient.nodes.get(
            name=node_name
        )


class AwsProviderMachine(Machine):
    """AWS provider specific methods"""

    # Class-level defaults, name-mangled to _AwsProviderMachine__*.
    # __internal_ip is shadowed per instance once an InternalIP is found;
    # __public_ip is currently unused.
    __internal_ip = None
    __public_ip = None

    @property
    def provider(self):
        return settings.AWS_PROVIDER_NAME

    def get_k8s_node_name(self):
        """Return the k8s node name for this machine, or None if unknown.

        Uses ``status.nodeRef.name`` when set; otherwise returns the first
        child cluster node whose provider ID contains this machine's
        ``spec.providerID``.
        """
        data = self.data  # one API read for both status and spec
        machine_status = data.get('status') or {}
        node_name = (machine_status.get('nodeRef') or {}).get('name')
        if node_name:
            return node_name
        provider_id = (data.get('spec') or {}).get('providerID') or ''
        if provider_id:
            # Use the protected accessor: ``self.__clusterdeployment`` here
            # would be mangled to _AwsProviderMachine__clusterdeployment,
            # which does not exist.
            nodes = self._clusterdeployment.k8sclient.nodes.list_raw().items
            for node in nodes:
                if provider_id in (node.spec.provider_id or ''):
                    return node.metadata.name
        LOG.info('Node name for machine not available')
        return None

    @property
    def internal_ip(self):
        """Return the machine's InternalIP from status.addresses, or None.

        A found address is memoized on the instance. A missing one is not
        cached, so callers polling for the IP see it as soon as it appears.
        """
        if self.__internal_ip is not None:
            return self.__internal_ip
        for address in self.status.get('addresses') or []:
            if address.get('type', '') == 'InternalIP':
                self.__internal_ip = address['address']
                return self.__internal_ip
        return None

    @property
    def public_ip(self):
        LOG.debug("public_ip() is not implemented for AWS provider")
        return None

    # TODO(va4st): Add SI provider resource manager and align it with k0rdent resources
    # def power_off(self):
    #     aws_instance_name = \
    #         self.data['metadata']['annotations']['kaas.mirantis.com/uid']
    #     client = self.__clusterdeployment.provider_resources.client
    #     power_state = client.get_instance_state_by_name(
    #         instance_name=aws_instance_name)
    #     if power_state == 'running':
    #         client.instance_power_action(instance_name=aws_instance_name,
    #                                      action='stop')
    #         waiters.wait(lambda: client.get_instance_state_by_name(
    #             instance_name=aws_instance_name) == 'stopped', timeout=300,
    #                      interval=30)
    #         power_state = client.get_instance_state_by_name(
    #             aws_instance_name)
    #         LOG.info(f"state after power_off instance: {power_state}\n")
    #     else:
    #         raise Exception(f"Machine with name {aws_instance_name} "
    #                         f"in status {power_state}, expected <running>")
    #
    # def power_on(self):
    #     aws_instance_name = \
    #         self.data['metadata']['annotations']['kaas.mirantis.com/uid']
    #     client = self.__clusterdeployment.provider_resources.client
    #     power_state = client.get_instance_state_by_name(aws_instance_name)
    #     if power_state == 'stopped':
    #         client.instance_power_action(aws_instance_name, action='start')
    #         waiters.wait(lambda: client.get_instance_state_by_name(
    #             instance_name=aws_instance_name) == 'running', timeout=300,
    #                      interval=30)
    #         power_state = client.get_instance_state_by_name(
    #             aws_instance_name)
    #         LOG.info(f"state after power_on instance: {power_state}\n")
    #     else:
    #         raise Exception(f"Machine with name {aws_instance_name} "
    #                         f"in status {power_state}, expected <stopped>")


# (va4st): From the kcm perspective an EKS machine == an AWS machine. From the ksi perspective, however, it needs its
# own class so that the provider class can be determined explicitly from the machine. The EKS machine class inherits
# from the AWS one for consistency and adds no new methods.
class EKSProviderMachine(AwsProviderMachine):
    """EKS provider specific methods.

    Only ``provider`` differs from AwsProviderMachine. Node-name and
    internal IP lookups (and their instance memo, stored under the
    AwsProviderMachine-mangled attribute name) are inherited unchanged, so
    no private class attributes are redeclared here.
    """

    @property
    def provider(self):
        return settings.EKS_PROVIDER_NAME


class AzureProviderMachine(Machine):
    """Azure provider specific methods"""

    # Class-level defaults, name-mangled to _AzureProviderMachine__*.
    # __internal_ip is shadowed per instance once an InternalIP is found;
    # __public_ip is currently unused.
    __internal_ip = None
    __public_ip = None

    @property
    def provider(self):
        return settings.AZURE_PROVIDER_NAME

    def get_k8s_node_name(self):
        """Return the k8s node name for this machine, or None if unknown.

        Uses ``status.nodeRef.name`` when set; otherwise returns the first
        child cluster node whose provider ID contains this machine's
        ``spec.providerID``.
        """
        data = self.data  # one API read for both status and spec
        machine_status = data.get('status') or {}
        node_name = (machine_status.get('nodeRef') or {}).get('name')
        if node_name:
            return node_name
        provider_id = (data.get('spec') or {}).get('providerID') or ''
        if provider_id:
            # Use the protected accessor: ``self.__clusterdeployment`` here
            # would be mangled to _AzureProviderMachine__clusterdeployment,
            # which does not exist.
            nodes = self._clusterdeployment.k8sclient.nodes.list_raw().items
            for node in nodes:
                if provider_id in (node.spec.provider_id or ''):
                    return node.metadata.name
        LOG.info('Node name for machine not available')
        return None

    @property
    def internal_ip(self):
        """Return the machine's InternalIP from status.addresses, or None.

        A found address is memoized on the instance. A missing one is not
        cached, so callers polling for the IP see it as soon as it appears.
        """
        if self.__internal_ip is not None:
            return self.__internal_ip
        for address in self.status.get('addresses') or []:
            if address.get('type', '') == 'InternalIP':
                self.__internal_ip = address['address']
                return self.__internal_ip
        return None


class VSphereProviderMachine(Machine):
    """vSphere provider specific methods"""

    # Class-level defaults, name-mangled to _VSphereProviderMachine__*.
    # __internal_ip is shadowed per instance once an InternalIP is found;
    # __public_ip is currently unused.
    __internal_ip = None
    __public_ip = None

    @property
    def provider(self):
        return settings.VSPHERE_PROVIDER_NAME

    def get_k8s_node_name(self):
        """Return the k8s node name for this machine, or None if unknown.

        Uses ``status.nodeRef.name`` when set; otherwise returns the first
        child cluster node whose provider ID contains this machine's
        ``spec.providerID``.
        """
        data = self.data  # one API read for both status and spec
        machine_status = data.get('status') or {}
        node_name = (machine_status.get('nodeRef') or {}).get('name')
        if node_name:
            return node_name
        provider_id = (data.get('spec') or {}).get('providerID') or ''
        if provider_id:
            # Use the protected accessor: ``self.__clusterdeployment`` here
            # would be mangled to _VSphereProviderMachine__clusterdeployment,
            # which does not exist.
            nodes = self._clusterdeployment.k8sclient.nodes.list_raw().items
            for node in nodes:
                if provider_id in (node.spec.provider_id or ''):
                    return node.metadata.name
        LOG.info('Node name for machine not available')
        return None

    @property
    def internal_ip(self):
        """Return the machine's InternalIP from status.addresses, or None.

        A found address is memoized on the instance. A missing one is not
        cached, so callers polling for the IP see it as soon as it appears.
        """
        if self.__internal_ip is not None:
            return self.__internal_ip
        for address in self.status.get('addresses') or []:
            if address.get('type', '') == 'InternalIP':
                self.__internal_ip = address['address']
                return self.__internal_ip
        return None


class OpenStackProviderMachine(Machine):
    """OpenStack provider specific methods"""

    # Class-level defaults, name-mangled to _OpenStackProviderMachine__*.
    # __internal_ip is shadowed per instance once an InternalIP is found;
    # __public_ip is currently unused.
    __internal_ip = None
    __public_ip = None

    @property
    def provider(self):
        return settings.OPENSTACK_PROVIDER_NAME

    def get_k8s_node_name(self):
        """Return the k8s node name for this machine, or None if unknown.

        Uses ``status.nodeRef.name`` when set; otherwise returns the first
        child cluster node whose provider ID contains this machine's
        ``spec.providerID``.
        """
        data = self.data  # one API read for both status and spec
        machine_status = data.get('status') or {}
        node_name = (machine_status.get('nodeRef') or {}).get('name')
        if node_name:
            return node_name
        provider_id = (data.get('spec') or {}).get('providerID') or ''
        if provider_id:
            # Use the protected accessor: ``self.__clusterdeployment`` here
            # would be mangled to _OpenStackProviderMachine__clusterdeployment,
            # which does not exist.
            nodes = self._clusterdeployment.k8sclient.nodes.list_raw().items
            for node in nodes:
                if provider_id in (node.spec.provider_id or ''):
                    return node.metadata.name
        LOG.info('Node name for machine not available')
        return None

    @property
    def internal_ip(self):
        """Return the machine's InternalIP from status.addresses, or None.

        A found address is memoized on the instance. A missing one is not
        cached, so callers polling for the IP see it as soon as it appears.
        """
        if self.__internal_ip is not None:
            return self.__internal_ip
        for address in self.status.get('addresses') or []:
            if address.get('type', '') == 'InternalIP':
                self.__internal_ip = address['address']
                return self.__internal_ip
        return None


class GCPProviderMachine(Machine):
    """GCP provider specific methods"""

    # Class-level defaults, name-mangled to _GCPProviderMachine__*.
    # __internal_ip is shadowed per instance once an InternalIP is found;
    # __public_ip is currently unused.
    __internal_ip = None
    __public_ip = None

    @property
    def provider(self):
        return settings.GCP_PROVIDER_NAME

    def get_k8s_node_name(self):
        """Return the k8s node name for this machine, or None if unknown.

        Uses ``status.nodeRef.name`` when set; otherwise returns the first
        child cluster node whose provider ID contains this machine's
        ``spec.providerID``.
        """
        data = self.data  # one API read for both status and spec
        machine_status = data.get('status') or {}
        node_name = (machine_status.get('nodeRef') or {}).get('name')
        if node_name:
            return node_name
        provider_id = (data.get('spec') or {}).get('providerID') or ''
        if provider_id:
            # Use the protected accessor: ``self.__clusterdeployment`` here
            # would be mangled to _GCPProviderMachine__clusterdeployment,
            # which does not exist.
            nodes = self._clusterdeployment.k8sclient.nodes.list_raw().items
            for node in nodes:
                if provider_id in (node.spec.provider_id or ''):
                    return node.metadata.name
        LOG.info('Node name for machine not available')
        return None

    @property
    def internal_ip(self):
        """Return the machine's InternalIP from status.addresses, or None.

        A found address is memoized on the instance. A missing one is not
        cached, so callers polling for the IP see it as soon as it appears.
        """
        if self.__internal_ip is not None:
            return self.__internal_ip
        for address in self.status.get('addresses') or []:
            if address.get('type', '') == 'InternalIP':
                self.__internal_ip = address['address']
                return self.__internal_ip
        return None


class Metal3Machines(Machine):
    """metal3 provider specific methods"""

    # Class-level defaults, name-mangled to _Metal3Machines__*.
    # __internal_ip is shadowed per instance once an InternalIP is found;
    # __public_ip is currently unused.
    __internal_ip = None
    __public_ip = None

    @property
    def provider(self):
        return settings.KSI_METAL3_PROVIDER_NAME

    def get_k8s_node_name(self):
        """Return the k8s node name for this machine, or None if unknown.

        Uses ``status.nodeRef.name`` when set; otherwise returns the first
        child cluster node whose provider ID contains this machine's
        ``spec.providerID``.
        """
        data = self.data  # one API read for both status and spec
        machine_status = data.get('status') or {}
        node_name = (machine_status.get('nodeRef') or {}).get('name')
        if node_name:
            return node_name
        provider_id = (data.get('spec') or {}).get('providerID') or ''
        if provider_id:
            # Use the protected accessor: ``self.__clusterdeployment`` here
            # would be mangled to _Metal3Machines__clusterdeployment,
            # which does not exist.
            nodes = self._clusterdeployment.k8sclient.nodes.list_raw().items
            for node in nodes:
                if provider_id in (node.spec.provider_id or ''):
                    return node.metadata.name
        LOG.info('Node name for machine not available')
        return None

    @property
    def internal_ip(self):
        """Return the machine's InternalIP from status.addresses, or None.

        A found address is memoized on the instance. A missing one is not
        cached, so callers polling for the IP see it as soon as it appears.
        """
        if self.__internal_ip is not None:
            return self.__internal_ip
        for address in self.status.get('addresses') or []:
            if address.get('type', '') == 'InternalIP':
                self.__internal_ip = address['address']
                return self.__internal_ip
        return None


class RemoteMachine(Machine):
    """remote/ssh provider specific methods"""

    # Class-level defaults, name-mangled to _RemoteMachine__*.
    # __internal_ip is shadowed per instance once an InternalIP is found;
    # __public_ip is currently unused.
    __internal_ip = None
    __public_ip = None

    @property
    def provider(self):
        return settings.KSI_REMOTE_PROVIDER_NAME

    def get_k8s_node_name(self):
        """Return the k8s node name for this machine, or None if unknown.

        Uses ``status.nodeRef.name`` when set; otherwise returns the first
        child cluster node whose provider ID contains this machine's
        ``spec.providerID``.
        """
        data = self.data  # one API read for both status and spec
        machine_status = data.get('status') or {}
        node_name = (machine_status.get('nodeRef') or {}).get('name')
        if node_name:
            return node_name
        provider_id = (data.get('spec') or {}).get('providerID') or ''
        if provider_id:
            # Use the protected accessor: ``self.__clusterdeployment`` here
            # would be mangled to _RemoteMachine__clusterdeployment,
            # which does not exist.
            nodes = self._clusterdeployment.k8sclient.nodes.list_raw().items
            for node in nodes:
                if provider_id in (node.spec.provider_id or ''):
                    return node.metadata.name
        LOG.info('Node name for machine not available')
        return None

    @property
    def internal_ip(self):
        """Return the machine's InternalIP from status.addresses, or None.

        A found address is memoized on the instance. A missing one is not
        cached, so callers polling for the IP see it as soon as it appears.
        """
        if self.__internal_ip is not None:
            return self.__internal_ip
        for address in self.status.get('addresses') or []:
            if address.get('type', '') == 'InternalIP':
                self.__internal_ip = address['address']
                return self.__internal_ip
        return None
