#    Copyright 2022 Mirantis, Inc.
#
#    Licensed under the Apache License, Version 2.0 (the "License"); you may
#    not use this file except in compliance with the License. You may obtain
#    a copy of the License at
#
#         http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
#    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
#    License for the specific language governing permissions and limitations
#    under the License.

import base64
import copy
from datetime import datetime
from socket import socket

import idna
import pytest
from OpenSSL import SSL
from OpenSSL.SSL import SysCallError, Error
from cryptography import x509
from cryptography.hazmat.backends import default_backend
from retry import retry

from si_tests import logger
from si_tests import settings
from si_tests.utils import waiters

LOG = logger.logger


# The service may drop TLS connections while it is switching to a renewed
# certificate, so SSL and connection errors are retried before the updated
# certificate is read.
@retry((SysCallError, Error, ConnectionError), delay=60, tries=3, logger=LOG)
def get_certificate_from_host(hostname, port=443):
    """Return the TLS certificate currently served at ``hostname:port``.

    The certificate is fetched without being validated, so the test can
    inspect whatever the service presents. ``hostname`` is also sent as the
    SNI server name.

    :param hostname: host name or IP address of the TLS endpoint.
    :param port: TCP port of the TLS endpoint.
    :returns: the peer certificate as a ``cryptography`` ``x509.Certificate``.
    :raises OpenSSL.SSL.Error: if the handshake fails or the peer presents no
        certificate; the decorator retries such failures.
    """
    hostname_idna = idna.encode(hostname)
    ctx = SSL.Context(SSL.SSLv23_METHOD)
    # NOTE(review): pyOpenSSL's Context does not appear to expose
    # check_hostname/verify_mode, so these assignments are likely no-ops;
    # verification stays off because set_verify() is never called. Confirm.
    ctx.check_hostname = False
    ctx.verify_mode = SSL.VERIFY_NONE

    sock = socket()
    try:
        sock.connect((hostname, port))
        sock_ssl = SSL.Connection(ctx, sock)
        sock_ssl.set_connect_state()
        sock_ssl.set_tlsext_host_name(hostname_idna)
        sock_ssl.do_handshake()
        cert = sock_ssl.get_peer_certificate()
    finally:
        # Close the socket on every path: a failed connect/handshake would
        # otherwise leak one file descriptor per retry attempt.
        sock.close()

    if cert is None:
        # Raise a retryable SSL error instead of failing on None below.
        raise Error(f"No certificate presented by {hostname}:{port}")
    return cert.to_cryptography()


def get_certificate_from_data(data):
    """Parse PEM-encoded X.509 certificate data.

    :param data: PEM text, either as ``str`` or as ``bytes``/``bytearray``.
    :returns: the parsed ``cryptography`` ``x509.Certificate``.
    """
    if not isinstance(data, (bytes, bytearray)):
        # The cryptography loader only accepts bytes-like input.
        data = data.encode()
    return x509.load_pem_x509_certificate(data, default_backend())


def get_certificate_from_internal_host(cluster, host, port='443'):
    """Fetch the TLS certificate served by ``host:port`` from inside the cluster.

    Some services (e.g. admission-controller) have no ingress with an external
    host or IP. For those, ``openssl s_client`` is run through
    ``exec_pod_cmd`` on the cluster's first machine, i.e. from the internal
    network, and its PEM output is parsed.

    :param cluster: cluster object; its first machine executes the command.
    :param host: internal host name or IP address of the service.
    :param port: TLS port of the service (``str`` or ``int``).
    :returns: the served certificate as a ``cryptography`` ``x509.Certificate``.
    """
    import shlex

    # Quote the endpoint so the value cannot break out of the shell pipeline;
    # plain "IP:port" strings are left unchanged by shlex.quote.
    endpoint = shlex.quote(f"{host}:{port}")
    cmd = f"true | openssl s_client -connect {endpoint} 2>/dev/null | openssl x509"
    machine = cluster.get_machines()[0]
    cmd_result = machine.exec_pod_cmd(cmd)
    return get_certificate_from_data(cmd_result['logs'])


