diff --git a/sponge.go b/sponge.go index e1839ac..f1c786d 100644 --- a/sponge.go +++ b/sponge.go @@ -11,7 +11,8 @@ func round(a *[25]uint64) { roundGo(a) } // digest implements hash.Hash type digest struct { a [25]uint64 // a[y][x][z] - buf [200]byte + buf [8]byte // buf[0:ulen] holds a partial uint64 + ulen int8 dsbyte byte len int size int @@ -29,34 +30,46 @@ func (d *digest) BlockSize() int { return 200 - d.size*2 } func (d *digest) Reset() { //fmt.Println("resetting") d.a = [25]uint64{} - d.buf = [200]byte{} + d.buf = [8]byte{} d.len = 0 } func (d *digest) Write(b []byte) (int, error) { written := len(b) - bs := d.BlockSize() - for len(b) > 0 { - n := copy(d.buf[d.len:bs], b) - d.len += n + bs := d.BlockSize() / 8 + // fill buf first, if non-empty + if d.ulen > 0 { + n := copy(d.buf[d.ulen:], b) b = b[n:] + d.ulen += int8(n) + // flush? + if int(d.ulen) == len(d.buf) { + d.a[d.len] ^= le64dec(d.buf[:]) + d.len += 1 + d.ulen = 0 + if d.len == bs { + d.flush() + } + } + } + // xor 8-byte chunks into the state + for len(b) >= 8 { + d.a[d.len] ^= le64dec(b) + b = b[8:] + d.len += 1 if d.len == bs { d.flush() } + } // len(b) < 8 + // store any remaining bytes + if len(b) > 0 { + d.ulen = int8(copy(d.buf[:], b)) } return written, nil } func (d *digest) flush() { - //fmt.Printf("Flushing with %d bytes\n", d.len) - b := d.buf[:d.len] - for i := range d.a { - if len(b) == 0 { - break - } - d.a[i] ^= le64dec(b) - b = b[8:] - } + //fmt.Printf("Flushing with %d bytes\n", d.len*8 + int(d.ulen)) keccakf(&d.a) d.len = 0 } @@ -75,13 +88,19 @@ func (d *digest) clone() *digest { func (d *digest) Sum(b []byte) []byte { d = d.clone() - d.buf[d.len] = d.dsbyte - bs := d.BlockSize() - for i := d.len + 1; i < bs; i++ { - d.buf[i] = 0 + if d.ulen == 0 { + d.a[d.len] ^= uint64(d.dsbyte) + } else { + d.buf[d.ulen] = d.dsbyte + for i := int(d.ulen) + 1; i < len(d.buf); i++ { + d.buf[i] = 0 + } + d.a[d.len] ^= le64dec(d.buf[:]) } - d.buf[bs-1] |= 0x80 - d.len = bs + + bs := d.BlockSize() / 8 + d.a[bs-1] |= 0x80 << 56 + //d.len = bs d.flush() for i := 0; i < d.size/8; i++ {