aboutsummaryrefslogtreecommitdiffstatshomepage
path: root/ywalk/graph.py
blob: 18556cdb71d2d24deab688863106f93bee72ec9b (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import math

from collections import deque
from ywalk.types import Connection, Mode

class Graph:
    """Directed multigraph of places linked by ``Connection`` edges.

    Each place maps to the set of connections leaving it.  Shortest paths are
    computed with Dijkstra's algorithm using a selectable edge-weight function
    (see ``WEIGHTS`` / :meth:`set_weight`) and optional predicates that exclude
    connections from the search (see :meth:`add_predicate`).
    """

    # Edge-weight functions selectable by name through ``set_weight``.
    WEIGHTS = {
        'least-hop': lambda c: 1,
        'least-time': lambda c: c.time + c.mode.weight,
    }

    def __init__(self):
        # place -> set of connections whose origin is that place
        self.connections = {}
        # Number of distinct places / distinct connections added so far.
        self.nodes = 0
        self.edges = 0
        # Callables conn -> bool; a connection is usable in shortest_paths()
        # only if every predicate returns a truthy value for it.
        self.predicates = []

        self.set_weight()

    def add_predicate(self, predicate):
        """Register a connection filter used by :meth:`shortest_paths`.

        Connections for which *predicate* returns a falsy value are ignored.
        """
        self.predicates.append(predicate)

    def set_weight(self, spec=None):
        """Select the edge-weight function by name.

        Args:
            spec: a key of ``WEIGHTS``; a falsy value selects ``'least-time'``.

        Raises:
            KeyError: if *spec* is not a key of ``WEIGHTS``.
        """
        self.weight_method = Graph.WEIGHTS[spec or 'least-time']

    def add_connection(self, conn: Connection):
        """Add *conn*, registering both of its endpoints as places.

        Connections that compare equal to one already stored from the same
        origin are not counted again (storage is a set).
        """
        for place in (conn.origin, conn.destination):
            if place not in self.connections:
                self.nodes += 1
                self.connections[place] = set()

        outgoing = self.connections[conn.origin]
        if conn not in outgoing:
            self.edges += 1
        outgoing.add(conn)

    def add_recall(self, destination):
        """Add a zero-time ``Mode.RECALL`` connection to *destination* from
        every other existing place.
        """
        # Snapshot the places first: add_connection() inserts *destination*
        # into self.connections when it is new, and changing a dict's size
        # while iterating its keys view raises RuntimeError.
        for origin in list(self.get_places()):
            if origin == destination:
                continue
            self.add_connection(Connection(origin, destination, Mode.RECALL, 0))

    def __contains__(self, place):
        return place in self.connections

    def get_connections_from(self, place):
        """Yield every connection leaving *place* (KeyError if unknown)."""
        yield from self.connections[place]

    def get_connections_to(self, place):
        """Yield every connection, from any place, arriving at *place*."""
        for origin in self.get_places():
            yield from self.get_connections_between(origin, place)

    def get_connections_between(self, origin, destination):
        """Yield the connections from *origin* to *destination*.

        Raises KeyError (on first iteration) if *origin* is not a place.
        """
        for conn in self.connections[origin]:
            if conn.destination == destination:
                yield conn

    def get_places(self):
        """Return a live view of all known places."""
        return self.connections.keys()

    def populate(self, gen):
        """Add every connection produced by the iterable *gen*."""
        for conn in gen:
            self.add_connection(conn)

    def shortest_paths(self, origin, destination=None):
        """Run Dijkstra's algorithm from *origin*.

        Args:
            origin: place to start from.
            destination: optional place; when given, the search stops as soon
                as that place is settled.

        Returns:
            ``(weights, hops)``: ``weights`` maps each place (and *origin*) to
            its best known distance from *origin* (``math.inf`` if not
            reached); ``hops`` maps each reached place other than *origin* to
            the connection used to arrive there on a best path.
        """
        tentative = list(self.get_places())
        # Settled places, kept in a set so the per-edge membership test is
        # O(1) instead of a linear scan of ``tentative``.
        settled = set()
        weights = {place: math.inf for place in tentative}
        hops = {}

        weights[origin] = 0

        while tentative:
            # Linear min over the list keeps the tie-breaking order (place
            # insertion order) and therefore the chosen paths stable.
            current = min(tentative, key=weights.get)
            tentative.remove(current)
            settled.add(current)

            # Exit early if we were given a destination and have reached it.
            # Compare against None so falsy place values still work.
            if destination is not None and current == destination:
                return weights, hops

            # All remaining places are unreachable: inf + w is never < inf,
            # so no further relaxation can change weights or hops.
            if weights[current] == math.inf:
                break

            for conn in self.connections[current]:
                next_hop = conn.destination
                if next_hop in settled:
                    continue

                # Connections rejected by any predicate are never used
                # (equivalent to giving them infinite weight).
                if not all(f(conn) for f in self.predicates):
                    continue

                alt_weight = weights[current] + self.weight_method(conn)
                if alt_weight < weights[next_hop]:
                    weights[next_hop] = alt_weight
                    hops[next_hop] = conn

        return weights, hops

    def find_path(self, *args):
        """Return the connections of a best journey visiting *args* in order.

        Each consecutive pair of stops is routed independently with
        :meth:`shortest_paths` and the legs are concatenated.

        Returns:
            A list of connections; an empty list for a single stop; ``None``
            if any leg has no path.  NOTE(review): a leg whose two stops are
            the same place also yields ``None`` (no hop is recorded for the
            origin) — confirm whether callers rely on that.

        Raises:
            IndexError: if called with no stops.
        """
        stops = deque(args)
        current_stop = stops.popleft()

        journey = []
        while stops:
            next_stop = stops.popleft()
            _, hops = self.shortest_paths(current_stop, next_stop)

            if next_stop not in hops:
                return None

            # Follow predecessor links back to the leg's origin, which is
            # settled first and therefore never has an entry in ``hops``.
            leg = deque()
            place = next_stop
            while place in hops:
                conn = hops[place]
                leg.appendleft(conn)
                place = conn.origin

            journey.extend(leg)
            current_stop = next_stop
        return journey