def _get_certificate_cn(certificate):
    """Return the subject Common Name of ``certificate``, or None if it has none.

    The CN is looked up by OID rather than parsed out of ``rfc4514_string()``:
    that string joins every subject attribute, so splitting it on ``=`` returns
    the value of the last one listed, which is the CN only when the subject
    contains nothing else.
    """
    cn_attributes = certificate.subject.get_attributes_for_oid(x509.NameOID.COMMON_NAME)
    return cn_attributes[0].value if cn_attributes else None


def check_cluster_certificate(kaas_manager, target_cluster, show_step, test_service):
    """Force re-issuing of a service TLS certificate and verify the result.

    The service's ``<service_name>-tls-certs`` secret is deleted. The function
    then waits for the CertificateRequest, MCCCertificateRequest, new secret and
    updated CertificateConfiguration. It checks that all of them carry the same
    new certificate with consistent validity dates and hostname, that the
    cluster TLS status reflects it, and that the service actually serves it.

    :param kaas_manager: manager used to look up CertificateRequests in the
        ``kaas`` namespace.
    :param target_cluster: cluster whose service certificate is renewed.
    :param show_step: pytest ``show_step`` fixture; steps 1-11 follow the
        calling test's scenario.
    :param test_service: dict with ``service_name``, ``namespace`` and
        ``svc_name`` keys describing the service under test.
    :raises AssertionError: if any of the consistency checks fails.
    """
    service_name = test_service['service_name']
    namespace = test_service['namespace']
    svc_name = test_service['svc_name']
    secret_name = f"{test_service['service_name']}-tls-certs"
    secret_data_key = 'tls.crt'

    def _get_secret(target, name=secret_name, ns=namespace):
        return target.k8sclient.secrets.get(name=name, namespace=ns)

    def _get_secret_tls_crt():
        # A freshly created secret may not have 'data' yet; treat it as empty
        # instead of failing with AttributeError on None.
        secret_data = _get_secret(target_cluster, name=secret_name, ns=namespace).read().to_dict().get('data') or {}
        return secret_data.get(secret_data_key)

    svc = target_cluster.k8sclient.services.get(name=svc_name, namespace=namespace)
    service_host = svc.get_external_ip()
    https_service_port = [s for s in svc.get_ports() if s.name == 'https']
    service_port = https_service_port[0].port if https_service_port else svc.get_ports()[0].port

    show_step(1)
    if service_name == 'admission':
        # admission-controller has no external endpoint: read its certificate
        # from inside the cluster and expect the internal DNS name as CN.
        internal_ip = svc.get_ip()
        service_host = 'admission-controller.kaas.svc'
        LOG.info(f"Start certificate reinitialization for {svc_name} service")
        LOG.info(f"Get current certificate for {internal_ip}:{service_port}")
        svc_old_certificate = get_certificate_from_internal_host(target_cluster, internal_ip, service_port)
    else:
        LOG.info(f"Start certificate reinitialization for {svc_name} service")
        LOG.info(f"Get current certificate for {service_host}:{service_port}")
        svc_old_certificate = get_certificate_from_host(service_host, port=service_port)

    LOG.info(f"Get current certificate secret {secret_name}")
    old_secret = _get_secret(target_cluster, name=secret_name, ns=namespace)
    old_secret_data = base64.b64decode(old_secret.read().to_dict()['data'][secret_data_key]).decode('utf-8')
    LOG.info("Get current CertificateConfiguration")
    certificate_configuration = target_cluster.get_certificate_configuration(name=service_name, namespace=namespace)
    certificate_configuration_revision = certificate_configuration.data['status']['revision']

    show_step(2)
    current_time = datetime.utcnow()
    LOG.info(f"Delete current certificate secret {secret_name}")
    old_secret.delete()

    certificate_request_name_prefix = f"{target_cluster.namespace}-{target_cluster.name}-{service_name}"
    LOG.info(f"CertificateRequests prefix name: {certificate_request_name_prefix}")

    show_step(3)
    LOG.info("Wait for the CertificateRequests to be created")
    waiters.wait(
        lambda: kaas_manager.get_certificate_requests(name_prefix=certificate_request_name_prefix, namespace='kaas'),
        timeout=900, interval=30,
        timeout_msg="Timeout for waiting CertificateRequest for cluster")
    certificate_request = kaas_manager.get_certificate_requests(name_prefix=certificate_request_name_prefix,
                                                                namespace='kaas')
    # Copy the data now: the CertificateRequest is expected to be deleted below.
    certificate_request_data = copy.deepcopy(certificate_request[0].data)
    LOG.info(f"Found CertificateRequests '{certificate_request_data['metadata']['name']}'")
    certificate_request_ca = base64.b64decode(certificate_request_data['status']['ca']).decode('utf-8')
    certificate_request_certificate = base64.b64decode(certificate_request_data['status']['certificate']) \
        .decode('utf-8')

    LOG.info("Wait for the CertificateRequests to be deleted")
    waiters.wait(lambda: not kaas_manager.get_certificate_requests(name_prefix=certificate_request_name_prefix,
                                                                   namespace='kaas'),
                 timeout=900, interval=30,
                 timeout_msg="Timeout for waiting CertificateRequest removing for cluster")

    show_step(4)
    # The renewal bumps the CertificateConfiguration revision by one.
    mcc_service_name = f"{service_name}-{certificate_configuration_revision + 1}"
    waiters.wait(
        lambda: target_cluster.get_mcc_certificate_requests(name_prefix=mcc_service_name, namespace=namespace),
        timeout=300, interval=30,
        timeout_msg="Timeout for waiting MCCCertificateRequest for cluster")
    mcc_certificate_request = target_cluster.get_mcc_certificate_request(name=mcc_service_name,
                                                                         namespace=namespace)
    LOG.info(f"Select MCCCertificateRequest '{mcc_certificate_request.name}'")

    mcc_certificate_request_certificate = base64.b64decode(mcc_certificate_request.data['status']['certificate']) \
        .decode('utf-8')

    show_step(5)
    LOG.info("Wait for a new secret to be created and updated data")
    waiters.wait(_get_secret_tls_crt, timeout_msg=f'Secret {secret_name} not present')
    LOG.info(f"Found secret '{secret_name}'")
    new_secret = _get_secret(target_cluster, name=secret_name, ns=namespace)
    new_secret_data = base64.b64decode(new_secret.read().to_dict()['data'][secret_data_key]).decode('utf-8')
    certificate_secret_data = get_certificate_from_data(new_secret_data)
    certificate_secret_data_not_after = certificate_secret_data.not_valid_after
    certificate_secret_data_not_before = certificate_secret_data.not_valid_before
    certificate_secret_data_cn = _get_certificate_cn(certificate_secret_data)

    show_step(6)
    waiters.wait(
        lambda: target_cluster.get_certificate_configuration(name=service_name, namespace=namespace).data['status'][
            'certificate'] == new_secret_data,
        timeout=300, interval=30,
        timeout_msg='Timeout for waiting update certificate in CertificateConfiguration')
    certificate_configuration = target_cluster.get_certificate_configuration(name=service_name, namespace=namespace)
    LOG.info(f"Found CertificateConfiguration '{certificate_configuration.name}'")
    certificate_configuration_certificate = certificate_configuration.data['status']['certificate']
    certificate_configuration_hostname = certificate_configuration.data['status']['hostname']
    certificate_configuration_not_after = certificate_configuration.data['status']['notAfter']
    certificate_configuration_not_before = certificate_configuration.data['status']['notBefore']

    show_step(7)
    # compare secrets: both sides are decoded PEM strings (comparing the old
    # PEM text against the parsed certificate object would always be unequal)
    assert old_secret_data != new_secret_data, "Secrets before and after the update should not be the same"
    assert mcc_certificate_request_certificate == certificate_configuration_certificate, \
        "Certificate data in MCCCertificateRequest and CertificateConfiguration should be the same"
    assert certificate_configuration_certificate == certificate_request_certificate + certificate_request_ca, \
        "Certificate data in CertificateConfiguration and CertificateRequests should be the same"
    assert certificate_configuration_certificate == new_secret_data, \
        "Certificate data in CertificateConfiguration and secret should be the same"
    LOG.info("All secrets updated and identical")

    show_step(8)
    # compare expiration; [:-1] drops the trailing 'Z' that fromisoformat rejects
    assert certificate_secret_data_not_after == datetime.fromisoformat(certificate_configuration_not_after[:-1]), \
        "Certificate expiration 'after' time in secret and CertificateConfiguration should be the same"
    assert certificate_secret_data_not_before == datetime.fromisoformat(certificate_configuration_not_before[:-1]), \
        "Certificate expiration 'before' time in secret and CertificateConfiguration should be the same"
    LOG.info("All secrets contains correct expiration date")

    # compare cert valid time
    assert certificate_secret_data_not_before > current_time, \
        "Certificate 'before' time should be later than the reinitialization process started"
    LOG.info("New certificate contains correct start date after reinitialization")

    show_step(9)
    # compare hostname
    assert certificate_secret_data_cn == service_host, "Certificate CN in secret and svc host should be the same"
    assert certificate_configuration_hostname == certificate_secret_data_cn, \
        "Certificate CN in CertificateConfiguration and secret should be the same"
    LOG.info("Hostname is correct")

    show_step(10)
    # check cluster expiration
    LOG.info(f"Get cluster TLS status for {service_name} service")
    iso_format = certificate_secret_data_not_after.strftime("%Y-%m-%dT%H:%M:%SZ")
    waiters.wait(
        lambda: target_cluster.get_tls_status(service_name).get('expirationTime') == iso_format,
        timeout=300, interval=30,
        timeout_msg='Timeout for waiting update expiration time in cluster data')
    cluster_tls_status = target_cluster.get_tls_status(service_name)
    LOG.info("Cluster TLS status contains correct expiration date")
    assert certificate_secret_data_cn == cluster_tls_status['hostname'], \
        "Certificate CN in secret and cluster status should be the same"
    LOG.info("Cluster TLS status contains correct hostname")

    show_step(11)
    # check applied certificate
    if service_name == 'admission':
        waiters.wait(
            lambda: get_certificate_from_internal_host(target_cluster, internal_ip,
                                                       port=service_port) != svc_old_certificate,
            interval=10, timeout=300,
            timeout_msg="SVC applied certificate before and after reinitialization should not be the same")

        waiters.wait(lambda: get_certificate_from_internal_host(target_cluster, internal_ip,
                                                                port=service_port) == certificate_secret_data,
                     interval=10, timeout=300,
                     timeout_msg="SVC applied certificate and stored in secret should be the same")
        LOG.info(f"Certificate for service {svc_name} applied correctly")
    else:
        waiters.wait(lambda: get_certificate_from_host(service_host, port=service_port) != svc_old_certificate,
                     interval=10, timeout=300,
                     timeout_msg="SVC applied certificate before and after reinitialization should not be the same")

        waiters.wait(lambda: get_certificate_from_host(service_host, port=service_port) == certificate_secret_data,
                     interval=10, timeout=300,
                     timeout_msg="SVC applied certificate and stored in secret should be the same")
        LOG.info(f"Certificate for service {svc_name} applied correctly")


