day 2 refactor

This commit is contained in:
magical 2025-12-02 08:06:03 +00:00
parent a0ff488cf1
commit 94ccfd4e18

View File

@ -11,7 +11,7 @@ def solve(input):
t = 0 t = 0
for lo,hi in ranges: for lo,hi in ranges:
invalid = list(findinvalid(lo,hi)) invalid = list(findinvalid(lo,hi,1))
#if len(invalid) == 0: #if len(invalid) == 0:
# print(lo,hi,invalid) # print(lo,hi,invalid)
t += sum(invalid) t += sum(invalid)
@ -19,53 +19,47 @@ def solve(input):
t2 = 0 t2 = 0
for lo,hi in ranges: for lo,hi in ranges:
invalid2 = set(findinvalid2(lo,hi)) invalid2 = set(findinvalid(lo,hi,2))
t2 += sum(invalid2) t2 += sum(invalid2)
print(t2) print(t2)
def findinvalid(lo, hi): def findinvalid(lo, hi, part):
assert len(str(hi)) - len(str(lo)) <= 1
assert lo <= hi assert lo <= hi
j = len(str(lo)) // 2 if log10(hi) != log10(lo):
if j < 1: assert log10(hi) - log10(lo) == 1
# (1,19) => [11] yield from findinvalid(lo,exp10(log10(lo))-1, part)
j = 1 yield from findinvalid(exp10(log10(hi)-1),hi, part)
for B in [10**j, 10**(j+1)]: elif part == 1:
for i in range(lo//B, hi//B+1): N = log10(hi)
if i < B//10: if N % 2 == 0:
continue d = N // 2
if i >= B: yield from _find(lo,hi,d,N)
break elif part == 2:
id = i*B + i N = log10(hi)
if lo <= id <= hi: for d in range(1,N):
assert(str(i)+str(i) == str(id)), (i,id) if N%d == 0:
#print(lo, hi, i,B,id) yield from _find(lo,hi,d,N)
yield id
def findinvalid2(lo, hi): def _find(lo, hi, d, N):
assert lo <= hi,(lo,hi) B = exp10(d)
if len(str(hi)) != len(str(lo)): BB = exp10(N-d)
assert len(str(hi)) - len(str(lo)) == 1 factor = sum(exp10(k) for k in range(0,N,d))
yield from findinvalid2(lo,10**len(str(lo))-1) for i in range(lo//BB, hi//BB+1):
yield from findinvalid2(10**(len(str(hi))-1),hi) if i < B//10:
return continue
N = len(str(hi)) if i >= B:
for d in range(1,N): break
if N%d == 0: id = i*factor
B = 10**d if lo <= id <= hi:
BB = 10**(N-d) #assert(str(i)*(N//d) == str(id)), (i,id)
factor = sum(10**k for k in range(0,N,d)) #print(lo, hi, i,B,id)
for i in range(lo//BB, hi//BB+1): yield id
if i < B//10:
continue
if i >= B:
break
id = i*factor
#print(lo,hi,factor,i,id)
if lo <= id <= hi:
#assert(str(i)*(N//d) == str(id)), (i,id)
#print(lo, hi, i,B,id)
yield id
#solve("sample") def exp10(n):
return 10 ** n
def log10(n):
return len(str(n))
solve("sample")
solve("input") solve("input")