day 8 optimization

This commit is contained in:
magical 2025-12-08 06:57:46 +00:00
parent af9dc5628a
commit c200e211c9

View File

@ -1,43 +1,55 @@
import math from math import floor, sqrt, dist
from math import floor, sqrt
def dist(p,q): def solve(input, limit, D=4000.0):
return sqrt(sum((x-y)**2 for x,y in zip(p,q))) # Parse
def solve(input, limit):
points = [] points = []
for line in open(input): for line in open(input):
x,y,z = map(int,line.split(',')) x,y,z = map(int,line.split(','))
points.append((x,y,z)) points.append((x,y,z))
def nearest(p):
mindist = float('inf')
minpoint = p
for q in points:
if q != p:
d = dist(p,q)
if d < mindist:
mindist = d
minpoint = q
return minpoint
#for p in points: # Optimization: Instead of computing the distance between
# q = shrink(p) # every pair of points (which is slow - it grows as n^2.
# print(p, nearest(p), list(chain(*[m.get(x,[]) for x in near(q)]))) # for n=1000 that's 1 million pairs), only compute the distance
# between pairs of "nearby" points, as controlled by the parameter D.
# We map each point to a cube in a DxDxD grid and only look at points
# in adjacent cubes. This is the core idea in Rabin and Lipton's algorithm
# for finding the closest pair of points. It works here because our points
# are roughly evenly distributed and none of the pairs we care about
# are more than about D units apart. (We end up only having to
# look at ~22,000 pairs, which is much quicker.)
def shrink(p):
x,y,z = p
return floor(x/D), floor(y/D), floor(z/D)
def neighbors(p):
x,y,z = shrink(p)
for dx in (0,-1,+1):
for dy in (0,-1,+1):
for dz in (0,-1,+1):
i = (x+dx,y+dy,z+dz)
if i in m:
yield from m[i]
m = {}
for p in points:
m.setdefault(shrink(p),[]).append(p)
pairs = [] pairs = []
for p in points: for p in points:
for q in points: for q in neighbors(p):
#q = nearest(p)
if p != q: if p != q:
d = dist(p,q) d = dist(p,q)
pairs.append((d,p,q)) pairs.append((d,p,q))
pairs.sort() pairs.sort()
print(len(pairs)) print("#pairs =", len(pairs))
print(*pairs[:10], sep='\n') print(*pairs[:10], sep='\n')
circuit = [] # Use a union-find (aka disjoint set) data structure to
# keep track of which circuit each point belongs to
#
direct = set() direct = set()
parent = {p:p for p in points} parent = {p:p for p in points}
size = {p:1 for p in points} size = {p:1 for p in points}
@ -69,7 +81,7 @@ def solve(input, limit):
if (p,q) in direct or (q,p) in direct: if (p,q) in direct or (q,p) in direct:
# already directly connected # already directly connected
continue continue
print("connecting", p, q) #print("connecting", p, q)
n += 1 n += 1
union(p,q) union(p,q)
direct.add((p,q)) direct.add((p,q))
@ -81,7 +93,7 @@ def solve(input, limit):
if r in seen: if r in seen:
continue continue
seen.add(r) seen.add(r)
print(r, size[r]) #print(r, size[r])
sizes.append(size[r]) sizes.append(size[r])
sizes.sort(reverse=True) sizes.sort(reverse=True)
@ -97,7 +109,7 @@ def solve(input, limit):
if (p,q) in direct or (q,p) in direct: if (p,q) in direct or (q,p) in direct:
# already directly connected # already directly connected
continue continue
print("connecting", p, q) #print("connecting", p, q)
n += 1 n += 1
union(p,q) union(p,q)
direct.add((p,q)) direct.add((p,q))
@ -106,8 +118,9 @@ def solve(input, limit):
print(p,q) print(p,q)
print(p[0]*q[0]) print(p[0]*q[0])
break break
else:
print("fail")
solve("sample", 10) solve("sample", 10, D=400.0)
solve("input", 1000) solve("input", 1000, D=10000.0)