// Source: pmap/pmap.go (280 lines, 5.3 KiB, Go)

package pmap
import (
"fmt"
"math/bits"
)
// Trie geometry. Every interior node consumes nodeShift bits of a key's hash,
// so it has nodeDegree = 2^nodeShift children and nodeMask selects one such
// digit. nodeDegree and nodeMask are derived from nodeShift so the three can
// never disagree. nodeShift must stay <= 5 so a node's occupancy fits in its
// uint32 bitmap.
const (
	nodeShift  = 4
	nodeDegree = 1 << nodeShift // branch factor of nodes
	nodeMask   = nodeDegree - 1
)
// Key and Value are the key and element types stored in a Map. Both are
// aliases for int, so callers pass plain ints.
type (
	Key   = int
	Value = int
)
// Map is an immutable map from Key to Value. Set and Del hand back the
// updated Map instead of modifying the receiver.
type Map interface {
	// Get returns the value stored under k and whether k is present.
	Get(k Key) (v Value, ok bool)
	// Set returns a Map in which k is bound to v.
	Set(k Key, v Value) Map
	// Del returns a Map without the entry for k.
	Del(k Key) Map
	// Len returns the number of entries.
	Len() int
}
type pmap struct {
root interface{}
len int
hash func(Key) uint32
}
// A Map implemented as a hashed trie
type node struct {
child [nodeDegree]interface{}
bitmap uint32
}
type (
	// collision holds every entry whose key hashes to exactly hash. Hash
	// digits cannot tell these entries apart, so they share one slot and are
	// distinguished by comparing keys.
	collision struct {
		hash uint32
		leaf []leaf
	}

	// leaf is a single key/value entry.
	leaf struct {
		k Key
		v Value
	}
)
// HashFunc computes the 32-bit hash that decides where a Key lives in the
// trie. Equal keys must produce equal hashes, since lookups follow the hash
// before comparing keys.
type HashFunc = func(key Key) uint32
// New returns an empty Map that places keys using hash.
// It panics if hash is nil.
func New(hash HashFunc) Map {
	if hash != nil {
		return pmap{hash: hash}
	}
	panic("pmap.New: nil hash")
}
// Len returns the number of keys stored in p.
func (p pmap) Len() (n int) {
	n = p.len
	return
}
// Get returns the value stored under k and whether k is present. On a miss
// the Value is its zero value.
func (p pmap) Get(k Key) (Value, bool) {
	var missing Value
	return lookup(p.root, 0, p.hash(k), k, missing)
}
// Set returns a Map in which k is bound to v. The receiver is not modified.
// p is a value copy, and insert copies the nodes on the path to k instead of
// mutating them. Len grows only when k was not already present.
func (p pmap) Set(k Key, v Value) Map {
	h := p.hash(k)
	// insert needs the map's hash function to re-hash keys already stored in
	// leaves when it splits or merges them; the bare name `hash` is undefined.
	root, added := insert(p.root, h, k, v, p.hash)
	p.root = root
	if added {
		p.len++
	}
	return p
}
// Del returns a Map without the entry for k. The receiver is not modified:
// p is a value copy, and remove copies the nodes on the path to k instead of
// mutating them. If k is absent, p is returned as is.
func (p pmap) Del(k Key) Map {
	root, removed := remove(p.root, 0, p.hash(k), k)
	if !removed {
		return p
	}
	p.root = root
	p.len--
	return p
}