@pytest.mark.parametrize("test_service", [
    {'service_name': 'keycloak', 'namespace': 'kaas', 'svc_name': 'iam-keycloak-http'},
    {'service_name': 'ui', 'namespace': 'kaas', 'svc_name': 'kaas-kaas-ui'},
    {'service_name': 'cache', 'namespace': 'kaas', 'svc_name': 'mcc-cache'},
    {'service_name': 'admission', 'namespace': 'kaas', 'svc_name': 'admission-controller'},
    {'service_name': 'telemeter-server', 'namespace': 'stacklight', 'svc_name': 'telemeter-server-external'},
    {'service_name': 'iam-proxy-alerta', 'namespace': 'stacklight', 'svc_name': 'iam-proxy-alerta'},
    {'service_name': 'iam-proxy-alertmanager', 'namespace': 'stacklight', 'svc_name': 'iam-proxy-alertmanager'},
    {'service_name': 'iam-proxy-grafana', 'namespace': 'stacklight', 'svc_name': 'iam-proxy-grafana'},
    {'service_name': 'iam-proxy-kibana', 'namespace': 'stacklight', 'svc_name': 'iam-proxy-kibana'},
    {'service_name': 'iam-proxy-prometheus', 'namespace': 'stacklight', 'svc_name': 'iam-proxy-prometheus'}
], ids=lambda name: name['svc_name'])
@pytest.mark.usefixtures('log_method_time')
def test_auto_renewal_mgmt_certificate(kaas_manager, show_step, test_service):
    """Certificate auto-renewal test for mgmt cluster services.

    Scenario:
        1. Get current tls secret data
        2. Delete current tls secret
        3. Get CertificateRequests data
        4. Get MCCCertificateRequest data
        5. Get new tls secret data
        6. Get CertificateConfiguration data
        7. Verify secret data
        8. Verify expiration date
        9. Verify hostname
        10. Verify cluster expiration
        11. Verify applied certificate


    """
    managed_ns = kaas_manager.get_namespace(settings.TARGET_NAMESPACE)
    cluster = managed_ns.get_cluster(settings.TARGET_CLUSTER)
    LOG.info(f"Start certificate verification for MGMT cluster: {cluster.name}")
    check_cluster_certificate(kaas_manager, cluster, show_step, test_service)


