"""Count the cells of a grid map reachable in an exact number of steps.

The map is read from standard input, one row per line.  ``#`` marks a
blocked cell, ``S`` marks the starting cell, and every other character is
treated as an open cell.  Run as a script, the program prints the parsed
rows, a drawing of the reachable set before every step and after the last
one, and finally the number of reachable cells.
"""
import sys
from collections import defaultdict

import numpy


def _read_rows(stream):
    """Return the non-blank lines of `stream` with surrounding whitespace removed."""
    stripped = (line.strip() for line in stream)
    return [line for line in stripped if line]


def _parse_map(rows):
    """Turn the text rows into a boolean rock mask and the start coordinate.

    Returns:
        (rock, start): `rock` is a bool array shaped like the map that is True
        on ``#`` cells; `start` is the (row, column) of the ``S`` cell.  If the
        map contains several ``S`` cells the last one in reading order wins.

    Raises:
        ValueError: if the map is empty, not rectangular, or has no ``S``.
    """
    if not rows:
        raise ValueError("map is empty")
    width = len(rows[0])
    if any(len(row) != width for row in rows):
        raise ValueError("all map rows must have the same length")

    rock = numpy.zeros((len(rows), width), dtype=bool)
    start = None
    for r, row in enumerate(rows):
        for c, char in enumerate(row):
            if char == "#":
                rock[r, c] = True
            elif char == "S":
                start = (r, c)
    if start is None:
        raise ValueError("map has no start cell 'S'")
    return rock, start


def _step(reached, rock):
    """Return the cells exactly one orthogonal move away from `reached`.

    Each assignment shifts the current set one cell in one direction using
    slices, so nothing wraps around the edges of the map.  Rock cells are
    removed from the result.
    """
    moved = numpy.zeros_like(reached)
    moved[1:, :] |= reached[:-1, :]   # moved down
    moved[:-1, :] |= reached[1:, :]   # moved up
    moved[:, 1:] |= reached[:, :-1]   # moved right
    moved[:, :-1] |= reached[:, 1:]   # moved left
    moved &= ~rock
    return moved


def _draw(reached, rock):
    """Print the map: ``#`` for rock, ``O`` for reached cells, ``.`` otherwise."""
    cells = numpy.where(rock, "#", numpy.where(reached, "O", "."))
    for row in cells:
        print("".join(row))


def solve(grid, rows=None, steps=64, verbose=True):
    """Count the cells reachable from ``S`` in exactly `steps` moves.

    A move goes one cell up, down, left or right onto a non-rock cell; cells
    may be revisited, so the reachable set after each move is recomputed from
    scratch rather than accumulated.

    Args:
        grid: Unused.  Kept as the first positional parameter so that
            ``solve(grid)`` calls keep working.
        rows: The map as a sequence of equal-length strings.  When None the
            map is read from ``sys.stdin``.
        steps: Number of moves to take (non-negative).
        verbose: When true, print the reachable set before every move and
            after the last one, followed by the final count.

    Returns:
        The number of cells occupied after `steps` moves.

    Raises:
        ValueError: if `steps` is negative, or the map is empty, not
            rectangular, or has no ``S`` cell.
    """
    if steps < 0:
        raise ValueError("steps must be non-negative")
    rows = _read_rows(sys.stdin) if rows is None else list(rows)
    rock, start = _parse_map(rows)

    reached = numpy.zeros_like(rock)
    reached[start] = True
    for _ in range(steps):
        if verbose:
            _draw(reached, rock)
        reached = _step(reached, rock)

    count = int(numpy.count_nonzero(reached))
    if verbose:
        _draw(reached, rock)
        print(count)
    return count


def main():
    """Read the map from stdin, echo the parsed rows, and run the solver."""
    rows = _read_rows(sys.stdin)
    print(rows)
    try:
        solve({}, rows)
    except ValueError as err:
        sys.exit(f"error: {err}")


if __name__ == "__main__":
    main()