aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--prog/encoding.go23
-rw-r--r--prog/validation.go5
-rw-r--r--sys/linux/init_test.go28
3 files changed, 49 insertions, 7 deletions
diff --git a/prog/encoding.go b/prog/encoding.go
index db32484a8..06099993f 100644
--- a/prog/encoding.go
+++ b/prog/encoding.go
@@ -235,8 +235,17 @@ func (a *ResultArg) serialize(ctx *serializer) {
type DeserializeMode int
const (
- Strict DeserializeMode = iota
- NonStrict DeserializeMode = iota
+ // In strict mode deserialization fails if the program is malformed in any way.
+ // This mode is used for manually written programs to ensure that they are correct.
+ Strict DeserializeMode = iota
+ // In non-strict mode malformed programs silently fixed in a best-effort way,
+ // e.g. missing/wrong arguments are replaced with default values.
+ // This mode is used for the corpus programs to "repair" them after descriptions changes.
+ NonStrict
+ // Unsafe mode is used for VM checking programs. In this mode programs are not fixed
+ // for safety, e.g. can access global files, issue prohibited ioctl's, disabled syscalls, etc.
+ StrictUnsafe
+ NonStrictUnsafe
)
func (target *Target) Deserialize(data []byte, mode DeserializeMode) (*Prog, error) {
@@ -246,7 +255,8 @@ func (target *Target) Deserialize(data []byte, mode DeserializeMode) (*Prog, err
err, target.OS, target.Arch, GitRevision, mode, data))
}
}()
- p := newParser(target, data, mode == Strict)
+ strict := mode == Strict || mode == StrictUnsafe
+ p := newParser(target, data, strict)
prog, err := p.parseProg()
if err := p.Err(); err != nil {
return nil, err
@@ -260,6 +270,7 @@ func (target *Target) Deserialize(data []byte, mode DeserializeMode) (*Prog, err
if err := prog.validateWithOpts(validationOptions{
// Don't validate auto-set conditional fields. We'll patch them later.
ignoreTransient: true,
+ allowUnsafe: mode == StrictUnsafe || mode == NonStrictUnsafe,
}); err != nil {
return nil, err
}
@@ -267,8 +278,10 @@ func (target *Target) Deserialize(data []byte, mode DeserializeMode) (*Prog, err
if p.autos != nil {
p.fixupAutos(prog)
}
- if err := prog.sanitize(mode == NonStrict); err != nil {
- return nil, err
+ if mode != StrictUnsafe {
+ if err := prog.sanitize(!strict); err != nil {
+ return nil, err
+ }
}
return prog, nil
}
diff --git a/prog/validation.go b/prog/validation.go
index f38dcd1e8..b2a358706 100644
--- a/prog/validation.go
+++ b/prog/validation.go
@@ -34,6 +34,7 @@ type validCtx struct {
type validationOptions struct {
ignoreTransient bool
+ allowUnsafe bool // allow global file names, etc
}
func (p *Prog) validateWithOpts(opts validationOptions) error {
@@ -60,7 +61,7 @@ func (p *Prog) validateWithOpts(opts validationOptions) error {
}
func (ctx *validCtx) validateCall(c *Call) error {
- if c.Meta.Attrs.Disabled {
+ if !ctx.opts.allowUnsafe && c.Meta.Attrs.Disabled {
return fmt.Errorf("use of a disabled call")
}
if c.Props.Rerun > 0 && c.Props.FailNth > 0 {
@@ -210,7 +211,7 @@ func (arg *DataArg) validate(ctx *validCtx, dir Dir) error {
typ.Name(), arg.Size(), typ.TypeSize)
}
case BufferFilename:
- if escapingFilename(string(arg.data)) {
+ if !ctx.opts.allowUnsafe && escapingFilename(string(arg.data)) {
return fmt.Errorf("escaping filename %q", arg.data)
}
}
diff --git a/sys/linux/init_test.go b/sys/linux/init_test.go
index c273a2519..7379f74c0 100644
--- a/sys/linux/init_test.go
+++ b/sys/linux/init_test.go
@@ -155,3 +155,31 @@ syz_open_dev$tty1(0xc, 0x4, 0x1)
},
})
}
+
+func TestDeserializeStrictUnsafe(t *testing.T) {
+ t.Parallel()
+ target, _ := prog.GetTarget("linux", "amd64")
+ // Raw mode must preserve the global file name, allow to use mmap with non-fixed addr,
+ // and allow to use disabled syscalls.
+ had := `openat(0x0, &(0x7f0000000000)='/dev/foo', 0x0, 0x0)
+mmap(0x0, 0x0, 0x0, 0x0, 0x0, 0x0)
+clone(0x0, &(0x7f0000000000), &(0x7f0000000010), &(0x7f0000000020), &(0x7f0000000030))
+`
+ p, err := target.Deserialize([]byte(had), prog.StrictUnsafe)
+ if err != nil {
+ t.Fatal(err)
+ }
+ got := string(p.Serialize())
+ if had != got {
+ t.Fatalf("program was changed:\n%s\ngot:\n%s", had, got)
+ }
+}
+
+func TestDeserializeNonStrictUnsafe(t *testing.T) {
+ t.Parallel()
+ target, _ := prog.GetTarget("linux", "amd64")
+ _, err := target.Deserialize([]byte("clone()"), prog.NonStrictUnsafe)
+ if err != nil {
+ t.Fatal(err)
+ }
+}