import heapq
import itertools
import math
from collections import deque

from ywalk.types import Connection, Mode


class Graph:
    """Directed graph of places joined by ``Connection`` edges.

    Supports filtering edges with predicates and finding weighted shortest
    paths (Dijkstra's algorithm) between places or through a series of stops.
    """

    # Edge-weight functions selectable by name through set_weight().
    WEIGHTS = {
        'least-hop': lambda c: 1,
        'least-time': lambda c: c.time + c.mode.weight,
    }

    def __init__(self):
        # Maps each place to the set of connections leaving it.
        self.connections = {}
        self.nodes = 0
        self.edges = 0
        # Callables taking a Connection; a connection is traversed only if
        # every predicate returns a truthy value for it.
        self.predicates = []
        self.set_weight()

    def add_predicate(self, predicate):
        """Only traverse connections for which ``predicate(conn)`` is truthy."""
        self.predicates.append(predicate)

    def set_weight(self, spec=None):
        """Select the edge-weight function by name (default ``'least-time'``).

        Raises:
            KeyError: if ``spec`` is not a key of ``Graph.WEIGHTS``.
        """
        self.weight_method = Graph.WEIGHTS[spec or 'least-time']

    def add_connection(self, conn: Connection):
        """Add ``conn``, registering its origin and destination as places.

        A connection already stored for the same origin is not counted again.
        """
        for place in (conn.origin, conn.destination):
            if place not in self.connections:
                self.nodes += 1
                self.connections[place] = set()
        if conn not in self.connections[conn.origin]:
            self.edges += 1
            self.connections[conn.origin].add(conn)

    def add_recall(self, destination):
        """Connect every other place to ``destination`` with a RECALL connection."""
        # Iterate over a snapshot: add_connection() inserts ``destination``
        # into self.connections when it is a new place, and mutating a dict
        # while iterating its keys view raises RuntimeError.
        for origin in list(self.get_places()):
            if origin == destination:
                continue
            # The trailing 0 is presumably the connection's time; confirm
            # against Connection's field order in ywalk.types.
            self.add_connection(Connection(origin, destination, Mode.RECALL, 0))

    def __contains__(self, place):
        return place in self.connections

    def get_connections_from(self, place):
        """Yield the connections leaving ``place`` (KeyError if it is unknown)."""
        yield from self.connections[place]

    def get_connections_to(self, place):
        """Yield every connection whose destination is ``place``."""
        for origin in self.get_places():
            yield from self.get_connections_between(origin, place)

    def get_connections_between(self, origin, destination):
        """Yield the connections going from ``origin`` to ``destination``."""
        for conn in self.connections[origin]:
            if conn.destination == destination:
                yield conn

    def get_places(self):
        """Return a live view of all places in the graph."""
        return self.connections.keys()

    def populate(self, gen):
        """Add every connection produced by the iterable ``gen``."""
        for conn in gen:
            self.add_connection(conn)

    def shortest_paths(self, origin, destination=None):
        """Run Dijkstra's algorithm from ``origin``.

        Edge weights come from the function chosen with ``set_weight()``;
        connections rejected by any predicate are never traversed.

        Args:
            origin: place to start from.
            destination: optional place; when given, the search stops as
                soon as its distance is final.

        Returns:
            ``(weights, hops)``. ``weights`` maps every place to its distance
            from ``origin`` (``math.inf`` when not reached). ``hops`` maps each
            reached place (except ``origin``) to the connection used to arrive
            there; follow ``conn.origin`` backwards to rebuild a path. When the
            search stops early, entries for places other than ``destination``
            may not be final.
        """
        weights = {place: math.inf for place in self.get_places()}
        weights[origin] = 0
        hops = {}
        settled = set()
        # The counter breaks ties between equal weights so the heap never has
        # to compare places themselves.
        order = itertools.count()
        queue = [(0, next(order), origin)]

        while queue:
            weight, _, current = heapq.heappop(queue)
            if current in settled:
                continue  # stale entry superseded by a cheaper one
            settled.add(current)

            # Exit early once the requested destination's distance is final.
            if destination is not None and current == destination:
                break

            # .get() tolerates an origin that is not a place in the graph.
            for conn in self.connections.get(current, ()):
                next_hop = conn.destination
                if next_hop in settled:
                    continue
                # Filter out unwanted connections entirely.
                if not all(f(conn) for f in self.predicates):
                    continue
                alt_weight = weight + self.weight_method(conn)
                if alt_weight < weights[next_hop]:
                    weights[next_hop] = alt_weight
                    hops[next_hop] = conn
                    heapq.heappush(queue, (alt_weight, next(order), next_hop))

        return weights, hops

    def find_path(self, *args):
        """Find the cheapest journey visiting the stops in ``args`` in order.

        Returns:
            A list of connections making up the journey, or ``None`` if some
            leg between consecutive stops cannot be completed. Fewer than two
            stops, or a stop repeated consecutively, contribute no connections.
        """
        journey = []
        for current_stop, next_stop in zip(args, args[1:]):
            if next_stop == current_stop:
                continue  # already there; this leg needs no connections
            _, hops = self.shortest_paths(current_stop, next_stop)
            if next_stop not in hops:
                return None

            leg = deque()
            last_hop = next_stop
            # Each hop's origin was settled before the hop itself, and the
            # search origin never gets an entry, so this walk ends at
            # current_stop.
            while last_hop in hops:
                conn = hops[last_hop]
                leg.appendleft(conn)
                last_hop = conn.origin
            journey.extend(leg)
        return journey