#    Licensed under the Apache License, Version 2.0 (the "License"); you may
#    not use this file except in compliance with the License. You may obtain
#    a copy of the License at
#
#         http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
#    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
#    License for the specific language governing permissions and limitations
#    under the License.

import base64
import copy
import operator as py_operator
import os
from contextlib import contextmanager
from io import StringIO

import exec_helpers
import jinja2
import paramiko
import yaml
from jinja2_ansible_filters import AnsibleCoreFiltersExtension

from si_tests import logger
from si_tests import settings, settings_func
from si_tests.utils import packaging_version as ver

LOG = logger.logger


class YamlEditor(object):
    """Manipulations with local or remote .yaml files.

    The file is accessed, in order of precedence, through:

    * ``remote`` -- any object providing ``open(path, mode=...)``;
    * ``host`` -- an exec_helpers SSH connection built from ``port``
      (default 22), ``username``, ``password`` and ``private_keys``
      (a list of RSA private keys given as strings);
    * otherwise the local filesystem.

    ``document_id`` selects which document of a multi-document YAML
    stream is exposed as ``content``. On leaving the ``with`` block the
    whole stream is written back, but only if ``content`` changed.

    Usage:

    with YamlEditor("tasks.yaml") as editor:
        editor.content[key] = "value"

    with YamlEditor("astute.yaml", host=self.admin_ip) as editor:
        editor.content[key] = "value"
    """

    def __init__(self, file_path, host=None, port=None,
                 username=None, password=None, private_keys=None, remote=None,
                 document_id=0,
                 default_flow_style=False, default_style=None):
        self.__file_path = file_path
        self.remote = remote
        self.host = host
        self.port = port or 22
        self.username = username
        self.__password = password
        self.__private_keys = private_keys or []
        # Lazily loaded selected document (see the ``content`` property).
        self.__content = None
        # All documents of the YAML stream; replaced by get_content().
        self.__documents = [{}, ]
        self.__document_id = document_id
        # Snapshot taken in __enter__ to detect modifications.
        self.__original_content = None
        # Passed through to yaml.dump_all() when writing.
        self.default_flow_style = default_flow_style
        self.default_style = default_style

    @property
    def file_path(self):
        """Open file path

        :rtype: str
        """
        return self.__file_path

    @property
    def content(self):
        """The selected YAML document, loaded from the file on first access."""
        if self.__content is None:
            self.__content = self.get_content()
        return self.__content

    @content.setter
    def content(self, new_content):
        self.__content = new_content

    @contextmanager
    def open(self, mode="r"):
        """Yield a file object for ``file_path`` opened with ``mode``.

        The file object is closed when the context exits.
        """
        file = None
        try:
            if self.remote:
                file = self.remote.open(self.__file_path, mode=mode)
            elif self.host:
                keys = [paramiko.RSAKey.from_private_key(StringIO(key))
                        for key in self.__private_keys]

                # NOTE(review): this SSH client is never closed here.
                # exec_helpers may cache/reuse clients per host+auth, so
                # closing it could affect other users -- confirm before
                # adding a close().
                remote = exec_helpers.SSHClient(
                    host=self.host,
                    port=self.port,
                    auth=exec_helpers.SSHAuth(
                        username=self.username,
                        password=self.__password,
                        keys=keys
                    )
                )

                file = remote.open(self.__file_path, mode=mode)
            else:
                file = open(self.__file_path, mode=mode)

            yield file
        finally:
            if file:
                file.close()

    def get_content(self):
        """Return a single document from YAML

        Loads every document of the file into the internal document list
        and returns the one at ``document_id``. The file is opened with
        'a+', so a missing local file is created; an empty file yields
        ``{}``. Local tags (``!something``) are kept as single-key dicts.
        """

        def multi_constructor(loader, tag_suffix, node):
            """Stores all unknown tags content into a dict

            Original yaml:
            !unknown_tag
            - some content

            Python object:
            {"!unknown_tag": ["some content", ]}
            """
            # Dispatch on the node kind rather than peeking at
            # node.value[0], which raised IndexError for empty tagged
            # collections such as "!tag {}" or "!tag []".
            if isinstance(node, yaml.MappingNode):
                value = loader.construct_mapping(node)
            elif isinstance(node, yaml.SequenceNode):
                value = loader.construct_sequence(node)
            else:
                value = loader.construct_scalar(node)
            return {node.tag: value}

        class _TagPreservingLoader(yaml.SafeLoader):
            """SafeLoader that turns unknown local tags into dicts."""

        # yaml.add_multi_constructor() without an explicit Loader does not
        # register on SafeLoader, so the constructor was never consulted
        # by a SafeLoader-based load and unknown tags raised
        # ConstructorError. Registering on a private subclass fixes that
        # without mutating PyYAML's global loaders on every call.
        _TagPreservingLoader.add_multi_constructor("!", multi_constructor)

        with self.open('a+') as file_obj:
            # 'a+' may start positioned at the end; rewind before reading.
            file_obj.seek(0)
            self.__documents = list(yaml.load_all(
                file_obj, Loader=_TagPreservingLoader)) or [{}, ]
            return self.__documents[self.__document_id]

    def write_content(self, content=None):
        """Write all YAML documents back to the file.

        :param content: new value for the selected document. When None,
            the current ``content`` is written. Falsy values such as
            ``{}`` or ``[]`` are honoured (they used to be ignored).
        """
        if content is not None:
            self.content = content
        # NOTE(review): if the file was never read (content set directly
        # without entering the context manager), the document list is still
        # the initial [{}], so other documents in the file are dropped and a
        # non-zero document_id raises IndexError.
        self.__documents[self.__document_id] = self.content

        def representer(dumper, data):
            """Represents a dict key started with '!' as a YAML tag

            Assumes that there is only one !tag in the dict at the
            current indent.

            Python object:
            {"!unknown_tag": ["some content", ]}

            Resulting yaml:
            !unknown_tag
            - some content
            """
            # dict.keys() is not subscriptable on Python 3.
            key = next(iter(data), "")
            if isinstance(key, str) and key.startswith("!"):
                value = data[key]
                if type(value) is dict:
                    node = dumper.represent_mapping(key, value)
                elif type(value) is list:
                    node = dumper.represent_sequence(key, value)
                else:
                    node = dumper.represent_scalar(key, value)
            else:
                node = dumper.represent_mapping(u'tag:yaml.org,2002:map', data)
            return node

        # FIXME: the representer is intentionally not registered; enabling it
        # changes the dump output for every dict and needs verification
        # against the current PyYAML before switching it on.
        # yaml.add_representer(dict, representer)
        with self.open('w') as file_obj:
            yaml.dump_all(self.__documents, file_obj,
                          default_flow_style=self.default_flow_style,
                          default_style=self.default_style)

    def __enter__(self):
        """Load the document and remember a deep copy to detect changes."""
        self.__content = self.get_content()
        self.__original_content = copy.deepcopy(self.content)
        return self

    def __exit__(self, x, y, z):
        """Write the file back only if ``content`` differs from the snapshot.

        NOTE(review): the exception arguments are ignored, so partially
        modified content is written even when the ``with`` body raised.
        """
        if self.content == self.__original_content:
            return
        self.write_content()


