# Source listing: 75 lines, 2.5 KiB, Python.
import itertools
import sys
from fractions import Fraction

import sympy
def solve(input, *, verbose=False):
    """Print and return the part-2 total: fewest button presses over all machines.

    Each line of the file at path *input* describes one machine.  The first
    whitespace-separated token is ignored (presumably the indicator-light
    diagram), every following token except the last is a button listing the
    counter indices it increments, and the last token lists the target value
    of each counter.  For every machine we look for non-negative integer press
    counts ``x`` with ``A @ x == target`` (``A[i][j] == 1`` iff button ``j``
    touches counter ``i``) that minimise ``sum(x)``, and add those minima up.

    Args:
        input: path of the puzzle input file.  The name shadows the builtin
            but is kept so existing callers passing ``input=`` still work.
        verbose: if true, also print each machine's target and minimum.

    Returns:
        The summed minimum press count (it is also printed).  A machine with
        no valid solution is reported as ``uhoh <target>`` and adds nothing.
    """
    part2 = 0
    with open(input) as lines:
        for line in lines:
            parts = line.split()[1:]
            if not parts:
                continue  # blank (or diagram-only) line: nothing to solve
            buttons = [delist(p) for p in parts[:-1]]
            target = delist(parts[-1])
            best = _min_presses(buttons, target)
            if best is None:
                print("uhoh", target)
                continue
            if verbose:
                print("*", target, best)
            part2 += best
    print(part2)
    return part2


def _min_presses(buttons, target):
    """Return the fewest total presses that hit *target* exactly, or None.

    *buttons* holds, per button, the counter indices it increments by one.
    Indices outside ``range(len(target))`` are ignored, matching how the
    coefficient matrix is built.
    """
    n_buttons = len(buttons)
    n_counters = len(target)
    # coeffs[i][j] == 1 iff button j increments counter i.
    coeffs = [[int(i in btn) for btn in buttons] for i in range(n_counters)]
    augmented = sympy.Matrix([row + [t] for row, t in zip(coeffs, target)])
    reduced, pivots = augmented.rref()
    if n_buttons in pivots:
        return None  # some row reduced to 0 == 1: no solution at all

    free = [j for j in range(n_buttons) if j not in pivots]
    # All coefficients are 0/1, so a button can't be pressed more often than
    # the smallest target among the counters it touches; a button touching
    # no counter is useless and is pressed 0 times.
    limits = [
        min((target[i] for i in range(n_counters) if coeffs[i][j]), default=0)
        for j in free
    ]
    # RREF row k reads  x[pivots[k]] + sum_f reduced[k, f] * x[f] == reduced[k, -1],
    # so every pivot variable is an affine function of the free variables.
    equations = [
        (
            _to_fraction(reduced[k, n_buttons]),
            [_to_fraction(reduced[k, j]) for j in free],
        )
        for k in range(len(pivots))
    ]

    best = None
    # With no free variables product() yields one empty tuple: the unique case.
    for values in itertools.product(*(range(lim + 1) for lim in limits)):
        total = sum(values)
        for rhs, row in equations:
            presses = rhs - sum(c * v for c, v in zip(row, values))
            # Press counts must be exact non-negative integers (no truncation).
            if presses < 0 or presses.denominator != 1:
                break
            total += presses
            if best is not None and total >= best:
                break  # remaining pivot counts are >= 0: can't beat best
        else:
            best = total
    return None if best is None else int(best)


def _to_fraction(value):
    """Convert an exact sympy rational from ``rref`` into a ``Fraction``."""
    return Fraction(int(value.p), int(value.q))
def delist(s):
    """Parse a token such as ``(1,3)`` or ``{3,5,4,7}`` into a list of ints.

    Every leading/trailing ``{``, ``}``, ``(`` and ``)`` character is removed,
    then the remainder is split on commas and each piece converted with int().
    """
    body = s.strip("{}()")
    return list(map(int, body.split(',')))
if __name__ == "__main__":
    # Solve the file named on the command line (e.g. "sample");
    # with no argument, fall back to the puzzle file "input".
    solve(sys.argv[1] if len(sys.argv) > 1 else "input")