diff options
Diffstat (limited to 'prog')
| -rw-r--r-- | prog/minimization.go | 86 | ||||
| -rw-r--r-- | prog/minimization_test.go | 119 |
2 files changed, 187 insertions, 18 deletions
diff --git a/prog/minimization.go b/prog/minimization.go index 260a15133..c835ea81a 100644 --- a/prog/minimization.go +++ b/prog/minimization.go @@ -129,6 +129,11 @@ func removeCalls(p0 *Prog, callIndex0 int, pred minimizePred) (*Prog, int) { p0 = p } } + + if callIndex0 != -1 { + p0, callIndex0 = removeUnrelatedCalls(p0, callIndex0, pred) + } + for i := len(p0.Calls) - 1; i >= 0; i-- { if i == callIndex0 { continue @@ -148,6 +153,87 @@ func removeCalls(p0 *Prog, callIndex0 int, pred minimizePred) (*Prog, int) { return p0, callIndex0 } +// removeUnrelatedCalls tries to remove all "unrelated" calls at once. +// Unrelated calls are the calls that don't use any resources/files from +// the transitive closure of the resources/files used by the target call. +// This may significantly reduce large generated programs in a single step. +func removeUnrelatedCalls(p0 *Prog, callIndex0 int, pred minimizePred) (*Prog, int) { + keepCalls := relatedCalls(p0, callIndex0) + if len(p0.Calls)-len(keepCalls) < 3 { + return p0, callIndex0 + } + p, callIndex := p0.Clone(), callIndex0 + for i := len(p0.Calls) - 1; i >= 0; i-- { + if keepCalls[i] { + continue + } + p.RemoveCall(i) + if i < callIndex { + callIndex-- + } + } + if !pred(p, callIndex, statMinRemoveCall, "unrelated calls") { + return p0, callIndex0 + } + return p, callIndex +} + +func relatedCalls(p0 *Prog, callIndex0 int) map[int]bool { + keepCalls := map[int]bool{callIndex0: true} + used := uses(p0.Calls[callIndex0]) + for { + n := len(used) + for i, call := range p0.Calls { + if keepCalls[i] { + continue + } + used1 := uses(call) + if intersects(used, used1) { + keepCalls[i] = true + for what := range used1 { + used[what] = true + } + } + } + if n == len(used) { + return keepCalls + } + } +} + +func uses(call *Call) map[any]bool { + used := make(map[any]bool) + ForeachArg(call, func(arg Arg, _ *ArgCtx) { + switch typ := arg.Type().(type) { + case *ResourceType: + a := arg.(*ResultArg) + used[a] = true + if a.Res != nil { + used[a.Res] = true + } + for use := range a.uses { + used[use] = true + } + case *BufferType: + a := arg.(*DataArg) + if a.Dir() != DirOut && typ.Kind == BufferFilename { + val := string(bytes.TrimRight(a.Data(), "\x00")) + used[val] = true + } + } + }) + return used +} + +func intersects(list, list1 map[any]bool) bool { + for what := range list1 { + if list[what] { + return true + } + } + return false +} + func resetCallProps(p0 *Prog, callIndex0 int, pred minimizePred) *Prog { // Try to reset all call props to their default values. // This should be reasonable for many progs. diff --git a/prog/minimization_test.go b/prog/minimization_test.go index 9fe822577..e8f1d40e5 100644 --- a/prog/minimization_test.go +++ b/prog/minimization_test.go @@ -4,7 +4,9 @@ package prog import ( + "fmt" "math/rand" + "strings" "testing" "github.com/google/syzkaller/pkg/hash" @@ -12,6 +14,7 @@ import ( // nolint:gocyclo func TestMinimize(t *testing.T) { + attempt := 0 // nolint: lll tests := []struct { os string @@ -243,27 +246,107 @@ func TestMinimize(t *testing.T) { "syz_mount_image$ext4(&(0x7f0000000000)='ext4\\x00', &(0x7f0000000100)='./file0\\x00', 0x0, &(0x7f0000010020), 0x1, 0x15, &(0x7f0000000200)=\"$eJwqrqzKTszJSS0CBAAA//8TyQPi\")\n", 0, }, + // Test for removeUnrelatedCalls. + // We test exact candidates we get on each step. + // First candidate should be removal of the trailing calls, which we reject. + // Next candidate is removal of unrelated calls, which we accept. + { + "linux", "amd64", MinimizeCorpus, + ` +getpid() +r0 = open(&(0x7f0000000040)='./file0', 0x0, 0x0) +r1 = open(&(0x7f0000000040)='./file1', 0x0, 0x0) +getuid() +read(r1, &(0x7f0000000040), 0x10) +read(r0, &(0x7f0000000040), 0x10) +pipe(&(0x7f0000000040)={<r2=>0x0, <r3=>0x0}) +creat(&(0x7f0000000040)='./file0', 0x0) +close(r1) +sendfile(r0, r2, &(0x7f0000000040), 0x1) +getgid() +fcntl$getflags(r0, 0x0) +getpid() +close(r3) +getuid() + `, + 11, + func(p *Prog, callIndex int) bool { + pp := strings.TrimSpace(string(p.Serialize())) + if attempt == 0 { + if pp == strings.TrimSpace(` +getpid() +r0 = open(&(0x7f0000000040)='./file0', 0x0, 0x0) +r1 = open(&(0x7f0000000040)='./file1', 0x0, 0x0) +getuid() +read(r1, &(0x7f0000000040), 0x10) +read(r0, &(0x7f0000000040), 0x10) +pipe(&(0x7f0000000040)={<r2=>0x0, 0x0}) +creat(&(0x7f0000000040)='./file0', 0x0) +close(r1) +sendfile(r0, r2, &(0x7f0000000040), 0x1) +getgid() +fcntl$getflags(r0, 0x0) + `) { + return false + } + } else if attempt == 1 { + if pp == strings.TrimSpace(` +r0 = open(&(0x7f0000000040)='./file0', 0x0, 0x0) +read(r0, &(0x7f0000000040), 0x10) +pipe(&(0x7f0000000040)={<r1=>0x0, <r2=>0x0}) +creat(&(0x7f0000000040)='./file0', 0x0) +sendfile(r0, r1, &(0x7f0000000040), 0x1) +fcntl$getflags(r0, 0x0) +close(r2) + `) { + return true + } + } else { + return false + } + panic(fmt.Sprintf("unexpected candidate on attempt %v:\n%v", attempt, pp)) + }, + ` +r0 = open(&(0x7f0000000040)='./file0', 0x0, 0x0) +read(r0, &(0x7f0000000040), 0x10) +pipe(&(0x7f0000000040)={<r1=>0x0, <r2=>0x0}) +creat(&(0x7f0000000040)='./file0', 0x0) +sendfile(r0, r1, &(0x7f0000000040), 0x1) +fcntl$getflags(r0, 0x0) +close(r2) + `, + 5, + }, } t.Parallel() for ti, test := range tests { - target, err := GetTarget(test.os, test.arch) - if err != nil { - t.Fatal(err) - } - p, err := target.Deserialize([]byte(test.orig), Strict) - if err != nil { - t.Fatalf("failed to deserialize original program #%v: %v", ti, err) - } - p1, ci := Minimize(p, test.callIndex, test.mode, test.pred) - res := p1.Serialize() - if string(res) != test.result { - t.Fatalf("minimization produced wrong result #%v\norig:\n%v\nexpect:\n%v\ngot:\n%v", - ti, test.orig, test.result, string(res)) - } - if ci != test.resultCallIndex { - t.Fatalf("minimization broke call index #%v: got %v, want %v", - ti, ci, test.resultCallIndex) - } + t.Run(fmt.Sprint(ti), func(t *testing.T) { + target, err := GetTarget(test.os, test.arch) + if err != nil { + t.Fatal(err) + } + p, err := target.Deserialize([]byte(strings.TrimSpace(test.orig)), Strict) + if err != nil { + t.Fatalf("failed to deserialize original program #%v: %v", ti, err) + } + attempt = 0 + pred := func(p *Prog, callIndex int) bool { + res := test.pred(p, callIndex) + attempt++ + return res + } + p1, ci := Minimize(p, test.callIndex, test.mode, pred) + res := strings.TrimSpace(string(p1.Serialize())) + expect := strings.TrimSpace(test.result) + if res != expect { + t.Fatalf("minimization produced wrong result #%v\norig:\n%v\nexpect:\n%v\ngot:\n%v", + ti, test.orig, expect, res) + } + if ci != test.resultCallIndex { + t.Fatalf("minimization broke call index #%v: got %v, want %v", + ti, ci, test.resultCallIndex) + } + }) } } |
