aboutsummaryrefslogtreecommitdiffstats
path: root/prog/test_util.go
diff options
context:
space:
mode:
authorDmitry Vyukov <dvyukov@google.com>2020-03-17 10:44:38 +0100
committerDmitry Vyukov <dvyukov@google.com>2020-03-17 21:19:13 +0100
commita2f9a446496d23c4bf6db95e0d4337583595c78c (patch)
treeb745c1e04b7b2f4997ca823a7d7a294bf62671f6 /prog/test_util.go
parent0a4d69469bf7e77f26f3036fbb183ecf73368a5d (diff)
prog: export deserialization test helper for sys/{linux,openbsd}
sys/{linux,openbsd} duplicate deserialization test logic as well. Export and reuse the existing helper function.
Diffstat (limited to 'prog/test_util.go')
-rw-r--r--prog/test_util.go72
1 files changed, 72 insertions, 0 deletions
diff --git a/prog/test_util.go b/prog/test_util.go
new file mode 100644
index 000000000..bb96a64e8
--- /dev/null
+++ b/prog/test_util.go
@@ -0,0 +1,72 @@
+// Copyright 2020 syzkaller project authors. All rights reserved.
+// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
+
+package prog
+
+import (
+ "fmt"
+ "strings"
+ "testing"
+)
+
+func InitTargetTest(t *testing.T, os, arch string) *Target {
+ t.Parallel()
+ target, err := GetTarget(os, arch)
+ if err != nil {
+ t.Fatal(err)
+ }
+ return target
+}
+
+type DeserializeTest struct {
+ In string
+ Out string
+ Err string
+ StrictErr string
+}
+
+func TestDeserializeHelper(t *testing.T, OS, arch string, transform func(*Target, *Prog), tests []DeserializeTest) {
+ target := InitTargetTest(t, OS, arch)
+ buf := make([]byte, ExecBufferSize)
+ for testidx, test := range tests {
+ t.Run(fmt.Sprint(testidx), func(t *testing.T) {
+ if test.StrictErr == "" {
+ test.StrictErr = test.Err
+ }
+ if test.Err != "" && test.Out != "" {
+ t.Fatalf("both Err and Out are set")
+ }
+ for _, mode := range []DeserializeMode{NonStrict, Strict} {
+ p, err := target.Deserialize([]byte(test.In), mode)
+ wantErr := test.Err
+ if mode == Strict {
+ wantErr = test.StrictErr
+ }
+ if err != nil {
+ if wantErr == "" {
+ t.Fatalf("deserialization failed with\n%s\ndata:\n%s\n",
+ err, test.In)
+ }
+ if !strings.Contains(err.Error(), wantErr) {
+ t.Fatalf("deserialization failed with\n%s\nwhich doesn't match\n%s\ndata:\n%s",
+ err, wantErr, test.In)
+ }
+ } else {
+ if wantErr != "" {
+ t.Fatalf("deserialization should have failed with:\n%s\ndata:\n%s\n",
+ wantErr, test.In)
+ }
+ if transform != nil {
+ transform(target, p)
+ }
+ output := strings.TrimSpace(string(p.Serialize()))
+ want := strings.TrimSpace(test.Out)
+ if want != "" && want != output {
+ t.Fatalf("wrong serialized data:\n%s\nexpect:\n%s\n", output, want)
+ }
+ p.SerializeForExec(buf)
+ }
+ }
+ })
+ }
+}