Compare commits

...

2 Commits

Author SHA1 Message Date
c200e211c9 day 8 optimization 2025-12-08 06:57:46 +00:00
af9dc5628a day 8 2025-12-08 06:27:09 +00:00
5 changed files with 1235 additions and 0 deletions

24
day08/est.py Normal file
View File

@ -0,0 +1,24 @@
points = []
for line in open("input"):
x,y,z = map(int,line.split(','))
points.append((x,y,z))
import math
def dist(p,q):
return math.sqrt(sum((x-y)**2 for x,y in zip(p,q)))
import random
n = len(points)
min_dist = float('inf')
#for _ in range(n):
# a = random.choice(points)
for a in points:
b = random.choice(points)
d = dist(a,b)
if d == 0:
continue
min_dist = min(d,min_dist)
print(min_dist)

1000
day08/input Normal file

File diff suppressed because it is too large Load Diff

65
day08/old.py Normal file
View File

@ -0,0 +1,65 @@
import math
from math import floor, sqrt
from itertools import chain
D = 4000.0
import math
def dist(p,q):
return sqrt(sum((x-y)**2 for x,y in zip(p,q)))
def solve(input, limit, D=D):
points = []
for line in open(input):
x,y,z = map(int,line.split(','))
points.append((x,y,z))
def shrink(p):
x,y,z = p
return floor(x/D), floor(y/D), floor(z/D)
m = {}
for p in points:
q = shrink(p)
m.setdefault(q,[]).append(p)
def nearest(p):
mindist = float('inf')
minpoint = p
for r in near(shrink(p)):
for n in m.get(r,[]):
if n != p:
d = dist(p,n)
if d < mindist:
mindist = d
minpoint = n
return minpoint
#for p in points:
# q = shrink(p)
# print(p, nearest(p), list(chain(*[m.get(x,[]) for x in near(q)])))
pairs = []
for p in points:
q = nearest(p)
if p != q:
d = dist(p,q)
pairs.append((d,p,q))
pairs.sort()
print(len(pairs))
print(*pairs[:10], sep='\n')
def near(q):
x,y,z = q
for dx in (-1,0,+1):
for dy in (-1,0,+1):
for dz in (-1,0,+1):
#if not (dx==dy==dz==0):
yield x+dx,y+dy,z+dz
solve("sample", 10, D=350)
solve("input", 10, D=8000)

20
day08/sample Normal file
View File

@ -0,0 +1,20 @@
162,817,812
57,618,57
906,360,560
592,479,940
352,342,300
466,668,158
542,29,236
431,825,988
739,650,466
52,470,668
216,146,977
819,987,18
117,168,530
805,96,715
346,949,466
970,615,88
941,993,340
862,61,35
984,92,344
425,690,689

126
day08/sol.py Normal file
View File

@ -0,0 +1,126 @@
from math import floor, sqrt, dist
def solve(input, limit, D=4000.0):
# Parse
points = []
for line in open(input):
x,y,z = map(int,line.split(','))
points.append((x,y,z))
# Optimization: Instead of computing the distance between
# every pair of points (which is slow - it grows as n^2.
# for n=1000 that's 1 million pairs), only compute the distance
# between pairs of "nearby" points, as controlled by the parameter D.
# We map each point to a cube in a DxDxD grid and only look at points
# in adjacent cubes. This is the core idea in Rabin and Lipton's algorithm
# for finding the closest pair of points. It works here because our points
# are roughly evenly distributed and none of the pairs we care about
# are more than about D units apart. (We end up only having to
# look at ~22,000 pairs, which is much quicker.)
def shrink(p):
x,y,z = p
return floor(x/D), floor(y/D), floor(z/D)
def neighbors(p):
x,y,z = shrink(p)
for dx in (0,-1,+1):
for dy in (0,-1,+1):
for dz in (0,-1,+1):
i = (x+dx,y+dy,z+dz)
if i in m:
yield from m[i]
m = {}
for p in points:
m.setdefault(shrink(p),[]).append(p)
pairs = []
for p in points:
for q in neighbors(p):
if p != q:
d = dist(p,q)
pairs.append((d,p,q))
pairs.sort()
print("#pairs =", len(pairs))
print(*pairs[:10], sep='\n')
# Use a union-find (aka disjoint set) data structure to
# keep track of which circuit each point belongs to
#
direct = set()
parent = {p:p for p in points}
size = {p:1 for p in points}
def find(x):
root = x
while parent[root] != root:
root = parent[root]
while parent[x] != root:
x, parent[x] = parent[x], root
return root
def union(p,q):
p = find(p)
q = find(q)
if p == q:
return
if size[p] < size[q]:
p,q = q,p
parent[q] = p
size[p] += size[q]
# part 1
n = 0
for _,p,q in pairs:
if n >= limit:
break
if (p,q) in direct or (q,p) in direct:
# already directly connected
continue
#print("connecting", p, q)
n += 1
union(p,q)
direct.add((p,q))
seen = set()
sizes = []
for p in points:
r = find(p)
if r in seen:
continue
seen.add(r)
#print(r, size[r])
sizes.append(size[r])
sizes.sort(reverse=True)
print(sizes[:3])
t = 1
for x in sizes[:3]:
t *= x
print(t)
# part 2 (continued)
for _,p,q in pairs:
if (p,q) in direct or (q,p) in direct:
# already directly connected
continue
#print("connecting", p, q)
n += 1
union(p,q)
direct.add((p,q))
r = find(p)
if size[r] == len(points):
print(p,q)
print(p[0]*q[0])
break
else:
print("fail")
solve("sample", 10, D=400.0)
solve("input", 1000, D=10000.0)