From 0fb4fc20559c9a5de5a10c74c1247635a1523255 Mon Sep 17 00:00:00 2001 From: Wolfgang Müller Date: Sun, 14 Nov 2021 18:55:52 +0100 Subject: Initial commit --- ywalk/graph.py | 125 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 125 insertions(+) create mode 100644 ywalk/graph.py (limited to 'ywalk/graph.py') diff --git a/ywalk/graph.py b/ywalk/graph.py new file mode 100644 index 0000000..18556cd --- /dev/null +++ b/ywalk/graph.py @@ -0,0 +1,125 @@ +import math + +from collections import deque +from ywalk.types import Connection, Mode + +class Graph: + WEIGHTS = { + 'least-hop': lambda c: 1, + 'least-time': lambda c: c.time + c.mode.weight, + } + + def __init__(self): + self.connections = {} + self.nodes = 0 + self.edges = 0 + self.predicates = [] + + self.set_weight() + + def add_predicate(self, predicate): + self.predicates.append(predicate) + + def set_weight(self, spec=None): + self.weight_method = Graph.WEIGHTS[spec or 'least-time'] + + def add_connection(self, conn: Connection): + for place in [conn.origin, conn.destination]: + if place not in self.connections: + self.nodes = self.nodes + 1 + self.connections[place] = set() + + if conn not in self.connections[conn.origin]: + self.edges = self.edges + 1 + self.connections[conn.origin].add(conn) + + def add_recall(self, destination): + for origin in self.get_places(): + if origin == destination: + continue + self.add_connection(Connection(origin, destination, Mode.RECALL, 0)) + + def __contains__(self, place): + return place in self.connections.keys() + + def get_connections_from(self, place): + for conn in self.connections[place]: + yield conn + + def get_connections_to(self, place): + for origin in self.get_places(): + yield from self.get_connections_between(origin, place) + + def get_connections_between(self, origin, destination): + for conn in self.connections[origin]: + if conn.destination == destination: + yield conn + + def get_places(self): + return self.connections.keys() + + def populate(self, gen): + for conn in gen: + self.add_connection(conn) + + def shortest_paths(self, origin, destination=None): + tentative = list(self.get_places()) + weights = {place: math.inf for place in self.get_places()} + hops = {} + + weights[origin] = 0 + + while tentative: + current = min(tentative, key=weights.get) + + tentative.remove(current) + + # Exit early if we were given a destination and have reached it + if destination and current == destination: + return weights, hops + + for conn in self.connections[current]: + if conn.destination not in tentative: + continue + + next_hop = conn.destination + + # Make sure to filter out any unwanted connections by setting + # their weight to infinity + if all(f(conn) for f in self.predicates): + alt_weights = weights[current] + self.weight_method(conn) + else: + alt_weights = math.inf + + if alt_weights < weights[next_hop]: + weights[next_hop] = alt_weights + hops[next_hop] = conn + + return weights, hops + + def find_path(self, *args): + stops = deque(args) + current_stop = stops.popleft() + + journey = [] + while stops: + next_stop = stops.popleft() + _, hops = self.shortest_paths(current_stop, next_stop) + + if not hops or next_stop not in hops: + return None + + path = deque() + last_hop = next_stop + + while last_hop in hops: + conn = hops[last_hop] + path.appendleft(conn) + last_hop = conn.origin + + if path is None: + return None + + journey.extend(path) + current_stop = next_stop + return journey -- cgit v1.2.3-2-gb3c3