// remove returns the trie rooted at n with key deleted, and reports whether
// key was present. hash is the key's hash, and shift is the number of hash
// bits consumed by n's ancestors (0 at the root).
//
// When a node is left with a single leaf or collision child, that child
// replaces the node. Lookup and insert accept leaves and collisions at any
// depth, so this keeps paths short. A lone *node child is kept in place,
// because its depth determines which hash digits it branches on.
func remove(n interface{}, shift, hash uint32, key Key) (interface{}, bool) {
	switch n := n.(type) {
	case nil:
		return nil, false
	case leaf:
		if n.k != key {
			return n, false
		}
		return nil, true
	case collision:
		if n.hash != hash {
			return n, false
		}
		for i := range n.leaf {
			if n.leaf[i].k != key {
				continue
			}
			rest := make([]leaf, 0, len(n.leaf)-1)
			rest = append(rest, n.leaf[:i]...)
			rest = append(rest, n.leaf[i+1:]...)
			switch len(rest) {
			case 0:
				return nil, true
			case 1:
				// A single survivor no longer collides with anything.
				return rest[0], true
			}
			return collision{n.hash, rest}, true
		}
		return n, false
	case *node:
		idx := hash >> shift & nodeMask
		c, removed := remove(n.child[idx], shift+nodeShift, hash, key)
		if !removed {
			return n, false
		}
		x := &node{child: n.child, bitmap: n.bitmap}
		x.child[idx] = c
		if c == nil {
			x.bitmap &^= 1 << idx
		}
		switch bits.OnesCount32(x.bitmap) {
		case 0:
			return nil, true
		case 1:
			only := x.child[bits.TrailingZeros32(x.bitmap)]
			if _, isNode := only.(*node); !isNode {
				return only, true
			}
		}
		return x, true
	default:
		panic("pmap: unhandled case in remove")
	}
}
// getNode returns the child of m in the slot chosen by the nodeShift-bit digit
// of hash starting at bit shift, or nil if that slot is empty. key is unused.
func (m *node) getNode(shift, hash uint32, key Key) interface{} {
	return m.child[(hash>>shift)&nodeMask]
}
// getNode returns the leaf in m whose key equals key, or nil when hash is not
// m's hash or no stored key matches.
func (m collision) getNode(hash uint32, key Key) interface{} {
	if hash == m.hash {
		for _, l := range m.leaf {
			if l.k == key {
				return l
			}
		}
	}
	return nil
}
// lookup searches the trie rooted at root for key, whose hash is hash. shift
// is the number of hash bits already consumed above root (0 at the top). It
// returns the stored value and true, or zero and false if key is absent.
// zero is the value reported on a miss, so it is typed Value, not Key.
func lookup(root interface{}, shift, hash uint32, key Key, zero Value) (Value, bool) {
	cur := root
	for {
		switch n := cur.(type) {
		case nil:
			return zero, false
		case leaf:
			if n.k != key {
				return zero, false
			}
			return n.v, true
		case *node:
			cur = n.getNode(shift, hash, key)
			shift += nodeShift
		case collision:
			// getNode yields the matching leaf or nil; either one ends the
			// walk on the next iteration.
			cur = n.getNode(hash, key)
		default:
			panic("pmap: unhandled case in lookup")
		}
	}
}
// singleton returns a new node whose only child is the entry key->val, stored
// in the slot selected by the digit of hash at shift.
func singleton(key Key, val Value, hash, shift uint32) *node {
	entry := leaf{k: key, v: val}
	return newnode(entry, hash, shift)
}
// newnode returns a fresh node holding child in the slot selected by the
// digit of hash at shift, with the matching bitmap bit set.
func newnode(child interface{}, hash, shift uint32) *node {
	slot := (hash >> shift) & nodeMask
	fresh := new(node)
	fresh.bitmap = 1 << slot
	fresh.child[slot] = child
	return fresh
}
// insert returns a copy of the trie rooted at n with key bound to val. It
// reports added=true when key was not present before. hash must be the key's
// hash. hashFn re-hashes keys already stored in leaves when they have to be
// split apart or merged into a collision. Nodes on the path are copied; n is
// never mutated.
func insert(n interface{}, hash uint32, key Key, val Value, hashFn HashFunc) (newNode interface{}, added bool) {
	if n == nil {
		return leaf{key, val}, true
	}
	return insertAt(n, 0, hash, key, val, hashFn)
}

