# Copyright (c) 2017 Midokura SARL
# All Rights Reserved.
#
#    Licensed under the Apache License, Version 2.0 (the "License"); you may
#    not use this file except in compliance with the License. You may obtain
#    a copy of the License at
#
#         http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
#    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
#    License for the specific language governing permissions and limitations
#    under the License.

import netaddr
from oslo_config import cfg
import testtools

from neutron_lib.utils import test
from tempest.common import utils
from tempest.common import waiters
from tempest.lib.common import ssh
from tempest.lib.common.utils import data_utils
from tempest.lib import decorators

from neutron_tempest_plugin import config
from neutron_tempest_plugin.scenario import constants
from neutron_tempest_plugin.vpnaas.scenario import base_vpnaas as base


CONF = config.CONF

# NOTE(huntxu): This is a workaround for an upstream bug [1].
# VPNaaS 4in6 and 6in4 are not working properly with LibreSwan 3.19+.
# In OpenStack Zuul checks, the base CentOS 7 node uses Libreswan 3.20 on
# CentOS 7.4, so we need to provide a way to skip the 4in6 and 6in4 test
# cases in Zuul.
#
# Once the upstream bug gets fixed and the base node uses a newer version of
# Libreswan with that fix, we can remove this.
#
# [1] https://github.com/libreswan/libreswan/issues/175
#
# NOTE: the ``testtools.skipIf`` decorators on the test classes in this module
# read these options while the classes are being defined, so the options must
# be registered here, at import time, before those classes.
CONF.register_opt(
    cfg.BoolOpt('skip_4in6_6in4_tests',
                default=False,
                help='Whether to skip 4in6 and 6in4 test cases.'),
    'neutron_vpnaas_plugin_options'
)

# The VPNaaS drivers for OVN don't support IPv6 VMs, so the test cases with
# IPv6 tenant subnets (6in4 and 6in6) can be skipped individually.
CONF.register_opt(
    cfg.BoolOpt('skip_6in4_tests',
                default=False,
                help='Whether to skip 6in4 test cases.'),
    'neutron_vpnaas_plugin_options'
)
CONF.register_opt(
    cfg.BoolOpt('skip_6in6_tests',
                default=False,
                help='Whether to skip 6in6 test cases.'),
    'neutron_vpnaas_plugin_options'
)


