#    Author: Alex Savatieiev (osavatieiev@mirantis.com; a.savex@gmail.com)
#    Copyright 2019-2022 Mirantis, Inc.
import base64
import json
import os
import tarfile
import io
from time import sleep
from datetime import datetime

from cfg_checker.common import logger_cli, logger
from cfg_checker.common.exception import KubeException

from cfg_checker.helpers.console_utils import Progress
from cfg_checker.helpers.tgz import TGZFile
from cfg_checker.nodes import KubeNodes
from cfg_checker.reports import reporter


class CephInfo(object):
    """Base container for collected Ceph cluster information.

    ``self.ceph_info`` maps an item key to a dict with 'title', 'data' and
    an optional 'filename' (used as the file name inside the archive).
    ``cluster_info``, ``nodes`` and ``ceph_version`` are passed to the HTML
    report; subclasses fill them while collecting data.
    """
    # Fallback version label; subclasses replace it once detected
    ceph_version = "unknown"
    # Keys stored next to the per-date readouts of a device health entry
    _health_meta_keys = ("osd_name", "node_name", "cmd")

    def __init__(
        self,
        config
    ):
        self.env_config = config
        # Defaults, so archiving/reporting never hits a missing attribute
        # when collection did not run or stopped early
        self.ceph_info = {}
        self.cluster_info = {}
        self.nodes = []
        return

    def get_info_archive_filename(self, client, project):
        """Build 'CephCollectData.<client>.<project>.<YYYY-mm-dd>.tar.gz'."""
        _dt = datetime.now().strftime("%Y-%m-%d")
        _tags = ["CephCollectData", client, project, _dt, "tar", "gz"]
        return ".".join(_tags)

    def get_transposed_latency_table(self):
        """Turn OSD latency samples into per-OSD rows.

        Every sample in ceph_info['osd_latency_data']['data']['data'] adds
        one header cell to the '<dev>' row and one cell to each
        'osd_<id>' row. The table is also stored under
        ceph_info['osd_latency_data']['table'].

        :return: dict of row name -> list of cell dicts
        """
        _table = {
            "<dev>": []
        }
        for _pfd in self.ceph_info['osd_latency_data']['data']['data']:
            _table["<dev>"].append({
                "formatted": " cL/aL ",
                "commit_latency_ms": "Commit, ms",
                "apply_latency_ms": "Apply, ms",
                "commit_latency_ns": "Commit, ns",
                "apply_latency_ns": "Apply, ns"
            })
            for _f in _pfd['osdstats']['osd_perf_infos']:
                _ps = _f['perf_stats']
                _table.setdefault("osd_{}".format(_f['id']), []).append({
                    "formatted": "{:>3}/{:<3}".format(
                        _ps['commit_latency_ms'],
                        _ps['apply_latency_ms'],
                    ),
                    "commit_latency_ms": _ps['commit_latency_ms'],
                    "apply_latency_ms": _ps['apply_latency_ms'],
                    "commit_latency_ns": _ps['commit_latency_ns'],
                    "apply_latency_ns": _ps['apply_latency_ns']
                })
        self.ceph_info['osd_latency_data']['table'] = _table
        return _table

    def get_latest_health_readout(self):
        """Return the latest health readout of every collected device.

        In ceph_info['ceph_health']['data'], every key of a device entry
        other than the metadata keys ('osd_name', 'node_name', 'cmd') is
        treated as a readout date. The greatest (lexically) one is taken
        as the latest. The collected data is not modified, so repeated
        calls give the same result.

        :return: dict of device name -> copy of the latest readout, extended
                 with 'osd_name' and 'node_name' ('unknown' when absent);
                 devices without readouts map to an empty dict
        """
        _health = self.ceph_info['ceph_health']
        _latest = {}
        _health['latest'] = _latest
        for _n, _d in _health['data'].items():
            _dates = []
            if _d:
                _dates = [
                    _k for _k in _d if _k not in self._health_meta_keys
                ]
            # Additional check for entries holding only metadata
            if not _dates:
                _latest[_n] = {}
                continue
            _date = max(_dates)
            _health['date'] = _date
            # Copy so the collected readout itself stays unchanged
            _readout = dict(_d[_date])
            _readout["osd_name"] = _d.get("osd_name", "unknown")
            _readout["node_name"] = _d.get("node_name", "unknown")
            _latest[_n] = _readout
        return _latest

    def print_summary(self):
        """Print cluster health, device health and OSD latency to the CLI."""
        logger_cli.info("\n# Ceph Cluster summary")
        # Health status
        _h = self.ceph_info['health_detail']['data']
        logger_cli.info("Cluster status: {}".format(_h['status']))
        for _chk, _d in _h.get('checks', {}).items():
            logger_cli.info(
                "+ {}: {}\n\tSummary: {}".format(
                    _chk,
                    _d['severity'],
                    _d['summary']['message']
                )
            )
            logger_cli.info("\tDetails:")
            for _item in _d.get('detail', []):
                logger_cli.info("\t  '{}'".format(_item['message']))

        # OSD health metrics
        logger_cli.info("\n# Device health metrics:")
        _fmt = " {:45}  {:^14} {:^9} {:^6} {:^6}"
        logger_cli.info(
            _fmt.format(
                "Device Name",
                "Info",
                "Speed",
                "SMART",
                "Tempr."
            )
        )
        _latest = self.get_latest_health_readout()
        for _n, _d in _latest.items():
            if not _d:
                logger_cli.info("{:45} {:<10}".format(_n, "<empty>"))
                continue
            # A readout may lack some sections; show '-' instead of failing
            if "smart_status" in _d:
                _passed = _d['smart_status'].get('passed')
                _status = 'passed' if _passed else 'failed'
            else:
                _status = "-"
            if "interface_speed" in _d:
                _speed = _d['interface_speed']['current']['string']
            else:
                _speed = "-"
            logger_cli.info(
                _fmt.format(
                    _n,
                    _d.get('device', {}).get('info_name', "-"),
                    _speed,
                    _status,
                    _d.get('temperature', {}).get('current', "-")
                )
            )

        # Latency table
        logger_cli.info(
            "\n# OSD Latency data ({} iterations, {} sec delay), "
            "table items 'osd_dev: N:cL/aL'\n"
            "  'Commit Latency' -> 'cL', 'Apply Latency' -> 'aL'\n".format(
                self.ceph_info['osd_latency_data']['data']['total'],
                self.ceph_info['osd_latency_data']['data']['delay']
            )
        )
        _strs = self.get_transposed_latency_table()
        for _osd, _list in _strs.items():
            _row = [c["formatted"] for c in _list]
            logger_cli.info(
                "  {:8}: {}".format(
                    _osd,
                    "  ".join(_row)
                )
            )
        logger_cli.info("\n")

        # critical config values
        # TODO: print/calculate config values

        return

    def dump_info(self, filename='cephdump.json'):
        """Save collected data to a json file (reloadable via load_info)."""
        with open(filename, 'wt') as _f:
            json.dump(self.ceph_info, _f, indent=2)

    def load_info(self, filename='cephdump.json'):
        """Load collected data from a json file written by dump_info."""
        with open(filename, 'rt') as _f:
            self.ceph_info = json.load(_f)

    def generate_archive(self, tgzfilename):
        """Write every collected item into a tar.gz archive.

        dict/list data is stored as indented json ('<key>.json'), anything
        else as text ('<key>.txt'); an item's own 'filename' overrides the
        generated name.

        :param tgzfilename: path of the archive to create
        """
        if not self.ceph_info:
            logger_cli.warning(
                "WARNING: Ceph Info Data not detected. "
                "Consider check for errors in log."
            )
            return

        # Create Archive
        logger_cli.info("-> Generating archive '{}'".format(tgzfilename))
        _tgz = TGZFile(
            tgzfilename,
            label="MCP Checker: Generated Ceph Information"
        )
        # Iterate every key and write data to tar file
        for key, d in self.ceph_info.items():
            _data = d["data"]
            # Cast buf to a proper type
            if isinstance(_data, (dict, list)):
                _buf = json.dumps(_data, indent=2)
                _ext = ".json"
            elif isinstance(_data, str):
                _buf = _data
                _ext = ".txt"
            else:
                _buf = str(_data)
                _ext = ".txt"
            _fname = d.get("filename")
            _filename = key + _ext if _fname is None else _fname
            logger_cli.debug("... writing '{}'".format(_filename))
            _tgz.add_file(_filename, buf=_buf, replace=True)

        return

    def create_html_report(self, filename):
        """
        Create static html showing ceph info report

        :param filename: path of the html file to write
        :return: none
        """
        logger_cli.info("### Generating report to '{}'".format(filename))
        _report = reporter.ReportToFile(
            reporter.HTMLCephInfo(self),
            filename
        )
        _report(
            {
                "info": self.ceph_info,
                "cluster": self.cluster_info,
                "nodes": self.nodes,
                "ceph_version": self.ceph_version,
            }
        )
        logger_cli.info("-> Done")

        return


