From a940dff8306a89d00c40de058432af0937ea5e79 Mon Sep 17 00:00:00 2001 From: Andrew Ekstedt Date: Sun, 24 Dec 2023 04:51:38 +0000 Subject: [PATCH] day 21 cleanup --- day21/sol.py | 283 ++++++++++++++++++++++----------------------------- 1 file changed, 122 insertions(+), 161 deletions(-) diff --git a/day21/sol.py b/day21/sol.py index efc4705..a5ce7c9 100644 --- a/day21/sol.py +++ b/day21/sol.py @@ -1,40 +1,58 @@ import sys -import time import numpy -map = [x.strip() for x in sys.stdin] -if len(map) < 80: - print(map) +def read_map(input): + map = [x.strip() for x in input] -Y = len(map) -X = len(map[0]) -assert all(len(row) == X for row in map) + if len(map) < 80: + print(map) + Y = len(map) + X = len(map[0]) + assert all(len(row) == X for row in map) + assert X == Y -grid = {} -#start = None -#for i, row in enumerate(map): -# for j,c in enumerate(row): -# if c in "#.": -# pass -# elif c == "S": -# start = (i,j) -# else: -# raise Exception("invalid char %s at (%s,%s)" % (c, i, j)) -# grid[i,j] = d - -#print(grid) + return map def draw(mask, fill): - return - for i, row in enumerate(fill): - print("".join(".O#!"[f + 2*mask[i,j]] for j,f in enumerate(row))) + if len(fill) < 50: + for i, row in enumerate(fill): + print("".join(".O#!"[f + 2*mask[i,j]] for j,f in enumerate(row))) print() -def new(): - return numpy.zeros((Y,X+1), dtype='uint8') -def solve(grid): +def solve(map): + Y = len(map) + X = len(map[0]) + + fill = numpy.zeros((Y,X+1), dtype='uint8') + mask = numpy.zeros((Y,X+1), dtype='uint8') + for i, row in enumerate(map): + for j, c in enumerate(row): + if c == '#': + mask[i,j] = 1 + if c == 'S': + start = (i,j) + + # add a border on the side of the map (to prevent wraparound) + mask[:, -1] = 1 + + # start at the start + fill[start] = 1 + + for i in range(64): + fill = step(mask, fill) + #draw(mask, fill) + if i == 6-1: # sample + print(fill.sum()) + + print("part1 = ", fill.sum()) + + +def solve2(map): + Y = len(map) + X = len(map[0]) + fill = numpy.zeros((Y,X), dtype='uint8') mask = numpy.zeros((Y,X), dtype='uint8') for i, row in enumerate(map): @@ -44,152 +62,96 @@ def solve(grid): if c == 'S': start = (i,j) - S = 11 - if X <= 11: - S = 15 - fill = numpy.c_[numpy.tile(fill,(S,S)), numpy.zeros(Y*S, dtype='uint8')] - mask = numpy.c_[numpy.tile(mask,(S,S)), numpy.zeros(Y*S, dtype='uint8')] - - cache = {} # bitmap->k - maps = {} # k->bitmap - cycles = {} # k->[k'] - print(fill.shape, mask.shape) - target_steps = 4064 + fill[start] = 1 + values, period, d2 = simulate(fill, mask, max_steps=500) - - fill[S//2*Y + start[0], S//2*X + start[1]] = 1 - values, d2 = sequence(fill, mask, target_steps, X) for n in 6,10,50,100,500,1000, 5000, 26501365: - print(n,extrapolate(n, values, X, d2)) - return + print(n, extrapolate(n, values, period, d2)) - # find the tiles reachable in n steps from each possible start - # position - reachable = {} - breach = {} - center = slice(S//2*Y, (S//2+1)*Y), slice(S//2*X, (S//2+1)*X) - left = slice(S//2*Y, (S//2+1)*Y), slice((S//2-1)*X, S//2*X) - right = slice(S//2*Y, (S//2+1)*Y), slice((S//2+1)*X, (S//2+2)*X) - up = slice((S//2-1)*Y, (S//2 )*Y), slice(S//2*X, (S//2+1)*X) - down = slice((S//2+1)*Y, (S//2+2)*Y), slice(S//2*X, (S//2+1)*X) - dirs = [left,right,up,down] - for i in range(Y): - for j in range(X): - if i in (0,Y-1) or j in (0,X-1) or (i,j) == start: - fill = numpy.zeros((Y*S,X*S+1), dtype='uint8') - fill[S//2*Y + i, S//2*X + j] = 1 - reachable[i,j] = [fill] - found = [0,0,0,0] - while True: - s0 = fill - s1 = step(mask, s0) - s2 = step(mask, s1) - fill = s2 - n1 = len(reachable[i,j]) - n2 = len(reachable[i,j])+1 - reachable[i,j].append(s1) - reachable[i,j].append(s2) - for d in range(4): - if not found[d]: - if s1[dirs[d]].any(): - found[d] = n1 - elif s2[dirs[d]].any(): - found[d] = n2 - if (s0 == s2)[center].all(): - break - breach[i,j] = found - print(i,j,len(reachable[i,j]), found) - #draw(mask, fill) +def simulate(fill, mask, max_steps): + Y,X = mask.shape + assert Y == X, "need a square matrix" + period = X + tiles = 1 - for (i,j),fills in reachable.items(): - fills = reachable[i,j] - when = breach[i,j] - for d in range(4): - if when[d]: - f = fills[when[d]][dirs[d]] - num_breach_points = sum(f.ravel()) - assert num_breach_points > 0 - print(i,j,d, num_breach_points == 1, num_breach_points) - if num_breach_points > 1: - draw(mask, fills[when[d]]) - - return - - # maximum number of steps for a tile to become completely reachable - M = max(len(r) for r in reachable.values()) - - states = {} - active = {} - reachable[start] - states = [((0,0),[(0,start)])] - for iters in range(target_steps): - assert fill.any() - - fill = step(mask, fill) - fill = step(mask, fill) - - #print("\033[2J") # clear screen - #draw(mask, fill) - for u in range(S): - for v in range(S): - small = fill[Y*u:Y*(u+1), X*v:X*(v+1)] - #if (u,v) == (1,1): print(small) - b = small.tobytes() - super[u,v] = cache.setdefault(b, len(cache)) - - print("========") - print(super) - print(flush=True) - #time.sleep(.1) - - - # look for symmetries - def look(): - for u in range(S): - for v in range(u,S): - if super[u,v] != super[v,u]: - return False - return True - - - #print(fill.tobytes()) - -def sequence(fill, mask, target_steps, period): + original_mask = mask prev = [0] - c1, d1 = 0, 0 - c2, d2 = 0, 0 - prev2 = [] - for i in range(target_steps): + diffs = [] + for i in range(max_steps): + # expand the map if necessary + if fill[0].any() or fill[-1].any() or fill[:,0].any() or fill[:,-1].any(): + draw(mask,fill) + print("expanding...") + tiles += 2 + mask = numpy.tile(original_mask, (tiles,tiles)) + fill = numpy.pad(fill, [(Y,Y), (X,X)]) + + # take one more step and count the number of reachable squares fill = step(mask, fill) - #fill = step(mask, fill) n = int(fill.sum()) - #if len(prev) > period: - # c1, d1 = d1, n - prev[-period] - # c2, d2 = d2, d1 - c1 - # prev2.append(d2) + + # look for patterns + # although the number of squares is somewhat unpreditable from step + # to step (due to the maze-like structure of the mask), the fact that + # the mask repeats in tiles (and the fact that there are unobstructed + # pathways in the orthogonal and diagonal directions) means that, in + # the long run, the flood fill will even out and -importantly- it should + # have some recognizable pattern every N steps (where N is the period + # of the tiling - 11 in the sample, 131 in the input). this is because + # we enter new tiles every N steps, and although it's hard to predict + # exactly how many of the squares in each tile we'll have visited, + # it should be the same number in every tile (or rather, there should + # be some small set of repeated tile-states, and we should be able to + # predict how many of each tile-state there will be). + # + # SO the first step is to get the difference between the current + # number of reachable squares and the number N steps ago. + # + # d1(i) = f(i) - f(i - N) + # + # we then have to discover some pattern in that sequence. + # we know the number of reachable squares will grow roughly as the + # square of the number of steps (because the map is 2D) so + # we should be looking for a quadratic relation. + # we can use second-order differences (see day 9) to do that. + # d1 is already a first-order difference, so take the difference + # between d1s to get a second-order difference. if our assumption is + # correct, then d1 should be a linear sequence and d2 should be a + # constant. + # + # d2(i) = d1(i) - d1(i-N) + # = (f(i) - f(i-N)) - (f(i-N) - f(i-N-N)) + # + # + # there may be some unstability at the beginning of the simulation + # so we need to wait until the d2 values for every step in the period + # all agree. + # + # once it settles down, this gives us a set of N (11, 131, whatever) + # equations we can use to predict the number of squares after any + # future number of steps + # + d2 = 0 if len(prev) >= 2*period: # find the second differece d2 = (n - prev[-period]) - (prev[-period] - prev[-2*period]) - prev2.append(d2) + diffs.append(d2) prev.append(n) - print(i,n,d1,d2,sep="\t",flush=True) - if len(prev2) > period and len(set(prev2[-period:])) == 1: + print(i, n, d2, sep="\t", flush=True) + if len(diffs) > period and all_same_value(diffs[-period:]): print("gotcha!") - return prev, prev2[-1] - break - if len(prev2) > period*2 and prev2[-period*2:-period] == prev2[-period:]: - print("gotcha!") - break - if len(prev2) > period*3 and prev2[-period*3:-period] == prev2[-period*2:]: - print("gotcha!") - break - if fill[0].any() or fill[-1].any() or fill[:,0].any() or fill[:,-1].any(): - draw(mask,fill) - break - return + return prev, period, diffs[-1] + assert False, "failed to find a stable pattern" + return prev, period, None + +def all_same_value(list): + if len(list) < 1: + return False + x = list[0] + return all(x == y for y in list) def extrapolate(n, values, period, d2): if n < len(values): @@ -207,8 +169,6 @@ def extrapolate(n, values, period, d2): def step(mask, old): # flood fill - # note that the provided map has a 1-tile border of empty spaces - # which we will use to our advantage #draw(fill) fill = numpy.zeros(old.shape, dtype='uint8') y,x = fill.shape @@ -217,7 +177,6 @@ def step(mask, old): f = numpy.roll(f, 1) | numpy.roll(f, -1) if i > 0: f |= (old[i-1] == 1) if i < y-1: f |= (old[i+1] == 1) - f[-1] = False f &= (mask[i] == 0) #print(old, f) if f.any(): @@ -227,4 +186,6 @@ def step(mask, old): return fill -solve(grid) +map = read_map(sys.stdin) +solve(map) +solve2(map)