class Vpnaas(base.BaseTempestTestCase):
    """Test the following topology

    .. code-block:: HTML

              +-------------------+
              | public            |
              | network           |
              |                   |
              +-+---------------+-+
                |               |
                |               |
        +-------+-+           +-+-------+
        | LEFT    |           | RIGHT   |
        | router  | <--VPN--> | router  |
        |         |           |         |
        +----+----+           +----+----+
             |                     |
        +----+----+           +----+----+
        | LEFT    |           | RIGHT   |
        | network |           | network |
        |         |           |         |
        +---------+           +---------+

    Subclasses select the address families under test:

    * ``inner_ipv6`` -- when True, the LEFT/RIGHT tenant subnets are IPv6
      (SLAAC); otherwise they are IPv4.
    * ``outer_ipv6`` -- when True, the IPsec peers use the VPN services'
      ``external_v6_ip``; otherwise their ``external_v4_ip``.
    """

    credentials = ['primary', 'admin']
    inner_ipv6 = False
    outer_ipv6 = False

    @classmethod
    @utils.requires_ext(extension="vpnaas", service="network")
    def resource_setup(cls):
        """Create the shared security group, policies and both sites."""
        super(Vpnaas, cls).resource_setup()

        # common
        cls.keypair = cls.create_keypair()
        cls.secgroup = cls.os_primary.network_client.create_security_group(
            name=data_utils.rand_name('secgroup-'))['security_group']
        cls.security_groups.append(cls.secgroup)
        cls.create_loginable_secgroup_rule(secgroup_id=cls.secgroup['id'])
        cls.create_pingable_secgroup_rule(secgroup_id=cls.secgroup['id'])
        cls.ikepolicy = cls.create_ikepolicy(
            data_utils.rand_name("ike-policy-"), 'v2')
        cls.ipsecpolicy = cls.create_ipsecpolicy(
            data_utils.rand_name("ipsec-policy-"))

        # Extra keyword arguments applied to every tenant subnet created for
        # the test (also reused by subclasses).
        cls.extra_subnet_attributes = {}
        if cls.inner_ipv6:
            cls.create_v6_pingable_secgroup_rule(
                secgroup_id=cls.secgroup['id'])
            cls.extra_subnet_attributes['ipv6_address_mode'] = 'slaac'
            cls.extra_subnet_attributes['ipv6_ra_mode'] = 'slaac'

        left_v4_cidr = netaddr.IPNetwork('10.220.0.0/24')
        left_v6_cidr = netaddr.IPNetwork('2001:db8:0:2::/64')
        cls.left_cidr = left_v6_cidr if cls.inner_ipv6 else left_v4_cidr
        right_v4_cidr = netaddr.IPNetwork('10.210.0.0/24')
        right_v6_cidr = netaddr.IPNetwork('2001:db8:0:1::/64')
        cls.right_cidr = right_v6_cidr if cls.inner_ipv6 else right_v4_cidr

        # LEFT
        cls.router = cls.create_router(
            data_utils.rand_name('left-router'),
            admin_state_up=True,
            external_network_id=CONF.network.public_network_id)
        cls.network = cls.create_network(network_name='left-network')
        ip_version = 6 if cls.inner_ipv6 else 4
        cls.subnet = cls.create_subnet(
            cls.network, ip_version=ip_version, cidr=cls.left_cidr,
            name='left-subnet', **cls.extra_subnet_attributes)
        cls.create_router_interface(cls.router['id'], cls.subnet['id'])
        cls._route_via_snat_if_distributed(
            cls.router, cls.subnet, cls.right_cidr)
        # Gives an internal IPv4 subnet for floating IP to the left server,
        # we use it to ssh into the left server.
        if cls.inner_ipv6:
            v4_subnet = cls.create_subnet(
                cls.network, ip_version=4, name='left-v4-subnet')
            cls.create_router_interface(cls.router['id'], v4_subnet['id'])

        # RIGHT
        cls._right_network, cls._right_subnet, cls._right_router = \
            cls._create_right_network()

    @classmethod
    def create_v6_pingable_secgroup_rule(cls, secgroup_id=None, client=None):
        """This rule is intended to permit inbound ping6"""
        # NOTE(huntxu): This method should be moved into the base class, along
        # with the v4 version.
        rule_list = [{'protocol': 'ipv6-icmp',
                      'direction': 'ingress',
                      'port_range_min': 128,  # type
                      'port_range_max': 0,  # code
                      'ethertype': 'IPv6',
                      'remote_ip_prefix': '::/0'}]
        client = client or cls.os_primary.network_client
        cls.create_secgroup_rules(rule_list, client=client,
                                  secgroup_id=secgroup_id)

    @classmethod
    def _route_via_snat_if_distributed(cls, router, subnet, destination):
        """Add a host route to ``destination`` via the SNAT port under DVR.

        If ``router`` reports ``distributed``, the host routes of ``subnet``
        are set to send ``destination`` to the IP that the router's
        ``network:router_centralized_snat`` port holds on ``subnet``.
        Nothing is changed for a non-distributed router.

        :param router: router dict (only ``id`` is read).
        :param subnet: subnet dict attached to ``router`` (``id`` is read).
        :param destination: peer CIDR to route through the SNAT port.
        """
        admin_client = cls.os_admin.network_client
        is_distributed = admin_client.show_router(
            router['id'])['router'].get('distributed')
        if not is_distributed:
            return
        snat_ports = admin_client.list_ports(
            device_id=router['id'],
            device_owner='network:router_centralized_snat')['ports']
        # Called from a classmethod: pass the class explicitly as ``self``.
        snat_ip = cls._get_ip_on_subnet_for_port(
            cls, snat_ports[0], subnet['id'])
        admin_client.update_subnet(
            subnet['id'],
            host_routes=[{"destination": destination, "nexthop": snat_ip}])

    @classmethod
    def _create_right_network(cls):
        """Create the RIGHT router, network and subnet.

        :returns: ``(network, subnet, router)`` tuple.
        """
        router = cls.create_router(
            data_utils.rand_name('right-router'),
            admin_state_up=True,
            external_network_id=CONF.network.public_network_id)
        network = cls.create_network(network_name='right-network')
        ip_version = 6 if cls.inner_ipv6 else 4
        subnet = cls.create_subnet(
            network, ip_version=ip_version, cidr=cls.right_cidr,
            name='right-subnet', **cls.extra_subnet_attributes)
        cls.create_router_interface(router['id'], subnet['id'])
        cls._route_via_snat_if_distributed(router, subnet, cls.left_cidr)
        return network, subnet, router

    def _create_server(self, create_floating_ip=True, network=None):
        """Boot a server on a new port and wait until it is ACTIVE.

        :param create_floating_ip: associate a floating IP with the port.
        :param network: network for the port; defaults to the LEFT network.
        :returns: dict with ``port``, ``fip`` (``None`` when no floating IP
            was requested) and ``server``.
        """
        if network is None:
            network = self.network
        port = self.create_port(network, security_groups=[self.secgroup['id']])
        fip = (self.create_and_associate_floatingip(port['id'])
               if create_floating_ip else None)
        server = self.create_server(
            flavor_ref=CONF.compute.flavor_ref,
            image_ref=CONF.compute.image_ref,
            key_name=self.keypair['name'],
            networks=[{'port': port['id']}])['server']
        waiters.wait_for_server_status(self.os_primary.servers_client,
                                       server['id'],
                                       constants.SERVER_STATUS_ACTIVE)
        return {'port': port, 'fip': fip, 'server': server}

    def _setup_vpn(self):
        """Create a VPN service per site and connect them to each other.

        Skips the test when ``outer_ipv6`` is set but a VPN service has no
        ``external_v6_ip``. Waits for both IPsec site connections to become
        ACTIVE.
        """
        sites = [
            dict(name="left", network=self.network, subnet=self.subnet,
                 router=self.router),
            dict(name="right", network=self._right_network,
                 subnet=self._right_subnet, router=self._right_router),
        ]
        psk = data_utils.rand_name('mysecret')
        for site in sites:
            site['vpnservice'] = self.create_vpnservice(
                site['subnet']['id'], site['router']['id'],
                name=data_utils.rand_name('%s-vpnservice' % site['name']))
        site_connections = []
        # Each site peers with the other one: left -> right, right -> left.
        for site, peer in zip(sites, reversed(sites)):
            if self.outer_ipv6:
                peer_address = peer['vpnservice']['external_v6_ip']
                if not peer_address:
                    msg = "Public network must have an IPv6 subnet."
                    raise self.skipException(msg)
            else:
                peer_address = peer['vpnservice']['external_v4_ip']
            site_connections.append(self.create_ipsec_site_connection(
                self.ikepolicy['id'],
                self.ipsecpolicy['id'],
                site['vpnservice']['id'],
                peer_address=peer_address,
                peer_id=peer_address,
                peer_cidrs=[peer['subnet']['cidr']],
                psk=psk,
                name=data_utils.rand_name(
                    '%s-ipsec-site-connection' % site['name'])))
        for site_connection in site_connections:
            self.wait_ipsec_site_connection_status(site_connection['id'],
                                                   status="ACTIVE")

    def _get_ip_on_subnet_for_port(self, port, subnet_id):
        """Return the IP address that ``port`` holds on ``subnet_id``.

        This is also invoked from classmethods with the class passed
        explicitly as ``self`` (``cls._get_ip_on_subnet_for_port(cls, ...)``),
        so the failure path must not rely on a bound ``TestCase.fail``.

        :raises: ``self.failureException`` if the port has no fixed IP on
            the subnet.
        """
        for fixed_ip in port['fixed_ips']:
            if fixed_ip['subnet_id'] == subnet_id:
                return fixed_ip['ip_address']
        msg = "Cannot get IP address on specified subnet %s for port %r." % (
            subnet_id, port)
        # ``failureException`` is a class attribute, so this works whether
        # ``self`` is a test instance or the test class itself.
        raise self.failureException(msg)

    def _test_vpnaas(self):
        """Check LEFT -> RIGHT connectivity without and with the VPN."""
        # RIGHT
        right_server = self._create_server(network=self._right_network,
                                           create_floating_ip=False)
        right_ip = self._get_ip_on_subnet_for_port(
            right_server['port'], self._right_subnet['id'])

        # LEFT
        left_server = self._create_server()
        ssh_client = ssh.Client(left_server['fip']['floating_ip_address'],
                                CONF.validation.image_ssh_user,
                                pkey=self.keypair['private_key'],
                                ssh_key_type=CONF.validation.ssh_key_type)

        # LEFT -> RIGHT must fail before the VPN exists...
        self.check_remote_connectivity(ssh_client, right_ip,
                                       should_succeed=False)
        self._setup_vpn()
        # ...and succeed through the VPN once it is set up.
        self.check_remote_connectivity(ssh_client, right_ip)

        # Test VPN traffic and floating IP traffic don't interfere with each
        # other. This part only runs with IPv4 tenant subnets.
        if not self.inner_ipv6:
            # Assign a floating-ip and check connectivity.
            # This is NOT via VPN.
            fip = self.create_and_associate_floatingip(
                right_server['port']['id'])
            self.check_remote_connectivity(ssh_client,
                                           fip['floating_ip_address'])

            # check LEFT -> RIGHT connectivity via VPN again, to ensure
            # the above floating-ip doesn't interfere with the traffic.
            self.check_remote_connectivity(ssh_client, right_ip)