class TemplateFile:
    """One template file, optionally shadowed by a per-environment override.

    :param templates_dir: the owning TemplatesDir; its ``path`` is the base
        directory and its ``override_path`` (possibly None) is the override
        subdirectory, relative to ``path``
    :param path: template path relative to ``templates_dir.path``; stored
        normalized with os.path.normpath
    """

    def __init__(self, templates_dir, path):
        self.templates_dir = templates_dir
        self.path = os.path.normpath(path)
        self.override = None

        override_root = self.templates_dir.override_path
        if not override_root:
            return

        candidate = os.path.join(override_root, self.path)
        candidate_on_disk = os.path.join(self.templates_dir.path, candidate)
        if not os.path.exists(candidate_on_disk):
            LOG.info("Override file does not exist ({})".format(candidate))
            return

        LOG.info("Override file exists ({})".format(candidate))
        # Stored relative to templates_dir.path, like self.path.
        self.override = candidate

    def get_template_file(self):
        """Return the path to render: the override copy if any, else the base."""
        chosen = self.override
        if chosen:
            LOG.info(f"Using override file '{chosen}' as '{self.path}'")
        else:
            chosen = self.path
            LOG.info(f"Using file '{chosen}'")
        return os.path.join(self.templates_dir.path, chosen)

    def render(self, **kwargs):
        """
        Render the template, preferring the copy from the `.override` directory.

        :param kwargs: forwarded to render_template()
        :return: Rendered content of a template
        """
        resolved = self.get_template_file()
        return render_template(resolved, **kwargs)


