diff --git a/day12/alt.py b/day12/alt.py new file mode 100644 index 0000000..5615e7c --- /dev/null +++ b/day12/alt.py @@ -0,0 +1,53 @@ +import string + +import sys, os; sys.path.append(os.path.join(os.path.dirname(__file__), "../lib")) +import astar + +elevation = {c:i for i, c in enumerate(string.ascii_lowercase)} +elevation['E'] = elevation['z'] +elevation['S'] = elevation['a'] + +data = open("input").read().splitlines(False) +#print(*data, sep="\n") + +extra = [] +for y, row in enumerate(data): + for x, c in enumerate(row): + if c == 'S': + start = (x,y) + elif c == 'E': + goal = (x,y) + elif c == 'a': + # for part 2 + extra.append((x,y)) + +def get(x, y): + return elevation[data[y][x]] + +def solve(part): + def neighbors(src): + x,y = src + here = get(x,y) + n = [] + def push(x,y): + if 0 <= y < len(data): + if 0 <= x < len(data[y]): + if get(x,y) <= here+1: + n.append((1, (x,y))) + push(x-1, y) + push(x+1,y) + push(x, y-1) + push(x, y+1) + return n + + def heuristic(n): + return abs(goal[0] - n[0]) + abs(goal[1] - n[1]) + + s = [start] + if part == 2: + s += extra + + print(part, "=", astar.search(s, goal, neighbors, heuristic)) + +solve(1) +solve(2) diff --git a/lib/astar.py b/lib/astar.py new file mode 100644 index 0000000..55d3634 --- /dev/null +++ b/lib/astar.py @@ -0,0 +1,90 @@ +from heapq import heappush, heappop + +def search(start, goal, neighbors, heuristic=None): + if heuristic == None: + def heuristic(x): + return 0 + # TODO: callable goal + if not isinstance(start, list): + start = [start] + i = 0 # tiebreaker + q = [] + for s in start: + i += 1 + heappush(q, (heuristic(s), i, s, None)) + done = {} + while q: + z, _, this, prev = heappop(q) + if this in done: + continue + cost_so_far = z - heuristic(this) + #print(this,z, cost_so_far) + + done[this] = (cost_so_far, prev) + + if this == goal: + print("astar: visited", len(done), "nodes") + # reconsruct the path + n = this + path = [] + while n is not None: + path.append(n) + _, n = done[n] + path.reverse() + return cost_so_far, this, path + + for c, n in neighbors(this): + if n not in done: + # calculate the "reduced cost" + c = cost_so_far + c + heuristic(n) + i += 1 + heappush(q, (c, i, n, this)) + + return float('inf'), None + +def test(): + data = [ + "aabqponm", + "abcryxxl", + "accszzxk", + "acctuvwj", + "abdefghi", + ] + #data = [ + # "aabbaaba", + # "abcbaaba", + # "accabcbb", + # "accbbbba", + # "abbbabab", + #] + start = (0,0) + goal = (5,2) + + def get(x,y): + return ord(data[y][x]) + + def neighbors(src): + x,y = src + here = get(x,y) + n = [] + def push(x,y): + if 0 <= y < len(data): + if 0 <= x < len(data[y]): + if get(x,y) <= here+1: + n.append((1, (x,y))) + push(x-1, y) + push(x+1,y) + push(x, y-1) + push(x, y+1) + return n + + def heuristic(n): + return abs(goal[0] - n[0]) + abs(goal[1] - n[1]) + + c, _, path = search(start, goal, neighbors, heuristic) + print(*data, sep="\n") + print(*path, sep="\n") + assert c == 31, c + +if __name__ == '__main__': + test()