package pmap import ( "fmt" "math/bits" ) const ( nodeDegree = 16 // branch factor of nodes nodeShift = 4 nodeMask = 0b1111 ) type Key = int type Value = int type Map interface { Get(Key) (Value, bool) Set(Key, Value) Map Del(Key) Map Len() int } type pmap struct { root interface{} len int hash func(Key) uint32 } // A Map implemented as a hashed trie type node struct { child []interface{} bitmap uint32 } func bitmask(shiftedHash uint32) uint32 { return uint32(1) << (shiftedHash & nodeMask) } func (n *node) index(mask uint32) int { return bits.OnesCount32(n.bitmap & (mask - 1)) } type collision struct { hash uint32 leaf []leaf } type leaf struct { k Key v Value } type HashFunc = func(Key) uint32 func New(hash HashFunc) Map { if hash == nil { panic("pmap.New: nil hash") } return pmap{hash: hash} } func (p pmap) Len() int { return p.len } func (p pmap) Get(k Key) (Value, bool) { var zero Value h := p.hash(k) return lookup(p.root, h, k, zero) } func (p pmap) Set(k Key, v Value) Map { h := p.hash(k) root, added := insert(p.root, h, k, v, p.hash) p.root = root if added { p.len++ } //pretty.Println(p) return p } func (p pmap) Del(k Key) Map { return p } func (n *node) check() { if bits.OnesCount32(n.bitmap) != len(n.child) { panic(fmt.Sprintf("pmap: corrupt bitmap b=%#b len=%d", n.bitmap, len(n.child))) } } func (n *node) getNode(shift, hash uint32, key Key) interface{} { n.check() m := bitmask(hash >> shift) if n.bitmap&m == 0 { return nil } return n.child[n.index(m)] } func (n *collision) getNode(hash uint32, key Key) interface{} { if hash != n.hash { return nil } for i := range n.leaf { if key == n.leaf[i].k { return n.leaf[i] } } return nil } func lookup(root interface{}, hash uint32, key Key, zero Value) (Value, bool) { var shift uint32 cur := root for { switch n := cur.(type) { case nil: return zero, false case leaf: if n.k == key { return n.v, true } else { return zero, false } case *node: cur = n.getNode(shift, hash, key) shift += nodeShift case *collision: cur = n.getNode(hash, key) default: panic("pmap: unhandled case in lookup") } } } func singleton(key Key, val Value, hash, shift uint32) *node { return newnode(leaf{key, val}, hash, shift) } func newnode(child interface{}, hash, shift uint32) *node { n := &node{} a := [1]interface{}{child} n.child = a[:] n.bitmap = bitmask(hash >> shift) n.check() return n } func insert(n interface{}, hash uint32, key Key, val Value, hashFn HashFunc) (newNode interface{}, added bool) { if n == nil { return leaf{key, val}, true } var _insert func(n interface{}, shift uint32) interface{} _insert = func(n interface{}, shift uint32) interface{} { //fmt.Printf("insert %v %x %#v\n", shift, hash, n) switch n := n.(type) { //case nil: // added = true // return leaf{key, val} case leaf: if n.k == key { // replace existing entry added = false return leaf{key, val} } else if h := hashFn(n.k); h == hash { // collision added = true return &collision{hash, []leaf{{key, val}, n}} } else { if h>>shift == hash>>shift { panic("pmap: infinite loop in insert") } // not a collision, so we must still have some hash bits left // split the trie x := newnode(n, h, shift) return _insert(x, shift) } case *node: c := n.getNode(shift, hash, key) if c == nil { // new node c = leaf{key, val} added = true m := bitmask(hash >> shift) x := &node{bitmap: n.bitmap | m} i := x.index(m) x.child = make([]interface{}, len(n.child)+1) copy(x.child[:i], n.child[:i]) x.child[i] = c copy(x.child[i+1:], n.child[i:]) x.check() return x } else { c = _insert(c, shift+nodeShift) // TODO: short circuit if c unchanged? m := bitmask(hash >> shift) x := &node{bitmap: n.bitmap} i := x.index(m) x.child = make([]interface{}, len(n.child)) copy(x.child, n.child) x.child[i] = c x.check() return x } case *collision: if n.hash != hash { // not a collision, so we must still have some hash bits left // split the trie x := newnode(n, n.hash, shift) return _insert(x, shift) } for i := range n.leaf { if key == n.leaf[i].k { // replace existing entry l := make([]leaf, len(n.leaf)) l[0] = leaf{key, val} copy(l[1:], n.leaf[:i]) copy(l[1+i:], n.leaf[i+1:]) added = false return &collision{hash, l} } } // new collision added = true return &collision{hash, append([]leaf{{key, val}}, n.leaf...)} default: panic("pmap: unhandled case in insert") } } newNode = _insert(n, 0) return } type stats struct { count int maxHeight int avgHeight float64 avgDegree float64 nodeCount int leafCount int collisionCount int collidedKeys int emptySlots int } func (p pmap) stats() stats { var s stats var th float64 var td float64 var visit func(n interface{}, h int) visit = func(n interface{}, h int) { switch n := n.(type) { case leaf: s.count++ s.leafCount++ th += float64(h) if s.maxHeight < h { s.maxHeight = h } case *node: s.count++ s.nodeCount++ for i := range n.child { if n.child[i] != nil { td += 1.0 visit(n.child[i], h+1) } else { s.emptySlots++ } } case *collision: s.count++ s.collisionCount++ s.collidedKeys += len(n.leaf) default: panic("pmap: unhandled case in stats") } } visit(p.root, 1) if s.leafCount > 0 { s.avgHeight = th / float64(s.leafCount) } if s.nodeCount > 0 { s.avgDegree = td / float64(s.nodeCount) } return s } func (s stats) String() string { return fmt.Sprintf( "count = %d\n"+ "maxHeight = %d\n"+ "avgHeight = %g\n"+ "avgDegree = %g\n"+ "nodeCount = %d\n"+ "leafCount = %d\n"+ "collisionCount = %d\n"+ "collidedKeys = %d\n"+ "emptySlots = %d\n", s.count, s.maxHeight, s.avgHeight, s.avgDegree, s.nodeCount, s.leafCount, s.collisionCount, s.collidedKeys, s.emptySlots, ) }