# Copyright 2012 OpenStack Foundation
# Copyright 2013 IBM Corp.
# All Rights Reserved.
#
#    Licensed under the Apache License, Version 2.0 (the "License"); you may
#    not use this file except in compliance with the License. You may obtain
#    a copy of the License at
#
#         http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
#    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
#    License for the specific language governing permissions and limitations
#    under the License.

import os

import netaddr
from oslo_log import log

from tempest.common import compute
from tempest.common import utils
from tempest.common import waiters
from tempest import config
from tempest.lib.common.utils import data_utils
from tempest.lib.common.utils import test_utils
from tempest.lib import exceptions as lib_exc
from tempest.scenario import manager


CONF = config.CONF
LOG = log.getLogger(__name__)

# Identifiers of the test networks; used as network name prefixes and as
# keys of the per-network dicts filled by the scenario helpers.
NET_A = 'A'
NET_A_BIS = 'A-Bis'
NET_B = 'B'
NET_C = 'C'

# Address space the /24 test subnets are carved from. It defaults to
# 10.100.0.0/16 and can be overridden through the SUBNETPOOL_PREFIX_V4
# environment variable, in which case it must be able to hold at least
# 8 /24 subnets (i.e. have a prefix length of 21 or less).
subnet_base = netaddr.IPNetwork(
    os.environ.get('SUBNETPOOL_PREFIX_V4', "10.100.0.0/16"))
if 'SUBNETPOOL_PREFIX_V4' in os.environ and subnet_base.prefixlen > 21:
    raise Exception("if SUBNETPOOL_PREFIX_V4 is set, it needs to offer "
                    "space for at least 8 /24 subnets")


def assign_24(idx):
    """Return the ``idx``-th /24 network carved out of ``subnet_base``.

    :param idx: zero-based position of the /24 inside ``subnet_base``
    :returns: netaddr.IPNetwork for that /24
    """
    block_size = 2 ** (32 - 24)  # number of IPv4 addresses in a /24
    first = block_size * idx
    addresses = subnet_base[first:first + block_size]
    merged = netaddr.cidr_merge(addresses)
    return merged[0]


# /24 subnets used by the test networks: S<n><X> is the n-th subnet of
# network X. Every subnet gets its own index so that no two of them
# overlap; indices 0..7 fit in the 8 /24s subnet_base must provide.
S1A = assign_24(1)
S2A = assign_24(2)
S1B = assign_24(4)
S2B = assign_24(6)
# Index 7, not 6: index 6 is already used by S2B, and sharing it would make
# NET_C_S1 identical to NET_B_S2.
S1C = assign_24(7)
NET_A_S1 = str(S1A)
NET_A_S2 = str(S2A)
NET_B_S1 = str(S1B)
NET_B_S2 = str(S2B)
NET_C_S1 = str(S1C)
# Fixed IP addresses given to test servers. The host offsets (10 and above)
# lie outside the .2-.9 allocation pool configured for the test subnets.
IP_A_S1_1 = str(S1A[10])
IP_B_S1_1 = str(S1B[20])
IP_C_S1_1 = str(S1C[30])
IP_A_S1_2 = str(S1A[30])
IP_B_S1_2 = str(S1B[40])
IP_A_S1_3 = str(S1A[50])
IP_B_S1_3 = str(S1B[60])
IP_A_S2_1 = str(S2A[50])
IP_B_S2_1 = str(S2B[60])
# Network A-Bis deliberately reuses the same addresses as network A.
IP_A_BIS_S1_1 = IP_A_S1_1
IP_A_BIS_S1_2 = IP_A_S1_2
IP_A_BIS_S1_3 = IP_A_S1_3
IP_A_BIS_S2_1 = IP_A_S2_1


class ScenarioTest(manager.NetworkScenarioTest):
    """Base class for scenario tests, relying on tempest's own clients.

    Only the primary credentials are requested.
    """

    credentials = ['primary']


