"""Sum of area * perimeter over the same-character regions of a grid."""
import itertools


def solve(file):
    """Print and return the sum of ``area * perimeter`` over all regions.

    ``file`` is any iterable of text lines (an open file, a list of
    strings, ...). Each line, with surrounding whitespace stripped, is one
    grid row; blank lines are ignored. Rows are assumed to share the width
    of the first row.

    A region is a maximal set of 4-connected cells holding the same
    character. ``area`` is its number of cells and ``perimeter`` is the
    number of cell edges not shared with another cell of the same region.
    The total is printed and also returned (0 for an empty grid).
    """
    # Dropping blank lines keeps a trailing newline at EOF from adding an
    # empty (zero-width) row, which would make indexing fail.
    grid = [row for row in (line.strip() for line in file) if row]
    ny = len(grid)
    nx = len(grid[0]) if grid else 0

    def get(p):
        """Return the character at ``p = (x, y)``, or None when off the grid.

        None cannot equal any grid character, so off-grid cells never join
        a region -- a printable sentinel such as '.' would let a region of
        that character flood outward without end.
        """
        x, y = p
        if 0 <= x < nx and 0 <= y < ny:
            return grid[y][x]
        return None

    def flood(initial_point):
        """Return the region containing ``initial_point``.

        The result maps each (x, y) cell of the region to the number of its
        4-neighbours that hold the same character.
        """
        c = get(initial_point)
        stack = [initial_point]
        count = {}
        while stack:
            p = stack.pop()
            if p in count:
                continue  # a cell may be pushed several times; visit it once
            count[p] = 0
            x, y = p
            for n in (x + 1, y), (x - 1, y), (x, y + 1), (x, y - 1):
                if get(n) == c:
                    stack.append(n)
                    count[p] += 1
        return count

    done = set()  # cells already counted as part of some region
    total = 0
    for start in itertools.product(range(nx), range(ny)):
        if start in done:
            continue
        shape = flood(start)
        done.update(shape)
        # Each cell contributes its 4 edges minus those shared with
        # same-region neighbours.
        perimeter = sum(4 - n for n in shape.values())
        total += len(shape) * perimeter
    print(total)
    return total


if __name__ == "__main__":
    # Guarded so importing this module does not read the puzzle files;
    # `with` closes each file once it has been solved.
    for path in ("sample4.in", "input"):
        with open(path, encoding="utf-8") as f:
            solve(f)