class TemplatesDir(list):
    """List of TemplateFile objects discovered under a templates directory.

    When ``env_name`` is given (and ``test_mode`` is False), files under
    ``<path>/.override/<env_name>`` are merged in and take precedence over
    same-named base files (see TemplateFile).

    :param path: root directory of the templates
    :param env_name: environment name selecting the override subdirectory
    :param test_mode: when True, override handling is skipped entirely
    """

    def __init__(self, path, env_name=None, test_mode=False):
        super().__init__()
        self.path = path
        self.env_name = env_name
        self._override_path = None

        base_files = list(self.file_list(self.path))
        override_files = self._collect_override_files(test_mode)

        LOG.info("Loading templates from '{}', env name '{}'"
                 .format(self.path, self.env_name))
        LOG.info("Base files:\n{}".format('\n'.join(base_files)))
        LOG.info("Override files:\n{}".format('\n'.join(override_files)))

        # Each relative path appears once, whether it is a base file, an
        # override, or both.
        unique_paths = sorted({*base_files, *override_files})
        self.extend(TemplateFile(templates_dir=self, path=rel_path)
                    for rel_path in unique_paths)

        LOG.info("{} template(s) found".format(len(self)))

    def _collect_override_files(self, test_mode):
        """Locate '.override/<env_name>' and return its files (maybe empty).

        Sets ``_override_path`` (relative to ``path``) when the directory
        exists.
        """
        if test_mode:
            LOG.info("Skipping overrides dir logic in test mode")
            return []
        if not self.env_name:
            return []

        rel_dir = os.path.join('.override', self.env_name)
        abs_dir = os.path.join(self.path, rel_dir)
        if not os.path.exists(abs_dir):
            LOG.info("Overrides dir does not exist ({})".format(rel_dir))
            return []

        LOG.info("Overrides dir exists ({})".format(rel_dir))
        self._override_path = rel_dir
        return list(self.file_list(abs_dir))

    def file_list(self, path):
        """Yield files under ``path`` as paths relative to ``path``.

        Any '.override' subdirectory is not descended into, and files
        named '.defaults' are skipped. Files directly in ``path`` are
        yielded with a './' prefix (os.path.relpath gives '.').
        """
        for current_dir, subdirs, names in os.walk(path):
            # Prune in place so os.walk does not descend into override trees.
            if '.override' in subdirs:
                subdirs.remove('.override')

            prefix = os.path.relpath(current_dir, path)
            yield from (os.path.join(prefix, name)
                        for name in names if name != '.defaults')

    @property
    def override_path(self):
        """Override directory relative to ``path``, or None if unused."""
        return self._override_path

    @property
    def filenames(self):
        """Sorted list of the (normalized) template paths."""
        names = [template.path for template in self]
        names.sort()
        return names