class Vpnaas4in4(Vpnaas):
    """VPN test with IPv4 tenant subnets and IPv4 peer addresses."""

    @decorators.idempotent_id('aa932ab2-63aa-49cf-a2a0-8ae71ac2bc24')
    @decorators.attr(type='smoke')
    def test_vpnaas(self):
        self._test_vpnaas()


class Vpnaas4in6(Vpnaas):
    """VPN test with IPv4 tenant subnets and IPv6 peer addresses."""
    outer_ipv6 = True

    @decorators.idempotent_id('2d5f18dc-6186-4deb-842b-051325bd0466')
    @testtools.skipUnless(CONF.network_feature_enabled.ipv6,
                          'IPv6 tests are disabled.')
    # Opt-out for the LibreSwan 4in6/6in4 issue described where the option
    # is registered.
    @testtools.skipIf(
        CONF.neutron_vpnaas_plugin_options.skip_4in6_6in4_tests,
        'VPNaaS 4in6 test is skipped.')
    @test.unstable_test("bug 1882220")
    def test_vpnaas_4in6(self):
        self._test_vpnaas()


class Vpnaas6in4(Vpnaas):
    """VPN test with IPv6 tenant subnets and IPv4 peer addresses."""
    inner_ipv6 = True

    @decorators.idempotent_id('10febf33-c5b7-48af-aa13-94b4fb585a55')
    @testtools.skipUnless(CONF.network_feature_enabled.ipv6,
                          'IPv6 tests are disabled.')
    # ``skip_4in6_6in4_tests`` is documented to skip 6in4 as well (LibreSwan
    # workaround), so honor it in addition to the 6in4-only option.
    @testtools.skipIf(
        CONF.neutron_vpnaas_plugin_options.skip_4in6_6in4_tests,
        'VPNaaS 6in4 test is skipped.')
    @testtools.skipIf(
        CONF.neutron_vpnaas_plugin_options.skip_6in4_tests,
        'VPNaaS 6in4 test is skipped.')
    @test.unstable_test("bug 1882220")
    def test_vpnaas_6in4(self):
        self._test_vpnaas()


