diff --git a/keccak_test.go b/keccak_test.go index 3e8d0ae..128a064 100644 --- a/keccak_test.go +++ b/keccak_test.go @@ -40,6 +40,31 @@ var tests = []struct { text: "The quick brown fox jumps over the lazy dog", hash: "69070dda01975c8c120c3aada1b282394e7f032fa9cf32f4cb2259a0897dfc04", }, + + { + f: New256, + name: "SHA3-256", + text: "a", + hash: "80084bf2fba02475726feb2cab2d8215eab14bc6bdd8bfb2c8151257032ecd8b", + }, + { + f: New256, + name: "SHA3-256", + text: "abcdefg", + hash: "7d55114476dfc6a2fbeaa10e221a8d0f32fc8f2efb69a6e878f4633366917a62", + }, + { + f: New256, + name: "SHA3-256", + text: "abcdefgh", + hash: "3e2020725a38a48eb3bbf75767f03a22c6b3f41f459c831309b06433ec649779", + }, + { + f: New256, + name: "SHA3-256", + text: "abcdefghi", + hash: "f74eb337992307c22bc59eb43e59583a683f3b93077e7f2472508e8c464d2657", + }, } func TestHash(t *testing.T) { @@ -58,6 +83,24 @@ func TestHash(t *testing.T) { } } +func TestHashSmallWrites(t *testing.T) { + for _, tt := range tests { + want, err := hex.DecodeString(tt.hash) + if err != nil { + t.Errorf("%s(%q): %s", tt.name, tt.text, err) + continue + } + h := tt.f() + for i := range []byte(tt.text) { + io.WriteString(h, tt.text[i:i+1]) + } + got := h.Sum(nil) + if !bytes.Equal(got, want) { + t.Errorf("%s(%q) = %x, want %x", tt.name, tt.text, got, want) + } + } +} + func benchmark(b *testing.B, f func() hash.Hash, size int64) { var tmp [Size * 2]byte var msg [8192]byte diff --git a/shake.go b/shake.go index 95f9601..f9f31b2 100644 --- a/shake.go +++ b/shake.go @@ -26,8 +26,8 @@ func newShake(N, S []byte, sizeBytes int) *Shake { s.digest.Write(N) s.digest.Write(leftEncode(uint64(len(S)) * 8)) // length of S in bits s.digest.Write(S) - if s.len > 0 { - s.pad(rate) + if s.len > 0 || s.ulen > 0 { + s.pad8() s.flush() } } @@ -36,18 +36,14 @@ func newShake(N, S []byte, sizeBytes int) *Shake { } func (s *Shake) pad8() { - n := -s.len & 7 - for i := 0; i < n; i++ { - s.buf[s.len+i] = 0 + if s.ulen > 0 { + for i := int(s.ulen); i < len(s.buf); i++ { + s.buf[i] = 0 + } + s.a[s.len] ^= le64dec(s.buf[:]) + s.len += 1 + s.ulen = 0 } - s.len += n -} - -func (s *Shake) pad(rate int) { - for i := s.len; i < rate && i < len(s.buf); i++ { - s.buf[i] = 0 - } - s.len = rate } // Shake is only resettable if Reset is called before the first Write or Read. @@ -64,7 +60,8 @@ func (s *Shake) Reset() { panic("keccak: Reset called after Read or Write") } s.a = *s.initialState - s.buf = [200]byte{} + s.buf = [8]byte{} + s.ulen = 0 s.len = 0 s.running = 0 } @@ -79,50 +76,62 @@ func (s *Shake) Read(p []byte) (int, error) { if s.running < 2 && len(p) > 0 { s.running = 2 - s.buf[s.len] = s.dsbyte - bs := s.BlockSize() - for i := s.len + 1; i < bs; i++ { - s.buf[i] = 0 - } - s.buf[bs-1] |= 0x80 - - for i := range s.a { - if i*8 > bs { - break + var dsword uint64 + if s.ulen == 0 { + dsword = uint64(s.dsbyte) + } else { + s.buf[s.ulen] = s.dsbyte + for i := int(s.ulen) + 1; i < len(s.buf); i++ { + s.buf[i] = 0 } - s.a[i] ^= le64dec(s.buf[i*8:]) + dsword = le64dec(s.buf[:]) } + s.a[s.len] ^= dsword + + bs := s.BlockSize() / 8 + s.a[bs-1] ^= 0x80 << 56 s.len = bs + s.ulen = 0 + } return s.digest.read(p) } func (d *digest) read(p []byte) (int, error) { - bs := d.BlockSize() + bs := d.BlockSize() / 8 size := len(p) - for len(p) > 0 { - if d.len == bs { - d.squeeze(bs) - } - n := copy(p, d.buf[:bs]) - d.len += n + + if d.ulen > 0 { + n := copy(p, d.buf[d.ulen:]) p = p[n:] + d.ulen += int8(n) + if int(d.ulen) == len(d.buf) { + d.ulen = 0 + d.len += 1 + } } + + for len(p) >= 8 { + if d.len == bs { + d.squeeze() + } + le64enc(p[:0], d.a[d.len]) + p = p[8:] + d.len += 1 + } + + if len(p) > 0 { + le64enc(d.buf[:0], d.a[d.len]) + d.ulen = int8(copy(p, d.buf[:])) + } + return size, nil } -func (d *digest) squeeze(bs int) { +func (d *digest) squeeze() { //fmt.Printf("Squeezing\n", d.len) keccakf(&d.a) - b := d.buf[:bs] - for i := range d.a { - if len(b) == 0 { - break - } - le64enc(b[:0], d.a[i]) // append - b = b[8:] - } d.len = 0 } diff --git a/sponge.go b/sponge.go index aab872d..f1c786d 100644 --- a/sponge.go +++ b/sponge.go @@ -11,7 +11,8 @@ func round(a *[25]uint64) { roundGo(a) } // digest implements hash.Hash type digest struct { a [25]uint64 // a[y][x][z] - buf [200]byte + buf [8]byte // buf[0:ulen] holds a partial uint64 + ulen int8 dsbyte byte len int size int @@ -29,34 +30,46 @@ func (d *digest) BlockSize() int { return 200 - d.size*2 } func (d *digest) Reset() { //fmt.Println("resetting") d.a = [25]uint64{} - d.buf = [200]byte{} + d.buf = [8]byte{} d.len = 0 } func (d *digest) Write(b []byte) (int, error) { written := len(b) - bs := d.BlockSize() - for len(b) > 0 { - n := copy(d.buf[d.len:bs], b) - d.len += n + bs := d.BlockSize() / 8 + // fill buf first, if non-empty + if d.ulen > 0 { + n := copy(d.buf[d.ulen:], b) b = b[n:] + d.ulen += int8(n) + // flush? + if int(d.ulen) == len(d.buf) { + d.a[d.len] ^= le64dec(d.buf[:]) + d.len += 1 + d.ulen = 0 + if d.len == bs { + d.flush() + } + } + } + // xor 8-byte chunks into the state + for len(b) >= 8 { + d.a[d.len] ^= le64dec(b) + b = b[8:] + d.len += 1 if d.len == bs { d.flush() } + } // len(b) < 8 + // store any remaining bytes + if len(b) > 0 { + d.ulen = int8(copy(d.buf[:], b)) } return written, nil } func (d *digest) flush() { - //fmt.Printf("Flushing with %d bytes\n", d.len) - b := d.buf[:d.len] - for i := range d.a { - if len(b) == 0 { - break - } - d.a[i] ^= le64dec(b) - b = b[8:] - } + //fmt.Printf("Flushing with %d bytes\n", d.len*8 + int(d.ulen)) keccakf(&d.a) d.len = 0 } @@ -75,13 +88,19 @@ func (d *digest) clone() *digest { func (d *digest) Sum(b []byte) []byte { d = d.clone() - d.buf[d.len] = d.dsbyte - bs := d.BlockSize() - for i := d.len + 1; i < bs; i++ { - d.buf[i] = 0 + if d.ulen == 0 { + d.a[d.len] ^= uint64(d.dsbyte) + } else { + d.buf[d.ulen] = d.dsbyte + for i := int(d.ulen) + 1; i < len(d.buf); i++ { + d.buf[i] = 0 + } + d.a[d.len] ^= le64dec(d.buf[:]) } - d.buf[bs-1] |= 0x80 - d.len = bs + + bs := d.BlockSize() / 8 + d.a[bs-1] |= 0x80 << 56 + //d.len = bs d.flush() for i := 0; i < d.size/8; i++ { @@ -91,6 +110,7 @@ func (d *digest) Sum(b []byte) []byte { } func le64dec(b []byte) uint64 { + _ = b[7] return uint64(b[0])<<0 | uint64(b[1])<<8 | uint64(b[2])<<16 | uint64(b[3])<<24 | uint64(b[4])<<32 | uint64(b[5])<<40 | uint64(b[6])<<48 | uint64(b[7])<<56 }