class SaltCephInfo(CephInfo):
    """Placeholder Ceph info collector for Salt based environments.

    Data collection is not implemented; only a warning is printed.
    """
    def __init__(
        self,
        config
    ):
        logger_cli.warning(
            "\nWARNING: Not implemented for Salt environment!\n"
        )

        # TODO: init Salt nodes access, e.g. self.master = SaltNodes(config)
        super(SaltCephInfo, self).__init__(config)
        return


class KubeCephInfo(CephInfo):
    """Ceph information collector for Kubernetes (Rook) environments.

    Commands run inside the pod matched by the 'rook-ceph-tools' name part
    in the 'rook-ceph' namespace; cluster details come from the
    'cephclusters' custom resource ('ceph.rook.io/v1').
    """
    ceph_ns = "rook-ceph"
    ceph_app_label = "rook-ceph-tools"
    ceph_group = "ceph.rook.io"
    ceph_apiversion = "v1"
    ceph_plural = "cephclusters"
    ceph_version = "unknown"

    def __init__(self, config):
        self.master = KubeNodes(config)
        super(KubeCephInfo, self).__init__(config)
        # Init ceph tools pod
        self.pod_name = self._get_tools_pod_name()
        self.ceph_info = {}
        self.cluster_info = {}
        # Defined up front: cluster detection may return before setting it
        self.nodes = []
        self.ceph_version = self.get_ceph_cluster_config()

    def _safe_tools_cmd(self, cmd_str, expect_output=True):
        """Run a shell command in the tools pod and return its output.

        A mismatch between the actual output and `expect_output` is only
        logged, the output is returned as is.
        """
        _r = self.master.exec_cmd_on_target_pod(
            self.pod_name,
            self.ceph_ns,
            cmd_str
        )
        if expect_output and not _r:
            logger.debug("... got empty output for '{}'".format(cmd_str))
        elif not expect_output and _r:
            logger.warning(
                "WARNING: Unexpected output for '{}':\n"
                "===== Start\n{}\n===== End".format(cmd_str, _r)
            )
        return _r

    def _safe_tools_cmd_zipped_output(self, cmd_str):
        """Run a command with large output and return that output.

        The output is redirected with '-o' into a file in the tools pod,
        packed with tar/gzip, transferred as base64 and unpacked locally.
        Temporary files are removed even when transfer/decoding fails.

        :param cmd_str: command accepting '-o <path>'
        :return: bytes read from the first archive member
        """
        _tmp_path = "/tmp"
        _path = os.path.join(_tmp_path, "checker_cmd_output")
        _tar_path = os.path.join(_tmp_path, "checker_cmd.tgz")
        try:
            # Run original cmd with redirect
            self._safe_tools_cmd(
                " ".join([cmd_str, "-o", _path]),
                expect_output=False
            )
            # zip it and base64 encode
            self._safe_tools_cmd(" ".join(["tar", "-zcvf", _tar_path, _path]))
            _b64 = self._safe_tools_cmd("base64 " + _tar_path)
            # decode and decompress
            _io = io.BytesIO(base64.standard_b64decode(_b64))
            with tarfile.open(fileobj=_io) as _tar:
                _tar_item = _tar.extractfile(_tar.getmembers()[0])
                return _tar_item.read()
        finally:
            # cleanup
            self._safe_tools_cmd("rm -f " + _path, expect_output=False)
            self._safe_tools_cmd("rm -f " + _tar_path, expect_output=False)

    @staticmethod
    def _as_json(buf):
        """Parse `buf` as json; on failure log an error and return `buf`.

        `buf` may be str, bytes (zipped output) or None (no output).
        """
        try:
            return json.loads(buf)
        except (TypeError, ValueError) as e:
            # Build a printable, size-limited excerpt for the error message
            if isinstance(buf, bytes):
                _out = buf.decode("utf-8", errors="replace")
            else:
                _out = "" if buf is None else str(buf)
            if len(_out) > 512:
                _out = _out[:512] + "..."
            logger_cli.error(
                "\nERROR: failed to parse json: '{}'. Data: '{}'".format(
                    e,
                    _out
                )
            )
            return buf

    def _safe_get_cmd_output_as_json(self, cmd, zipped=False):
        """Run `cmd` in the tools pod and parse its output as json."""
        if zipped:
            _buf = self._safe_tools_cmd_zipped_output(cmd)
        else:
            _buf = self._safe_tools_cmd(cmd)
        return self._as_json(_buf)

    def _get_tools_pod_name(self):
        """Find the tools pod, store it in self.ceph_pod, return its name.

        :raises KubeException: when no matching pod is found
        """
        _pods = self.master.kube.get_pods_by_partial_name(
            self.ceph_app_label,
            self.ceph_ns
        )
        if not _pods:
            raise KubeException(
                "Failed to find pod using '{}'".format(self.ceph_app_label)
            )
        elif len(_pods) > 1:
            logger_cli.warning(
                "WARNING: Environment has more than one pod "
                "with '{}' app: {}".format(
                    self.ceph_app_label,
                    ", ".join([p.metadata.name for p in _pods])
                )
            )
        else:
            logger_cli.debug("... found '{}'".format(_pods[0].metadata.name))
        # The first match is used even when there are several
        self.ceph_pod = _pods[0]
        return _pods[0].metadata.name

    def _add_ceph_info_item(self, key, title, data, filename=None):
        """Create or update the ceph_info item `key`.

        A previously stored filename is kept unless a new one is given.
        """
        _item = self.ceph_info.setdefault(key, {})
        _item["title"] = title
        _item["data"] = data
        if filename:
            _item["filename"] = filename

    def _parse_dev_classes(self, deviceClasses):
        """Return the set of all values from a list of device class dicts."""
        return {_v for _i in deviceClasses for _v in _i.values()}

    def get_ceph_cluster_config(self):
        """Load the CephCluster resource into self.cluster_info/self.nodes.

        :return: Ceph version reported by the cluster status, or
                 'unknown' when no cluster resource is found
        """
        # get cephclusters resource
        logger_cli.info("# Loading '{}' object of type '{}/{}'".format(
            self.ceph_plural,
            self.ceph_group,
            self.ceph_apiversion
        ))
        _r = self.master.kube.get_custom_resource(
            self.ceph_group,
            self.ceph_apiversion,
            self.ceph_plural,
        )
        # find cluster
        _items = _r.get('items', [])
        if not _items:
            logger_cli.warning(
                "WARNING: Failed to find '{}' ({}/{})".format(
                    self.ceph_plural,
                    self.ceph_group,
                    self.ceph_apiversion
                )
            )
            return 'unknown'
        elif len(_items) > 1:
            logger_cli.warning(
                "WARNING: Multiple clusters found '{}' ({}/{})".format(
                    self.ceph_plural,
                    self.ceph_group,
                    self.ceph_apiversion
                )
            )
        _cluster = _items[0]
        _s = _cluster['status']
        self.cluster_info.update({
            'image': _s['version']['image'],
            'version': _s['version']['version'],
            'device_classes': self._parse_dev_classes(
                _s.get('storage', {}).get('deviceClasses', [])
            ),
            'phase': _s['phase'],
            'state': _s['state'],
            'health': _s['ceph'].get('health', {}),
            'previousHealth': _s['ceph'].get('previousHealth', {}),
            'lastChanged': _s['ceph'].get('lastChanged', ""),
            'lastChecked': _s['ceph'].get('lastChecked', ""),
            'mon_count': _cluster['spec']['mon']['count']
        })
        # Storage nodes from the cluster spec, kept as the plain list
        self.nodes = _cluster['spec']['storage'].get('nodes', [])
        logger_cli.info("-> Found Ceph cluster: {} ({})".format(
            self.cluster_info['version'],
            self.cluster_info['image']
        ))
        return self.cluster_info['version']

    def get_cluster_status(self):
        return self._safe_get_cmd_output_as_json("ceph -s -f json")

    def get_health_detail(self):
        return self._safe_get_cmd_output_as_json("ceph -f json health detail")

    def get_ceph_df(self):
        return self._safe_get_cmd_output_as_json("ceph df -f json")

    def get_ceph_pg_dump(self):
        return self._safe_get_cmd_output_as_json(
            "ceph pg dump -f json",
            zipped=True
        )

    def get_ceph_osd_df(self):
        return self._safe_get_cmd_output_as_json("ceph osd df -f json")

    def _collect_crushmap(self):
        """Store the CRUSH map both as json and as decompiled text."""
        _c = self._safe_tools_cmd
        logger_cli.info("-> Collecting CRUSH map")
        _cmap_tmp_path = "/tmp/crushmap.bin"
        _r = _c(
            "ceph osd getcrushmap -o " + _cmap_tmp_path,
            expect_output=False
        )
        # TODO: Handle errors in _r
        logger_cli.debug("... 'getcrushmap' return value is: '{}'".format(_r))

        # Get Crush map as json and text
        self._add_ceph_info_item(
            "crushmap_json",
            "Crush Map (json)",
            self._safe_get_cmd_output_as_json(
                "crushtool -i " + _cmap_tmp_path + " --dump"
            ),
            filename="crushmap.json"
        )
        # Own file name: reusing 'crushmap.json' would replace the json
        # dump inside the archive (files are added with replace=True)
        self._add_ceph_info_item(
            "crushmap_text",
            "Crush Map (text)",
            _c("crushtool -d " + _cmap_tmp_path),
            filename="crushmap.txt"
        )
        # Remove the temporary binary map from the tools pod
        _c("rm -f " + _cmap_tmp_path, expect_output=False)

    def _collect_auth_data(self):
        """Store 'ceph auth list' output with every key redacted."""
        logger_cli.info("-> Collecting auth data anonymized")
        _auth_data = self._safe_get_cmd_output_as_json(
            "ceph auth list -f json"
        )
        if isinstance(_auth_data, dict):
            # Anonymize data
            for item in _auth_data.get("auth_dump", []):
                if "key" in item:
                    item['key'] = "key-data-redacted"
        else:
            # Unparsed output may hold raw secrets, so it is never stored
            logger_cli.warning(
                "WARNING: Failed to parse auth data, skipping it"
            )
            _auth_data = {}
        self._add_ceph_info_item(
            "ceph_auth_ls",
            "Ceph Auth Data (anonymized)",
            _auth_data
        )

    def _collect_health_metrics(self):
        """Store 'ceph device get-health-metrics' of every listed device."""
        logger_cli.info("-> Collecting health metrics")
        _health_metrics = {}
        # command -> device entry name, to match results back to devices
        _cmd_map = {}
        _devices = self._safe_tools_cmd("ceph device ls") or ""
        for device in _devices.splitlines():
            _t = device.split()
            # skip blank lines and the table header
            if not _t or _t[0] == "DEVICE":
                continue
            _dev = _t[0]
            _node = _t[1] if len(_t) > 1 else "unknown"
            _osd = _t[2] if len(_t) > 2 else "unknown"
            _dev_name = "{}_{}".format(_osd, _dev)
            _health_metrics[_dev_name] = {
                'node_name': _node,
                'osd_name': _osd
            }
            _cmd_map["ceph device get-health-metrics {}".format(_dev)] = \
                _dev_name

        results = []
        if _cmd_map:
            results = self.master.exec_cmds_on_pod(
                self.ceph_pod,
                list(_cmd_map)
            )

        logger_cli.info("-> Processing results")
        for _r in results:
            # As in the rest of this class: output at [2], command at [3]
            _dev_name = _cmd_map.get(_r[3])
            if _dev_name is None:
                continue
            _j = self._as_json(_r[2])
            # Unparsed output is already reported by _as_json
            if isinstance(_j, dict):
                _health_metrics[_dev_name].update(_j)

        self._add_ceph_info_item(
            "ceph_health",
            "Ceph Health Metrics",
            _health_metrics
        )

    def _collect_osd_latency(self, count, delay):
        """Sample 'ceph osd perf' `count` times, `delay` seconds apart."""
        logger_cli.info(
            "-> Collecting ceph osd latency data "
            "({} total, {} sec delay)".format(
                count,
                delay
            )
        )
        _osd_lat = {
            "total": count,
            "delay": delay,
            "data": []
        }
        _progress = Progress(count)
        for _index in range(1, count + 1):
            _progress.write_progress(_index)
            _osd_lat["data"].append(
                self._safe_get_cmd_output_as_json("ceph osd perf -f json")
            )
            # No need to wait after the last sample
            if _index < count:
                sleep(delay)
        _progress.end()
        self._add_ceph_info_item(
            "osd_latency_data",
            "OSD Latency metrics",
            _osd_lat
        )

    def gather_info(self, latency_count=10, latency_delay=4):
        """Collect Ceph cluster information into self.ceph_info.

        :param latency_count: number of 'ceph osd perf' samples to take
        :param latency_delay: seconds to wait between latency samples
        """
        logger_cli.info("# Gathering Ceph cluster info")
        _cj = self._safe_get_cmd_output_as_json

        self._collect_crushmap()

        logger_cli.info("-> Collecting ceph osd crush dump")
        self._add_ceph_info_item(
            "osd_crushdump",
            "Crush dump (osd)",
            _cj("ceph osd crush dump")
        )

        logger_cli.info("-> Collecting cluster status")
        self._add_ceph_info_item(
            "cluster_status",
            "Cluster status",
            self.get_cluster_status()
        )

        logger_cli.info("-> Collecting health detail")
        self._add_ceph_info_item(
            "health_detail",
            "Health details",
            self.get_health_detail()
        )

        logger_cli.info("-> Collecting monmap")
        self._add_ceph_info_item(
            "monmap",
            "Ceph Mon map",
            _cj("ceph mon dump -f json")
        )

        logger_cli.info("-> Collecting ceph df")
        self._add_ceph_info_item(
            "ceph_df",
            "Ceph DF",
            self.get_ceph_df()
        )

        logger_cli.info("-> Collecting ceph osd df")
        self._add_ceph_info_item(
            "ceph_osd_df",
            "Ceph OSD DF",
            self.get_ceph_osd_df()
        )

        logger_cli.info("-> Collecting ceph osd dump")
        self._add_ceph_info_item(
            "ceph_osd_dump",
            "Ceph OSD dump",
            _cj("ceph osd dump -f json")
        )

        logger_cli.info("-> Collecting rados df")
        self._add_ceph_info_item(
            "rados_df",
            "Rados DF",
            _cj("rados df -f json")
        )

        logger_cli.info("-> Collecting ceph report")
        self._add_ceph_info_item(
            "ceph_report",
            "Ceph Report",
            _cj("ceph report")
        )

        self._collect_auth_data()

        logger_cli.info("-> Collecting ceph pg dump")
        self._add_ceph_info_item(
            "ceph_pg_dump",
            "Ceph PG dump",
            self.get_ceph_pg_dump()
        )

        logger_cli.info("-> Collecting ceph running configuration")
        self._add_ceph_info_item(
            "ceph_config_dump",
            "Ceph Configuration Dump",
            _cj("ceph config dump -f json")
        )

        self._collect_health_metrics()
        self._collect_osd_latency(latency_count, latency_delay)

        return

    def gather_osd_configs(self):
        """Collect per-OSD configuration and split common/unique values.

        Requires 'ceph_osd_df' from gather_info. The first value seen for
        an option goes to 'common'; every OSD option whose value differs
        from it is stored under 'uniq' -> <osd> -> <option>.
        """
        _osd_nodes = self.ceph_info["ceph_osd_df"]["data"]["nodes"]
        _total_osd = len(_osd_nodes)
        logger_cli.info(
            "-> Gathering OSD configuration ({})".format(_total_osd)
        )
        # command -> osd name, to match results back to OSDs
        _cmd_map = {}
        for _osd in _osd_nodes:
            _cmd = "ceph config show-with-defaults -f json {}".format(
                _osd["name"]
            )
            _cmd_map[_cmd] = _osd["name"]

        results = []
        if _cmd_map:
            results = self.master.exec_cmds_on_pod(
                self.ceph_pod,
                list(_cmd_map)
            )

        logger_cli.info("-> Processing results")
        _cfgs = {}
        for _r in results:
            _osd_name = _cmd_map.get(_r[3])
            if _osd_name is not None:
                _cfgs[_osd_name] = self._as_json(_r[2])

        # Process configs
        _base = {}
        _uniq = {}
        logger_cli.info("-> Filtering config values")
        _progress = Progress(_total_osd)
        for _idx, (_osd, _data) in enumerate(_cfgs.items(), start=1):
            _progress.write_progress(_idx, note=_osd)
            if not isinstance(_data, list):
                # Unparsed output is already reported by _as_json
                continue
            for _o in _data:
                _name = _o.pop("name")
                if not _o.get("value"):
                    _o["value"] = "-"
                if _name not in _base:
                    _base[_name] = _o
                elif _base[_name]["value"] != _o["value"]:
                    _progress.clearline()
                    logger_cli.info(
                        "...specific value for {} (src: '{}'): {}={}".format(
                            _osd,
                            _o.get("source"),
                            _name,
                            _o["value"]
                        )
                    )
                    # Keep every differing option of the OSD, not the last
                    _uniq.setdefault(_osd, {})[_name] = _o
        _progress.end()
        self._add_ceph_info_item(
            "osd_config_data",
            "OSD Configuration values",
            {
                "common": _base,
                "uniq": _uniq
            }
        )
        return