@pytest.mark.parametrize("test_service", [
    {'service_name': 'iam-proxy-alerta', 'namespace': 'stacklight', 'svc_name': 'iam-proxy-alerta'},
    {'service_name': 'iam-proxy-alertmanager', 'namespace': 'stacklight', 'svc_name': 'iam-proxy-alertmanager'},
    {'service_name': 'iam-proxy-grafana', 'namespace': 'stacklight', 'svc_name': 'iam-proxy-grafana'},
    {'service_name': 'iam-proxy-prometheus', 'namespace': 'stacklight', 'svc_name': 'iam-proxy-prometheus'}
], ids=lambda name: name['svc_name'])
def test_auto_renewal_child_certificate(kaas_manager, show_step, test_service):
    """Certificate auto-renewal test for every child cluster.

    Scenario:
        1. Get current tls secret data
        2. Delete current tls secret
        3. Get CertificateRequests data
        4. Get MCCCertificateRequest data
        5. Get new tls secret data
        6. Get CertificateConfiguration data
        7. Verify secret data
        8. Verify expiration date
        9. Verify hostname
        10. Verify cluster expiration
        11. Verify applied certificate


    """
    child_clusters = kaas_manager.get_child_clusters()
    # Without child clusters nothing would be checked; report a skip instead
    # of a misleading pass.
    if not child_clusters:
        pytest.skip("No child clusters found, nothing to verify")

    for child_cluster in child_clusters:
        LOG.info(f"Start certificate verification for child cluster: {child_cluster.name}")
        check_cluster_certificate(kaas_manager, child_cluster, show_step, test_service)


