aboutsummaryrefslogtreecommitdiffstats
path: root/prog
diff options
context:
space:
mode:
Diffstat (limited to 'prog')
-rw-r--r--prog/minimization.go86
-rw-r--r--prog/minimization_test.go119
2 files changed, 187 insertions, 18 deletions
diff --git a/prog/minimization.go b/prog/minimization.go
index 260a15133..c835ea81a 100644
--- a/prog/minimization.go
+++ b/prog/minimization.go
@@ -129,6 +129,11 @@ func removeCalls(p0 *Prog, callIndex0 int, pred minimizePred) (*Prog, int) {
p0 = p
}
}
+
+ if callIndex0 != -1 {
+ p0, callIndex0 = removeUnrelatedCalls(p0, callIndex0, pred)
+ }
+
for i := len(p0.Calls) - 1; i >= 0; i-- {
if i == callIndex0 {
continue
@@ -148,6 +153,87 @@ func removeCalls(p0 *Prog, callIndex0 int, pred minimizePred) (*Prog, int) {
return p0, callIndex0
}
+// removeUnrelatedCalls tries to remove all "unrelated" calls at once.
+// Unrelated calls are the calls that don't use any resources/files from
+// the transitive closure of the resources/files used by the target call.
+// This may significantly reduce large generated programs in a single step.
+func removeUnrelatedCalls(p0 *Prog, callIndex0 int, pred minimizePred) (*Prog, int) {
+ keepCalls := relatedCalls(p0, callIndex0)
+ if len(p0.Calls)-len(keepCalls) < 3 {
+ return p0, callIndex0
+ }
+ p, callIndex := p0.Clone(), callIndex0
+ for i := len(p0.Calls) - 1; i >= 0; i-- {
+ if keepCalls[i] {
+ continue
+ }
+ p.RemoveCall(i)
+ if i < callIndex {
+ callIndex--
+ }
+ }
+ if !pred(p, callIndex, statMinRemoveCall, "unrelated calls") {
+ return p0, callIndex0
+ }
+ return p, callIndex
+}
+
+func relatedCalls(p0 *Prog, callIndex0 int) map[int]bool {
+ keepCalls := map[int]bool{callIndex0: true}
+ used := uses(p0.Calls[callIndex0])
+ for {
+ n := len(used)
+ for i, call := range p0.Calls {
+ if keepCalls[i] {
+ continue
+ }
+ used1 := uses(call)
+ if intersects(used, used1) {
+ keepCalls[i] = true
+ for what := range used1 {
+ used[what] = true
+ }
+ }
+ }
+ if n == len(used) {
+ return keepCalls
+ }
+ }
+}
+
+func uses(call *Call) map[any]bool {
+ used := make(map[any]bool)
+ ForeachArg(call, func(arg Arg, _ *ArgCtx) {
+ switch typ := arg.Type().(type) {
+ case *ResourceType:
+ a := arg.(*ResultArg)
+ used[a] = true
+ if a.Res != nil {
+ used[a.Res] = true
+ }
+ for use := range a.uses {
+ used[use] = true
+ }
+ case *BufferType:
+ a := arg.(*DataArg)
+ if a.Dir() != DirOut && typ.Kind == BufferFilename {
+ val := string(bytes.TrimRight(a.Data(), "\x00"))
+ used[val] = true
+ }
+ }
+ })
+ return used
+}
+
+func intersects(list, list1 map[any]bool) bool {
+ for what := range list1 {
+ if list[what] {
+ return true
+ }
+ }
+ return false
+}
+
func resetCallProps(p0 *Prog, callIndex0 int, pred minimizePred) *Prog {
// Try to reset all call props to their default values.
// This should be reasonable for many progs.
diff --git a/prog/minimization_test.go b/prog/minimization_test.go
index 9fe822577..e8f1d40e5 100644
--- a/prog/minimization_test.go
+++ b/prog/minimization_test.go
@@ -4,7 +4,9 @@
package prog
import (
+ "fmt"
"math/rand"
+ "strings"
"testing"
"github.com/google/syzkaller/pkg/hash"
@@ -12,6 +14,7 @@ import (
// nolint:gocyclo
func TestMinimize(t *testing.T) {
+ attempt := 0
// nolint: lll
tests := []struct {
os string
@@ -243,27 +246,107 @@ func TestMinimize(t *testing.T) {
"syz_mount_image$ext4(&(0x7f0000000000)='ext4\\x00', &(0x7f0000000100)='./file0\\x00', 0x0, &(0x7f0000010020), 0x1, 0x15, &(0x7f0000000200)=\"$eJwqrqzKTszJSS0CBAAA//8TyQPi\")\n",
0,
},
+ // Test for removeUnrelatedCalls.
+ // We test exact candidates we get on each step.
+ // First candidate should be removal of the trailing calls, which we reject.
+ // Next candidate is removal of unrelated calls, which we accept.
+ {
+ "linux", "amd64", MinimizeCorpus,
+ `
+getpid()
+r0 = open(&(0x7f0000000040)='./file0', 0x0, 0x0)
+r1 = open(&(0x7f0000000040)='./file1', 0x0, 0x0)
+getuid()
+read(r1, &(0x7f0000000040), 0x10)
+read(r0, &(0x7f0000000040), 0x10)
+pipe(&(0x7f0000000040)={<r2=>0x0, <r3=>0x0})
+creat(&(0x7f0000000040)='./file0', 0x0)
+close(r1)
+sendfile(r0, r2, &(0x7f0000000040), 0x1)
+getgid()
+fcntl$getflags(r0, 0x0)
+getpid()
+close(r3)
+getuid()
+ `,
+ 11,
+ func(p *Prog, callIndex int) bool {
+ pp := strings.TrimSpace(string(p.Serialize()))
+ if attempt == 0 {
+ if pp == strings.TrimSpace(`
+getpid()
+r0 = open(&(0x7f0000000040)='./file0', 0x0, 0x0)
+r1 = open(&(0x7f0000000040)='./file1', 0x0, 0x0)
+getuid()
+read(r1, &(0x7f0000000040), 0x10)
+read(r0, &(0x7f0000000040), 0x10)
+pipe(&(0x7f0000000040)={<r2=>0x0, 0x0})
+creat(&(0x7f0000000040)='./file0', 0x0)
+close(r1)
+sendfile(r0, r2, &(0x7f0000000040), 0x1)
+getgid()
+fcntl$getflags(r0, 0x0)
+ `) {
+ return false
+ }
+ } else if attempt == 1 {
+ if pp == strings.TrimSpace(`
+r0 = open(&(0x7f0000000040)='./file0', 0x0, 0x0)
+read(r0, &(0x7f0000000040), 0x10)
+pipe(&(0x7f0000000040)={<r1=>0x0, <r2=>0x0})
+creat(&(0x7f0000000040)='./file0', 0x0)
+sendfile(r0, r1, &(0x7f0000000040), 0x1)
+fcntl$getflags(r0, 0x0)
+close(r2)
+ `) {
+ return true
+ }
+ } else {
+ return false
+ }
+ panic(fmt.Sprintf("unexpected candidate on attempt %v:\n%v", attempt, pp))
+ },
+ `
+r0 = open(&(0x7f0000000040)='./file0', 0x0, 0x0)
+read(r0, &(0x7f0000000040), 0x10)
+pipe(&(0x7f0000000040)={<r1=>0x0, <r2=>0x0})
+creat(&(0x7f0000000040)='./file0', 0x0)
+sendfile(r0, r1, &(0x7f0000000040), 0x1)
+fcntl$getflags(r0, 0x0)
+close(r2)
+ `,
+ 5,
+ },
}
t.Parallel()
for ti, test := range tests {
- target, err := GetTarget(test.os, test.arch)
- if err != nil {
- t.Fatal(err)
- }
- p, err := target.Deserialize([]byte(test.orig), Strict)
- if err != nil {
- t.Fatalf("failed to deserialize original program #%v: %v", ti, err)
- }
- p1, ci := Minimize(p, test.callIndex, test.mode, test.pred)
- res := p1.Serialize()
- if string(res) != test.result {
- t.Fatalf("minimization produced wrong result #%v\norig:\n%v\nexpect:\n%v\ngot:\n%v",
- ti, test.orig, test.result, string(res))
- }
- if ci != test.resultCallIndex {
- t.Fatalf("minimization broke call index #%v: got %v, want %v",
- ti, ci, test.resultCallIndex)
- }
+ t.Run(fmt.Sprint(ti), func(t *testing.T) {
+ target, err := GetTarget(test.os, test.arch)
+ if err != nil {
+ t.Fatal(err)
+ }
+ p, err := target.Deserialize([]byte(strings.TrimSpace(test.orig)), Strict)
+ if err != nil {
+ t.Fatalf("failed to deserialize original program #%v: %v", ti, err)
+ }
+ attempt = 0
+ pred := func(p *Prog, callIndex int) bool {
+ res := test.pred(p, callIndex)
+ attempt++
+ return res
+ }
+ p1, ci := Minimize(p, test.callIndex, test.mode, pred)
+ res := strings.TrimSpace(string(p1.Serialize()))
+ expect := strings.TrimSpace(test.result)
+ if res != expect {
+ t.Fatalf("minimization produced wrong result #%v\norig:\n%v\nexpect:\n%v\ngot:\n%v",
+ ti, test.orig, expect, res)
+ }
+ if ci != test.resultCallIndex {
+ t.Fatalf("minimization broke call index #%v: got %v, want %v",
+ ti, ci, test.resultCallIndex)
+ }
+ })
}
}