from math import floor, sqrt, dist, prod


def solve(input, limit, D=4000.0):
    """Connect the closest pairs of 3-D points and report circuit statistics.

    ``input`` is the path of a text file with one ``x,y,z`` integer triple per
    line (blank lines are ignored). Point pairs are connected in order of
    increasing Euclidean distance; points joined by connections form a
    "circuit".

    Part 1: after making the ``limit`` shortest connections (connecting two
    points already in the same circuit still counts as a connection), print
    the three largest circuit sizes and their product.

    Part 2: keep connecting in the same order until every point is in one
    circuit, then print the pair whose connection completed it and the
    product of their x coordinates -- or "fail" if the candidate pairs run
    out first.

    ``D`` is the grid cell size used to prune candidate pairs (see below).
    Every pair closer than ``D`` is guaranteed to be considered; farther
    pairs may be missed, and a warning is printed if an answer relied on them.

    Returns ``(part1, part2)``; ``part2`` is None when part 2 fails.
    """
    # Parse. Identical coordinates would collapse into a single key of the
    # union-find dicts below, so drop duplicates up front to keep the point
    # count consistent with the number of union-find nodes.
    points = []
    with open(input) as f:
        for line in f:
            if not line.strip():
                continue
            x, y, z = map(int, line.split(','))
            points.append((x, y, z))
    points = list(dict.fromkeys(points))

    # Optimization: Instead of computing the distance between every pair of
    # points (which is slow - it grows as n^2; for n=1000 that's ~500,000
    # distinct pairs), only compute the distance between pairs of "nearby"
    # points, as controlled by the parameter D.
    # We map each point to a cube in a DxDxD grid and only look at points
    # in the same or adjacent cubes. This is the core idea in Rabin and
    # Lipton's algorithm for finding the closest pair of points. Any two
    # points closer than D differ by < D in each coordinate, so their cube
    # indices differ by at most 1: every pair closer than D is found. It works
    # here because our points are roughly evenly distributed and none of the
    # pairs we care about are more than about D units apart.
    def shrink(p):
        x, y, z = p
        return floor(x / D), floor(y / D), floor(z / D)

    grid = {}
    for p in points:
        grid.setdefault(shrink(p), []).append(p)

    def neighbors(p):
        """Yield every point in p's cube or one of the 26 adjacent cubes."""
        x, y, z = shrink(p)
        for dx in (-1, 0, 1):
            for dy in (-1, 0, 1):
                for dz in (-1, 0, 1):
                    yield from grid.get((x + dx, y + dy, z + dz), ())

    # Keep each unordered pair once (from its smaller endpoint). Sorting
    # (d, min, max) tuples gives the same order in which the distinct pairs
    # were first met when both orientations were stored.
    pairs = sorted((dist(p, q), p, q)
                   for p in points for q in neighbors(p) if p < q)
    print("#pairs =", len(pairs))
    print(*pairs[:10], sep='\n')

    # pairs[:exact] is exactly the set of pairs closer than D, in true global
    # order. Beyond it, a missed pair could belong earlier in the order.
    exact = next((i for i, (d, _, _) in enumerate(pairs) if d >= D),
                 len(pairs))

    def warn_if_inexact(used, part):
        if used > exact:
            print(f"warning: {part} used pairs at distance >= D={D}; "
                  "closer pairs may have been missed, try a larger D")

    # Use a union-find (aka disjoint set) data structure to
    # keep track of which circuit each point belongs to.
    parent = {p: p for p in points}
    size = {p: 1 for p in points}

    def find(x):
        root = x
        while parent[root] != root:
            root = parent[root]
        # Path compression. Assign parent[x] *before* rebinding x: assignment
        # targets are bound left to right, so `x, parent[x] = parent[x], root`
        # would update the next node on the path instead of x.
        while parent[x] != root:
            parent[x], x = root, parent[x]
        return root

    def union(p, q):
        """Merge the circuits of p and q; return True if they were separate."""
        p, q = find(p), find(q)
        if p == q:
            return False
        if size[p] < size[q]:
            p, q = q, p
        parent[q] = p
        size[p] += size[q]
        return True

    circuits = len(points)
    last = None     # (p, q) whose connection merged the final two circuits
    last_used = 0   # number of pairs consumed when that happened

    def connect(i):
        """Make connection number i (0-based, in distance order)."""
        nonlocal circuits, last, last_used
        _, p, q = pairs[i]
        if union(p, q):
            circuits -= 1
            if circuits == 1:
                last, last_used = (p, q), i + 1

    # part 1
    k = max(0, min(limit, len(pairs)))
    for i in range(k):
        connect(i)
    warn_if_inexact(k, "part 1")
    roots = {find(p) for p in points}
    sizes = sorted((size[r] for r in roots), reverse=True)
    print(sizes[:3])
    part1 = prod(sizes[:3])
    print(part1)

    # part 2 (continued): keep going from where part 1 stopped. The final
    # merge may already have happened during part 1.
    i = k
    while last is None and i < len(pairs):
        connect(i)
        i += 1
    if last is None:
        print("fail")
        return part1, None
    warn_if_inexact(last_used, "part 2")
    p, q = last
    print(p, q)
    part2 = p[0] * q[0]
    print(part2)
    return part1, part2


if __name__ == "__main__":
    solve("sample", 10, D=400.0)
    solve("input", 1000, D=10000.0)