diff --git a/day08/sol.py b/day08/sol.py index 4e5f2c1..41ac7ec 100644 --- a/day08/sol.py +++ b/day08/sol.py @@ -1,43 +1,55 @@ -import math -from math import floor, sqrt +from math import floor, sqrt, dist -def dist(p,q): - return sqrt(sum((x-y)**2 for x,y in zip(p,q))) - -def solve(input, limit): +def solve(input, limit, D=4000.0): + # Parse points = [] for line in open(input): x,y,z = map(int,line.split(',')) points.append((x,y,z)) - def nearest(p): - mindist = float('inf') - minpoint = p - for q in points: - if q != p: - d = dist(p,q) - if d < mindist: - mindist = d - minpoint = q - return minpoint - #for p in points: - # q = shrink(p) - # print(p, nearest(p), list(chain(*[m.get(x,[]) for x in near(q)]))) + # Optimization: Instead of computing the distance between + # every pair of points (which is slow - it grows as n^2. + # for n=1000 that's 1 million pairs), only compute the distance + # between pairs of "nearby" points, as controlled by the parameter D. + # We map each point to a cube in a DxDxD grid and only look at points + # in adjacent cubes. This is the core idea in Rabin and Lipton's algorithm + # for finding the closest pair of points. It works here because our points + # are roughly evenly distributed and none of the pairs we care about + # are more than about D units apart. (We end up only having to + # look at ~22,000 pairs, which is much quicker.) + + def shrink(p): + x,y,z = p + return floor(x/D), floor(y/D), floor(z/D) + + def neighbors(p): + x,y,z = shrink(p) + for dx in (0,-1,+1): + for dy in (0,-1,+1): + for dz in (0,-1,+1): + i = (x+dx,y+dy,z+dz) + if i in m: + yield from m[i] + + m = {} + for p in points: + m.setdefault(shrink(p),[]).append(p) pairs = [] for p in points: - for q in points: - #q = nearest(p) + for q in neighbors(p): if p != q: d = dist(p,q) pairs.append((d,p,q)) pairs.sort() - print(len(pairs)) + print("#pairs =", len(pairs)) print(*pairs[:10], sep='\n') - circuit = [] + # Use a union-find (aka disjoint set) data structure to + # keep track of which circuit each point belongs to + # direct = set() parent = {p:p for p in points} size = {p:1 for p in points} @@ -69,7 +81,7 @@ def solve(input, limit): if (p,q) in direct or (q,p) in direct: # already directly connected continue - print("connecting", p, q) + #print("connecting", p, q) n += 1 union(p,q) direct.add((p,q)) @@ -81,7 +93,7 @@ def solve(input, limit): if r in seen: continue seen.add(r) - print(r, size[r]) + #print(r, size[r]) sizes.append(size[r]) sizes.sort(reverse=True) @@ -97,7 +109,7 @@ def solve(input, limit): if (p,q) in direct or (q,p) in direct: # already directly connected continue - print("connecting", p, q) + #print("connecting", p, q) n += 1 union(p,q) direct.add((p,q)) @@ -106,8 +118,9 @@ def solve(input, limit): print(p,q) print(p[0]*q[0]) break + else: + print("fail") -solve("sample", 10) -solve("input", 1000) - +solve("sample", 10, D=400.0) +solve("input", 1000, D=10000.0)