import itertools
import math

import sympy


def _min_presses(buttons, target):
    """Return the fewest total presses that make the counters equal *target*.

    Model (from how the matrix is built): pressing button ``j`` once adds 1 to
    every counter index listed in ``buttons[j]``.  We need non-negative
    integer press counts ``x`` with ``A @ x == target`` minimising ``sum(x)``.

    Returns ``None`` when no non-negative integer solution exists.
    """
    if not any(target):
        # All targets zero (or no counters): pressing nothing is optimal.
        return 0
    n_buttons = len(buttons)
    # Augmented matrix [A | target]: row i is counter i, column j is button j.
    aug = sympy.Matrix(
        [[int(i in btn) for btn in buttons] + [t] for i, t in enumerate(target)]
    )
    reduced, pivots = aug.rref()
    if n_buttons in pivots:
        # Pivot in the RHS column: the linear system itself is inconsistent.
        return None
    free = [j for j in range(n_buttons) if j not in pivots]

    # Row r of the RREF reads
    #   x[pivots[r]] + sum(reduced[r, f] * x[f] for f in free) == reduced[r, -1].
    # Scale each row by the lcm of its denominators so the search loop below
    # is pure integer arithmetic with an exact divisibility test.
    equations = []
    for r in range(len(pivots)):
        entries = [reduced[r, f] for f in free] + [reduced[r, n_buttons]]
        denom = 1
        for entry in entries:
            q = int(entry.q)
            denom = denom * q // math.gcd(denom, q)
        scaled = [int(entry * denom) for entry in entries]
        equations.append((scaled[:-1], scaled[-1], denom))

    # A button adds non-negatively to every counter it touches, so it can be
    # pressed at most min(target of those counters) times.  This makes the
    # enumeration below exhaustive over all feasible free-variable values.
    limits = [min((target[i] for i in buttons[j]), default=0) for j in free]

    best = None
    for values in itertools.product(*(range(lim + 1) for lim in limits)):
        total = sum(values)
        for coefs, rhs, denom in equations:
            num = rhs - sum(c * v for c, v in zip(coefs, values))
            # Reject negative or fractional presses for the pivot button
            # (denom > 0, so the sign of num is the sign of the press count).
            if num < 0 or num % denom:
                break
            total += num // denom
        else:
            if best is None or total < best:
                best = total
    return best


def solve(input, *, verbose=False):
    """Sum the minimal press counts over every machine listed in file *input*.

    Each non-blank line is split on whitespace; the first token is ignored,
    the middle tokens are button wirings like ``(1,3)``, and the last token
    is the counter target like ``{3,5,4,7}``.  Prints the total (and a
    ``uhoh`` line for any machine with no solution) and also returns it.

    ``input`` shadows the builtin but is kept for caller compatibility.
    Pass ``verbose=True`` to print each machine's result.
    """
    part2 = 0
    with open(input) as fh:
        for line in fh:
            parts = line.split()[1:]
            if not parts:
                continue  # blank line, or a line with nothing after the first token
            buttons = [delist(p) for p in parts[:-1]]
            target = delist(parts[-1])
            presses = _min_presses(buttons, target)
            if presses is None:
                print("uhoh", target)
                continue
            if verbose:
                print("*", target, presses)
            part2 += presses
    print(part2)
    return part2


def delist(s):
    """Parse a token like ``(1,3)`` or ``{3,5,4,7}`` into a list of ints."""
    return [int(x) for x in s.strip("{}()").split(',')]


if __name__ == "__main__":
    solve("input")