import itertools
from collections import Counter, deque

# Maps each height below '9' to the height exactly one step further uphill.
_NEXT_HEIGHT = {str(h): str(h + 1) for h in range(9)}

# Orthogonal neighbour offsets: +x, -x, +y, -y.
_STEPS = ((1, 0), (-1, 0), (0, 1), (0, -1))


def read(file):
    """Return the lines of *file* with surrounding whitespace stripped.

    *file* may be any iterable of strings: an open text file, a list, etc.
    """
    return [line.strip() for line in file]


def _trail_counts(grid, start):
    """Count uphill paths from *start* to every cell reachable from it.

    A path moves orthogonally and climbs exactly one height per step.
    Returns a Counter mapping each reachable cell (x, y), *start* included
    (with count 1), to the number of distinct paths from *start* to it.
    """

    def get(x, y):
        # Out-of-bounds cells read as '.', which is never a height.
        if 0 <= y < len(grid) and 0 <= x < len(grid[y]):
            return grid[y][x]
        return '.'

    count = Counter({start: 1})
    queue = deque([start])
    while queue:
        x, y = queue.popleft()
        want = _NEXT_HEIGHT.get(get(x, y))
        if want is None:  # '9' or a non-height cell: no further climb
            continue
        n = count[x, y]
        for dx, dy in _STEPS:
            nxt = (x + dx, y + dy)
            if get(*nxt) != want:
                continue
            # Each step climbs exactly one, so a cell k levels above *start*
            # is always exactly k steps away. BFS therefore pops every cell of
            # one level before any cell of the next, and a cell's count is
            # complete by the time it is popped; each cell is enqueued once.
            if nxt not in count:
                queue.append(nxt)
            count[nxt] += n
    return count


def solve(file):
    """Compute and print both totals for a grid of heights.

    *file* is any iterable of text lines forming a grid of characters.
    Digits '0'-'9' are heights; any other character (e.g. '.') is
    impassable. Blank lines are ignored, and rows may differ in length.

    For every '0' cell (trailhead), this finds every '9' cell (summit)
    reachable by an orthogonal path that climbs exactly one per step.
    It computes two totals:

    * t1: over all trailheads, the number of distinct reachable summits.
    * t2: over all trailheads, the number of distinct such paths to summits.

    Both totals are printed, one per line, and returned as ``(t1, t2)``.
    """
    grid = [row for row in read(file) if row]
    cells = [(x, y, c) for y, row in enumerate(grid) for x, c in enumerate(row)]
    trailheads = [(x, y) for x, y, c in cells if c == '0']
    summits = {(x, y) for x, y, c in cells if c == '9'}

    t1 = t2 = 0
    for start in trailheads:
        count = _trail_counts(grid, start)
        reached = summits.intersection(count)
        t1 += len(reached)
        t2 += sum(count[p] for p in reached)
    print(t1)
    print(t2)
    return t1, t2


if __name__ == "__main__":
    for path in ("sample5.in", "input"):
        # Close each input file once it has been consumed.
        with open(path) as f:
            solve(f)