aboutsummaryrefslogtreecommitdiffstats
path: root/prog
diff options
context:
space:
mode:
authorDmitry Vyukov <dvyukov@google.com>2024-08-08 14:50:12 +0200
committerDmitry Vyukov <dvyukov@google.com>2024-08-08 14:06:39 +0000
commit61405512146275a395ed4174f448ddc175f8c189 (patch)
treea612ce81a0a58d370ea27b9947a648f970ea91d0 /prog
parenta85b371c5584664083ed7e1a394607c3100534c2 (diff)
prog: try to remove all unrelated calls during minimization
We have too many corpus minimization executions and the main source of these is call removal. Try to remove all "unrelated" calls at once. Unrelated calls are the calls that don't use any resources/files from the transitive closure of the resources/files used by the target call. This may significantly reduce large generated programs in a single step.
Diffstat (limited to 'prog')
-rw-r--r--prog/minimization.go86
-rw-r--r--prog/minimization_test.go119
2 files changed, 187 insertions, 18 deletions
diff --git a/prog/minimization.go b/prog/minimization.go
index 260a15133..c835ea81a 100644
--- a/prog/minimization.go
+++ b/prog/minimization.go
@@ -129,6 +129,11 @@ func removeCalls(p0 *Prog, callIndex0 int, pred minimizePred) (*Prog, int) {
p0 = p
}
}
+
+ if callIndex0 != -1 {
+ p0, callIndex0 = removeUnrelatedCalls(p0, callIndex0, pred)
+ }
+
for i := len(p0.Calls) - 1; i >= 0; i-- {
if i == callIndex0 {
continue
@@ -148,6 +153,87 @@ func removeCalls(p0 *Prog, callIndex0 int, pred minimizePred) (*Prog, int) {
return p0, callIndex0
}
+// removeUnrelatedCalls tries to remove all "unrelated" calls at once.
+// Unrelated calls are the calls that don't use any resources/files from
+// the transitive closure of the resources/files used by the target call.
+// This may significantly reduce large generated programs in a single step.
+func removeUnrelatedCalls(p0 *Prog, callIndex0 int, pred minimizePred) (*Prog, int) {
+ keepCalls := relatedCalls(p0, callIndex0)
+ if len(p0.Calls)-len(keepCalls) < 3 {
+ return p0, callIndex0
+ }
+ p, callIndex := p0.Clone(), callIndex0
+ for i := len(p0.Calls) - 1; i >= 0; i-- {
+ if keepCalls[i] {
+ continue
+ }
+ p.RemoveCall(i)
+ if i < callIndex {
+ callIndex--
+ }
+ }
+ if !pred(p, callIndex, statMinRemoveCall, "unrelated calls") {
+ return p0, callIndex0
+ }
+ return p, callIndex
+}
+
+func relatedCalls(p0 *Prog, callIndex0 int) map[int]bool {
+ keepCalls := map[int]bool{callIndex0: true}
+ used := uses(p0.Calls[callIndex0])
+ for {
+ n := len(used)
+ for i, call := range p0.Calls {
+ if keepCalls[i] {
+ continue
+ }
+ used1 := uses(call)
+ if intersects(used, used1) {
+ keepCalls[i] = true
+ for what := range used1 {
+ used[what] = true
+ }
+ }
+ }
+ if n == len(used) {
+ return keepCalls
+ }
+ }
+}
+
+func uses(call *Call) map[any]bool {
+ used := make(map[any]bool)
+ ForeachArg(call, func(arg Arg, _ *ArgCtx) {
+ switch typ := arg.Type().(type) {
+ case *ResourceType:
+ a := arg.(*ResultArg)
+ used[a] = true
+ if a.Res != nil {
+ used[a.Res] = true
+ }
+ for use := range a.uses {
+ used[use] = true
+ }
+ case *BufferType:
+ a := arg.(*DataArg)
+ if a.Dir() != DirOut && typ.Kind == BufferFilename {
+ val := string(bytes.TrimRight(a.Data(), "\x00"))
+ used[val] = true
+ }
+ }
+ })
+ return used
+}
+
+func intersects(list, list1 map[any]bool) bool {
+ for what := range list1 {
+ if list[what] {
+ return true
+ }
+ }
+ return false
+}
+
func resetCallProps(p0 *Prog, callIndex0 int, pred minimizePred) *Prog {
// Try to reset all call props to their default values.
// This should be reasonable for many progs.
diff --git a/prog/minimization_test.go b/prog/minimization_test.go
index 9fe822577..e8f1d40e5 100644
--- a/prog/minimization_test.go
+++ b/prog/minimization_test.go
@@ -4,7 +4,9 @@
package prog
import (
+ "fmt"
"math/rand"
+ "strings"
"testing"
"github.com/google/syzkaller/pkg/hash"
@@ -12,6 +14,7 @@ import (
// nolint:gocyclo
func TestMinimize(t *testing.T) {
+ attempt := 0
// nolint: lll
tests := []struct {
os string
@@ -243,27 +246,107 @@ func TestMinimize(t *testing.T) {
"syz_mount_image$ext4(&(0x7f0000000000)='ext4\\x00', &(0x7f0000000100)='./file0\\x00', 0x0, &(0x7f0000010020), 0x1, 0x15, &(0x7f0000000200)=\"$eJwqrqzKTszJSS0CBAAA//8TyQPi\")\n",
0,
},
+ // Test for removeUnrelatedCalls.
+ // We test exact candidates we get on each step.
+ // First candidate should be removal of the trailing calls, which we reject.
+ // Next candidate is removal of unrelated calls, which we accept.
+ {
+ "linux", "amd64", MinimizeCorpus,
+ `
+getpid()
+r0 = open(&(0x7f0000000040)='./file0', 0x0, 0x0)
+r1 = open(&(0x7f0000000040)='./file1', 0x0, 0x0)
+getuid()
+read(r1, &(0x7f0000000040), 0x10)
+read(r0, &(0x7f0000000040), 0x10)
+pipe(&(0x7f0000000040)={<r2=>0x0, <r3=>0x0})
+creat(&(0x7f0000000040)='./file0', 0x0)
+close(r1)
+sendfile(r0, r2, &(0x7f0000000040), 0x1)
+getgid()
+fcntl$getflags(r0, 0x0)
+getpid()
+close(r3)
+getuid()
+ `,
+ 11,
+ func(p *Prog, callIndex int) bool {
+ pp := strings.TrimSpace(string(p.Serialize()))
+ if attempt == 0 {
+ if pp == strings.TrimSpace(`
+getpid()
+r0 = open(&(0x7f0000000040)='./file0', 0x0, 0x0)
+r1 = open(&(0x7f0000000040)='./file1', 0x0, 0x0)
+getuid()
+read(r1, &(0x7f0000000040), 0x10)
+read(r0, &(0x7f0000000040), 0x10)
+pipe(&(0x7f0000000040)={<r2=>0x0, 0x0})
+creat(&(0x7f0000000040)='./file0', 0x0)
+close(r1)
+sendfile(r0, r2, &(0x7f0000000040), 0x1)
+getgid()
+fcntl$getflags(r0, 0x0)
+ `) {
+ return false
+ }
+ } else if attempt == 1 {
+ if pp == strings.TrimSpace(`
+r0 = open(&(0x7f0000000040)='./file0', 0x0, 0x0)
+read(r0, &(0x7f0000000040), 0x10)
+pipe(&(0x7f0000000040)={<r1=>0x0, <r2=>0x0})
+creat(&(0x7f0000000040)='./file0', 0x0)
+sendfile(r0, r1, &(0x7f0000000040), 0x1)
+fcntl$getflags(r0, 0x0)
+close(r2)
+ `) {
+ return true
+ }
+ } else {
+ return false
+ }
+ panic(fmt.Sprintf("unexpected candidate on attempt %v:\n%v", attempt, pp))
+ },
+ `
+r0 = open(&(0x7f0000000040)='./file0', 0x0, 0x0)
+read(r0, &(0x7f0000000040), 0x10)
+pipe(&(0x7f0000000040)={<r1=>0x0, <r2=>0x0})
+creat(&(0x7f0000000040)='./file0', 0x0)
+sendfile(r0, r1, &(0x7f0000000040), 0x1)
+fcntl$getflags(r0, 0x0)
+close(r2)
+ `,
+ 5,
+ },
}
t.Parallel()
for ti, test := range tests {
- target, err := GetTarget(test.os, test.arch)
- if err != nil {
- t.Fatal(err)
- }
- p, err := target.Deserialize([]byte(test.orig), Strict)
- if err != nil {
- t.Fatalf("failed to deserialize original program #%v: %v", ti, err)
- }
- p1, ci := Minimize(p, test.callIndex, test.mode, test.pred)
- res := p1.Serialize()
- if string(res) != test.result {
- t.Fatalf("minimization produced wrong result #%v\norig:\n%v\nexpect:\n%v\ngot:\n%v",
- ti, test.orig, test.result, string(res))
- }
- if ci != test.resultCallIndex {
- t.Fatalf("minimization broke call index #%v: got %v, want %v",
- ti, ci, test.resultCallIndex)
- }
+ t.Run(fmt.Sprint(ti), func(t *testing.T) {
+ target, err := GetTarget(test.os, test.arch)
+ if err != nil {
+ t.Fatal(err)
+ }
+ p, err := target.Deserialize([]byte(strings.TrimSpace(test.orig)), Strict)
+ if err != nil {
+ t.Fatalf("failed to deserialize original program #%v: %v", ti, err)
+ }
+ attempt = 0
+ pred := func(p *Prog, callIndex int) bool {
+ res := test.pred(p, callIndex)
+ attempt++
+ return res
+ }
+ p1, ci := Minimize(p, test.callIndex, test.mode, pred)
+ res := strings.TrimSpace(string(p1.Serialize()))
+ expect := strings.TrimSpace(test.result)
+ if res != expect {
+ t.Fatalf("minimization produced wrong result #%v\norig:\n%v\nexpect:\n%v\ngot:\n%v",
+ ti, test.orig, expect, res)
+ }
+ if ci != test.resultCallIndex {
+ t.Fatalf("minimization broke call index #%v: got %v, want %v",
+ ti, ci, test.resultCallIndex)
+ }
+ })
}
}