pmap/pmap.go

315 lines
6.1 KiB
Go

package pmap
import (
"fmt"
"math/bits"
)
// Trie geometry. Every level of the trie consumes nodeShift bits of the
// hash, so a node has nodeDegree = 2^nodeShift possible slots, and nodeMask
// extracts one level's digit from a shifted hash. node.bitmap is a uint32,
// so nodeDegree must not exceed 32.
const (
	nodeShift  = 4
	nodeDegree = 1 << nodeShift // branch factor of nodes
	nodeMask   = nodeDegree - 1 // 0b1111
)
// Key is the key type stored in a Map.
type Key = int

// Value is the value type stored in a Map.
type Value = int

// Map is a persistent map from Key to Value: Set and Del return a Map and
// leave the receiver unchanged, so earlier versions stay usable.
type Map interface {
	// Get returns the value stored for the key and true, or the zero Value
	// and false if the key is absent.
	Get(Key) (Value, bool)
	// Set returns a Map in which the key maps to the value.
	Set(Key, Value) Map
	// Del returns a Map without an entry for the key.
	Del(Key) Map
	// Len returns the number of entries.
	Len() int
}
// pmap is a Map implemented as a hashed trie (a hash array mapped trie).
// It is used by value: Set and Del work on a copy of the struct and copy
// only the trie nodes they change, so a returned pmap is never mutated.
type pmap struct {
	root interface{}      // nil (empty map), leaf, *collision or *node
	len  int              // number of entries
	hash func(Key) uint32 // hash function supplied to New
}

// node is an interior trie node. Bit i of bitmap is set iff slot i is
// occupied; child holds the contents of the occupied slots in ascending
// slot order, so len(child) always equals the popcount of bitmap.
type node struct {
	child  []interface{}
	bitmap uint32
}
// bitmask returns the single-bit bitmap mask of the slot selected by the
// low nodeShift bits of shiftedHash (callers pass hash >> shift).
func bitmask(shiftedHash uint32) uint32 {
	slot := shiftedHash & nodeMask
	return uint32(1) << slot
}
// index converts a single-bit slot mask into a position in n.child. The
// position is the number of occupied slots whose bit is below the mask's bit.
func (n *node) index(mask uint32) int {
	below := mask - 1 // every bit lower than the mask bit
	return bits.OnesCount32(below & n.bitmap)
}
type (
	// collision is a bucket of two or more entries whose keys share the
	// same full hash value, so no trie digit can separate them.
	collision struct {
		hash uint32 // the hash shared by every entry
		leaf []leaf // the entries, most recently set first
	}

	// leaf is a single key/value entry.
	leaf struct {
		k Key
		v Value
	}
)
// HashFunc maps a Key to the 32 bits that route it through the trie. Keys
// that compare equal must produce equal hashes, or lookups will miss them.
type HashFunc = func(Key) uint32

// New returns an empty Map that places keys using hash.
// It panics if hash is nil.
func New(hash HashFunc) Map {
	if hash != nil {
		return pmap{hash: hash}
	}
	panic("pmap.New: nil hash")
}
// Len reports how many entries p holds.
func (p pmap) Len() int { return p.len }