class Vpnaas6in6(Vpnaas):
    """VPN test with IPv6 tenant subnets and IPv6 peer addresses."""
    inner_ipv6 = True
    outer_ipv6 = True

    @decorators.idempotent_id('8b503ffc-aeb0-4938-8dba-73c7323e276d')
    @testtools.skipUnless(CONF.network_feature_enabled.ipv6,
                          'IPv6 tests are disabled.')
    @testtools.skipIf(
        CONF.neutron_vpnaas_plugin_options.skip_6in6_tests,
        'VPNaaS 6in6 test is skipped.')
    @test.unstable_test("bug 1882220")
    def test_vpnaas_6in6(self):
        self._test_vpnaas()


class VpnaasWithEndpointGroupBase(Vpnaas):
    """Test the following topology

    .. code-block:: HTML

                   +-------------------+
                   | public            |
                   | network           |
                   |                   |
                   ++-----------------++
                    |                 |
                    |                 |
            +-------+-+             +-+-------+
            | LEFT    |             | RIGHT   |
            | router  |  <--VPN-->  | router  |
            |         |             |         |
            +-+-----+-+             +-+-----+-+
              |     |                 |     |
      +-------+-+ +-+-------+ +-------+-+ +-+-------+
      | LEFT    | | LEFT    | | RIGHT   | | RIGHT   |
      | network | | network2| | network | | network2|
      |         | |         | |         | |         |
      +---------+ +---------+ +---------+ +---------+

    The IPsec site connections select their traffic with endpoint groups
    (local: subnet groups, peer: CIDR groups) instead of ``peer_cidrs``.
    """

    # Failure descriptions accumulated by the ``_union_tests``-wrapped checks,
    # so every client/destination pair is tried before a single assertion.
    check_failures = ""

    @classmethod
    @utils.requires_ext(extension="vpnaas", service="network")
    def resource_setup(cls):
        """Add a second network per site and create the endpoint groups."""
        super(VpnaasWithEndpointGroupBase, cls).resource_setup()

        left_v4_cidr2 = netaddr.IPNetwork('10.220.2.0/24')
        left_v6_cidr2 = netaddr.IPNetwork('2001:db8:0:3::/64')
        right_v4_cidr2 = netaddr.IPNetwork('10.210.2.0/24')
        right_v6_cidr2 = netaddr.IPNetwork('2001:db8:0:4::/64')

        cls.left_cidr2 = left_v6_cidr2 if cls.inner_ipv6 else left_v4_cidr2
        cls.right_cidr2 = right_v6_cidr2 if cls.inner_ipv6 else right_v4_cidr2

        cls.network2, cls.subnet2 = cls._add_network(
            'left2',
            cls.router,
            cls.left_cidr2
        )

        cls._right_network2, cls._right_subnet2 = cls._add_network(
            'right2',
            cls._right_router,
            cls.right_cidr2
        )

        # Update subnets in case of distributed routers
        cls._update_host_routes()

        cls.left_ep_group_subnet = cls.create_endpoint_group(
            name=data_utils.rand_name("left-endpoint-group-subnet-"),
            type="subnet",
            endpoints=[cls.subnet['id'], cls.subnet2['id']])
        cls.left_ep_group_cidr = cls.create_endpoint_group(
            name=data_utils.rand_name("left-endpoint-group-cidr-"),
            type="cidr",
            endpoints=[cls.left_cidr, cls.left_cidr2])
        cls.right_ep_group_subnet = cls.create_endpoint_group(
            name=data_utils.rand_name("right-endpoint-group-subnet-"),
            type="subnet",
            endpoints=[cls._right_subnet['id'], cls._right_subnet2['id']])
        cls.right_ep_group_cidr = cls.create_endpoint_group(
            name=data_utils.rand_name("right-endpoint-group-cidr-"),
            type="cidr",
            endpoints=[cls.right_cidr, cls.right_cidr2])

    @classmethod
    def _add_network(cls, prefix, router, cidr):
        """Create ``<prefix>-network`` with one subnet attached to ``router``.

        :returns: ``(network, subnet)`` tuple.
        """
        network = cls.create_network(network_name=f"{prefix}-network")
        ip_version = 6 if cls.inner_ipv6 else 4
        subnet = cls.create_subnet(
            network, ip_version=ip_version, cidr=cidr,
            name=f"{prefix}-subnet", **cls.extra_subnet_attributes)
        cls.create_router_interface(router['id'], subnet['id'])

        return network, subnet

    @classmethod
    def _update_host_routes(cls):
        """Under DVR, route both peer CIDRs via each router's SNAT port."""
        cls._route_peer_cidrs_via_snat(
            cls.router,
            [cls.subnet, cls.subnet2],
            [cls.right_cidr, cls.right_cidr2])
        cls._route_peer_cidrs_via_snat(
            cls._right_router,
            [cls._right_subnet, cls._right_subnet2],
            [cls.left_cidr, cls.left_cidr2])

    @classmethod
    def _route_peer_cidrs_via_snat(cls, router, subnets, destinations):
        """Set host routes to ``destinations`` on ``subnets`` under DVR.

        If ``router`` reports ``distributed``, each subnet's host routes are
        replaced with one route per destination whose next hop is the IP the
        router's ``network:router_centralized_snat`` port holds on that
        subnet. Nothing is changed for a non-distributed router.
        """
        admin_client = cls.os_admin.network_client
        is_distributed = admin_client.show_router(
            router['id'])['router'].get('distributed')
        if not is_distributed:
            return
        snat_ports = admin_client.list_ports(
            device_id=router['id'],
            device_owner='network:router_centralized_snat')['ports']
        for subnet in subnets:
            # Called from a classmethod: pass the class explicitly as ``self``.
            snat_ip = cls._get_snat_ip(cls, snat_ports, subnet)
            host_routes = [
                {"destination": destination, "nexthop": snat_ip}
                for destination in destinations
            ]
            admin_client.update_subnet(subnet['id'], host_routes=host_routes)

    def _get_snat_ip(self, ports, subnet):
        """Return the IP one of the SNAT ``ports`` holds on ``subnet``.

        Invoked as ``cls._get_snat_ip(cls, ...)`` from classmethods, so
        ``self`` may be the test class itself.

        :raises: ``self.failureException`` if no port on the subnet's network
            has a fixed IP on the subnet (rather than returning ``None``,
            which would end up as a host-route next hop).
        """
        for port in ports:
            if port['network_id'] != subnet['network_id']:
                continue
            for fixed_ip in port['fixed_ips']:
                if fixed_ip['subnet_id'] == subnet['id']:
                    return fixed_ip['ip_address']
        # ``failureException`` is a class attribute, so this works whether
        # ``self`` is a test instance or the test class itself.
        raise self.failureException(
            "Cannot find a SNAT port IP on subnet %s among ports %r." % (
                subnet['id'], ports))

    def _setup_vpn(self):
        """Create endpoint-group based IPsec connections between the sites.

        The VPN services are created without a subnet; each site connection
        uses its local subnet endpoint group and the peer's CIDR endpoint
        group. Skips the test when ``outer_ipv6`` is set but a VPN service
        has no ``external_v6_ip``.
        """
        sites = [
            dict(
                name="left",
                local_ep_group_id=self.left_ep_group_subnet['id'],
                peer_ep_group_id=self.right_ep_group_cidr['id'],
                router=self.router,
            ),
            dict(
                name="right",
                local_ep_group_id=self.right_ep_group_subnet['id'],
                peer_ep_group_id=self.left_ep_group_cidr['id'],
                router=self._right_router,
            ),
        ]
        psk = data_utils.rand_name('mysecret')
        for site in sites:
            site['vpnservice'] = self.create_vpnservice_no_subnet(
                site['router']['id'])
        site_connections = []
        # Each site peers with the other one: left -> right, right -> left.
        for site, peer in zip(sites, reversed(sites)):
            if self.outer_ipv6:
                peer_address = peer['vpnservice']['external_v6_ip']
                if not peer_address:
                    msg = "Public network must have an IPv6 subnet."
                    raise self.skipException(msg)
            else:
                peer_address = peer['vpnservice']['external_v4_ip']
            site_connections.append(self.create_ipsec_site_connection(
                self.ikepolicy['id'],
                self.ipsecpolicy['id'],
                site['vpnservice']['id'],
                peer_address=peer_address,
                peer_id=peer_address,
                local_ep_group_id=site['local_ep_group_id'],
                peer_ep_group_id=site['peer_ep_group_id'],
                psk=psk,
                name=data_utils.rand_name(
                    '%s-ipsec-site-connection' % site['name'])))
        for site_connection in site_connections:
            self.wait_ipsec_site_connection_status(site_connection['id'],
                                                   status="ACTIVE")

    def _union_tests(description):
        """Decorator factory: record a failed check instead of raising.

        The wrapped method is expected to be called as
        ``method(ssh_client, dst_ip)``. If it raises, a line with
        ``description``, the SSH client's host, the destination and the error
        is appended to ``self.check_failures``, so that all checks run before
        one ``assertEmpty`` reports every failure together.
        """
        import functools

        def decorator(func):
            @functools.wraps(func)
            def f(self, *args, **kwargs):
                try:
                    func(self, *args, **kwargs)
                except Exception as exc:
                    self.check_failures += (
                        f"\nTest connection {description} '{args[0].host} "
                        f"--> {args[1]}' failed: {exc}"
                    )
            return f
        return decorator

    @_union_tests('without VPN')
    def _test_check_connectivity_fail(self, fip, dst_ip):
        # ``fip`` is the SSH client of the source server.
        self.check_remote_connectivity(fip, dst_ip, should_succeed=False)

    @_union_tests('via VPN')
    def _test_check_connectivity_with_vpn(self, fip, dst_ip):
        # ``fip`` is the SSH client of the source server.
        self.check_remote_connectivity(fip, dst_ip)

    @_union_tests('via floating IP')
    def _test_check_connectivity_with_fip(self, fip, dst_ip):
        # ``fip`` is the SSH client of the source server.
        self.check_remote_connectivity(fip, dst_ip)

    def _get_ssh_client(self, server):
        """Return an SSH client for ``server`` through its floating IP."""
        return ssh.Client(
            server['fip']['floating_ip_address'],
            CONF.validation.image_ssh_user,
            pkey=self.keypair['private_key'],
            ssh_key_type=CONF.validation.ssh_key_type
        )

    def _left_to_right_pairs(self, right_addresses):
        """Pair both LEFT SSH clients with every RIGHT address.

        The order is: for each address, client A then client B.
        """
        return [(ssh_client, address)
                for address in right_addresses
                for ssh_client in (self.ssh_client_A, self.ssh_client_B)]

    def _test_vpnaas(self, right_servers_fip=False):
        """Check LEFT -> RIGHT connectivity without and with the VPN.

        :param right_servers_fip: also give the RIGHT servers floating IPs.
        """
        # RIGHT
        self.right_server_A = self._create_server(
            network=self._right_network,
            create_floating_ip=right_servers_fip)
        self.right_ip_A = self._get_ip_on_subnet_for_port(
            self.right_server_A['port'], self._right_subnet['id'])
        self.right_server_B = self._create_server(
            network=self._right_network2,
            create_floating_ip=right_servers_fip)
        self.right_ip_B = self._get_ip_on_subnet_for_port(
            self.right_server_B['port'], self._right_subnet2['id'])

        # LEFT
        left_server_A = self._create_server()
        self.ssh_client_A = self._get_ssh_client(left_server_A)
        left_server_B = self._create_server(network=self.network2)
        self.ssh_client_B = self._get_ssh_client(left_server_B)

        right_ips = (self.right_ip_A, self.right_ip_B)

        # check LEFT -> RIGHT connectivity without VPN
        for ssh_client, dst_ip in self._left_to_right_pairs(right_ips):
            self._test_check_connectivity_fail(ssh_client, dst_ip)
        self.assertEmpty(self.check_failures, self.check_failures)

        # check LEFT -> RIGHT connectivity via VPN
        self._setup_vpn()
        for ssh_client, dst_ip in self._left_to_right_pairs(right_ips):
            self._test_check_connectivity_with_vpn(ssh_client, dst_ip)
        self.assertEmpty(self.check_failures, self.check_failures)

    def _test_vpnaas_with_fip(self):
        """Run ``_test_vpnaas`` and then check the floating IP paths too."""
        self._test_vpnaas(right_servers_fip=True)
        # check LEFT -> RIGHT connectivity via floating IP
        right_fips = (
            self.right_server_A['fip']['floating_ip_address'],
            self.right_server_B['fip']['floating_ip_address'],
        )
        for ssh_client, dst_ip in self._left_to_right_pairs(right_fips):
            self._test_check_connectivity_with_fip(ssh_client, dst_ip)
        self.assertEmpty(self.check_failures, self.check_failures)


class VpnaasEG4in4(VpnaasWithEndpointGroupBase):
    """Endpoint-group VPN test, IPv4 inside and outside, plus FIP checks."""

    @decorators.idempotent_id('26dad126-665f-4f59-ba2d-e7e27a9675de')
    def test_vpnaas_with_fip(self):
        self._test_vpnaas_with_fip()
