# Listing header from the original paste: 61 lines, 1.9 KiB, Python.

import itertools
from fractions import Fraction

import sympy
def solve(input, *, verbose=False):
    """Sum, over every line of *input*, the fewest button presses hitting the targets.

    Each non-blank line is split on whitespace. The first token is ignored,
    the last token (e.g. ``{3,5,4,7}``) gives the target value of every
    counter, and each token in between (e.g. ``(1,3)``) lists the counter
    indices that one press of that button increments by one. For each line
    the minimum total number of presses that makes every counter equal its
    target is found; the sum over all lines is printed and returned.

    Args:
        input: path of the puzzle file to read.
        verbose: keyword-only; print per-line diagnostics when true.

    Returns:
        The summed minimum press count as an ``int``. Lines with no
        non-negative integer solution are reported with ``uhoh`` and
        contribute nothing.
    """
    part2 = 0
    with open(input) as fh:  # closes the file even if a line fails to parse
        for line in fh:
            parts = line.split()[1:]
            if not parts:
                continue  # blank line, or nothing after the first token
            idxs = [delist(p) for p in parts[:-1]]
            target = delist(parts[-1])
            best = _min_presses(idxs, target, verbose=verbose)
            if best is None:
                print("uhoh", target)
                continue
            if verbose:
                print("*", target, best)
            part2 += best
    print(part2)
    return part2


def _min_presses(idxs, target, *, verbose=False):
    """Return the fewest total presses that drive every counter to *target*.

    ``idxs[j]`` lists the counters that button ``j`` increments per press.
    Returns ``None`` when no non-negative integer press vector exists.
    """
    n_buttons = len(idxs)
    n_counters = len(target)
    # Augmented matrix [A | b]: column j is button j's 0/1 effect on each
    # counter, and the last column is the target vector.
    effect = [[int(i in btn) for i in range(n_counters)] for btn in idxs]
    reduced, pivots = sympy.Matrix(effect + [target]).T.rref()
    if n_buttons in pivots:
        return None  # pivot in the target column: A x = b is inconsistent
    free = [c for c in range(n_buttons) if c not in pivots]
    # Pressing button c adds 1 to every counter it touches, so it can be
    # pressed at most min(target) over those counters (0 if it touches none).
    limits = [
        min((target[i] for i in range(n_counters) if i in idxs[c]), default=0)
        for c in free
    ]
    # Exact rational copy of the pivot rows. Row r determines button
    # pivots[r] as rhs[r] - sum(coef[r][k] * free_value[k]).
    rhs = [
        Fraction(int(reduced[r, n_buttons].p), int(reduced[r, n_buttons].q))
        for r in range(len(pivots))
    ]
    coef = [
        [Fraction(int(reduced[r, c].p), int(reduced[r, c].q)) for c in free]
        for r in range(len(pivots))
    ]
    if verbose:
        print(repr(reduced), "free:", free, "limits:", limits)

    best = None
    for values in itertools.product(*(range(lim + 1) for lim in limits)):
        total = sum(values)
        # Pivot presses are >= 0 in any valid solution, so this assignment
        # cannot beat the best total found so far.
        if best is not None and total >= best:
            continue
        for r in range(len(pivots)):
            pressed = rhs[r] - sum(k * v for k, v in zip(coef[r], values))
            if pressed.denominator != 1 or pressed < 0:
                break  # not a valid (non-negative integer) press count
            total += pressed
        else:
            if best is None or total < best:
                best = total
    return None if best is None else int(best)
def delist(s):
    """Parse a bracketed comma-separated integer list like ``(1,3)`` or ``{3,5,4,7}``."""
    inner = s.strip("{}()")
    return list(map(int, inner.split(',')))
if __name__ == "__main__":
    # Optional first argument picks the puzzle file (e.g. "sample");
    # without one the script reads "input", as before.
    import sys

    solve(sys.argv[1] if len(sys.argv) > 1 else "input")