day 10 attempt with sympy
This commit is contained in:
parent
e21df10751
commit
deec2b4981
@ -1,32 +1,60 @@
|
||||
import numpy, numpy.linalg
|
||||
import sympy
|
||||
import itertools
|
||||
def solve(input):
|
||||
less = 0
|
||||
greater = 0
|
||||
part2 = 0
|
||||
for line in open(input):
|
||||
parts = line.split()[1:]
|
||||
idxs = [delist(p) for p in parts[:-1]]
|
||||
target = delist(parts[-1])
|
||||
if len(idxs) != len(target) and 0:
|
||||
if len(idxs) > len(target):
|
||||
greater += 1
|
||||
if len(idxs) < len(target):
|
||||
less += 1
|
||||
#print(line.strip())
|
||||
else:
|
||||
#assert len(idxs) == len(target)
|
||||
N = len(target)
|
||||
matrix = [[int(i in row) for i in range(N)] for row in idxs]
|
||||
|
||||
#print(matrix, target)
|
||||
arr = numpy.array(matrix)
|
||||
print(arr)
|
||||
try:
|
||||
print(target, numpy.linalg.lstsq(arr.T, target))
|
||||
except numpy.linalg.LinAlgError as ex:
|
||||
print(target, ex)
|
||||
print(less, greater)
|
||||
BTNS = len(idxs)
|
||||
N = len(target)
|
||||
matrix = [[int(i in row) for i in range(N)] for row in idxs]
|
||||
|
||||
#print(matrix, target)
|
||||
M = sympy.Matrix(matrix + [target]).T
|
||||
#print(repr(M))
|
||||
S, pivots = M.rref()
|
||||
if len(pivots) < BTNS:
|
||||
# not solved
|
||||
print(repr(S))
|
||||
coords = []
|
||||
limits = []
|
||||
extra_rows = []
|
||||
for p in range(BTNS):
|
||||
if p not in pivots:
|
||||
print(p, M.col(p))
|
||||
row = [0]*(BTNS+1)
|
||||
row[p] = 1
|
||||
coords.append((M.rows, M.cols-1))
|
||||
extra_rows.append(row)
|
||||
limit = min(target[i] for i,x in enumerate(M.col(p)) if x)
|
||||
limits.append(limit)
|
||||
M = M.col_join(sympy.Matrix(extra_rows))
|
||||
print(repr(M), limits)
|
||||
totals = []
|
||||
for values in itertools.product(*[range(l+1) for l in limits]):
|
||||
for c,v in zip(coords,values):
|
||||
M[c] = v
|
||||
S, newpivot = M.rref()
|
||||
presses = S.col(-1)
|
||||
if all(int(x) >= 0 for x in presses):
|
||||
print(presses.T, flush=True)
|
||||
totals.append(sum(S.col(-1)))
|
||||
if totals:
|
||||
print("*", target, totals, min(totals))
|
||||
part2 += min(totals)
|
||||
else:
|
||||
print("uhoh", target)
|
||||
else:
|
||||
presses = S.col(-1)
|
||||
print("*", target, presses.T, sum(presses))
|
||||
part2 += sum(presses)
|
||||
print(part2)
|
||||
#print(less, greater)
|
||||
|
||||
def delist(s):
|
||||
return [int(x) for x in s.strip("{}()").split(',')]
|
||||
|
||||
#solve("sample")
|
||||
solve("input")
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user