]> Cypherpunks.ru repositories - gostls13.git/blobdiff - src/internal/zstd/zstd_test.go
internal/zstd: use dynamic path resolution for zstd in tests
[gostls13.git] / src / internal / zstd / zstd_test.go
index 22af814acfce57d8344436394c128a13dc748431..4ae6f2b3982c1cc3ed88d393255b9b5b1c762393 100644 (file)
@@ -6,12 +6,14 @@ package zstd
 
 import (
        "bytes"
+       "crypto/sha256"
        "fmt"
        "internal/race"
        "internal/testenv"
        "io"
        "os"
        "os/exec"
+       "path/filepath"
        "strings"
        "sync"
        "testing"
@@ -90,6 +92,22 @@ var tests = []struct {
                "0\x00\x00\x00\x00\x000\x00\x00\x00\x00\x001\x00\x00\x00\x00\x000000",
                "(\xb5/\xfd\x04X\x8d\x00\x00P0\x000\x001\x000000\x03T\x02\x00\x01\x01m\xf9\xb7G",
        },
+       {
+               "empty block",
+               "",
+               "\x28\xb5\x2f\xfd\x00\x00\x15\x00\x00\x00\x00",
+       },
+       {
+               "single skippable frame",
+               "",
+               "\x50\x2a\x4d\x18\x00\x00\x00\x00",
+       },
+       {
+               "two skippable frames",
+               "",
+               "\x50\x2a\x4d\x18\x00\x00\x00\x00" +
+                       "\x50\x2a\x4d\x18\x00\x00\x00\x00",
+       },
 }
 
 func TestSamples(t *testing.T) {
@@ -109,6 +127,26 @@ func TestSamples(t *testing.T) {
        }
 }
 
+func TestReset(t *testing.T) {
+       input := strings.NewReader("")
+       r := NewReader(input)
+       for _, test := range tests {
+               test := test
+               t.Run(test.name, func(t *testing.T) {
+                       input.Reset(test.compressed)
+                       r.Reset(input)
+                       got, err := io.ReadAll(r)
+                       if err != nil {
+                               t.Fatal(err)
+                       }
+                       gotstr := string(got)
+                       if gotstr != test.uncompressed {
+                               t.Errorf("got %q want %q", gotstr, test.uncompressed)
+                       }
+               })
+       }
+}
+
 var (
        bigDataOnce  sync.Once
        bigDataBytes []byte
@@ -129,10 +167,17 @@ func bigData(t testing.TB) []byte {
        return bigDataBytes
 }
 
+func findZstd(t testing.TB) string {
+       zstd, err := exec.LookPath("zstd")
+       if err != nil {
+               t.Skip("skipping because zstd not found")
+       }
+       return zstd
+}
+
 var (
        zstdBigOnce  sync.Once
        zstdBigBytes []byte
-       zstdBigSkip  bool
        zstdBigErr   error
 )
 
@@ -142,13 +187,10 @@ var (
 func zstdBigData(t testing.TB) []byte {
        input := bigData(t)
 
-       zstdBigOnce.Do(func() {
-               if _, err := os.Stat("/usr/bin/zstd"); err != nil {
-                       zstdBigSkip = true
-                       return
-               }
+       zstd := findZstd(t)
 
-               cmd := exec.Command("/usr/bin/zstd", "-z")
+       zstdBigOnce.Do(func() {
+               cmd := exec.Command(zstd, "-z")
                cmd.Stdin = bytes.NewReader(input)
                var compressed bytes.Buffer
                cmd.Stdout = &compressed
@@ -160,9 +202,6 @@ func zstdBigData(t testing.TB) []byte {
 
                zstdBigBytes = compressed.Bytes()
        })
-       if zstdBigSkip {
-               t.Skip("skipping because /usr/bin/zstd does not exist")
-       }
        if zstdBigErr != nil {
                t.Fatal(zstdBigErr)
        }
@@ -179,7 +218,7 @@ func TestLarge(t *testing.T) {
        data := bigData(t)
        compressed := zstdBigData(t)
 
-       t.Logf("/usr/bin/zstd compressed %d bytes to %d", len(data), len(compressed))
+       t.Logf("zstd compressed %d bytes to %d", len(data), len(compressed))
 
        r := NewReader(bytes.NewReader(compressed))
        got, err := io.ReadAll(r)
@@ -232,6 +271,39 @@ func TestAlloc(t *testing.T) {
        }
 }
 
+func TestFileSamples(t *testing.T) {
+       samples, err := os.ReadDir("testdata")
+       if err != nil {
+               t.Fatal(err)
+       }
+
+       for _, sample := range samples {
+               name := sample.Name()
+               if !strings.HasSuffix(name, ".zst") {
+                       continue
+               }
+
+               t.Run(name, func(t *testing.T) {
+                       f, err := os.Open(filepath.Join("testdata", name))
+                       if err != nil {
+                               t.Fatal(err)
+                       }
+
+                       r := NewReader(f)
+                       h := sha256.New()
+                       if _, err := io.Copy(h, r); err != nil {
+                               t.Fatal(err)
+                       }
+                       got := fmt.Sprintf("%x", h.Sum(nil))[:8]
+
+                       want, _, _ := strings.Cut(name, ".")
+                       if got != want {
+                               t.Errorf("Wrong uncompressed content hash: got %s, want %s", got, want)
+                       }
+               })
+       }
+}
+
 func BenchmarkLarge(b *testing.B) {
        b.StopTimer()
        b.ReportAllocs()