aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorDmitry Vyukov <dvyukov@google.com>2022-12-17 12:12:05 +0100
committerDmitry Vyukov <dvyukov@google.com>2022-12-22 10:11:08 +0100
commita0df376348d2ad1d3e557ea221e75c78a5d9fd96 (patch)
treeb15ead346eb5f9f01e71dca1f4f4d1966da3e0cf
parent09ff16760eac2d6f03e93bd7d50892a6d536ed1b (diff)
pkg/image: make Decompress easier to use
Change DecompressWriter to DecompressCheck: checking validity of the image is the only useful use of DecompressWriter. Change Decompress to MustDecompress which does not return an error. We check validity during program deserialization, so all other uses already panic on errors. Also add dtor return value in preparation for subsequent changes.
-rw-r--r--pkg/image/compression.go17
-rw-r--r--pkg/image/compression_test.go31
-rw-r--r--prog/analysis.go6
-rw-r--r--prog/encoding.go3
-rw-r--r--prog/mutation.go6
5 files changed, 33 insertions, 30 deletions
diff --git a/pkg/image/compression.go b/pkg/image/compression.go
index 9878b460d..a2a51b146 100644
--- a/pkg/image/compression.go
+++ b/pkg/image/compression.go
@@ -9,6 +9,7 @@ import (
"encoding/base64"
"fmt"
"io"
+ "io/ioutil"
)
func Compress(rawData []byte) []byte {
@@ -28,14 +29,20 @@ func Compress(rawData []byte) []byte {
return buffer.Bytes()
}
-func Decompress(compressedData []byte) ([]byte, error) {
+func MustDecompress(compressed []byte) (data []byte, dtor func()) {
buf := new(bytes.Buffer)
- err := DecompressWriter(buf, compressedData)
- return buf.Bytes(), err
+ if err := decompressWriter(buf, compressed); err != nil {
+ panic(err)
+ }
+ return buf.Bytes(), func() {}
+}
+
+func DecompressCheck(compressed []byte) error {
+ return decompressWriter(ioutil.Discard, compressed)
}
-func DecompressWriter(w io.Writer, compressedData []byte) error {
- zlibReader, err := zlib.NewReader(bytes.NewReader(compressedData))
+func decompressWriter(w io.Writer, compressed []byte) error {
+ zlibReader, err := zlib.NewReader(bytes.NewReader(compressed))
if err != nil {
return fmt.Errorf("could not initialise zlib: %v", err)
}
diff --git a/pkg/image/compression_test.go b/pkg/image/compression_test.go
index cf18ed340..0107f82aa 100644
--- a/pkg/image/compression_test.go
+++ b/pkg/image/compression_test.go
@@ -13,32 +13,33 @@ import (
)
func TestCompress(t *testing.T) {
+ t.Parallel()
r := rand.New(testutil.RandSource(t))
- err := testRoundTrip(r, Compress, Decompress)
- if err != nil {
- t.Fatalf("compress/decompress %v", err)
+ for i := 0; i < testutil.IterCount(); i++ {
+ t.Run(fmt.Sprint(i), func(t *testing.T) {
+ randBytes := testutil.RandMountImage(r)
+ resultBytes := Compress(randBytes)
+ resultBytes, dtor := MustDecompress(resultBytes)
+ defer dtor()
+ if !bytes.Equal(randBytes, resultBytes) {
+ t.Fatalf("roundtrip changes data (length %v->%v)", len(randBytes), len(resultBytes))
+ }
+ })
}
}
func TestEncode(t *testing.T) {
+ t.Parallel()
r := rand.New(testutil.RandSource(t))
- err := testRoundTrip(r, EncodeB64, DecodeB64)
- if err != nil {
- t.Fatalf("encode/decode Base64 %v", err)
- }
-}
-
-func testRoundTrip(r *rand.Rand, transform func([]byte) []byte, inverse func([]byte) ([]byte, error)) error {
for i := 0; i < testutil.IterCount(); i++ {
randBytes := testutil.RandMountImage(r)
- resultBytes := transform(randBytes)
- resultBytes, err := inverse(resultBytes)
+ resultBytes := EncodeB64(randBytes)
+ resultBytes, err := DecodeB64(resultBytes)
if err != nil {
- return err
+ t.Fatalf("decoding failed: %v", err)
}
if !bytes.Equal(randBytes, resultBytes) {
- return fmt.Errorf("roundtrip changes data (original length %d)", len(randBytes))
+ t.Fatalf("roundtrip changes data (original length %d)", len(randBytes))
}
}
- return nil
}
diff --git a/prog/analysis.go b/prog/analysis.go
index fec43e1bd..850fdb1dd 100644
--- a/prog/analysis.go
+++ b/prog/analysis.go
@@ -356,10 +356,8 @@ func (p *Prog) ForEachAsset(cb func(name string, typ AssetType, r io.Reader)) {
if !ok || a.Type().(*BufferType).Kind != BufferCompressed {
return
}
- data, err := image.Decompress(a.Data())
- if err != nil {
- panic(err)
- }
+ data, dtor := image.MustDecompress(a.Data())
+ defer dtor()
if len(data) == 0 {
return
}
diff --git a/prog/encoding.go b/prog/encoding.go
index bfa80b983..bd2efb836 100644
--- a/prog/encoding.go
+++ b/prog/encoding.go
@@ -7,7 +7,6 @@ import (
"bytes"
"encoding/hex"
"fmt"
- "io/ioutil"
"reflect"
"strconv"
"strings"
@@ -606,7 +605,7 @@ func (p *parser) parseArgString(t Type, dir Dir) (Arg, error) {
}
// Check compressed data for validity.
if typ.IsCompressed() {
- if err := image.DecompressWriter(ioutil.Discard, data); err != nil {
+ if err := image.DecompressCheck(data); err != nil {
p.strictFailf("invalid compressed data in arg: %v", err)
// In non-strict mode, empty the data slice.
data = image.Compress(nil)
diff --git a/prog/mutation.go b/prog/mutation.go
index 100c47a91..5a38cfa88 100644
--- a/prog/mutation.go
+++ b/prog/mutation.go
@@ -396,10 +396,8 @@ func (r *randGen) mutateImage(compressed []byte) (data []byte, retry bool) {
// Reconsider when/if we move mutation to the host process.
imageMu.Lock()
defer imageMu.Unlock()
- data, err := image.Decompress(compressed)
- if err != nil {
- panic(fmt.Sprintf("could not decompress data: %v", err))
- }
+ data, dtor := image.MustDecompress(compressed)
+ defer dtor()
if len(data) == 0 {
return compressed, true // Do not mutate empty data.
}