aboutsummaryrefslogblamecommitdiffstatshomepage
path: root/ywalk/graph.py
blob: 18556cdb71d2d24deab688863106f93bee72ec9b (plain) (tree)




























































































































                                                                                
import math

from collections import deque
from ywalk.types import Connection, Mode

class Graph:
    """Directed multigraph of places linked by ``Connection`` edges.

    Connections are stored in an adjacency mapping from each origin place to
    the set of connections leaving it; several connections may link the same
    pair of places. ``nodes`` and ``edges`` count the distinct places and
    connections added so far. Path queries use a selectable weight function
    (see ``WEIGHTS`` / ``set_weight``) and skip any connection rejected by one
    of the registered predicates.
    """

    # Named edge-weight functions selectable through ``set_weight``.
    WEIGHTS = {
        # Every connection costs the same, so searches minimise hop count.
        'least-hop': lambda c: 1,
        # Connection time plus the weight attached to its mode.
        'least-time': lambda c: c.time + c.mode.weight,
    }

    def __init__(self):
        # Maps each place to the set of connections that leave it.
        self.connections = {}
        self.nodes = 0
        self.edges = 0
        # Callables taking a connection; a connection is usable in path
        # searches only if every predicate returns a truthy value for it.
        self.predicates = []

        self.set_weight()

    def add_predicate(self, predicate):
        """Restrict path searches to connections accepted by *predicate*."""
        self.predicates.append(predicate)

    def set_weight(self, spec=None):
        """Select the edge-weight function by name (default ``'least-time'``).

        Raises:
            KeyError: if *spec* is not a key of ``Graph.WEIGHTS``.
        """
        self.weight_method = Graph.WEIGHTS[spec or 'least-time']

    def add_connection(self, conn: 'Connection'):
        """Add *conn*, registering its endpoints as places if they are new.

        Adding a connection that is already present does not change the
        graph or the ``edges`` count.
        """
        for place in (conn.origin, conn.destination):
            if place not in self.connections:
                self.nodes += 1
                self.connections[place] = set()

        outgoing = self.connections[conn.origin]
        if conn not in outgoing:
            self.edges += 1
            outgoing.add(conn)

    def add_recall(self, destination):
        """Add a ``Mode.RECALL`` connection from every other place to *destination*.

        The place list is snapshotted before iterating. If *destination* is
        not yet a place, the first ``add_connection`` call registers it, and
        mutating ``self.connections`` while iterating its live key view would
        raise ``RuntimeError``.
        """
        for origin in list(self.get_places()):
            if origin != destination:
                self.add_connection(
                    Connection(origin, destination, Mode.RECALL, 0))

    def __contains__(self, place):
        """Return True if *place* has been registered in the graph."""
        return place in self.connections

    def get_connections_from(self, place):
        """Yield every connection leaving *place* (KeyError if unknown)."""
        yield from self.connections[place]

    def get_connections_to(self, place):
        """Yield every connection arriving at *place*, scanning all origins."""
        for origin in self.get_places():
            yield from self.get_connections_between(origin, place)

    def get_connections_between(self, origin, destination):
        """Yield every connection from *origin* directly to *destination*."""
        for conn in self.connections[origin]:
            if conn.destination == destination:
                yield conn

    def get_places(self):
        """Return a live view of all registered places."""
        return self.connections.keys()

    def populate(self, gen):
        """Add every connection produced by the iterable *gen*."""
        for conn in gen:
            self.add_connection(conn)

    def shortest_paths(self, origin, destination=None):
        """Run Dijkstra's algorithm from *origin*.

        Args:
            origin: place to start from.
            destination: optional place. When given (any value other than
                ``None``), the search stops as soon as it is settled.

        Returns:
            ``(weights, hops)``. ``weights`` maps each place to its best known
            total weight from *origin* (``math.inf`` when unreached).
            ``hops`` maps each reached place to the connection used to arrive
            there on the best known path. After an early exit, entries for
            places that were never settled may be incomplete.
        """
        weights = {place: math.inf for place in self.connections}
        weights[origin] = 0
        hops = {}

        # Unsettled places. A dict (rather than a list) gives O(1) membership
        # tests and removal while preserving insertion order, so ties in
        # ``min`` are broken the same way a list would break them.
        tentative = dict.fromkeys(self.connections)

        while tentative:
            current = min(tentative, key=weights.get)
            del tentative[current]

            # Exit early if we were given a destination and have reached it.
            if destination is not None and current == destination:
                break

            # The cheapest unsettled place is unreachable, so all remaining
            # ones are too. Relaxing from an infinite weight improves nothing.
            if weights[current] == math.inf:
                break

            for conn in self.connections[current]:
                next_hop = conn.destination
                if next_hop not in tentative:
                    continue

                # Connections rejected by any predicate are impassable.
                if not all(f(conn) for f in self.predicates):
                    continue

                candidate = weights[current] + self.weight_method(conn)
                if candidate < weights[next_hop]:
                    weights[next_hop] = candidate
                    hops[next_hop] = conn

        return weights, hops

    def find_path(self, *args):
        """Return the connections of the best route visiting *args* in order.

        Each consecutive pair of stops is solved independently with
        ``shortest_paths``, and the legs are concatenated.

        Returns:
            A list of connections (empty when only one stop is given), or
            ``None`` if some leg has no path. A leg whose two stops are the
            same place also yields ``None``, because no connection is
            recorded for it.

        Raises:
            IndexError: if called with no stops.
        """
        stops = deque(args)
        current_stop = stops.popleft()

        journey = []
        while stops:
            next_stop = stops.popleft()
            _, hops = self.shortest_paths(current_stop, next_stop)

            if next_stop not in hops:
                return None

            # Follow the predecessor links back from the target to the start.
            leg = deque()
            last_hop = next_stop
            while last_hop in hops:
                conn = hops[last_hop]
                leg.appendleft(conn)
                last_hop = conn.origin

            journey.extend(leg)
            current_stop = next_stop
        return journey