From deec2b4981bd7c2a2a84fa8775fd5ac3f6e987a4 Mon Sep 17 00:00:00 2001 From: Andrew Ekstedt Date: Thu, 11 Dec 2025 00:00:49 +0000 Subject: [PATCH] day 10 attempt with sympy --- day10/sol2.py | 70 +++++++++++++++++++++++++++++++++++---------------- 1 file changed, 49 insertions(+), 21 deletions(-) diff --git a/day10/sol2.py b/day10/sol2.py index 7d7095f..d886ce6 100644 --- a/day10/sol2.py +++ b/day10/sol2.py @@ -1,32 +1,60 @@ -import numpy, numpy.linalg +import sympy +import itertools def solve(input): - less = 0 - greater = 0 + part2 = 0 for line in open(input): parts = line.split()[1:] idxs = [delist(p) for p in parts[:-1]] target = delist(parts[-1]) - if len(idxs) != len(target) and 0: - if len(idxs) > len(target): - greater += 1 - if len(idxs) < len(target): - less += 1 - #print(line.strip()) - else: - #assert len(idxs) == len(target) - N = len(target) - matrix = [[int(i in row) for i in range(N)] for row in idxs] - #print(matrix, target) - arr = numpy.array(matrix) - print(arr) - try: - print(target, numpy.linalg.lstsq(arr.T, target)) - except numpy.linalg.LinAlgError as ex: - print(target, ex) - print(less, greater) + BTNS = len(idxs) + N = len(target) + matrix = [[int(i in row) for i in range(N)] for row in idxs] + + #print(matrix, target) + M = sympy.Matrix(matrix + [target]).T + #print(repr(M)) + S, pivots = M.rref() + if len(pivots) < BTNS: + # not solved + print(repr(S)) + coords = [] + limits = [] + extra_rows = [] + for p in range(BTNS): + if p not in pivots: + print(p, M.col(p)) + row = [0]*(BTNS+1) + row[p] = 1 + coords.append((M.rows, M.cols-1)) + extra_rows.append(row) + limit = min(target[i] for i,x in enumerate(M.col(p)) if x) + limits.append(limit) + M = M.col_join(sympy.Matrix(extra_rows)) + print(repr(M), limits) + totals = [] + for values in itertools.product(*[range(l+1) for l in limits]): + for c,v in zip(coords,values): + M[c] = v + S, newpivot = M.rref() + presses = S.col(-1) + if all(int(x) >= 0 for x in presses): + print(presses.T, flush=True) + totals.append(sum(S.col(-1))) + if totals: + print("*", target, totals, min(totals)) + part2 += min(totals) + else: + print("uhoh", target) + else: + presses = S.col(-1) + print("*", target, presses.T, sum(presses)) + part2 += sum(presses) + print(part2) + #print(less, greater) def delist(s): return [int(x) for x in s.strip("{}()").split(',')] +#solve("sample") solve("input")