class NetworkScenarioTest(ScenarioTest):
    """Base class for network scenario tests.

    This class provides helpers for network scenario tests, using the
    neutron API. Helpers from the ancestor which use the nova network API
    are overridden with the neutron API.

    This class also enforces using Neutron instead of nova-network.
    Subclassed tests will be skipped if Neutron or the ``bgpvpn`` network
    extension is not enabled.

    NOTE(review): the helpers read and write ``self.networks``,
    ``self.subnets``, ``self.servers``, ``self.ports``, ``self.server_fips``,
    ``self.servers_keypairs``, ``self.server_fixed_ips`` and ``self.RT1``,
    none of which are defined here; presumably the concrete test classes
    set them up - confirm.
    """

    credentials = ['primary', 'admin']

    @classmethod
    def skip_checks(cls):
        super(NetworkScenarioTest, cls).skip_checks()
        if not CONF.service_available.neutron:
            raise cls.skipException('Neutron not available')
        if not utils.is_extension_enabled('bgpvpn', 'network'):
            msg = "Bgpvpn extension not enabled."
            raise cls.skipException(msg)

    def _check_remote_connectivity(self, source, dest, should_succeed=True,
                                   nic=None):
        """Check ping to a destination through an ssh connection.

        The ping is retried every second until it yields the expected
        outcome or ``CONF.validation.ping_timeout`` expires.

        :param source: RemoteClient: an ssh connection from which to ping
        :param dest: an IP to ping against
        :param should_succeed: boolean, whether the ping should succeed
        :param nic: specific network interface to ping from
        :returns: True if the ping outcome matched ``should_succeed``
            before the timeout, False otherwise
        """
        def ping_remote():
            try:
                source.ping_host(dest, nic=nic)
            except lib_exc.SSHExecCommandFailed:
                LOG.warning('Failed to ping IP: %s via a ssh connection '
                            'from: %s.', dest, source.ssh_client.host)
                return not should_succeed
            return should_succeed

        return test_utils.call_until_true(ping_remote,
                                          CONF.validation.ping_timeout,
                                          1)

    def create_loginable_secgroup_rule(self, security_group_rules_client=None,
                                       secgroup=None,
                                       security_groups_client=None):
        """Create loginable security group rules.

        This function will create:
        1. egress and ingress tcp port 22 allow rule in order to allow ssh
        access for ipv4.
        2. egress and ingress tcp port 80 allow rule in order to allow http
        access for ipv4.
        3. egress and ingress ipv6 icmp allow rule, in order to allow icmpv6.
        4. egress and ingress ipv4 icmp allow rule, in order to allow icmpv4.

        Rules which already exist (Conflict with the 'Security group rule
        already exists' message) are skipped; any other Conflict is raised.

        :returns: list of the security group rules actually created
        """
        if security_group_rules_client is None:
            security_group_rules_client = self.security_group_rules_client
        if security_groups_client is None:
            security_groups_client = self.security_groups_client
        rulesets = [
            dict(
                # ssh
                protocol='tcp',
                port_range_min=22,
                port_range_max=22,
            ),
            dict(
                # http
                protocol='tcp',
                port_range_min=80,
                port_range_max=80,
            ),
            dict(
                # ping
                protocol='icmp',
            ),
            dict(
                # ipv6-icmp for ping6
                protocol='icmp',
                ethertype='IPv6',
            ),
        ]
        rules = []
        for ruleset in rulesets:
            for r_direction in ['ingress', 'egress']:
                try:
                    sg_rule = self.create_security_group_rule(
                        sec_group_rules_client=security_group_rules_client,
                        secgroup=secgroup,
                        security_groups_client=security_groups_client,
                        direction=r_direction,
                        **ruleset)
                except lib_exc.Conflict as ex:
                    # if rule already exists - skip rule and continue
                    msg = 'Security group rule already exists'
                    if msg not in ex._error_string:
                        raise
                else:
                    self.assertEqual(r_direction, sg_rule['direction'])
                    rules.append(sg_rule)

        return rules

    def _create_router(self, client=None, tenant_id=None,
                       namestart='router-smoke'):
        """Create an admin-up router and register its deletion on cleanup.

        :param client: routers client (defaults to ``self.routers_client``)
        :param tenant_id: owner tenant (defaults to the client's tenant)
        :param namestart: prefix of the random router name
        :returns: the created router dict
        """
        if not client:
            client = self.routers_client
        if not tenant_id:
            tenant_id = client.tenant_id
        name = data_utils.rand_name(namestart)
        result = client.create_router(name=name,
                                      admin_state_up=True,
                                      tenant_id=tenant_id)
        router = result['router']
        self.assertEqual(router['name'], name)
        self.addCleanup(test_utils.call_and_ignore_notfound_exc,
                        client.delete_router,
                        router['id'])
        return router

    def _create_security_group_for_test(self):
        """Create a security group for the bgpvpn client's project.

        The group is stored in ``self.security_group``.
        """
        self.security_group = self.create_security_group(
            project_id=self.bgpvpn_client.project_id)

    def _create_networks_and_subnets(self, names=None, subnet_cidrs=None,
                                     port_security=True):
        """Create networks and their IPv4 subnets.

        Created objects are stored in ``self.networks[name]`` and
        ``self.subnets[name]`` (a list, in CIDR order).

        :param names: network name prefixes (default NET_A, NET_B, NET_C)
        :param subnet_cidrs: one list of CIDRs per network, matched with
            ``names`` by position (default one subnet per network)
        :param port_security: value for ``port_security_enabled``
        """
        if not names:
            names = [NET_A, NET_B, NET_C]
        if not subnet_cidrs:
            subnet_cidrs = [[NET_A_S1], [NET_B_S1], [NET_C_S1]]
        for (name, net_cidrs) in zip(names, subnet_cidrs):
            network = super(NetworkScenarioTest, self).create_network(
                namestart=name,
                port_security_enabled=port_security)
            self.networks[name] = network
            self.subnets[name] = []
            for (j, cidr) in enumerate(net_cidrs):
                sub_name = "subnet-%s-%d" % (name, j + 1)
                subnet = self._create_subnet_with_cidr(network,
                                                       namestart=sub_name,
                                                       cidr=cidr,
                                                       ip_version=4)
                self.subnets[name].append(subnet)

    def _create_subnet_with_cidr(self, network, subnets_client=None,
                                 namestart='subnet-smoke', **kwargs):
        """Create a subnet with an explicit CIDR on ``network``.

        The allocation pool is restricted to host offsets 2..9 of the CIDR,
        keeping the higher addresses free for statically assigned test IPs.

        :param network: network dict the subnet belongs to
        :param subnets_client: subnets client (default
            ``self.subnets_client``)
        :param namestart: prefix of the random subnet name
        :param kwargs: extra subnet attributes; ``cidr`` is required
        :returns: the created subnet dict
        """
        if not subnets_client:
            subnets_client = self.subnets_client
        tenant_cidr = kwargs.get('cidr')
        # reserving pool for Neutron service ports
        net = netaddr.IPNetwork(tenant_cidr)
        allocation_pools = [{'start': str(net[2]), 'end': str(net[9])}]
        subnet = dict(
            name=data_utils.rand_name(namestart),
            network_id=network['id'],
            tenant_id=network['tenant_id'],
            allocation_pools=allocation_pools,
            **kwargs)
        result = subnets_client.create_subnet(**subnet)
        self.assertIsNotNone(result, 'Unable to allocate tenant network')
        subnet = result['subnet']
        self.assertEqual(subnet['cidr'], tenant_cidr)
        self.addCleanup(test_utils.call_and_ignore_notfound_exc,
                        subnets_client.delete_subnet, subnet['id'])
        return subnet

    def _create_fip_router(self, client=None, public_network_id=None,
                           subnet_id=None):
        """Create a router with an external gateway.

        :param client: routers client (defaults to ``self.routers_client``)
        :param public_network_id: external network (defaults to
            ``CONF.network.public_network_id``)
        :param subnet_id: optional subnet to attach as a router interface
        :returns: the updated router dict
        """
        router = self._create_router(client, namestart='router-')
        router_id = router['id']
        if public_network_id is None:
            public_network_id = CONF.network.public_network_id
        if client is None:
            client = self.routers_client
        kwargs = {'external_gateway_info': {'network_id': public_network_id}}
        router = client.update_router(router_id, **kwargs)['router']
        if subnet_id is not None:
            client.add_router_interface(router_id, subnet_id=subnet_id)
            self.addCleanup(test_utils.call_and_ignore_notfound_exc,
                            client.remove_router_interface, router_id,
                            subnet_id=subnet_id)
        return router

    def _associate_fip(self, server_index):
        """Attach a floating IP to the port of ``self.servers[idx]``.

        :returns: the floating IP, also stored in ``self.server_fips``
        """
        server = self.servers[server_index]
        fip = self.create_floating_ip(
            server, external_network_id=CONF.network.public_network_id,
            port_id=self.ports[server['id']]['id'])
        self.server_fips[server['id']] = fip
        return fip

    def _create_router_and_associate_fip(self, server_index, subnet):
        """Create a gateway router for ``subnet`` and give a server a FIP.

        :returns: the router created by ``_create_fip_router``
        """
        router = self._create_fip_router(subnet_id=subnet['id'])
        self._associate_fip(server_index)
        return router

    def _create_server(self, name, keypair, network, ip_address,
                       security_group_ids, clients, port_security):
        """Boot a server on a pre-created port with a fixed IP address.

        Security groups are only applied to the port when
        ``port_security`` is true. The port is stored in
        ``self.ports[server_id]``.

        :returns: the server dict as returned by ``show_server``
        """
        security_groups = security_group_ids if port_security else []
        create_port_body = {'fixed_ips': [{'ip_address': ip_address}],
                            'namestart': 'port-smoke',
                            'security_groups': security_groups}

        port = super(NetworkScenarioTest, self).create_port(
            network_id=network['id'],
            client=clients.ports_client,
            **create_port_body)

        create_server_kwargs = {
            'key_name': keypair['name'],
            'networks': [{'uuid': network['id'], 'port': port['id']}]
        }
        body, _ = compute.create_test_server(
            clients, wait_until='ACTIVE', name=name, **create_server_kwargs)
        # Cleanups run in reverse order: delete first, then wait for the
        # termination to complete.
        self.addCleanup(waiters.wait_for_server_termination,
                        clients.servers_client, body['id'])
        self.addCleanup(test_utils.call_and_ignore_notfound_exc,
                        clients.servers_client.delete_server, body['id'])
        server = clients.servers_client.show_server(body['id'])['server']
        LOG.debug('Created server: %s with status: %s', server['id'],
                  server['status'])
        self.ports[server['id']] = port
        return server

    def _create_servers(self, ports_config=None, port_security=True):
        """Create one server per ``[network, fixed_ip]`` entry.

        All servers share a single new keypair and ``self.security_group``.

        :param ports_config: list of ``[network, ip_address]`` pairs
            (defaults to one server on NET_A and one on NET_B)
        :param port_security: whether to apply security groups to ports
        """
        keypair = self.create_keypair()
        security_group_ids = [self.security_group['id']]
        if not ports_config:
            ports_config = [[self.networks[NET_A], IP_A_S1_1],
                            [self.networks[NET_B], IP_B_S1_1]]

        for (i, port_config) in enumerate(ports_config):
            network = port_config[0]
            server = self._create_server(
                'server-' + str(i + 1), keypair, network, port_config[1],
                security_group_ids, self.os_primary, port_security)
            self.servers.append(server)
            self.servers_keypairs[server['id']] = keypair
            self.server_fixed_ips[server['id']] = (
                server['addresses'][network['name']][0]['addr'])
            self.assertTrue(self.servers_keypairs)

    def _create_l3_bgpvpn(self, name='test-l3-bgpvpn', rts=None,
                          import_rts=None, export_rts=None):
        """Create a BGPVPN (as admin) for the bgpvpn client's tenant.

        ``rts`` defaults to ``[self.RT1]`` only when no target list at all
        is given. The BGPVPN is stored in ``self.bgpvpn``.

        :returns: the created BGPVPN dict
        """
        if rts is None and import_rts is None and export_rts is None:
            rts = [self.RT1]
        import_rts = import_rts or []
        export_rts = export_rts or []
        self.bgpvpn = self.create_bgpvpn(
            self.bgpvpn_admin_client, tenant_id=self.bgpvpn_client.tenant_id,
            name=name, route_targets=rts, export_targets=export_rts,
            import_targets=import_rts)
        self.addCleanup(test_utils.call_and_ignore_notfound_exc,
                        self.bgpvpn_admin_client.delete_bgpvpn,
                        self.bgpvpn['id'])
        return self.bgpvpn

    def _update_l3_bgpvpn(self, rts=None, import_rts=None, export_rts=None,
                          bgpvpn=None):
        """Replace the route targets of a BGPVPN (default ``self.bgpvpn``).

        Unlike ``_create_l3_bgpvpn``, ``rts`` defaults to ``[self.RT1]``
        whenever it is None, even if import/export targets are given.
        """
        bgpvpn = bgpvpn or self.bgpvpn
        if rts is None:
            rts = [self.RT1]
        import_rts = import_rts or []
        export_rts = export_rts or []
        LOG.debug('Updating targets in BGPVPN %s', bgpvpn['id'])
        self.bgpvpn_admin_client.update_bgpvpn(bgpvpn['id'],
                                               route_targets=rts,
                                               export_targets=export_rts,
                                               import_targets=import_rts)

    def _associate_all_nets_to_bgpvpn(self, bgpvpn=None):
        """Associate every network in ``self.networks`` with a BGPVPN."""
        bgpvpn = bgpvpn or self.bgpvpn
        for network in self.networks.values():
            self.bgpvpn_client.create_network_association(
                bgpvpn['id'], network['id'])
        LOG.debug('BGPVPN network associations completed')

    def _setup_ssh_client(self, server):
        """Return an ssh client to ``server`` through its floating IP."""
        server_fip = self.server_fips[server['id']][
            'floating_ip_address']
        private_key = self.servers_keypairs[server['id']][
            'private_key']
        ssh_client = self.get_remote_client(server_fip,
                                            private_key=private_key,
                                            server=server)
        return ssh_client

    def _setup_http_server(self, server_index):
        """Start a background nc listener on port 80 on a server.

        The listener answers '<server name>:<server id>'.
        """
        server = self.servers[server_index]
        ssh_client = self._setup_ssh_client(server)
        ssh_client.exec_command("sudo nc -kl -p 80 -e echo '%s:%s' &"
                                % (server['name'], server['id']))

    def _setup_ip_forwarding(self, server_index):
        """Enable IPv4 forwarding on ``self.servers[server_index]``."""
        server = self.servers[server_index]
        ssh_client = self._setup_ssh_client(server)
        ssh_client.exec_command("sudo sysctl -w net.ipv4.ip_forward=1")

    def _setup_ip_address(self, server_index, cidr, device=None):
        """Add a single address to ``device`` (default 'lo') on a server."""
        self._setup_range_ip_address(server_index, [cidr], device=device)

    def _setup_range_ip_address(self, server_index, cidrs, device=None):
        """Add addresses to ``device`` (default 'lo') on a server.

        Addresses are added in batches of ``MAX_CIDRS`` per ssh command.
        """
        MAX_CIDRS = 50
        if device is None:
            device = 'lo'
        server = self.servers[server_index]
        ssh_client = self._setup_ssh_client(server)
        for i in range(0, len(cidrs), MAX_CIDRS):
            ips = ' '.join(cidrs[i:i + MAX_CIDRS])
            ssh_client.exec_command(
                ("for ip in {ips}; do sudo ip addr add $ip "
                 "dev {dev}; done").format(ips=ips, dev=device))

    def _check_l3_bgpvpn(self, from_server=None, to_server=None,
                         should_succeed=True, validate_server=False):
        """Check connectivity to the fixed IP of ``to_server``.

        :param from_server: source server (defaults to ``self.servers[0]``)
        :param to_server: destination server (default ``self.servers[1]``)
        :param should_succeed: whether connectivity is expected
        :param validate_server: also check the destination's http answer
        """
        to_server = to_server or self.servers[1]
        destination_srv = None
        if validate_server:
            destination_srv = '%s:%s' % (to_server['name'], to_server['id'])
        destination_ip = self.server_fixed_ips[to_server['id']]
        self._check_l3_bgpvpn_by_specific_ip(from_server=from_server,
                                             to_server_ip=destination_ip,
                                             should_succeed=should_succeed,
                                             validate_server=destination_srv)

    def _check_l3_bgpvpn_by_specific_ip(self, from_server=None,
                                        to_server_ip=None,
                                        should_succeed=True,
                                        validate_server=None,
                                        repeat_validate_server=10):
        """Check L3 connectivity from a server to a given IP address.

        Pings ``to_server_ip`` from ``from_server`` over ssh. When
        ``validate_server`` is given, the destination must answer pings and
        its port 80 answer is then compared with ``validate_server``
        ``repeat_validate_server`` times.

        :param from_server: source server (defaults to ``self.servers[0]``)
        :param to_server_ip: destination IP (defaults to the fixed IP of
            ``self.servers[1]``)
        :param should_succeed: whether the destination is expected to be
            reachable (and, with ``validate_server``, to be that server)
        :param validate_server: expected '<name>:<id>' answer, or None
        :param repeat_validate_server: number of http answer checks
        """
        from_server = from_server or self.servers[0]
        from_server_ip = self.server_fips[from_server['id']][
            'floating_ip_address']
        if to_server_ip is None:
            to_server_ip = self.server_fixed_ips[self.servers[1]['id']]
        ssh_client = self._setup_ssh_client(from_server)
        # Validating the http answer requires the destination to answer
        # pings, even when it is expected to be the wrong server.
        check_reachable = bool(should_succeed or validate_server)
        if check_reachable:
            msg = "Timed out waiting for {ip} to become reachable".format(
                ip=to_server_ip)
        else:
            msg = ("Unexpected ping response from VM with IP address "
                   "{dest} originated from VM with IP address "
                   "{src}").format(dest=to_server_ip, src=from_server_ip)
        try:
            result = self._check_remote_connectivity(ssh_client,
                                                     to_server_ip,
                                                     check_reachable)
            # if a negative connectivity check was unsuccessful (unexpected
            # ping reply) then try to know more:
            if not check_reachable and not result:
                try:
                    content = ssh_client.exec_command(
                        "nc %s 80" % to_server_ip).strip()
                    LOG.warning("Can connect to %s: %s", to_server_ip, content)
                except Exception:
                    LOG.warning("Could ping %s, but no http", to_server_ip)

            self.assertTrue(result, msg)

            if validate_server and result:
                # repeating multiple times gives increased odds of avoiding
                # false positives in the case where the dataplane does
                # equal-cost multipath
                for i in range(0, repeat_validate_server):
                    real_dest = ssh_client.exec_command(
                        "nc %s 80" % to_server_ip).strip()
                    result = real_dest == validate_server
                    self.assertTrue(
                        should_succeed == result,
                        ("Destination server name is '%s', expected is '%s'" %
                         (real_dest, validate_server)))
                    LOG.info("nc server name check %d successful", i)
        except Exception:
            LOG.exception("Error validating connectivity to %s "
                          "from VM with IP address %s: %s",
                          to_server_ip, from_server_ip, msg)
            raise

    def _associate_fip_and_check_l3_bgpvpn(self, subnet=None,
                                           should_succeed=True):
        """Give server 0 a floating IP, then run ``_check_l3_bgpvpn``.

        :param subnet: key in ``self.subnets`` whose first subnet gets the
            gateway router (defaults to NET_A)
        :param should_succeed: whether connectivity is expected
        """
        subnet = self.subnets[subnet or NET_A][0]
        self.router = self._create_router_and_associate_fip(0, subnet)
        self._check_l3_bgpvpn(should_succeed=should_succeed)

    def _live_migrate(self, server_id, target_host, state,
                      volume_backed=False):
        """Live migrate a server and check where it ended up.

        If ``target_host`` is None, only checks that the server moved away
        from its original host; otherwise checks it is on ``target_host``.

        :param state: server status to wait for after the migration request
        """
        source_host = None
        if target_host is None:
            source_host = self.get_host_for_server(server_id)
        self._migrate_server_to(server_id, target_host, volume_backed)
        waiters.wait_for_server_status(self.servers_client, server_id, state)
        migration_list = (self.admin_migration_client.list_migrations()
                          ['migrations'])
        msg = ("Live Migration failed. Migrations list for Instance "
               "%s: [" % server_id)
        for live_migration in migration_list:
            if live_migration['instance_uuid'] == server_id:
                msg += "\n%s" % live_migration
        msg += "]"
        if target_host is None:
            self.assertNotEqual(source_host,
                                self.get_host_for_server(server_id), msg)
        else:
            self.assertEqual(target_host, self.get_host_for_server(server_id),
                             msg)

    def _migrate_server_to(self, server_id, dest_host, volume_backed=False):
        """Request a live migration of a server to ``dest_host``.

        Uses ``self.block_migration`` when it is defined and not None.
        Otherwise, block migration is used if it is enabled in the compute
        feature configuration and the server is not volume-backed.
        """
        block_migration = getattr(self, 'block_migration', None)
        if block_migration is None:
            block_migration = (CONF.compute_feature_enabled.
                               block_migration_for_live_migration and
                               not volume_backed)
        self.admin_servers_client.live_migrate_server(
            server_id, host=dest_host, block_migration=block_migration)
