#    Copyright 2025 Mirantis, Inc.
#
#    Licensed under the Apache License, Version 2.0 (the "License"); you may
#    not use this file except in compliance with the License. You may obtain
#    a copy of the License at
#
#         http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
#    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
#    License for the specific language governing permissions and limitations
#    under the License.

from tabulate import tabulate

from si_tests import logger
from si_tests.utils import utils, waiters

# Module-level logger object taken from si_tests.logger; the methods of
# InferenceManager below use it for their debug/info output.
LOG = logger.logger


class InferenceManager(object):
    """GCore inference manager.

    Convenience wrapper around ``box_api_client``: lists cluster resources
    (capacity, GPU nodes, flavors, projects) and creates, inspects, waits for
    and deletes inference workloads. The ``show_*`` methods only log a table
    and return nothing.
    """

    def __init__(self, box_api_client):
        """Keep the API client.

        The methods below use its ``capacity``, ``nodes``, ``flavors``,
        ``projects`` and ``inferences`` attributes.
        """
        self.box_api_client = box_api_client

    @staticmethod
    def _log_table(title, headers, rows):
        """Log ``rows`` at INFO level as a 'presto' table preceded by ``title``."""
        status_msg = tabulate(rows, tablefmt="presto", headers=headers)
        LOG.info(f"{title}:\n{status_msg}\n")

    @staticmethod
    def _find_inference(inferences, inference_name):
        """Return the first object in ``inferences`` whose name is ``inference_name``.

        Raises:
            Exception: if no object with that name is present.
        """
        for inference in inferences:
            if inference.name == inference_name:
                LOG.debug(f"Inference: {inference}")
                return inference
        raise Exception(f"Inference {inference_name} not found")

    def _log_inferences(self, inferences):
        """Log the given, already fetched, inference objects as a table."""
        headers = ["Inference name", "Image", "Regions", "Listening port", "Flavor", "Status", "Addresses"]
        inferences_data = [[data.name,
                            data.image,
                            data.regions,
                            data.listening_port,
                            data.flavor.name,
                            data.status,
                            # One line per region; 'statuses' may be None, treat it as empty
                            '\n'.join([f"[{region}] {status.address} - {status.status}"
                                       for region, status in (data.statuses or {}).items()]),
                            ]
                           for data in inferences]
        self._log_table("Inferences", headers, inferences_data)

    def get_capacity(self):
        """Get box cluster capacity: how many instances can be deployed with each flavor

        return: the result of ``capacity.v1_list_capacities()`` as is
        """
        return self.box_api_client.capacity.v1_list_capacities()

    def show_capacity(self):
        """Show box cluster capacity"""
        capacity = self.get_capacity()
        LOG.debug(f"Capacity: {capacity}")
        headers = ["Flavor", "Region", "Capacity"]
        # The capacity object is iterated as {region: {flavor: capacity}}
        capacity_data = [[k, region, v]
                         for region, cap in capacity.items() for k, v in cap.items()]
        self._log_table("Capacity", headers, capacity_data)

    def get_gpu_nodes(self):
        """Get nodes with GPU that are available to use for inference workloads

        return: the ``results`` attribute of the ``nodes.v1_list_gpu_nodes()`` response
        """
        gpu_nodes = self.box_api_client.nodes.v1_list_gpu_nodes()
        return gpu_nodes.results

    def show_gpu_nodes(self):
        """Show GPU nodes details (available / total CPU, memory and storage)"""
        def _gib_pair(node, resource):
            # "<available> GiB / <total> GiB" for the given resource key
            available = utils.convert_to_gb(node.available_resources[resource])
            total = utils.convert_to_gb(node.total_resources[resource])
            return f"{available} GiB / {total} GiB"

        gpu_nodes = self.get_gpu_nodes()
        LOG.debug(f"GPU nodes: {gpu_nodes}")
        headers = ["GPU node name",
                   "Region",
                   "Node group",
                   "GPU",
                   "Available CPU",
                   "Available memory",
                   "Available storage"]
        nodes_data = [[data.name,
                       data.region_name,
                       data.node_group,
                       data.gpu_model,
                       f"{data.available_resources['cpu']} / {data.total_resources['cpu']}",
                       _gib_pair(data, 'memory'),
                       _gib_pair(data, 'ephemeral-storage'),
                       ]
                      for data in gpu_nodes]
        self._log_table("GPU Nodes", headers, nodes_data)

    def get_flavors(self):
        """Get available flavors for inference workloads

        return: the ``results`` attribute of the ``flavors.v1_list_flavors()`` response
        """
        flavors = self.box_api_client.flavors.v1_list_flavors()
        LOG.debug(f"Flavors: {flavors}")
        return flavors.results

    def show_flavors(self):
        """Show available flavors; the GPU column is "-" for flavors without GPU"""
        flavors = self.get_flavors()
        headers = ["Flavor name", "CPU", "Memory", "GPU"]
        flavors_data = [[data.name,
                         data.cpu,
                         data.memory,
                         # 'or 0' keeps a flavor with gpu=None from raising TypeError in int()
                         f"{data.gpu} x {data.gpu_model} , {data.gpu_memory}" if int(data.gpu or 0) else "-",
                         ]
                        for data in flavors]
        self._log_table("Flavors", headers, flavors_data)

    def get_projects(self):
        """Get box projects (special namespaces in kubernetes)

        return: the ``results`` attribute of the ``projects.v1_list_projects()`` response
        """
        projects = self.box_api_client.projects.v1_list_projects()
        LOG.debug(f"Projects: {projects}")
        return projects.results

    def get_or_create_project(self, project_name):
        """Create a new project only if not exists

        return: the first listed project named ``project_name``, or the result
                of ``projects.v1_create_project()`` when there is none
        """
        match_projects = [p for p in self.get_projects() if project_name == p.name]
        if match_projects:
            return match_projects[0]

        LOG.info(f"Project {project_name} not found, creating")
        create_project_request = self.box_api_client.projects.V1CreateProjectRequest(name=project_name)
        return self.box_api_client.projects.v1_create_project(create_project_request)

    def show_projects(self):
        """Show box projects"""
        projects = self.get_projects()
        headers = ["Project name"]
        projects_data = [[data.name] for data in projects]
        self._log_table("Projects", headers, projects_data)

    def create_inference_request(self, inference_data):
        """Prepare a special object V1CreateInferenceRequest to create a new inference

        inference_data - dict passed to ``V1CreateInferenceRequest.from_dict()``

        Raises:
            Exception: with the ``errors()`` details when the raised error
                provides them; the original error is kept as ``__cause__``.
                Any other error is re-raised unchanged.
        """
        try:
            inference_request = self.box_api_client.inferences.V1CreateInferenceRequest.from_dict(inference_data)
        except Exception as e:
            # Swagger generated client can miss some attributes in the '.from_dict()' method.
            # Catch such errors and show details.
            if hasattr(e, 'errors'):
                raise Exception(f"Error creating 'CreateInferenceRequest' object: {e.errors()}") from e
            raise
        return inference_request

    def create_inference(self, project_name, inference_data, dry_run):
        """Create a new inference workload

        project_name - str, name of the project where to create the inference workload
        inference_data - dict, all the parameters for inference object
        dry_run - forwarded as the ``dry_run`` keyword to ``v1_create_inference()``

        return: the result of ``inferences.v1_create_inference()``
        """
        LOG.info(f"Create inference '{inference_data.get('name')}' in project '{project_name}'")
        inference_request = self.create_inference_request(inference_data)
        return self.box_api_client.inferences.v1_create_inference(project_name, inference_request, dry_run=dry_run)

    def get_inferences(self, project_name):
        """Get list of inference objects in the specified project"""
        inferences = self.box_api_client.inferences.v1_list_inferences(project_name=project_name)
        LOG.debug(f"Inferences: {inferences}")
        return inferences.results

    def get_inference(self, project_name, inference_name):
        """Get inference object from the list of inferences

        Raises:
            Exception: if the project has no inference named ``inference_name``.
        """
        # Use v1_list_inferences instead of v1_get_inference because of bug in swagger client for 'statuses' field
        return self._find_inference(self.get_inferences(project_name), inference_name)

    def is_inference_exists(self, project_name, inference_name):
        """Check if the inference with the specified name exists"""
        inferences = self.get_inferences(project_name)
        return any(inference.name == inference_name for inference in inferences)

    def delete_inference(self, project_name, inference_name, timeout=1200, interval=15):
        """Delete the specified inference and wait until it is removed from the inferences list

        timeout, interval - passed to ``waiters.wait()`` together with the check below
        """
        def _is_inference_deleted():
            # List once per poll, so the logged table and the verdict describe the same snapshot
            inferences = self.get_inferences(project_name)
            self._log_inferences(inferences)
            return not any(inference.name == inference_name for inference in inferences)

        LOG.info(f"Delete inference '{inference_name}' from the project '{project_name}'")
        self.box_api_client.inferences.v1_delete_inference(project_name=project_name, inference_name=inference_name)
        waiters.wait(_is_inference_deleted,
                     timeout=timeout,
                     interval=interval,
                     timeout_msg=f"Timeout deleting inference '{inference_name}' after {timeout} sec.")
        LOG.info(f"Inference '{inference_name}' is successfully deleted from the project '{project_name}'")

    def get_inference_addresses(self, project_name, inference_name):
        """Get the specified inference addresses per region

        return: dict(region1=address1, region2=address2, ...); empty when the
                inference has no statuses
        """
        inference = self.get_inference(project_name, inference_name)
        return {region: status.address for region, status in (inference.statuses or {}).items()}

    def show_inferences(self, project_name):
        """Show inferences in the specified project"""
        self._log_inferences(self.get_inferences(project_name))

    def wait_inference_ready(self, project_name, inference_name, timeout=1800, interval=15):
        """Wait until the specified inference status becomes 'Active' and all its addresses become 'Ready'

        timeout, interval - passed to ``waiters.wait()`` together with the check below

        Raises:
            Exception: from the check, if the inference is missing from the project.
        """
        def _get_inference_readiness():
            # List once per poll, so the logged table and the verdict describe the same snapshot
            inferences = self.get_inferences(project_name)
            inference = self._find_inference(inferences, inference_name)
            self._log_inferences(inferences)
            if inference.status != 'Active':
                return False
            # NOTE(review): an 'Active' inference without statuses counts as ready (all() of nothing)
            return all(status.status == 'Ready' for status in (inference.statuses or {}).values())

        waiters.wait(_get_inference_readiness,
                     timeout=timeout,
                     interval=interval,
                     timeout_msg=f"Timeout for waiting inference '{inference_name}' readiness after {timeout} sec.")
        LOG.info(f"Inference '{inference_name}' is 'Active' and all it's addresses are 'Ready'")