def render_template(file_path, options=None, log_env_vars=True,
                    log_template=True, extra_env_vars=None, jinja2_kwargs=None):
    """Render a Jinja2 template file

    Extra options (globals available inside the template):
      {{ os_env(SOME_ENV_NAME, "default_value") }} : Get environment variable
      {{ os_env_bool(SOME_ENV_NAME, "1") }} : Get bool environment variable
      {{ extra_env(SOME_VAR_NAME, "default_value") }} : Get a variable from
          ``extra_env_vars`` instead of the process environment
      {{ settings.SOME_SETTING }} : Read a value from si_tests.settings
      {{ feature_flags.enabled("aflag") }} : Only when a ``feature_flags``
          object is supplied through ``options``

    Extra filters (in addition to AnsibleCoreFiltersExtension):
      {{ "aaa/bbb/ccc" | basename }} : get basename of the path
      {{ "aaa/bbb/ccc" | dirname }} : get dirname of the path
      {{ "1.2.3" | version_compare("1.2", ">=") }} : compare versions
      {{ "text" | base64_encode }} / {{ "dGV4dA==" | base64_decode }}

    Templates are looked up in the directory of ``file_path`` and in its
    parent directory, so includes/imports may reference files there.

    :param file_path: str, path to the jinja2 template
    :param options: dict, extra objects to use in Jinja code blocks; they
        override the built-in globals above on name clashes
    :param log_env_vars: bool, log the environment variables used in the
        template
    :param log_template: bool, log the template path at INFO level
        (otherwise at DEBUG)
    :param extra_env_vars: dict, extra k-v set of variables, analog global os_env
    :param jinja2_kwargs: dict, extra k-v set of variables to pass as
        arguments for jinja2.Environment()
    :return: str, the rendered template
    :raises Exception: when a variable without a default is undefined, or
        on an invalid version_compare operator/version
    """
    required_env_vars = set()
    optional_env_vars = dict()
    if not extra_env_vars:
        extra_env_vars = dict()
    if not jinja2_kwargs:
        jinja2_kwargs = dict()

    # Controls only the level of the "Reading template file" message.
    _LOG = LOG.info if log_template else LOG.debug

    def extra_env(var_name, default=None, env_type=None):
        """Like os_env(), but read the value from ``extra_env_vars``."""
        return os_env(var_name, default=default, env_type=env_type, env_source='extra_env_vars')

    def os_env(var_name: str, default=None, env_type=None, env_source='os'):
        """Return a template variable and record it for logging.

        :param var_name: name of the variable to look up
        :param default: value used when the variable is missing; when None
            the variable is treated as required
        :param env_type: "bool" to convert via settings_func.get_var_as_bool
        :param env_source: enum: os, extra_env_vars
                os -  means os.environ.get() (an empty value counts as unset)
                extra_env_vars - means from dict, extra_env_vars
        :raises Exception: if the variable is missing and has no default,
            or env_source is unknown
        """
        if env_source == 'os':
            requested_var = os.environ.get(var_name) or default
        elif env_source == 'extra_env_vars':
            requested_var = extra_env_vars.get(var_name, default)
        else:
            raise Exception('Wrong invocation')

        if requested_var is None:
            raise Exception("Environment variable '{0}' is undefined!"
                            .format(var_name))

        if default is None:
            required_env_vars.add(var_name)
        else:
            optional_env_vars[var_name] = requested_var

        if env_type == "bool":
            # NOTE(review): this re-reads ``var_name`` via settings_func
            # regardless of env_source, so extra_env(..., env_type="bool")
            # presumably consults the process environment rather than
            # extra_env_vars -- confirm intent.
            requested_var = settings_func.get_var_as_bool(var_name, default)

        return requested_var

    def os_env_bool(var_name, default=None):
        return os_env(var_name, default, env_type='bool')

    def basename(path):
        return os.path.basename(path)

    def dirname(path):
        return os.path.dirname(path)

    def base64_encode(value):
        """Base64-encode a UTF-8 string and return the result as str."""
        return base64.b64encode(value.encode('utf-8')).decode('utf-8')

    def base64_decode(value):
        """Decode a base64 string and return the payload as UTF-8 str."""
        return base64.b64decode(value.encode('utf-8')).decode('utf-8')

    def version_compare(value, version, operator='eq'):
        """
         Perform a version comparison on a value
         Partial-copy-paste from Ansible 2.10.7
        """
        op_map = {
            '==': 'eq', '=': 'eq', 'eq': 'eq',
            '<': 'lt', 'lt': 'lt',
            '<=': 'le', 'le': 'le',
            '>': 'gt', 'gt': 'gt',
            '>=': 'ge', 'ge': 'ge',
            '!=': 'ne', '<>': 'ne', 'ne': 'ne'
        }
        if operator in op_map:
            operator = op_map[operator]
        else:
            raise Exception('Invalid operator type')

        try:
            method = getattr(py_operator, operator)
            return method(ver.parse(str(value)), ver.parse(str(version)))
        except Exception as e:
            # Keep the original error as the cause for easier debugging.
            raise Exception('Version comparison: %s' % e) from e

    render_options = {
        'os_env': os_env,
        'os_env_bool': os_env_bool,
        'extra_env': extra_env,
        'settings': settings,
    }
    if options:
        render_options.update(options)

    _LOG(f"Reading template file '{file_path}'")

    path, filename = os.path.split(file_path)
    environment = jinja2.Environment(
        extensions=[AnsibleCoreFiltersExtension],
        loader=jinja2.FileSystemLoader([path, os.path.dirname(path)],
                                       followlinks=True),
        **jinja2_kwargs)
    environment.filters['basename'] = basename
    environment.filters['dirname'] = dirname
    environment.filters['version_compare'] = version_compare
    environment.filters['base64_encode'] = base64_encode
    environment.filters['base64_decode'] = base64_decode

    # yaml.dump uses the full (non-safe) Dumper, which serializes arbitrary
    # objects from ``options`` through their reduce protocol and can raise
    # on some of them. A debug message must never break rendering.
    try:
        options_dump = yaml.dump(render_options, indent=4)
    except Exception as e:
        options_dump = "<failed to dump render options ({0})>\n{1!r}".format(
            e, render_options)
    LOG.debug("Attempt to render template with ops:\n{0}".format(
        options_dump))

    template = environment.get_template(filename).render(render_options)

    if required_env_vars and log_env_vars:
        LOG.info("Required environment variables:")
        # Sorted for deterministic output, like the optional list below.
        for var in sorted(required_env_vars):
            LOG.info("    {0}".format(var))
    if optional_env_vars and log_env_vars:
        LOG.info("Optional environment variables:")
        for var, default in sorted(optional_env_vars.items()):
            LOG.info("    {0} , value = {1}".format(var, default))
    return template


class NoAliasDumper(yaml.SafeDumper):
    """SafeDumper that never emits YAML anchors (&id) or aliases (*id)."""

    def ignore_aliases(self, data):
        """Treat every object as alias-free, so repeats are written in full."""
        always_inline = True
        return always_inline
