aboutsummaryrefslogtreecommitdiffstatshomepage
path: root/ywalk/graph.py
diff options
context:
space:
mode:
Diffstat (limited to 'ywalk/graph.py')
-rw-r--r--ywalk/graph.py125
1 files changed, 125 insertions, 0 deletions
diff --git a/ywalk/graph.py b/ywalk/graph.py
new file mode 100644
index 0000000..18556cd
--- /dev/null
+++ b/ywalk/graph.py
@@ -0,0 +1,125 @@
+import math
+
+from collections import deque
+from ywalk.types import Connection, Mode
+
+class Graph:
+ WEIGHTS = {
+ 'least-hop': lambda c: 1,
+ 'least-time': lambda c: c.time + c.mode.weight,
+ }
+
+ def __init__(self):
+ self.connections = {}
+ self.nodes = 0
+ self.edges = 0
+ self.predicates = []
+
+ self.set_weight()
+
+ def add_predicate(self, predicate):
+ self.predicates.append(predicate)
+
+ def set_weight(self, spec=None):
+ self.weight_method = Graph.WEIGHTS[spec or 'least-time']
+
+ def add_connection(self, conn: Connection):
+ for place in [conn.origin, conn.destination]:
+ if place not in self.connections:
+ self.nodes = self.nodes + 1
+ self.connections[place] = set()
+
+ if conn not in self.connections[conn.origin]:
+ self.edges = self.edges + 1
+ self.connections[conn.origin].add(conn)
+
+ def add_recall(self, destination):
+ for origin in self.get_places():
+ if origin == destination:
+ continue
+ self.add_connection(Connection(origin, destination, Mode.RECALL, 0))
+
+ def __contains__(self, place):
+ return place in self.connections.keys()
+
+ def get_connections_from(self, place):
+ for conn in self.connections[place]:
+ yield conn
+
+ def get_connections_to(self, place):
+ for origin in self.get_places():
+ yield from self.get_connections_between(origin, place)
+
+ def get_connections_between(self, origin, destination):
+ for conn in self.connections[origin]:
+ if conn.destination == destination:
+ yield conn
+
+ def get_places(self):
+ return self.connections.keys()
+
+ def populate(self, gen):
+ for conn in gen:
+ self.add_connection(conn)
+
+ def shortest_paths(self, origin, destination=None):
+ tentative = list(self.get_places())
+ weights = {place: math.inf for place in self.get_places()}
+ hops = {}
+
+ weights[origin] = 0
+
+ while tentative:
+ current = min(tentative, key=weights.get)
+
+ tentative.remove(current)
+
+ # Exit early if we were given a destination and have reached it
+ if destination and current == destination:
+ return weights, hops
+
+ for conn in self.connections[current]:
+ if conn.destination not in tentative:
+ continue
+
+ next_hop = conn.destination
+
+ # Make sure to filter out any unwanted connections by setting
+ # their weight to infinity
+ if all(f(conn) for f in self.predicates):
+ alt_weights = weights[current] + self.weight_method(conn)
+ else:
+ alt_weights = math.inf
+
+ if alt_weights < weights[next_hop]:
+ weights[next_hop] = alt_weights
+ hops[next_hop] = conn
+
+ return weights, hops
+
+ def find_path(self, *args):
+ stops = deque(args)
+ current_stop = stops.popleft()
+
+ journey = []
+ while stops:
+ next_stop = stops.popleft()
+ _, hops = self.shortest_paths(current_stop, next_stop)
+
+ if not hops or next_stop not in hops:
+ return None
+
+ path = deque()
+ last_hop = next_stop
+
+ while last_hop in hops:
+ conn = hops[last_hop]
+ path.appendleft(conn)
+ last_hop = conn.origin
+
+ if path is None:
+ return None
+
+ journey.extend(path)
+ current_stop = next_stop
+ return journey