@pytest.mark.parametrize("test_service", [
    {'service_name': 'iam-proxy-alerta', 'namespace': 'stacklight', 'svc_name': 'iam-proxy-alerta'},
    {'service_name': 'iam-proxy-alertmanager', 'namespace': 'stacklight', 'svc_name': 'iam-proxy-alertmanager'},
    {'service_name': 'iam-proxy-grafana', 'namespace': 'stacklight', 'svc_name': 'iam-proxy-grafana'},
    {'service_name': 'iam-proxy-prometheus', 'namespace': 'stacklight', 'svc_name': 'iam-proxy-prometheus'},
    {'service_name': 'cache', 'namespace': 'kaas', 'svc_name': 'mcc-cache'},
    {'service_name': 'telemeter-server', 'namespace': 'stacklight', 'svc_name': 'telemeter-server-external'}
], ids=lambda name: name['svc_name'])
def test_auto_renewal_region_certificate(kaas_manager, show_step, test_service):
    """Certificate auto-renewal test for every regional cluster.

    Scenario:
        1. Get current tls secret data
        2. Delete current tls secret
        3. Get CertificateRequests data
        4. Get MCCCertificateRequest data
        5. Get new tls secret data
        6. Get CertificateConfiguration data
        7. Verify secret data
        8. Verify expiration date
        9. Verify hostname
        10. Verify cluster expiration
        11. Verify applied certificate


    """
    regional_clusters = kaas_manager.get_regional_clusters()
    # Without regional clusters nothing would be checked; report a skip
    # instead of a misleading pass.
    if not regional_clusters:
        pytest.skip("No regional clusters found, nothing to verify")

    for regional_cluster in regional_clusters:
        LOG.info(f"Start certificate verification for regional cluster: {regional_cluster.name}")
        check_cluster_certificate(kaas_manager, regional_cluster, show_step, test_service)
