day 10 attempt with sympy
This commit is contained in:
parent
e21df10751
commit
deec2b4981
@ -1,32 +1,60 @@
|
|||||||
import numpy, numpy.linalg
|
import sympy
|
||||||
|
import itertools
|
||||||
def solve(input):
|
def solve(input):
|
||||||
less = 0
|
part2 = 0
|
||||||
greater = 0
|
|
||||||
for line in open(input):
|
for line in open(input):
|
||||||
parts = line.split()[1:]
|
parts = line.split()[1:]
|
||||||
idxs = [delist(p) for p in parts[:-1]]
|
idxs = [delist(p) for p in parts[:-1]]
|
||||||
target = delist(parts[-1])
|
target = delist(parts[-1])
|
||||||
if len(idxs) != len(target) and 0:
|
|
||||||
if len(idxs) > len(target):
|
|
||||||
greater += 1
|
|
||||||
if len(idxs) < len(target):
|
|
||||||
less += 1
|
|
||||||
#print(line.strip())
|
|
||||||
else:
|
|
||||||
#assert len(idxs) == len(target)
|
|
||||||
N = len(target)
|
|
||||||
matrix = [[int(i in row) for i in range(N)] for row in idxs]
|
|
||||||
|
|
||||||
#print(matrix, target)
|
BTNS = len(idxs)
|
||||||
arr = numpy.array(matrix)
|
N = len(target)
|
||||||
print(arr)
|
matrix = [[int(i in row) for i in range(N)] for row in idxs]
|
||||||
try:
|
|
||||||
print(target, numpy.linalg.lstsq(arr.T, target))
|
#print(matrix, target)
|
||||||
except numpy.linalg.LinAlgError as ex:
|
M = sympy.Matrix(matrix + [target]).T
|
||||||
print(target, ex)
|
#print(repr(M))
|
||||||
print(less, greater)
|
S, pivots = M.rref()
|
||||||
|
if len(pivots) < BTNS:
|
||||||
|
# not solved
|
||||||
|
print(repr(S))
|
||||||
|
coords = []
|
||||||
|
limits = []
|
||||||
|
extra_rows = []
|
||||||
|
for p in range(BTNS):
|
||||||
|
if p not in pivots:
|
||||||
|
print(p, M.col(p))
|
||||||
|
row = [0]*(BTNS+1)
|
||||||
|
row[p] = 1
|
||||||
|
coords.append((M.rows, M.cols-1))
|
||||||
|
extra_rows.append(row)
|
||||||
|
limit = min(target[i] for i,x in enumerate(M.col(p)) if x)
|
||||||
|
limits.append(limit)
|
||||||
|
M = M.col_join(sympy.Matrix(extra_rows))
|
||||||
|
print(repr(M), limits)
|
||||||
|
totals = []
|
||||||
|
for values in itertools.product(*[range(l+1) for l in limits]):
|
||||||
|
for c,v in zip(coords,values):
|
||||||
|
M[c] = v
|
||||||
|
S, newpivot = M.rref()
|
||||||
|
presses = S.col(-1)
|
||||||
|
if all(int(x) >= 0 for x in presses):
|
||||||
|
print(presses.T, flush=True)
|
||||||
|
totals.append(sum(S.col(-1)))
|
||||||
|
if totals:
|
||||||
|
print("*", target, totals, min(totals))
|
||||||
|
part2 += min(totals)
|
||||||
|
else:
|
||||||
|
print("uhoh", target)
|
||||||
|
else:
|
||||||
|
presses = S.col(-1)
|
||||||
|
print("*", target, presses.T, sum(presses))
|
||||||
|
part2 += sum(presses)
|
||||||
|
print(part2)
|
||||||
|
#print(less, greater)
|
||||||
|
|
||||||
def delist(s):
|
def delist(s):
|
||||||
return [int(x) for x in s.strip("{}()").split(',')]
|
return [int(x) for x in s.strip("{}()").split(',')]
|
||||||
|
|
||||||
|
#solve("sample")
|
||||||
solve("input")
|
solve("input")
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user