Merge branch 'shrink-buf' into shake

good example of how a merge can have no conflicts yet neverthless
require a bunch of adjustments to got something that actually works.

these two changes are conceptually independent but since shrink-buf
changes the internal representation of a digest, the shake code needs to
be adjusted to match. i wanted to capture the code both before and
after the shrink-buf change, thus this merge.

there's a bunch of code duplication that i will clean up later.
shake
magical 2024-10-06 22:21:14 -07:00
commit 4c06f23ff1
3 changed files with 134 additions and 62 deletions

View File

@ -40,6 +40,31 @@ var tests = []struct {
text: "The quick brown fox jumps over the lazy dog",
hash: "69070dda01975c8c120c3aada1b282394e7f032fa9cf32f4cb2259a0897dfc04",
},
{
f: New256,
name: "SHA3-256",
text: "a",
hash: "80084bf2fba02475726feb2cab2d8215eab14bc6bdd8bfb2c8151257032ecd8b",
},
{
f: New256,
name: "SHA3-256",
text: "abcdefg",
hash: "7d55114476dfc6a2fbeaa10e221a8d0f32fc8f2efb69a6e878f4633366917a62",
},
{
f: New256,
name: "SHA3-256",
text: "abcdefgh",
hash: "3e2020725a38a48eb3bbf75767f03a22c6b3f41f459c831309b06433ec649779",
},
{
f: New256,
name: "SHA3-256",
text: "abcdefghi",
hash: "f74eb337992307c22bc59eb43e59583a683f3b93077e7f2472508e8c464d2657",
},
}
func TestHash(t *testing.T) {
@ -58,6 +83,24 @@ func TestHash(t *testing.T) {
}
}
func TestHashSmallWrites(t *testing.T) {
for _, tt := range tests {
want, err := hex.DecodeString(tt.hash)
if err != nil {
t.Errorf("%s(%q): %s", tt.name, tt.text, err)
continue
}
h := tt.f()
for i := range []byte(tt.text) {
io.WriteString(h, tt.text[i:i+1])
}
got := h.Sum(nil)
if !bytes.Equal(got, want) {
t.Errorf("%s(%q) = %x, want %x", tt.name, tt.text, got, want)
}
}
}
func benchmark(b *testing.B, f func() hash.Hash, size int64) {
var tmp [Size * 2]byte
var msg [8192]byte

View File

@ -26,8 +26,8 @@ func newShake(N, S []byte, sizeBytes int) *Shake {
s.digest.Write(N)
s.digest.Write(leftEncode(uint64(len(S)) * 8)) // length of S in bits
s.digest.Write(S)
if s.len > 0 {
s.pad(rate)
if s.len > 0 || s.ulen > 0 {
s.pad8()
s.flush()
}
}
@ -36,18 +36,14 @@ func newShake(N, S []byte, sizeBytes int) *Shake {
}
func (s *Shake) pad8() {
n := -s.len & 7
for i := 0; i < n; i++ {
s.buf[s.len+i] = 0
if s.ulen > 0 {
for i := int(s.ulen); i < len(s.buf); i++ {
s.buf[i] = 0
}
s.a[s.len] ^= le64dec(s.buf[:])
s.len += 1
s.ulen = 0
}
s.len += n
}
func (s *Shake) pad(rate int) {
for i := s.len; i < rate && i < len(s.buf); i++ {
s.buf[i] = 0
}
s.len = rate
}
// Shake is only resettable if Reset is called before the first Write or Read.
@ -64,7 +60,8 @@ func (s *Shake) Reset() {
panic("keccak: Reset called after Read or Write")
}
s.a = *s.initialState
s.buf = [200]byte{}
s.buf = [8]byte{}
s.ulen = 0
s.len = 0
s.running = 0
}
@ -79,50 +76,62 @@ func (s *Shake) Read(p []byte) (int, error) {
if s.running < 2 && len(p) > 0 {
s.running = 2
s.buf[s.len] = s.dsbyte
bs := s.BlockSize()
for i := s.len + 1; i < bs; i++ {
s.buf[i] = 0
}
s.buf[bs-1] |= 0x80
for i := range s.a {
if i*8 > bs {
break
var dsword uint64
if s.ulen == 0 {
dsword = uint64(s.dsbyte)
} else {
s.buf[s.ulen] = s.dsbyte
for i := int(s.ulen) + 1; i < len(s.buf); i++ {
s.buf[i] = 0
}
s.a[i] ^= le64dec(s.buf[i*8:])
dsword = le64dec(s.buf[:])
}
s.a[s.len] ^= dsword
bs := s.BlockSize() / 8
s.a[bs-1] ^= 0x80 << 56
s.len = bs
s.ulen = 0
}
return s.digest.read(p)
}
func (d *digest) read(p []byte) (int, error) {
bs := d.BlockSize()
bs := d.BlockSize() / 8
size := len(p)
for len(p) > 0 {
if d.len == bs {
d.squeeze(bs)
}
n := copy(p, d.buf[:bs])
d.len += n
if d.ulen > 0 {
n := copy(p, d.buf[d.ulen:])
p = p[n:]
d.ulen += int8(n)
if int(d.ulen) == len(d.buf) {
d.ulen = 0
d.len += 1
}
}
for len(p) >= 8 {
if d.len == bs {
d.squeeze()
}
le64enc(p[:0], d.a[d.len])
p = p[8:]
d.len += 1
}
if len(p) > 0 {
le64enc(d.buf[:0], d.a[d.len])
d.ulen = int8(copy(p, d.buf[:]))
}
return size, nil
}
func (d *digest) squeeze(bs int) {
func (d *digest) squeeze() {
//fmt.Printf("Squeezing\n", d.len)
keccakf(&d.a)
b := d.buf[:bs]
for i := range d.a {
if len(b) == 0 {
break
}
le64enc(b[:0], d.a[i]) // append
b = b[8:]
}
d.len = 0
}

View File

@ -11,7 +11,8 @@ func round(a *[25]uint64) { roundGo(a) }
// digest implements hash.Hash
type digest struct {
a [25]uint64 // a[y][x][z]
buf [200]byte
buf [8]byte // buf[0:ulen] holds a partial uint64
ulen int8
dsbyte byte
len int
size int
@ -29,34 +30,46 @@ func (d *digest) BlockSize() int { return 200 - d.size*2 }
func (d *digest) Reset() {
//fmt.Println("resetting")
d.a = [25]uint64{}
d.buf = [200]byte{}
d.buf = [8]byte{}
d.len = 0
}
func (d *digest) Write(b []byte) (int, error) {
written := len(b)
bs := d.BlockSize()
for len(b) > 0 {
n := copy(d.buf[d.len:bs], b)
d.len += n
bs := d.BlockSize() / 8
// fill buf first, if non-empty
if d.ulen > 0 {
n := copy(d.buf[d.ulen:], b)
b = b[n:]
d.ulen += int8(n)
// flush?
if int(d.ulen) == len(d.buf) {
d.a[d.len] ^= le64dec(d.buf[:])
d.len += 1
d.ulen = 0
if d.len == bs {
d.flush()
}
}
}
// xor 8-byte chunks into the state
for len(b) >= 8 {
d.a[d.len] ^= le64dec(b)
b = b[8:]
d.len += 1
if d.len == bs {
d.flush()
}
} // len(b) < 8
// store any remaining bytes
if len(b) > 0 {
d.ulen = int8(copy(d.buf[:], b))
}
return written, nil
}
func (d *digest) flush() {
//fmt.Printf("Flushing with %d bytes\n", d.len)
b := d.buf[:d.len]
for i := range d.a {
if len(b) == 0 {
break
}
d.a[i] ^= le64dec(b)
b = b[8:]
}
//fmt.Printf("Flushing with %d bytes\n", d.len*8 + int(d.ulen))
keccakf(&d.a)
d.len = 0
}
@ -75,13 +88,19 @@ func (d *digest) clone() *digest {
func (d *digest) Sum(b []byte) []byte {
d = d.clone()
d.buf[d.len] = d.dsbyte
bs := d.BlockSize()
for i := d.len + 1; i < bs; i++ {
d.buf[i] = 0
if d.ulen == 0 {
d.a[d.len] ^= uint64(d.dsbyte)
} else {
d.buf[d.ulen] = d.dsbyte
for i := int(d.ulen) + 1; i < len(d.buf); i++ {
d.buf[i] = 0
}
d.a[d.len] ^= le64dec(d.buf[:])
}
d.buf[bs-1] |= 0x80
d.len = bs
bs := d.BlockSize() / 8
d.a[bs-1] |= 0x80 << 56
//d.len = bs
d.flush()
for i := 0; i < d.size/8; i++ {
@ -91,6 +110,7 @@ func (d *digest) Sum(b []byte) []byte {
}
func le64dec(b []byte) uint64 {
_ = b[7]
return uint64(b[0])<<0 | uint64(b[1])<<8 | uint64(b[2])<<16 | uint64(b[3])<<24 | uint64(b[4])<<32 | uint64(b[5])<<40 | uint64(b[6])<<48 | uint64(b[7])<<56
}