import itertools
from collections import Counter, deque


def read(file):
    """Return the lines of *file* with surrounding whitespace stripped.

    *file* may be any iterable of strings (an open text file, a list of
    lines, ...).  Blank lines are kept as empty strings.
    """
    return [line.strip() for line in file]


def solve(file):
    """Read a character grid from *file* and print two totals.

    Every cell holding ``'0'`` is used as a starting point.  From a cell
    holding digit ``d`` the search may step to an orthogonally adjacent
    cell holding ``d + 1``; cells holding ``'9'`` are end points.

    Two integers are printed, one per line:

    1. summed over all starting cells, the number of distinct ``'9'``
       cells reachable from that start;
    2. summed over all starting cells, the number of distinct paths from
       that start to any ``'9'`` cell.

    An empty input prints ``0`` twice.  Rows shorter than the widest row
    (for example a trailing blank line) are treated as padded with
    impassable cells.
    """
    grid = read(file)
    ny = len(grid)
    # Widest row, so ragged input is still scanned completely; 0 if empty.
    nx = max((len(row) for row in grid), default=0)

    def points():
        """Iterate over every (x, y) coordinate of the bounding rectangle."""
        return itertools.product(range(nx), range(ny))

    def get(x, y):
        """Return the character at (x, y), or '.' when outside the grid."""
        if 0 <= y < ny and 0 <= x < len(grid[y]):
            return grid[y][x]
        return '.'

    t1 = 0
    t2 = 0
    for start in points():
        if get(*start) != '0':
            continue

        queue = deque([start])
        done = set()
        # count[cell] == number of distinct paths from `start` to `cell`.
        count = Counter({start: 1})

        while queue:
            cell = queue.popleft()
            if cell in done:
                # The same cell is enqueued once per predecessor.
                continue
            done.add(cell)

            c = get(*cell)
            # Queued cells are the '0' start or matched a successor digit
            # '1'..'9', so `c` is always a digit here.
            if c == '9':
                continue

            h = str(int(c) + 1)
            n = count[cell]
            x, y = cell
            for nxt in ((x + 1, y), (x - 1, y), (x, y + 1), (x, y - 1)):
                if get(*nxt) == h:
                    # FIFO order processes every cell of height d before
                    # any cell of height d + 1, so count[nxt] is complete
                    # by the time nxt itself is dequeued.
                    queue.append(nxt)
                    count[nxt] += n

        for cell in done:
            if get(*cell) == '9':
                t1 += 1
                t2 += count[cell]

    print(t1)
    print(t2)


if __name__ == "__main__":
    for path in ("sample5.in", "input"):
        with open(path) as handle:
            solve(handle)