From 94ccfd4e183a0e18bd89a76b4b4def7c741b5636 Mon Sep 17 00:00:00 2001 From: Andrew Ekstedt Date: Tue, 2 Dec 2025 08:06:03 +0000 Subject: [PATCH] day 2 refactor --- day02/sol.py | 82 ++++++++++++++++++++++++---------------------------- 1 file changed, 38 insertions(+), 44 deletions(-) diff --git a/day02/sol.py b/day02/sol.py index eed97da..7add81a 100644 --- a/day02/sol.py +++ b/day02/sol.py @@ -11,7 +11,7 @@ def solve(input): t = 0 for lo,hi in ranges: - invalid = list(findinvalid(lo,hi)) + invalid = list(findinvalid(lo,hi,1)) #if len(invalid) == 0: # print(lo,hi,invalid) t += sum(invalid) @@ -19,53 +19,47 @@ def solve(input): t2 = 0 for lo,hi in ranges: - invalid2 = set(findinvalid2(lo,hi)) + invalid2 = set(findinvalid(lo,hi,2)) t2 += sum(invalid2) print(t2) -def findinvalid(lo, hi): - assert len(str(hi)) - len(str(lo)) <= 1 +def findinvalid(lo, hi, part): assert lo <= hi - j = len(str(lo)) // 2 - if j < 1: - # (1,19) => [11] - j = 1 - for B in [10**j, 10**(j+1)]: - for i in range(lo//B, hi//B+1): - if i < B//10: - continue - if i >= B: - break - id = i*B + i - if lo <= id <= hi: - assert(str(i)+str(i) == str(id)), (i,id) - #print(lo, hi, i,B,id) - yield id + if log10(hi) != log10(lo): + assert log10(hi) - log10(lo) == 1 + yield from findinvalid(lo,exp10(log10(lo))-1, part) + yield from findinvalid(exp10(log10(hi)-1),hi, part) + elif part == 1: + N = log10(hi) + if N % 2 == 0: + d = N // 2 + yield from _find(lo,hi,d,N) + elif part == 2: + N = log10(hi) + for d in range(1,N): + if N%d == 0: + yield from _find(lo,hi,d,N) -def findinvalid2(lo, hi): - assert lo <= hi,(lo,hi) - if len(str(hi)) != len(str(lo)): - assert len(str(hi)) - len(str(lo)) == 1 - yield from findinvalid2(lo,10**len(str(lo))-1) - yield from findinvalid2(10**(len(str(hi))-1),hi) - return - N = len(str(hi)) - for d in range(1,N): - if N%d == 0: - B = 10**d - BB = 10**(N-d) - factor = sum(10**k for k in range(0,N,d)) - for i in range(lo//BB, hi//BB+1): - if i < B//10: - continue - if i >= B: - break - id = i*factor - #print(lo,hi,factor,i,id) - if lo <= id <= hi: - #assert(str(i)*(N//d) == str(id)), (i,id) - #print(lo, hi, i,B,id) - yield id +def _find(lo, hi, d, N): + B = exp10(d) + BB = exp10(N-d) + factor = sum(exp10(k) for k in range(0,N,d)) + for i in range(lo//BB, hi//BB+1): + if i < B//10: + continue + if i >= B: + break + id = i*factor + if lo <= id <= hi: + #assert(str(i)*(N//d) == str(id)), (i,id) + #print(lo, hi, i,B,id) + yield id -#solve("sample") +def exp10(n): + return 10 ** n + +def log10(n): + return len(str(n)) + +solve("sample") solve("input")