]> Cypherpunks.ru repositories - gostls13.git/commitdiff
compress/flate: fix deflate Reset consistency
authorKlaus Post <klauspost@gmail.com>
Thu, 7 May 2020 12:50:00 +0000 (12:50 +0000)
committerRuss Cox <rsc@golang.org>
Thu, 16 Jul 2020 20:54:37 +0000 (20:54 +0000)
Modify the overflow detection logic to shuffle the contents
of the table to a lower offset to avoid leaking the effects
of a previous use of compress.Writer past Reset calls.

Fixes #34121

Change-Id: I9963eadfa5482881e7b7adbad4c2cae146b669ab
GitHub-Last-Rev: 8b35798cdd4d5a901d6422647b12984d7e500ba3
GitHub-Pull-Request: golang/go#34128
Reviewed-on: https://go-review.googlesource.com/c/go/+/193605
Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
Run-TryBot: Joe Tsai <thebrokentoaster@gmail.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>

src/compress/flate/deflatefast.go
src/compress/flate/writer_test.go

index 08298b76bbb2da5a7882d439d61d8428b28f632a..24f8be9d5db86c79764e6c234ac2a7ac4a91ad54 100644 (file)
@@ -4,6 +4,8 @@
 
 package flate
 
+import "math"
+
 // This encoding algorithm, which prioritizes speed over output size, is
 // based on Snappy's LZ77-style encoder: github.com/golang/snappy
 
@@ -12,6 +14,13 @@ const (
        tableSize  = 1 << tableBits // Size of the table.
        tableMask  = tableSize - 1  // Mask for table indices. Redundant, but can eliminate bounds checks.
        tableShift = 32 - tableBits // Right-shift to get the tableBits most significant bits of a uint32.
+
+       // Reset the buffer offset when reaching this.
+       // Offsets are stored between blocks as int32 values.
+       // Since the offset we are checking against is at the beginning
+       // of the buffer, we need to subtract the current and input
+       // buffer to not risk overflowing the int32.
+       bufferReset = math.MaxInt32 - maxStoreBlockSize*2
 )
 
 func load32(b []byte, i int32) uint32 {
@@ -59,8 +68,8 @@ func newDeflateFast() *deflateFast {
 // to dst and returns the result.
 func (e *deflateFast) encode(dst []token, src []byte) []token {
        // Ensure that e.cur doesn't wrap.
-       if e.cur > 1<<30 {
-               e.resetAll()
+       if e.cur >= bufferReset {
+               e.shiftOffsets()
        }
 
        // This check isn't in the Snappy implementation, but there, the caller
@@ -264,22 +273,32 @@ func (e *deflateFast) reset() {
        e.cur += maxMatchOffset
 
        // Protect against e.cur wraparound.
-       if e.cur > 1<<30 {
-               e.resetAll()
+       if e.cur >= bufferReset {
+               e.shiftOffsets()
        }
 }
 
-// resetAll resets the deflateFast struct and is only called in rare
-// situations to prevent integer overflow. It manually resets each field
-// to avoid causing large stack growth.
+// shiftOffsets will shift down all match offset.
+// This is only called in rare situations to prevent integer overflow.
 //
-// See https://golang.org/issue/18636.
-func (e *deflateFast) resetAll() {
-       // This is equivalent to:
-       //      *e = deflateFast{cur: maxStoreBlockSize, prev: e.prev[:0]}
-       e.cur = maxStoreBlockSize
-       e.prev = e.prev[:0]
-       for i := range e.table {
-               e.table[i] = tableEntry{}
+// See https://golang.org/issue/18636 and https://github.com/golang/go/issues/34121.
+func (e *deflateFast) shiftOffsets() {
+       if len(e.prev) == 0 {
+               // We have no history; just clear the table.
+               for i := range e.table[:] {
+                       e.table[i] = tableEntry{}
+               }
+               e.cur = maxMatchOffset
+               return
+       }
+
+       // Shift down everything in the table that isn't already too far away.
+       for i := range e.table[:] {
+               v := e.table[i].offset - e.cur + maxMatchOffset
+               if v < 0 {
+                       v = 0
+               }
+               e.table[i].offset = v
        }
+       e.cur = maxMatchOffset
 }
index c4d36aa37e9b6fcb743ddfb840134c7e14ed8363..881cb71cc3a407c8458ca13b0e80c1c98a21e1ab 100644 (file)
@@ -173,3 +173,66 @@ func testDeterministic(i int, t *testing.T) {
                t.Errorf("level %d did not produce deterministic result, result mismatch, len(a) = %d, len(b) = %d", i, len(b1b), len(b2b))
        }
 }
+
+// TestDeflateFast_Reset will test that encoding is consistent
+// across a warparound of the table offset.
+// See https://github.com/golang/go/issues/34121
+func TestDeflateFast_Reset(t *testing.T) {
+       buf := new(bytes.Buffer)
+       n := 65536
+
+       for i := 0; i < n; i++ {
+               fmt.Fprintf(buf, "asdfasdfasdfasdf%d%dfghfgujyut%dyutyu\n", i, i, i)
+       }
+       // This is specific to level 1.
+       const level = 1
+       in := buf.Bytes()
+       offset := 1
+       if testing.Short() {
+               offset = 256
+       }
+
+       // We do an encode with a clean buffer to compare.
+       var want bytes.Buffer
+       w, err := NewWriter(&want, level)
+       if err != nil {
+               t.Fatalf("NewWriter: level %d: %v", level, err)
+       }
+
+       // Output written 3 times.
+       w.Write(in)
+       w.Write(in)
+       w.Write(in)
+       w.Close()
+
+       for ; offset <= 256; offset *= 2 {
+               w, err := NewWriter(ioutil.Discard, level)
+               if err != nil {
+                       t.Fatalf("NewWriter: level %d: %v", level, err)
+               }
+
+               // Reset until we are right before the wraparound.
+               // Each reset adds maxMatchOffset to the offset.
+               for i := 0; i < (bufferReset-len(in)-offset-maxMatchOffset)/maxMatchOffset; i++ {
+                       // skip ahead to where we are close to wrap around...
+                       w.d.reset(nil)
+               }
+               var got bytes.Buffer
+               w.Reset(&got)
+
+               // Write 3 times, close.
+               for i := 0; i < 3; i++ {
+                       _, err = w.Write(in)
+                       if err != nil {
+                               t.Fatal(err)
+                       }
+               }
+               err = w.Close()
+               if err != nil {
+                       t.Fatal(err)
+               }
+               if !bytes.Equal(got.Bytes(), want.Bytes()) {
+                       t.Fatalf("output did not match at wraparound, len(want)  = %d, len(got) = %d", want.Len(), got.Len())
+               }
+       }
+}