package main import ( "bufio" "container/heap" "fmt" "log" "math/bits" "os" "strconv" "strings" ) func die(err error) { log.Fatal(err) } func check(err error) { if err != nil { die(err) } } func main() { solve("sample") solve("input") } // Brilliant insight from a redditor: // suppose we only want to match the parity of the required joltages. // we can use our solution for part 1 to find the minimum number of button presses // to make that happen. now, consider what happens if we press a button twice: // it won't change the parity at all, only add 2 to some joltages. ok. // take the 2nd bit of each joltage. what's the minimum number of double-presses // required to match that pattern? and so on. func solve(filename string) { input, err := os.Open(filename) check(err) scanner := bufio.NewScanner(input) scanner.Split(bufio.ScanLines) total := 0 errors := 0 for scanner.Scan() { line := scanner.Text() parts := strings.Fields(line) parts = parts[1:] var buts []uint32 var target Jolts for _, p := range parts { j, err := parseJolt(p) check(err) if p[0] == '{' { target = j } else { buts = append(buts, j.uint32()) } } fmt.Printf("%v %b %v\n", target, buts, parts) n := solveJolts(target, buts) fmt.Printf("%s = %v\n", line, n) fmt.Println() if n >= 0 { total += n } else { errors += 1 } //fmt.Println(n) } check(scanner.Err()) fmt.Println(total) if errors > 0 { fmt.Printf("error! failed to solve %d instances\n", errors) } } type Jolts [10]int16 func parseJolt(s string) (Jolts, error) { if s[0] == '(' { return parseJoltByIndex(s) } else if s[0] == '{' { return parseJoltByValue(s) } return Jolts{}, fmt.Errorf("invalid jolts: %q", s) } func parseJoltByIndex(s string) (Jolts, error) { var j Jolts s = strings.Trim(s, "()") for _, part := range strings.Split(s, ",") { idx, err := strconv.ParseInt(part, 10, 16) if err != nil { return j, err } j[idx] = 1 } return j, nil } func parseJoltByValue(s string) (Jolts, error) { var j Jolts s = strings.Trim(s, "{}") for i, part := range strings.Split(s, ",") { val, err := strconv.ParseInt(part, 10, 16) if err != nil { return j, err } j[i] = int16(val) } return j, nil } func (j *Jolts) uint32() (u uint32) { for i, v := range j { if v != 0 { u |= uint32(1) << i } } return u } func (j Jolts) Add(m uint32, value int) Jolts { for i := range j { if m>>i&1 != 0 { j[i] += int16(value) } } return j } func (j Jolts) Sub(m uint32, value int) Jolts { return j.Add(m, -value) } func (j Jolts) Valid() bool { for i := range j { if j[i] < 0 { return false } } return true } func (j Jolts) Less(k Jolts) bool { for i := range j { if j[i] != k[i] { return j[i] < k[i] } } return false // equal } func (j Jolts) BitLength() int { var total uint32 for _, x := range j { total |= uint32(x) } return bits.Len32(total) } func (j Jolts) BitSlice(bit int) uint32 { var mask uint32 for i, x := range j { mask |= (uint32(x) >> uint(bit)) & 1 << uint(i) } return mask } // TODO: this isn't quite enough. we need to know not only how many button presses // a mask can be solved in, but also *the specific buttons* -- because even though // each button is pressed at most once, two buttons can be connected to the same // output so the joltages can be between like 0-10. this needs to be subtracted // from the target, and two different button patterns can have different effects // on the target. // // the good news is that we can cache and reuse the cost map computed in best(), // since it depends only on the button wiring. all that's left after that is a // graph search through the state space, using masks to prune each level. // suppose we have four buttons whose bitmasks are 110, 101, 011, and 100. // suppose our target jolts are [2 1 1], corresponding to bitmasks 100x2 and 011x1. // the shortest solution for the low bits (011) is to press button 3 once. but then we need // to press button 4 twice to get the 2nd bits. that's 3 total button presses. // a more efficient solution is to press button 1 once and button 2 once (110+101 = 211), // for 2 total button presses. // this is a potential issue whenever we have a set of buttons which aren't linearly independent // (e.g. button 1 ^ 2 ^ 3 = button 4). we know this to be the case in some of the inputs. func solveJolts(target Jolts, pool []uint32) int { var bitcost = new(state) bitcost.init(pool) for mask, vals := range bitcost.nodes { fmt.Printf("%b: %v\n", mask, vals) } bits := target.BitLength() qu := newHeap(target) var zero Jolts addstates := func(t, n Node, bit int) Node { t.cost += n.cost << bit t.mask += 1 for i := range t.jolts { t.jolts[i] -= n.jolts[i] << bit } return t } ntargets := len(target) for ntargets > 1 && target[ntargets-1] == 0 { ntargets-- } seen := make(map[Jolts]bool) for qu.Len() > 0 { this := heap.Pop(qu).(Node) log.Printf("sj: %v, len(q) = %v", this, qu.Len()+1) if this.jolts == zero { return int(this.cost) } seen[this.jolts] = true bit := int(this.mask) if bit >= int(bits) { continue } //mask := masks[this.mask] mask := this.jolts.BitSlice(bit) log.Printf("jolts = %v, bit = %d, mask = %0*b", this.jolts, bit, ntargets, mask) // We can't short circuit mask==0 // Consider this set of buttons: // (2) (0,3,4,5,6,7) (0,4,5,6,7) (1,2,3,5) (0,1,2,3,4,6,7) (0,2,3,5) {42,20,50,40,32,36,32,32} // with masks 100 11111001 11110001 101110 11011111 101101. // There is no way to combine the buttons that results in mask 00000101 // but! there is a way to press them that results in the zero mask 00000000 // and has joltage 2,2,2,2,2,2,2,2 // (press the third, fourth and fifth button) // if we use that solution for the low bits then for the next bits // we have to find 11111010 instead, which is possible (the second, fourth, and sixth buttons) if mask == 0 { seen[this.jolts] = false } for _, n := range bitcost.best(mask) { t := addstates(this, n, int(bit)) if t.jolts.Valid() && !seen[t.jolts] { qu.AddNode(t) } } //for i, p := range pool { // if mask>>i&1 != 0 { // target.Sub(p, 1) // } //} } log.Printf("%v: no solution found\n", target) return -1 } type Node struct { mask uint32 buttons uint32 cost int32 jolts Jolts } func (n Node) Add(mask uint32) Node { n.mask ^= mask n.cost += 1 n.jolts = n.jolts.Add(mask, 1) return n } //func (n *Node) cost() int { return bits.OnesCount(n.buttons) } type Heap struct { heap []Node idx map[Jolts]int } // these functions are for the heap.Interface interface // start func (h *Heap) Len() int { return len(h.heap) } func (h *Heap) Less(i, j int) bool { if h.heap[i].cost != h.heap[j].cost { return h.heap[i].cost < h.heap[j].cost } return h.heap[i].jolts.Less(h.heap[j].jolts) } func (h *Heap) Swap(i, j int) { h.heap[i], h.heap[j] = h.heap[j], h.heap[i] h.idx[h.heap[i].jolts] = i h.idx[h.heap[j].jolts] = j } func (h *Heap) Push(x any) { n := x.(Node) h.heap = append(h.heap, n) h.idx[n.jolts] = len(h.heap) - 1 } func (h *Heap) Pop() any { old := h.heap n := len(old) x := old[n-1] h.heap = old[0 : n-1] delete(h.idx, x.jolts) return x } // end func (h *Heap) AddNode(t Node) { defer func() { if e := recover(); e != nil { log.Println(h.heap) log.Println(h.idx) panic(e) } }() if i, ok := h.idx[t.jolts]; ok { old_cost := h.heap[i].cost if t.cost < old_cost { h.heap[i].cost = t.cost heap.Fix(h, i) } else { // nothing; don't add node with higher cost to the queue } } else { heap.Push(h, t) } } func newHeap(start Jolts) *Heap { var idx = make(map[Jolts]int) idx[start] = 0 return &Heap{[]Node{{jolts: start}}, idx} } type state struct { nodes map[uint32][]Node } func (s *state) init(pool []uint32) { qu := newHeap(Jolts{}) best := make(map[uint32][]Node) seen := make(map[uint32]bool) for qu.Len() > 0 { s := heap.Pop(qu).(Node) best[s.mask] = append(best[s.mask], s) // enumerate all the states that can be reached by toggling button p for b, p := range pool { if s.buttons>>b&1 == 0 { t := s.Add(p) t.buttons |= uint32(1) << b if !seen[t.buttons] { heap.Push(qu, t) seen[t.buttons] = true } //qu.AddNode(t) } } } s.nodes = best } func (s *state) best(target uint32) []Node { return s.nodes[target] }