day 8 optimization
This commit is contained in:
parent
af9dc5628a
commit
c200e211c9
71
day08/sol.py
71
day08/sol.py
@ -1,43 +1,55 @@
|
|||||||
import math
|
from math import floor, sqrt, dist
|
||||||
from math import floor, sqrt
|
|
||||||
|
|
||||||
def dist(p,q):
|
def solve(input, limit, D=4000.0):
|
||||||
return sqrt(sum((x-y)**2 for x,y in zip(p,q)))
|
# Parse
|
||||||
|
|
||||||
def solve(input, limit):
|
|
||||||
points = []
|
points = []
|
||||||
for line in open(input):
|
for line in open(input):
|
||||||
x,y,z = map(int,line.split(','))
|
x,y,z = map(int,line.split(','))
|
||||||
points.append((x,y,z))
|
points.append((x,y,z))
|
||||||
|
|
||||||
def nearest(p):
|
|
||||||
mindist = float('inf')
|
|
||||||
minpoint = p
|
|
||||||
for q in points:
|
|
||||||
if q != p:
|
|
||||||
d = dist(p,q)
|
|
||||||
if d < mindist:
|
|
||||||
mindist = d
|
|
||||||
minpoint = q
|
|
||||||
return minpoint
|
|
||||||
|
|
||||||
#for p in points:
|
# Optimization: Instead of computing the distance between
|
||||||
# q = shrink(p)
|
# every pair of points (which is slow - it grows as n^2.
|
||||||
# print(p, nearest(p), list(chain(*[m.get(x,[]) for x in near(q)])))
|
# for n=1000 that's 1 million pairs), only compute the distance
|
||||||
|
# between pairs of "nearby" points, as controlled by the parameter D.
|
||||||
|
# We map each point to a cube in a DxDxD grid and only look at points
|
||||||
|
# in adjacent cubes. This is the core idea in Rabin and Lipton's algorithm
|
||||||
|
# for finding the closest pair of points. It works here because our points
|
||||||
|
# are roughly evenly distributed and none of the pairs we care about
|
||||||
|
# are more than about D units apart. (We end up only having to
|
||||||
|
# look at ~22,000 pairs, which is much quicker.)
|
||||||
|
|
||||||
|
def shrink(p):
|
||||||
|
x,y,z = p
|
||||||
|
return floor(x/D), floor(y/D), floor(z/D)
|
||||||
|
|
||||||
|
def neighbors(p):
|
||||||
|
x,y,z = shrink(p)
|
||||||
|
for dx in (0,-1,+1):
|
||||||
|
for dy in (0,-1,+1):
|
||||||
|
for dz in (0,-1,+1):
|
||||||
|
i = (x+dx,y+dy,z+dz)
|
||||||
|
if i in m:
|
||||||
|
yield from m[i]
|
||||||
|
|
||||||
|
m = {}
|
||||||
|
for p in points:
|
||||||
|
m.setdefault(shrink(p),[]).append(p)
|
||||||
|
|
||||||
pairs = []
|
pairs = []
|
||||||
for p in points:
|
for p in points:
|
||||||
for q in points:
|
for q in neighbors(p):
|
||||||
#q = nearest(p)
|
|
||||||
if p != q:
|
if p != q:
|
||||||
d = dist(p,q)
|
d = dist(p,q)
|
||||||
pairs.append((d,p,q))
|
pairs.append((d,p,q))
|
||||||
|
|
||||||
pairs.sort()
|
pairs.sort()
|
||||||
print(len(pairs))
|
print("#pairs =", len(pairs))
|
||||||
print(*pairs[:10], sep='\n')
|
print(*pairs[:10], sep='\n')
|
||||||
|
|
||||||
circuit = []
|
# Use a union-find (aka disjoint set) data structure to
|
||||||
|
# keep track of which circuit each point belongs to
|
||||||
|
#
|
||||||
direct = set()
|
direct = set()
|
||||||
parent = {p:p for p in points}
|
parent = {p:p for p in points}
|
||||||
size = {p:1 for p in points}
|
size = {p:1 for p in points}
|
||||||
@ -69,7 +81,7 @@ def solve(input, limit):
|
|||||||
if (p,q) in direct or (q,p) in direct:
|
if (p,q) in direct or (q,p) in direct:
|
||||||
# already directly connected
|
# already directly connected
|
||||||
continue
|
continue
|
||||||
print("connecting", p, q)
|
#print("connecting", p, q)
|
||||||
n += 1
|
n += 1
|
||||||
union(p,q)
|
union(p,q)
|
||||||
direct.add((p,q))
|
direct.add((p,q))
|
||||||
@ -81,7 +93,7 @@ def solve(input, limit):
|
|||||||
if r in seen:
|
if r in seen:
|
||||||
continue
|
continue
|
||||||
seen.add(r)
|
seen.add(r)
|
||||||
print(r, size[r])
|
#print(r, size[r])
|
||||||
sizes.append(size[r])
|
sizes.append(size[r])
|
||||||
|
|
||||||
sizes.sort(reverse=True)
|
sizes.sort(reverse=True)
|
||||||
@ -97,7 +109,7 @@ def solve(input, limit):
|
|||||||
if (p,q) in direct or (q,p) in direct:
|
if (p,q) in direct or (q,p) in direct:
|
||||||
# already directly connected
|
# already directly connected
|
||||||
continue
|
continue
|
||||||
print("connecting", p, q)
|
#print("connecting", p, q)
|
||||||
n += 1
|
n += 1
|
||||||
union(p,q)
|
union(p,q)
|
||||||
direct.add((p,q))
|
direct.add((p,q))
|
||||||
@ -106,8 +118,9 @@ def solve(input, limit):
|
|||||||
print(p,q)
|
print(p,q)
|
||||||
print(p[0]*q[0])
|
print(p[0]*q[0])
|
||||||
break
|
break
|
||||||
|
else:
|
||||||
|
print("fail")
|
||||||
|
|
||||||
|
|
||||||
solve("sample", 10)
|
solve("sample", 10, D=400.0)
|
||||||
solve("input", 1000)
|
solve("input", 1000, D=10000.0)
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user