diff --git a/pmap.go b/pmap.go index 4a105c5..3281e90 100644 --- a/pmap.go +++ b/pmap.go @@ -29,10 +29,18 @@ type pmap struct { // A Map implemented as a hashed trie type node struct { - child [nodeDegree]interface{} + child []interface{} bitmap uint32 } +func bitmask(shiftedHash uint32) uint32 { + return uint32(1) << (shiftedHash & nodeMask) +} + +func (n *node) index(mask uint32) int { + return bits.OnesCount32(n.bitmap & (mask - 1)) +} + type collision struct { hash uint32 leaf []leaf @@ -77,18 +85,28 @@ func (p pmap) Del(k Key) Map { return p } -func (m *node) getNode(shift, hash uint32, key Key) interface{} { - i := hash >> shift & nodeMask - return m.child[i] +func (n *node) check() { + if bits.OnesCount32(n.bitmap) != len(n.child) { + panic(fmt.Sprintf("pmap: corrupt bitmap b=%#b len=%d", n.bitmap, len(n.child))) + } } -func (m collision) getNode(hash uint32, key Key) interface{} { - if hash != m.hash { +func (n *node) getNode(shift, hash uint32, key Key) interface{} { + n.check() + m := bitmask(hash >> shift) + if n.bitmap&m == 0 { return nil } - for i := range m.leaf { - if key == m.leaf[i].k { - return m.leaf[i] + return n.child[n.index(m)] +} + +func (n collision) getNode(hash uint32, key Key) interface{} { + if hash != n.hash { + return nil + } + for i := range n.leaf { + if key == n.leaf[i].k { + return n.leaf[i] } } return nil @@ -123,9 +141,10 @@ func singleton(key Key, val Value, hash, shift uint32) *node { func newnode(child interface{}, hash, shift uint32) *node { n := &node{} - idx := hash >> shift & nodeMask - n.child[idx] = child - n.bitmap = 1 << idx + a := [1]interface{}{child} + n.child = a[:] + n.bitmap = bitmask(hash >> shift) + n.check() return n } @@ -155,8 +174,8 @@ func insert(n interface{}, hash uint32, key Key, val Value, hashFn HashFunc) (ne } // not a collision, so we must still have some hash bits left // split the trie - m := newnode(n, h, shift) - return _insert(m, shift) + x := newnode(n, h, shift) + return _insert(x, shift) } case *node: c := n.getNode(shift, hash, key) @@ -164,28 +183,42 @@ func insert(n interface{}, hash uint32, key Key, val Value, hashFn HashFunc) (ne // new node c = leaf{key, val} added = true + + m := bitmask(hash >> shift) + x := &node{bitmap: n.bitmap | m} + i := x.index(m) + x.child = make([]interface{}, len(n.child)+1) + copy(x.child[:i], n.child[:i]) + x.child[i] = c + copy(x.child[i+1:], n.child[i:]) + x.check() + return x } else { c = _insert(c, shift+nodeShift) + // TODO: short circuit if c unchanged? + m := bitmask(hash >> shift) + x := &node{bitmap: n.bitmap} + i := x.index(m) + x.child = make([]interface{}, len(n.child)) + copy(x.child, n.child) + x.child[i] = c + x.check() + return x } - idx := hash >> shift & nodeMask - x := &node{child: n.child, bitmap: n.bitmap} - x.child[idx] = c - x.bitmap |= 1 << idx - return x case collision: if n.hash != hash { // not a collision, so we must still have some hash bits left // split the trie - m := newnode(n, n.hash, shift) - return _insert(m, shift) + x := newnode(n, n.hash, shift) + return _insert(x, shift) } for i := range n.leaf { if key == n.leaf[i].k { // replace existing entry - l := make([]leaf, 1, len(n.leaf)) + l := make([]leaf, len(n.leaf)) l[0] = leaf{key, val} - l = append(l, n.leaf[:i]...) - l = append(l, n.leaf[i+1:]...) + copy(l[1:], n.leaf[:i]) + copy(l[1+i:], n.leaf[i+1:]) added = false return collision{hash, l} } @@ -230,11 +263,12 @@ func (p pmap) stats() stats { case *node: s.count++ s.nodeCount++ - s.emptySlots += bits.OnesCount32(^n.bitmap) - td += float64(bits.OnesCount32(n.bitmap)) for i := range n.child { if n.child[i] != nil { + td += 1.0 visit(n.child[i], h+1) + } else { + s.emptySlots++ } } case collision: diff --git a/pmap_test.go b/pmap_test.go index ca96b59..7e98cde 100644 --- a/pmap_test.go +++ b/pmap_test.go @@ -12,6 +12,8 @@ func hash(k Key) uint32 { return uint32(u + u>>32) } +// FIXME: collisions can cause allocations during Get + func TestPmap(t *testing.T) { p := New(hash) const numElems = 100 @@ -172,3 +174,18 @@ func benchmarkHmapSet(b *testing.B, numElems int) { delete(h, k) } } + +func TestBitmap(t *testing.T) { + n := node{bitmap: 0xff0a} + const x = -1 + want := []int{x, 0, x, 1, x, x, x, x, 2, 3, 4, 5, 6, 7, 8, 9} + for i, j := range want { + if j == x { + continue + } + m := bitmask(uint32(i)) + if got := n.index(m); got != j { + t.Errorf("got %v, want %v", got, j) + } + } +}