// insertAt is the recursive worker for insert. shift is the number of hash
// bits consumed by n's ancestors. It returns the replacement subtree and
// whether key was newly added.
func insertAt(n interface{}, shift, hash uint32, key Key, val Value, hashFn HashFunc) (interface{}, bool) {
	switch t := n.(type) {
	case leaf:
		if t.k == key {
			// Same key: replace the value, the count is unchanged.
			return leaf{key, val}, false
		}
		h := hashFn(t.k)
		if h == hash {
			// Full-hash collision: the keys can only be told apart by comparison.
			return collision{hash, []leaf{{key, val}, t}}, true
		}
		if h>>shift == hash>>shift {
			panic("pmap: infinite loop in insert")
		}
		// The hashes still differ in an unconsumed digit: push the existing
		// leaf down into a node and retry at the same depth.
		return insertAt(newnode(t, h, shift), shift, hash, key, val, hashFn)
	case *node:
		idx := hash >> shift & nodeMask
		var child interface{}
		var added bool
		if old := t.child[idx]; old == nil {
			child, added = leaf{key, val}, true
		} else {
			child, added = insertAt(old, shift+nodeShift, hash, key, val, hashFn)
		}
		x := &node{child: t.child, bitmap: t.bitmap | 1<<idx}
		x.child[idx] = child
		return x, added
	case collision:
		if t.hash != hash {
			// Different full hash: split on the current digit and retry.
			return insertAt(newnode(t, t.hash, shift), shift, hash, key, val, hashFn)
		}
		for i, l := range t.leaf {
			if l.k != key {
				continue
			}
			out := make([]leaf, 0, len(t.leaf))
			out = append(out, leaf{key, val})
			out = append(out, t.leaf[:i]...)
			out = append(out, t.leaf[i+1:]...)
			return collision{hash, out}, false
		}
		return collision{hash, append([]leaf{{key, val}}, t.leaf...)}, true
	default:
		panic("pmap: unhandled case in insert")
	}
}
// stats summarizes the shape of a trie, as reported by pmap.stats.
type stats struct {
	count, maxHeight     int
	avgHeight, avgDegree float64

	nodeCount, leafCount, collisionCount, collidedKeys, emptySlots int
}
// stats walks the whole trie and summarizes its shape for debugging and
// tuning. Depth is counted from 1 at the root. maxHeight and avgHeight cover
// plain leaves only; keys held in collisions are counted in collidedKeys.
// emptySlots is the number of unused child slots summed over all nodes.
// An empty map yields all-zero stats.
func (p pmap) stats() stats {
	var s stats
	var th float64 // sum of leaf depths, divided by leafCount for avgHeight
	var td float64 // sum of occupied slots, divided by nodeCount for avgDegree
	var visit func(n interface{}, h int)
	visit = func(n interface{}, h int) {
		switch n := n.(type) {
		case nil:
			// Only the root of an empty map is nil; nil children are skipped
			// in the *node case below.
		case leaf:
			s.count++
			s.leafCount++
			th += float64(h)
			if s.maxHeight < h {
				s.maxHeight = h
			}
		case *node:
			s.count++
			s.nodeCount++
			used := bits.OnesCount32(n.bitmap)
			// A node has nodeDegree slots, but bitmap is 32 bits wide, so
			// counting the set bits of ^bitmap would overstate empty slots.
			s.emptySlots += nodeDegree - used
			td += float64(used)
			for i := range n.child {
				if n.child[i] != nil {
					visit(n.child[i], h+1)
				}
			}
		case collision:
			s.count++
			s.collisionCount++
			s.collidedKeys += len(n.leaf)
		default:
			panic("pmap: unhandled case in stats")
		}
	}
	visit(p.root, 1)
	if s.leafCount > 0 {
		s.avgHeight = th / float64(s.leafCount)
	}
	if s.nodeCount > 0 {
		s.avgDegree = td / float64(s.nodeCount)
	}
	return s
}
// String renders s as one "name = value" line per field, in declaration
// order, each terminated by a newline.
func (s stats) String() string {
	const layout = `count = %d
maxHeight = %d
avgHeight = %g
avgDegree = %g
nodeCount = %d
leafCount = %d
collisionCount = %d
collidedKeys = %d
emptySlots = %d
`
	return fmt.Sprintf(layout,
		s.count, s.maxHeight,
		s.avgHeight, s.avgDegree,
		s.nodeCount, s.leafCount,
		s.collisionCount, s.collidedKeys,
		s.emptySlots)
}