"""Sum, over every machine in the puzzle input, the fewest button presses that
bring all of that machine's counters exactly to their targets.

Run as a script with an optional input path (defaults to ``input``).
"""
import itertools
import sys
from fractions import Fraction

import sympy


def solve(input, *, verbose=False):
    """Print and return the total minimum number of presses for the file at *input*.

    Every non-blank line describes one machine. Its first whitespace-separated
    token is ignored. The remaining tokens except the last are buttons, each a
    comma-separated list of counter indices such as ``(1,3)``. The last token
    is the per-counter target, such as ``{3,5,4,7}``. One press of a button adds
    1 to every counter it lists, so each machine is the linear system
    ``A @ presses == target`` to be solved over non-negative integers.

    Args:
        input: Path of the puzzle file. The name shadows the builtin but is
            kept so existing callers keep working.
        verbose: If true, also print each machine's RREF matrix, free buttons,
            search bounds and every improving candidate.

    Returns:
        The summed minimum presses (also printed). Machines with no
        non-negative integer solution are reported as ``uhoh`` and contribute 0.
    """
    total_presses = 0
    with open(input) as fh:  # ensure the handle is closed
        for line in fh:
            parts = line.split()[1:]
            if not parts:  # blank or trailing line
                continue
            buttons = [delist(p) for p in parts[:-1]]
            target = delist(parts[-1])
            presses = _min_presses(buttons, target, verbose=verbose)
            if presses is None:
                print("uhoh", target)
                continue
            print("*", target, presses)
            total_presses += presses
    print(total_presses)
    return total_presses


def _min_presses(buttons, target, *, verbose=False):
    """Return the fewest total presses making every counter equal *target*, or None.

    Solves ``A @ x == target`` exactly via sympy's RREF. The free (non-pivot)
    buttons are enumerated over ``0..bound``. The pivot buttons are recovered
    by back-substitution, and only non-negative integer solutions are kept.
    """
    n_buttons = len(buttons)
    n_counters = len(target)
    # Augmented matrix [A | target]: one row per counter, one column per button.
    augmented = sympy.Matrix(
        [[int(c in button) for button in buttons] + [target[c]]
         for c in range(n_counters)]
    )
    reduced, pivots = augmented.rref()
    if n_buttons in pivots:
        # A pivot in the target column means A @ x == target has no solution.
        return None
    # Row i of the RREF expresses pivot button pivots[i] via the free buttons:
    #   x[pivots[i]] = row[-1] - sum(row[f] * x[f] for f in free)
    # Converted once to Fraction so the enumeration below is exact and cheap
    # (no re-solving of the whole system per candidate).
    rows = [[Fraction(int(e.p), int(e.q)) for e in row]
            for row in reduced.tolist()[:len(pivots)]]
    free = [b for b in range(n_buttons) if b not in pivots]
    # Presses only ever raise counters, so a free button can be pressed at most
    # min(target) over the counters it touches; one touching none never helps.
    bounds = [
        min((target[c] for c in range(n_counters) if c in buttons[b]), default=0)
        for b in free
    ]
    if verbose:
        print(repr(reduced), "free:", free, "bounds:", bounds)

    best = None
    for values in itertools.product(*(range(bound + 1) for bound in bounds)):
        total = sum(values)
        for row in rows:
            x = row[-1] - sum(row[f] * v for f, v in zip(free, values))
            # Reject negative AND fractional pivot values; truncating with int()
            # would wrongly accept candidates such as -1/2 or 1/2.
            if x < 0 or x.denominator != 1:
                break
            total += x.numerator
        else:
            if best is None or total < best:
                best = total
                if verbose:
                    print(dict(zip(free, values)), total)
    return best


def delist(s):
    """Parse a token such as ``(1,3)`` or ``{3,5,4,7}`` into a list of ints.

    Any leading/trailing ``{``, ``}``, ``(`` and ``)`` characters are stripped
    before splitting on commas.
    """
    return [int(x) for x in s.strip("{}()").split(",")]


if __name__ == "__main__":
    # Pass another path (e.g. a sample file) on the command line.
    solve(sys.argv[1] if len(sys.argv) > 1 else "input")