From c5d4f796e0aa81261284e622efad009bbf7d1f4a Mon Sep 17 00:00:00 2001 From: Andrew Ekstedt Date: Sun, 24 Dec 2023 03:24:35 +0000 Subject: [PATCH] day 21 part 2 solve! this is messy because i went in like three unsuccessful directions before hitting on an approach that worked --- day21/sol.py | 218 +++++++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 192 insertions(+), 26 deletions(-) diff --git a/day21/sol.py b/day21/sol.py index 4bc0c7e..efc4705 100644 --- a/day21/sol.py +++ b/day21/sol.py @@ -1,13 +1,16 @@ import sys -from collections import defaultdict +import time +import numpy map = [x.strip() for x in sys.stdin] -print(map) +if len(map) < 80: + print(map) Y = len(map) X = len(map[0]) assert all(len(row) == X for row in map) + grid = {} #start = None #for i, row in enumerate(map): @@ -22,11 +25,18 @@ grid = {} #print(grid) +def draw(mask, fill): + return + for i, row in enumerate(fill): + print("".join(".O#!"[f + 2*mask[i,j]] for j,f in enumerate(row))) + print() + +def new(): + return numpy.zeros((Y,X+1), dtype='uint8') def solve(grid): - import numpy - fill = numpy.zeros((Y,X+1), dtype='uint8') - mask = numpy.zeros((Y,X+1), dtype='uint8') + fill = numpy.zeros((Y,X), dtype='uint8') + mask = numpy.zeros((Y,X), dtype='uint8') for i, row in enumerate(map): for j, c in enumerate(row): if c == '#': @@ -34,31 +44,187 @@ def solve(grid): if c == 'S': start = (i,j) - def draw(fill): - for i, row in enumerate(fill): - print("".join(".O#"[f + 2*mask[i,j]] for j,f in enumerate(row))) + S = 11 + if X <= 11: + S = 15 + fill = numpy.c_[numpy.tile(fill,(S,S)), numpy.zeros(Y*S, dtype='uint8')] + mask = numpy.c_[numpy.tile(mask,(S,S)), numpy.zeros(Y*S, dtype='uint8')] - fill[start] = 1 + cache = {} # bitmap->k + maps = {} # k->bitmap + cycles = {} # k->[k'] + print(fill.shape, mask.shape) + + target_steps = 4064 + + + fill[S//2*Y + start[0], S//2*X + start[1]] = 1 + values, d2 = sequence(fill, mask, target_steps, X) + for n in 6,10,50,100,500,1000, 5000, 26501365: + print(n,extrapolate(n, values, X, d2)) + return + + + # find the tiles reachable in n steps from each possible start + # position + reachable = {} + breach = {} + center = slice(S//2*Y, (S//2+1)*Y), slice(S//2*X, (S//2+1)*X) + left = slice(S//2*Y, (S//2+1)*Y), slice((S//2-1)*X, S//2*X) + right = slice(S//2*Y, (S//2+1)*Y), slice((S//2+1)*X, (S//2+2)*X) + up = slice((S//2-1)*Y, (S//2 )*Y), slice(S//2*X, (S//2+1)*X) + down = slice((S//2+1)*Y, (S//2+2)*Y), slice(S//2*X, (S//2+1)*X) + dirs = [left,right,up,down] + for i in range(Y): + for j in range(X): + if i in (0,Y-1) or j in (0,X-1) or (i,j) == start: + fill = numpy.zeros((Y*S,X*S+1), dtype='uint8') + fill[S//2*Y + i, S//2*X + j] = 1 + reachable[i,j] = [fill] + found = [0,0,0,0] + while True: + s0 = fill + s1 = step(mask, s0) + s2 = step(mask, s1) + fill = s2 + n1 = len(reachable[i,j]) + n2 = len(reachable[i,j])+1 + reachable[i,j].append(s1) + reachable[i,j].append(s2) + for d in range(4): + if not found[d]: + if s1[dirs[d]].any(): + found[d] = n1 + elif s2[dirs[d]].any(): + found[d] = n2 + if (s0 == s2)[center].all(): + break + breach[i,j] = found + print(i,j,len(reachable[i,j]), found) + #draw(mask, fill) + + for (i,j),fills in reachable.items(): + fills = reachable[i,j] + when = breach[i,j] + for d in range(4): + if when[d]: + f = fills[when[d]][dirs[d]] + num_breach_points = sum(f.ravel()) + assert num_breach_points > 0 + print(i,j,d, num_breach_points == 1, num_breach_points) + if num_breach_points > 1: + draw(mask, fills[when[d]]) + + return + + # maximum number of steps for a tile to become completely reachable + M = max(len(r) for r in reachable.values()) + + states = {} + active = {} + reachable[start] + states = [((0,0),[(0,start)])] + for iters in range(target_steps): + assert fill.any() + + fill = step(mask, fill) + fill = step(mask, fill) + + #print("\033[2J") # clear screen + #draw(mask, fill) + for u in range(S): + for v in range(S): + small = fill[Y*u:Y*(u+1), X*v:X*(v+1)] + #if (u,v) == (1,1): print(small) + b = small.tobytes() + super[u,v] = cache.setdefault(b, len(cache)) + + print("========") + print(super) + print(flush=True) + #time.sleep(.1) + + + # look for symmetries + def look(): + for u in range(S): + for v in range(u,S): + if super[u,v] != super[v,u]: + return False + return True + + + #print(fill.tobytes()) + +def sequence(fill, mask, target_steps, period): + prev = [0] + c1, d1 = 0, 0 + c2, d2 = 0, 0 + prev2 = [] + for i in range(target_steps): + fill = step(mask, fill) + #fill = step(mask, fill) + n = int(fill.sum()) + #if len(prev) > period: + # c1, d1 = d1, n - prev[-period] + # c2, d2 = d2, d1 - c1 + # prev2.append(d2) + if len(prev) >= 2*period: + # find the second differece + d2 = (n - prev[-period]) - (prev[-period] - prev[-2*period]) + prev2.append(d2) + prev.append(n) + print(i,n,d1,d2,sep="\t",flush=True) + if len(prev2) > period and len(set(prev2[-period:])) == 1: + print("gotcha!") + return prev, prev2[-1] + break + if len(prev2) > period*2 and prev2[-period*2:-period] == prev2[-period:]: + print("gotcha!") + break + if len(prev2) > period*3 and prev2[-period*3:-period] == prev2[-period*2:]: + print("gotcha!") + break + if fill[0].any() or fill[-1].any() or fill[:,0].any() or fill[:,-1].any(): + draw(mask,fill) + break + return + +def extrapolate(n, values, period, d2): + if n < len(values): + return values[n] + quo, rem = divmod(n-len(values)+period, period) + x = len(values)-period+rem + y = values[-period+rem] + d1 = y - values[-2*period+rem] + while x < n: + d1 += d2 + y += d1 + x += period + assert x == n + return y + +def step(mask, old): # flood fill - for i in range(64): - draw(fill) - old = fill - fill = numpy.zeros(fill.shape, dtype='uint8') - for i in range(Y): - f = (old[i] == 1) - f = numpy.roll(f, 1) | numpy.roll(f, -1) - if i > 0: f |= (old[i-1] == 1) - if i < Y-1: f |= (old[i+1] == 1) - f[-1] = False - f &= (mask[i] == 0) - #print(old, f) - if f.any(): - fill[i, f] = 1 + # note that the provided map has a 1-tile border of empty spaces + # which we will use to our advantage + #draw(fill) + fill = numpy.zeros(old.shape, dtype='uint8') + y,x = fill.shape + for i in range(y): + f = (old[i] == 1) + f = numpy.roll(f, 1) | numpy.roll(f, -1) + if i > 0: f |= (old[i-1] == 1) + if i < y-1: f |= (old[i+1] == 1) + f[-1] = False + f &= (mask[i] == 0) + #print(old, f) + if f.any(): + fill[i, f] = 1 - draw(fill) - - print(numpy.sum(fill)) + assert fill.any() + return fill solve(grid)