def parse(s):
    """Parse comma-separated ``lo-hi`` ranges into a list of ``(lo, hi)`` int tuples.

    Surrounding whitespace (e.g. the trailing newline from ``f.read()``) and
    empty entries (from a trailing comma or blank input) are ignored.

    Raises:
        ValueError: if a non-empty entry is not of the form ``<int>-<int>``.
    """
    ret = []
    for pair in s.split(','):
        pair = pair.strip()
        if not pair:
            continue
        x, y = pair.split('-')
        ret.append((int(x), int(y)))
    return ret


def solve(input):
    """Print, then return, the part-1 and part-2 totals for the range file *input*.

    Part 1 sums every ID in the listed ranges whose digits are one chunk
    repeated exactly twice; part 2 sums IDs whose digits are one chunk
    repeated at least twice (each ID counted once per range).

    Returns:
        ``(part1_total, part2_total)``.
    """
    # NOTE: the parameter name shadows the builtin ``input``; it is kept so
    # existing keyword callers keep working.
    with open(input, encoding="utf-8") as f:
        ranges = parse(f.read())

    t = sum(sum(findinvalid(lo, hi, 1)) for lo, hi in ranges)
    print(t)

    # A part-2 ID such as 111111 matches several chunk widths (1, 2 and 3),
    # so it is yielded more than once; deduplicate per range before summing.
    t2 = sum(sum(set(findinvalid(lo, hi, 2))) for lo, hi in ranges)
    print(t2)
    return t, t2


def findinvalid(lo, hi, part):
    """Yield the "invalid" IDs in the inclusive range ``[lo, hi]``.

    An ID is invalid when its decimal digits are a single chunk repeated:
    exactly twice for ``part == 1``, at least twice for ``part == 2``.
    For part 2 the same ID may be yielded more than once (e.g. 1111 is both
    "1" x4 and "11" x2), so callers should deduplicate.

    Ranges spanning several digit lengths are split at each power of ten, so
    any non-negative ``lo <= hi`` is accepted.

    Raises:
        ValueError: if ``lo > hi`` or ``part`` is not 1 or 2 (raised when the
            generator is first advanced).
    """
    if lo > hi:
        raise ValueError(f"empty range: {lo}-{hi}")
    if part not in (1, 2):
        raise ValueError(f"part must be 1 or 2, got {part!r}")

    n_lo, N = log10(lo), log10(hi)
    if n_lo != N:
        # Peel off the numbers that have lo's digit count and recurse on the
        # remainder; this handles ranges covering any number of digit lengths.
        boundary = exp10(n_lo)
        yield from findinvalid(lo, boundary - 1, part)
        yield from findinvalid(boundary, hi, part)
    elif part == 1:
        # Exactly two repetitions: only possible when N is even.
        if N % 2 == 0:
            yield from _find(lo, hi, N // 2, N)
    else:
        # Two or more repetitions: every proper divisor of N is a chunk width.
        for d in range(1, N):
            if N % d == 0:
                yield from _find(lo, hi, d, N)


def _find(lo, hi, d, N):
    """Yield N-digit IDs in ``[lo, hi]`` formed by repeating a d-digit chunk ``N // d`` times.

    Assumes ``d`` divides ``N`` and that ``lo`` and ``hi`` both have N digits.
    IDs are yielded in increasing order.
    """
    B = exp10(d)        # a chunk is a d-digit number: B // 10 <= chunk < B
    BB = exp10(N - d)   # integer-dividing an N-digit number by BB keeps its leading d digits
    factor = sum(exp10(k) for k in range(0, N, d))  # e.g. N=6, d=2 -> 10101
    # Any repeated ID in [lo, hi] starts with a chunk between lo's and hi's
    # leading d digits, clamped to the valid d-digit range.
    first = max(lo // BB, B // 10)
    last = min(hi // BB, B - 1)
    for chunk in range(first, last + 1):
        candidate = chunk * factor
        if lo <= candidate <= hi:
            yield candidate


def exp10(n):
    """Return ``10 ** n``."""
    return 10 ** n


def log10(n):
    """Return the number of decimal digits of *n* (an integer digit count, not ``math.log10``).

    Intended for non-negative ints; ``log10(0) == 1`` and a minus sign would
    be counted as a digit.
    """
    return len(str(n))


if __name__ == "__main__":
    solve("sample")
    solve("input")