| import netaddr |
| import itertools |
| |
| |
class Route(object):
    """One routing entry: a destination reached via a next hop on an interface.

    The constructor stores its four arguments unchanged:
      interface -- object the route is bound to; ``__str__`` reads its
                   ``name`` and ``cidr.ip``.
      router    -- object that provides the route; ``__str__`` reads its
                   ``name``.
      cidr      -- destination of the route.
      next_hop  -- address the destination is reached through.
    """

    def __init__(self, interface, router, cidr, next_hop):
        self.interface, self.router = interface, router
        self.cidr, self.next_hop = cidr, next_hop

    def __str__(self):
        """Return ``<cidr> via <next_hop> dev <ifname> src <ifaddr> # router: <name>``."""
        local = self.interface
        fields = (
            self.cidr,
            self.next_hop,
            local.name,
            local.cidr.ip,
            self.router.name,
        )
        return "%s via %s dev %s src %s # router: %s" % fields
| |
| |
class Interface(object):
    """A named interface together with its (optional) address.

    ``cidr`` is stored as given; ``from_dict`` sets it either to a
    ``netaddr.IPNetwork`` or to ``None``.
    """

    def __init__(self, name, cidr):
        self.name, self.cidr = name, cidr

    def __str__(self):
        """Return the interface name."""
        return self.name

    @staticmethod
    def from_dict(data, default_name=None):
        """Create an Interface from a settings mapping.

        ``data`` may provide ``'address'`` and ``'netmask'``, which are joined
        as ``address/netmask`` and parsed by ``netaddr.IPNetwork``.  When
        either key is missing the interface is created with ``cidr=None``.
        The name comes from ``data['name']``, falling back to
        ``default_name``.
        """
        try:
            spec = "%(address)s/%(netmask)s" % data
            network = netaddr.IPNetwork(spec)
        except KeyError:
            # Address or netmask not configured: keep the interface, but
            # without an address.
            network = None

        label = data.get('name', default_name)
        return Interface(name=label, cidr=network)
| |
| |
class Router(object):
    """A router: its own addresses plus the networks reachable through it.

    Attributes:
      name      -- router name; also what ``str()`` returns.
      addresses -- list of the router's addresses.  Each item must expose
                   ``ip``, ``network`` and ``cidr`` (``from_dict`` supplies
                   ``netaddr.IPNetwork`` objects).
      networks  -- list of destination CIDRs, free of duplicates and kept in
                   first-seen order.
    """

    def __init__(self,
                 name,
                 addresses=None,
                 networks=None,
                 implicit_routes=True):
        """Store and validate the router definition.

        Parameters:
          name            -- router name.
          addresses       -- iterable of router addresses (optional).
          networks        -- iterable of destination CIDRs (optional).
          implicit_routes -- when true, the ``cidr`` of every router address
                             is added to the destinations as well.

        Raises:
          ValueError -- if an address has ``ip == network``, or if a
                        destination has ``ip != network``.
        """
        self.name = name
        self.addresses = list(addresses or [])

        candidates = list(networks or [])
        if implicit_routes:
            candidates.extend(addr.cidr for addr in self.addresses)

        # Validate every candidate before de-duplicating, so an invalid entry
        # is never hidden by an equal-comparing valid one.
        if any(addr.ip == addr.network for addr in self.addresses):
            raise ValueError("Invalid router address")
        if any(cidr.ip != cidr.network for cidr in candidates):
            raise ValueError("Invalid destination CIDR")

        # A network listed explicitly *and* derived from an address (or simply
        # listed twice) must appear once only; otherwise get_routes() would
        # emit the very same route twice.  Membership uses ``==`` so the
        # entries need not be hashable.
        self.networks = []
        for cidr in candidates:
            if cidr not in self.networks:
                self.networks.append(cidr)

    def __str__(self):
        """Return the router name."""
        return self.name

    @staticmethod
    def from_dict(data, default_name=None):
        """Create a Router from a settings mapping.

        Keys read from ``data``:
          'name'      -- router name (defaults to ``default_name``).
          'addresses' -- list parsed with ``netaddr.IPNetwork``.
          'networks'  -- list parsed with ``netaddr.IPNetwork`` and
                         normalised through ``.cidr``.
          'options'   -- mapping; its ``'implicit_routes'`` entry (default
                         ``True``) is passed to the constructor.
        """
        options = data.get('options', {})
        return Router(
            name=data.get('name', default_name),
            addresses=[
                netaddr.IPNetwork(router_addr)
                for router_addr in data.get('addresses', [])
            ],
            networks=[
                netaddr.IPNetwork(dest_cidr).cidr
                for dest_cidr in data.get('networks', [])
            ],
            implicit_routes=options.get('implicit_routes', True)
        )

    def get_routes(self, interface):
        """Return the routes ``interface`` can use through this router.

        A router address is a usable next hop when its ``ip`` lies inside
        ``interface.cidr``.  For every such next hop one ``Route`` is created
        per destination network that does not itself contain the next hop.
        """
        next_hops = [addr.ip for addr in self.addresses
                     if addr.ip in interface.cidr]

        return [
            Route(interface, self, dest_cidr, next_hop)
            for next_hop in next_hops
            # Skip destinations that already contain the next hop.
            for dest_cidr in self.networks if next_hop not in dest_cidr
        ]
| |
| |
def get_routes(interfaces, routers):
    """Work out, per interface, the routes offered by the given routers.

    ``interfaces`` and ``routers`` are mappings of name -> settings mapping.
    Every entry is parsed with ``Interface.from_dict`` / ``Router.from_dict``,
    the mapping key serving as the default name.  Interfaces whose ``cidr``
    is falsy are skipped.

    Returns a dict mapping interface name to that interface's list of
    ``Route`` objects, sorted by destination CIDR.

    Raises ``ValueError`` when more than one route exists for a destination
    CIDR.
    """
    def cidr_order(route):
        return route.cidr.sort_key()

    def ifname_then_cidr_order(route):
        return (route.interface.name, route.cidr.sort_key())

    def cidr_of(route):
        return route.cidr

    def ifname_of(route):
        return route.interface.name

    parsed_interfaces = []
    for key, data in interfaces.items():
        parsed_interfaces.append(Interface.from_dict(data, key))

    parsed_routers = []
    for key, data in routers.items():
        parsed_routers.append(Router.from_dict(data, key))

    routes = []
    for iface in parsed_interfaces:
        if not iface.cidr:
            # Interface without an address: nothing can be routed over it.
            continue
        for rtr in parsed_routers:
            routes.extend(rtr.get_routes(iface))

    # groupby() only merges adjacent items, hence the sort on the CIDR first.
    routes_by_cidr = {}
    ordered_by_cidr = sorted(routes, key=cidr_order)
    for cidr, group in itertools.groupby(ordered_by_cidr, cidr_of):
        routes_by_cidr[cidr] = list(group)

    for cidr, duplicates in routes_by_cidr.items():
        if len(duplicates) <= 1:
            continue
        described = ["[via %s dev %s]" % (route.next_hop, route.interface)
                     for route in duplicates]
        raise ValueError("Found multiple routes for %s: %s" % (cidr,
                                                                described))

    routes_by_ifname = {}
    ordered_by_ifname = sorted(routes, key=ifname_then_cidr_order)
    for ifname, group in itertools.groupby(ordered_by_ifname, ifname_of):
        routes_by_ifname[ifname] = list(group)

    return routes_by_ifname