diff options
| -rw-r--r-- | prog/encoding.go | 23 | ||||
| -rw-r--r-- | prog/validation.go | 5 | ||||
| -rw-r--r-- | sys/linux/init_test.go | 28 |
3 files changed, 49 insertions, 7 deletions
diff --git a/prog/encoding.go b/prog/encoding.go index db32484a8..06099993f 100644 --- a/prog/encoding.go +++ b/prog/encoding.go @@ -235,8 +235,17 @@ func (a *ResultArg) serialize(ctx *serializer) { type DeserializeMode int const ( - Strict DeserializeMode = iota - NonStrict DeserializeMode = iota + // In strict mode deserialization fails if the program is malformed in any way. + // This mode is used for manually written programs to ensure that they are correct. + Strict DeserializeMode = iota + // In non-strict mode malformed programs silently fixed in a best-effort way, + // e.g. missing/wrong arguments are replaced with default values. + // This mode is used for the corpus programs to "repair" them after descriptions changes. + NonStrict + // Unsafe mode is used for VM checking programs. In this mode programs are not fixed + // for safety, e.g. can access global files, issue prohibited ioctl's, disabled syscalls, etc. + StrictUnsafe + NonStrictUnsafe ) func (target *Target) Deserialize(data []byte, mode DeserializeMode) (*Prog, error) { @@ -246,7 +255,8 @@ func (target *Target) Deserialize(data []byte, mode DeserializeMode) (*Prog, err err, target.OS, target.Arch, GitRevision, mode, data)) } }() - p := newParser(target, data, mode == Strict) + strict := mode == Strict || mode == StrictUnsafe + p := newParser(target, data, strict) prog, err := p.parseProg() if err := p.Err(); err != nil { return nil, err @@ -260,6 +270,7 @@ func (target *Target) Deserialize(data []byte, mode DeserializeMode) (*Prog, err if err := prog.validateWithOpts(validationOptions{ // Don't validate auto-set conditional fields. We'll patch them later. ignoreTransient: true, + allowUnsafe: mode == StrictUnsafe || mode == NonStrictUnsafe, }); err != nil { return nil, err } @@ -267,8 +278,10 @@ func (target *Target) Deserialize(data []byte, mode DeserializeMode) (*Prog, err if p.autos != nil { p.fixupAutos(prog) } - if err := prog.sanitize(mode == NonStrict); err != nil { - return nil, err + if mode != StrictUnsafe { + if err := prog.sanitize(!strict); err != nil { + return nil, err + } } return prog, nil } diff --git a/prog/validation.go b/prog/validation.go index f38dcd1e8..b2a358706 100644 --- a/prog/validation.go +++ b/prog/validation.go @@ -34,6 +34,7 @@ type validCtx struct { type validationOptions struct { ignoreTransient bool + allowUnsafe bool // allow global file names, etc } func (p *Prog) validateWithOpts(opts validationOptions) error { @@ -60,7 +61,7 @@ func (p *Prog) validateWithOpts(opts validationOptions) error { } func (ctx *validCtx) validateCall(c *Call) error { - if c.Meta.Attrs.Disabled { + if !ctx.opts.allowUnsafe && c.Meta.Attrs.Disabled { return fmt.Errorf("use of a disabled call") } if c.Props.Rerun > 0 && c.Props.FailNth > 0 { @@ -210,7 +211,7 @@ func (arg *DataArg) validate(ctx *validCtx, dir Dir) error { typ.Name(), arg.Size(), typ.TypeSize) } case BufferFilename: - if escapingFilename(string(arg.data)) { + if !ctx.opts.allowUnsafe && escapingFilename(string(arg.data)) { return fmt.Errorf("escaping filename %q", arg.data) } } diff --git a/sys/linux/init_test.go b/sys/linux/init_test.go index c273a2519..7379f74c0 100644 --- a/sys/linux/init_test.go +++ b/sys/linux/init_test.go @@ -155,3 +155,31 @@ syz_open_dev$tty1(0xc, 0x4, 0x1) }, }) } + +func TestDeserializeStrictUnsafe(t *testing.T) { + t.Parallel() + target, _ := prog.GetTarget("linux", "amd64") + // Raw mode must preserve the global file name, allow to use mmap with non-fixed addr, + // and allow to use disabled syscalls. + had := `openat(0x0, &(0x7f0000000000)='/dev/foo', 0x0, 0x0) +mmap(0x0, 0x0, 0x0, 0x0, 0x0, 0x0) +clone(0x0, &(0x7f0000000000), &(0x7f0000000010), &(0x7f0000000020), &(0x7f0000000030)) +` + p, err := target.Deserialize([]byte(had), prog.StrictUnsafe) + if err != nil { + t.Fatal(err) + } + got := string(p.Serialize()) + if had != got { + t.Fatalf("program was changed:\n%s\ngot:\n%s", had, got) + } +} + +func TestDeserializeNonStrictUnsafe(t *testing.T) { + t.Parallel() + target, _ := prog.GetTarget("linux", "amd64") + _, err := target.Deserialize([]byte("clone()"), prog.NonStrictUnsafe) + if err != nil { + t.Fatal(err) + } +} |