// Get returns the value stored for k and true, or the zero Value and false
// when p has no entry for k.
func (p pmap) Get(k Key) (Value, bool) {
	var notFound Value
	return lookup(p.root, p.hash(k), k, notFound)
}
// Set returns a Map in which k maps to v. The receiver is a copy, and insert
// rebuilds only the path to k, so the original map is left unchanged. Len
// grows only when k was not already present.
func (p pmap) Set(k Key, v Value) Map {
	var added bool
	p.root, added = insert(p.root, p.hash(k), k, v, p.hash)
	if added {
		p.len++
	}
	return p
}
// Del returns a Map without the entry for k. The receiver is not modified.
// Only the nodes on the path from the root to k are rebuilt; all other
// subtrees are shared with p. If k is absent, p is returned as is.
//
// To keep the trie compact, a collision bucket reduced to one entry becomes
// a plain leaf, and an interior node left with no children disappears. An
// interior node left with a single leaf or collision child is replaced by
// that child. This is safe because leaves and collisions are matched by key
// and hash, not by the slot they sit in.
func (p pmap) Del(k Key) Map {
	hash := p.hash(k)
	removed := false

	// del returns subtree n with k removed (nil if nothing is left). It sets
	// removed when k was found. shift is the bit offset of n's digit in hash.
	var del func(n interface{}, shift uint32) interface{}
	del = func(n interface{}, shift uint32) interface{} {
		switch n := n.(type) {
		case nil:
			return nil
		case leaf:
			if n.k != k {
				return n
			}
			removed = true
			return nil
		case *collision:
			if n.hash != hash {
				return n
			}
			for i := range n.leaf {
				if n.leaf[i].k != k {
					continue
				}
				removed = true
				if len(n.leaf) == 2 {
					// One entry left: a bucket of one is just a leaf.
					return n.leaf[1-i]
				}
				l := make([]leaf, 0, len(n.leaf)-1)
				l = append(l, n.leaf[:i]...)
				l = append(l, n.leaf[i+1:]...)
				return &collision{n.hash, l}
			}
			return n
		case *node:
			m := bitmask(hash >> shift)
			if n.bitmap&m == 0 {
				return n
			}
			i := n.index(m)
			c := del(n.child[i], shift+nodeShift)
			if !removed {
				// Nothing changed below; share the existing node.
				return n
			}
			x := &node{bitmap: n.bitmap}
			if c == nil {
				// The slot became empty: drop it from bitmap and child.
				x.bitmap &^= m
				x.child = make([]interface{}, 0, len(n.child)-1)
				x.child = append(x.child, n.child[:i]...)
				x.child = append(x.child, n.child[i+1:]...)
			} else {
				x.child = make([]interface{}, len(n.child))
				copy(x.child, n.child)
				x.child[i] = c
			}
			switch len(x.child) {
			case 0:
				return nil
			case 1:
				// A lone leaf/collision can take its parent's place. A lone
				// *node cannot, because its digits depend on its depth.
				if _, isNode := x.child[0].(*node); !isNode {
					return x.child[0]
				}
			}
			x.check()
			return x
		default:
			panic("pmap: unhandled case in Del")
		}
	}

	root := del(p.root, 0)
	if !removed {
		return p
	}
	p.root = root
	p.len--
	return p
}
// check panics when the number of set bits in n.bitmap differs from
// len(n.child), meaning the node's bitmap/child invariant has been broken.
func (n *node) check() {
	if want, got := bits.OnesCount32(n.bitmap), len(n.child); want != got {
		panic(fmt.Sprintf("pmap: corrupt bitmap b=%#b len=%d", n.bitmap, got))
	}
}
// getNode returns the child of n in the slot selected by hash's digit at
// bit offset shift, or nil when that slot is empty. key is not consulted.
// It first verifies the node invariant via check.
func (n *node) getNode(shift, hash uint32, key Key) interface{} {
	n.check()
	if m := bitmask(hash >> shift); n.bitmap&m != 0 {
		return n.child[n.index(m)]
	}
	return nil
}
// getNode returns the entry of bucket c whose key equals key, as a leaf.
// It returns nil when hash is not the bucket's hash or no entry matches.
func (c *collision) getNode(hash uint32, key Key) interface{} {
	if hash == c.hash {
		for _, l := range c.leaf {
			if l.k == key {
				return l
			}
		}
	}
	return nil
}
// lookup walks the trie under root, following hash one digit per interior
// node. It returns the value stored for key and true, or zero and false when
// key is absent.
func lookup(root interface{}, hash uint32, key Key, zero Value) (Value, bool) {
	shift := uint32(0)
	for cur := root; ; {
		switch t := cur.(type) {
		case nil:
			return zero, false
		case leaf:
			if t.k != key {
				return zero, false
			}
			return t.v, true
		case *node:
			cur = t.getNode(shift, hash, key)
			shift += nodeShift
		case *collision:
			// Yields the matching leaf or nil; the next turn finishes.
			cur = t.getNode(hash, key)
		default:
			panic("pmap: unhandled case in lookup")
		}
	}
}
// singleton returns a node whose only entry is the leaf key→val, in the slot
// selected by hash's digit at bit offset shift.
func singleton(key Key, val Value, hash, shift uint32) *node {
	l := leaf{k: key, v: val}
	return newnode(l, hash, shift)
}
// newnode returns a node holding child as its only entry, in the slot
// selected by hash's digit at bit offset shift.
func newnode(child interface{}, hash, shift uint32) *node {
	n := &node{
		child:  []interface{}{child},
		bitmap: bitmask(hash >> shift),
	}
	n.check()
	return n
}
// insert returns the subtree n with key bound to val. added reports whether
// key was new; false means an existing entry was replaced. n itself is not
// modified. Every node, collision and leaf on the path to key is rebuilt,
// and all other subtrees are shared with n.
//
// hash is expected to be hashFn(key), as Set passes it. hashFn is needed to
// rehash an existing leaf when it must be split into a node.
func insert(n interface{}, hash uint32, key Key, val Value, hashFn HashFunc) (newNode interface{}, added bool) {
	// An empty trie becomes a lone leaf.
	if n == nil {
		return leaf{key, val}, true
	}
	// _insert handles a non-nil subtree n whose digit is read from hash at
	// bit offset shift. It reports "new key" by assigning the enclosing
	// named result added.
	var _insert func(n interface{}, shift uint32) interface{}
	_insert = func(n interface{}, shift uint32) interface{} {
		switch n := n.(type) {
		// No nil case: nil is handled before _insert is first called, and
		// node children are never nil.
		case leaf:
			if n.k == key {
				// replace existing entry
				added = false
				return leaf{key, val}
			} else if h := hashFn(n.k); h == hash {
				// collision: different keys with the same full hash
				added = true
				return &collision{hash, []leaf{{key, val}, n}}
			} else {
				// Normally h and hash already agree on every digit below
				// shift, because that is how this leaf was reached. If they
				// also agreed from shift upward they would be equal. This
				// guard stops unbounded recursion if that assumption is ever
				// broken: once shift reaches 32, both shifts yield 0 and the
				// same slot would be chosen forever.
				if h>>shift == hash>>shift {
					panic("pmap: infinite loop in insert")
				}
				// not a collision, so we must still have some hash bits left
				// split the trie: push the old leaf into a fresh node at this
				// level and retry the insertion there
				x := newnode(n, h, shift)
				return _insert(x, shift)
			}
		case *node:
			c := n.getNode(shift, hash, key)
			if c == nil {
				// Empty slot: build a copy of n that is one slot wider, with
				// the new leaf at its sorted position.
				c = leaf{key, val}
				added = true
				m := bitmask(hash >> shift)
				x := &node{bitmap: n.bitmap | m}
				i := x.index(m)
				x.child = make([]interface{}, len(n.child)+1)
				copy(x.child[:i], n.child[:i])
				x.child[i] = c
				copy(x.child[i+1:], n.child[i:])
				x.check()
				return x
			} else {
				// Occupied slot: insert one level down, then copy n with
				// that single child replaced (path copying).
				c = _insert(c, shift+nodeShift)
				// TODO: short circuit if c unchanged?
				m := bitmask(hash >> shift)
				x := &node{bitmap: n.bitmap}
				i := x.index(m)
				x.child = make([]interface{}, len(n.child))
				copy(x.child, n.child)
				x.child[i] = c
				x.check()
				return x
			}
		case *collision:
			if n.hash != hash {
				// not a collision, so we must still have some hash bits left
				// split the trie: push the bucket into a fresh node at this
				// level and retry the insertion there
				x := newnode(n, n.hash, shift)
				return _insert(x, shift)
			}
			for i := range n.leaf {
				if key == n.leaf[i].k {
					// replace existing entry, moving it to the front of a
					// fresh copy of the bucket
					l := make([]leaf, len(n.leaf))
					l[0] = leaf{key, val}
					copy(l[1:], n.leaf[:i])
					copy(l[1+i:], n.leaf[i+1:])
					added = false
					return &collision{hash, l}
				}
			}
			// new collision entry, prepended to a fresh copy of the bucket
			added = true
			return &collision{hash, append([]leaf{{key, val}}, n.leaf...)}
		default:
			panic("pmap: unhandled case in insert")
		}
	}
	newNode = _insert(n, 0)
	return
}
// stats summarizes the shape of a trie, as gathered by pmap.stats.
type stats struct {
	// count is the number of trie elements visited (nodes, leaves and
	// collision buckets). maxHeight is the height of the deepest leaf,
	// counting the root as height 1.
	count, maxHeight int

	// Mean leaf height and mean number of children per node.
	avgHeight, avgDegree float64

	// Element counts by kind. collidedKeys is the number of keys stored
	// inside collision buckets.
	nodeCount, leafCount, collisionCount, collidedKeys int

	// emptySlots counts unused child slots as computed by pmap.stats.
	emptySlots int
}
// stats walks the whole trie and returns shape statistics for debugging and
// tuning. An empty map yields all-zero statistics.
//
// Heights are 1-based: an entry stored directly at the root has height 1.
// maxHeight and avgHeight consider only plain leaves; keys held in collision
// buckets are reported through collisionCount and collidedKeys instead.
// emptySlots is the total, over all nodes, of the unused slots out of
// nodeDegree.
func (p pmap) stats() stats {
	var s stats
	var th float64 // sum of leaf heights
	var td float64 // sum of per-node child counts
	var visit func(n interface{}, h int)
	visit = func(n interface{}, h int) {
		switch n := n.(type) {
		case nil:
			// Only the root of an empty map is nil; nothing to count.
			// (Previously this fell into default and panicked.)
		case leaf:
			s.count++
			s.leafCount++
			th += float64(h)
			if s.maxHeight < h {
				s.maxHeight = h
			}
		case *node:
			s.count++
			s.nodeCount++
			// child is a compressed array of occupied slots only (it never
			// holds nil), so the unused slots are those len(child) leaves
			// out of nodeDegree.
			s.emptySlots += nodeDegree - len(n.child)
			td += float64(len(n.child))
			for _, c := range n.child {
				visit(c, h+1)
			}
		case *collision:
			s.count++
			s.collisionCount++
			s.collidedKeys += len(n.leaf)
		default:
			panic("pmap: unhandled case in stats")
		}
	}
	visit(p.root, 1)
	if s.leafCount > 0 {
		s.avgHeight = th / float64(s.leafCount)
	}
	if s.nodeCount > 0 {
		s.avgDegree = td / float64(s.nodeCount)
	}
	return s
}
// String renders s as one "name = value" line per field, in declaration
// order, each line terminated by a newline.
func (s stats) String() string {
	const layout = `count = %d
maxHeight = %d
avgHeight = %g
avgDegree = %g
nodeCount = %d
leafCount = %d
collisionCount = %d
collidedKeys = %d
emptySlots = %d
`
	return fmt.Sprintf(layout,
		s.count, s.maxHeight,
		s.avgHeight, s.avgDegree,
		s.nodeCount, s.leafCount,
		s.collisionCount, s.collidedKeys,
		s.emptySlots,
	)
}