diff options
| author | Dmitry Vyukov <dvyukov@google.com> | 2020-04-26 14:14:14 +0200 |
|---|---|---|
| committer | Dmitry Vyukov <dvyukov@google.com> | 2020-05-01 13:31:17 +0200 |
| commit | e54e9781a4e043b3140b0c908ba4f4e469fd317e (patch) | |
| tree | 16e6387d78a8577c5f3d9fb8d05a51752da6338e | |
| parent | 3f4dbb2f6fff9479d6c250e224bc3cb7f5cd66ed (diff) | |
prog: remove Dir from Type
Having Dir is Type is handy, but forces us to duplicate lots of types.
E.g. if a struct is referenced as both in and out, then we need to
have 2 copies and 2 copies of structs/types it includes.
If also prevents us from having the struct type as struct identity
(because we can have up to 3 of them).
Revert to the old way we used to do it: propagate Dir as we walk
syscall arguments. This moves lots of dir passing from pkg/compiler
to prog package.
Now Arg contains the dir, so once we build the tree, we can use dirs
as before.
Reduces size of sys/linux/gen/amd64.go from 6058336 to 5661150 (-6.6%).
Update #1580
| -rw-r--r-- | pkg/compiler/check.go | 14 | ||||
| -rw-r--r-- | pkg/compiler/compiler.go | 7 | ||||
| -rw-r--r-- | pkg/compiler/gen.go | 36 | ||||
| -rw-r--r-- | pkg/compiler/types.go | 15 | ||||
| -rw-r--r-- | prog/analysis.go | 4 | ||||
| -rw-r--r-- | prog/any.go | 34 | ||||
| -rw-r--r-- | prog/encoding.go | 118 | ||||
| -rw-r--r-- | prog/encodingexec.go | 2 | ||||
| -rw-r--r-- | prog/hints.go | 2 | ||||
| -rw-r--r-- | prog/hints_test.go | 8 | ||||
| -rw-r--r-- | prog/minimization.go | 6 | ||||
| -rw-r--r-- | prog/mutation.go | 28 | ||||
| -rw-r--r-- | prog/prio.go | 16 | ||||
| -rw-r--r-- | prog/prog.go | 60 | ||||
| -rw-r--r-- | prog/prog_test.go | 17 | ||||
| -rw-r--r-- | prog/rand.go | 145 | ||||
| -rw-r--r-- | prog/rand_test.go | 4 | ||||
| -rw-r--r-- | prog/resources.go | 26 | ||||
| -rw-r--r-- | prog/rotation.go | 2 | ||||
| -rw-r--r-- | prog/target.go | 20 | ||||
| -rw-r--r-- | prog/types.go | 148 | ||||
| -rw-r--r-- | prog/validation.go | 34 | ||||
| -rw-r--r-- | sys/linux/init.go | 39 | ||||
| -rw-r--r-- | sys/linux/init_alg.go | 41 | ||||
| -rw-r--r-- | sys/linux/init_iptables.go | 16 | ||||
| -rw-r--r-- | sys/linux/init_vusb.go | 8 | ||||
| -rw-r--r-- | sys/targets/common.go | 18 | ||||
| -rw-r--r-- | sys/windows/init.go | 8 | ||||
| -rw-r--r-- | tools/syz-trace2syz/proggen/generate_unions.go | 12 | ||||
| -rw-r--r-- | tools/syz-trace2syz/proggen/proggen.go | 136 |
30 files changed, 521 insertions, 503 deletions
diff --git a/pkg/compiler/check.go b/pkg/compiler/check.go index e22f217fd..2c0438086 100644 --- a/pkg/compiler/check.go +++ b/pkg/compiler/check.go @@ -323,7 +323,7 @@ func (comp *compiler) checkLenType(t0, t *ast.Type, parents []parentDesc, warned[parentName] = true return } - _, args, _ := comp.getArgsBase(t, "", prog.DirIn, isArg) + _, args, _ := comp.getArgsBase(t, "", isArg) for i, arg := range args { argDesc := desc.Args[i] if argDesc.Type == typeArgLenTarget { @@ -522,7 +522,7 @@ func (comp *compiler) collectUsedType(structs, flags, strflags map[string]bool, } return } - _, args, _ := comp.getArgsBase(t, "", prog.DirIn, isArg) + _, args, _ := comp.getArgsBase(t, "", isArg) for i, arg := range args { if desc.Args[i].Type == typeArgType { comp.collectUsedType(structs, flags, strflags, arg, desc.Args[i].IsArg) @@ -603,7 +603,7 @@ func (comp *compiler) checkTypeCtors(t *ast.Type, dir prog.Dir, isArg bool, if desc == typePtr { dir = genDir(t.Args[0]) } - _, args, _ := comp.getArgsBase(t, "", dir, isArg) + _, args, _ := comp.getArgsBase(t, "", isArg) for i, arg := range args { if desc.Args[i].Type == typeArgType { comp.checkTypeCtors(arg, dir, desc.Args[i].IsArg, ctors, checked) @@ -684,7 +684,7 @@ func (comp *compiler) recurseField(checked map[string]bool, t *ast.Type, path [] comp.checkStructRecursion(checked, comp.structs[t.Ident], path) return } - _, args, base := comp.getArgsBase(t, "", prog.DirIn, false) + _, args, base := comp.getArgsBase(t, "", false) if desc == typePtr && base.IsOptional { return // optional pointers prune recursion } @@ -774,7 +774,7 @@ func (comp *compiler) checkType(ctx checkCtx, t *ast.Type, flags checkFlags) { return } if desc.Check != nil { - _, args, base := comp.getArgsBase(t, "", prog.DirIn, flags&checkIsArg != 0) + _, args, base := comp.getArgsBase(t, "", flags&checkIsArg != 0) desc.Check(comp, t, args, base) } } @@ -1098,12 +1098,12 @@ func (comp *compiler) checkVarlens() { } func (comp *compiler) isVarlen(t *ast.Type) bool { - desc, args, _ := comp.getArgsBase(t, "", prog.DirIn, false) + desc, args, _ := comp.getArgsBase(t, "", false) return desc.Varlen != nil && desc.Varlen(comp, t, args) } func (comp *compiler) isZeroSize(t *ast.Type) bool { - desc, args, _ := comp.getArgsBase(t, "", prog.DirIn, false) + desc, args, _ := comp.getArgsBase(t, "", false) return desc.ZeroSize != nil && desc.ZeroSize(comp, t, args) } diff --git a/pkg/compiler/compiler.go b/pkg/compiler/compiler.go index 7316b4d4c..8e5afddde 100644 --- a/pkg/compiler/compiler.go +++ b/pkg/compiler/compiler.go @@ -249,14 +249,13 @@ func (comp *compiler) getTypeDesc(t *ast.Type) *typeDesc { return nil } -func (comp *compiler) getArgsBase(t *ast.Type, field string, dir prog.Dir, isArg bool) ( - *typeDesc, []*ast.Type, prog.IntTypeCommon) { +func (comp *compiler) getArgsBase(t *ast.Type, field string, isArg bool) (*typeDesc, []*ast.Type, prog.IntTypeCommon) { desc := comp.getTypeDesc(t) if desc == nil { panic(fmt.Sprintf("no type desc for %#v", *t)) } args, opt := removeOpt(t) - com := genCommon(t.Ident, field, sizeUnassigned, dir, opt != nil) + com := genCommon(t.Ident, field, sizeUnassigned, opt != nil) base := genIntCommon(com, 0, false) if desc.NeedBase { base.TypeSize = comp.ptrSize @@ -306,7 +305,7 @@ func (comp *compiler) foreachType(n0 ast.Node, func (comp *compiler) foreachSubType(t *ast.Type, isArg bool, cb func(*ast.Type, *typeDesc, []*ast.Type, prog.IntTypeCommon)) { - desc, args, base := comp.getArgsBase(t, "", prog.DirIn, isArg) + desc, args, base := comp.getArgsBase(t, "", isArg) cb(t, desc, args, base) for i, arg := range args { if desc.Args[i].Type == typeArgType { diff --git a/pkg/compiler/gen.go b/pkg/compiler/gen.go index c8255c554..61dd85eae 100644 --- a/pkg/compiler/gen.go +++ b/pkg/compiler/gen.go @@ -65,8 +65,8 @@ func (comp *compiler) collectCallArgSizes() map[string][]uint64 { if len(argSizes) <= i { argSizes = append(argSizes, comp.ptrSize) } - desc, _, _ := comp.getArgsBase(arg.Type, arg.Name.Name, prog.DirIn, true) - typ := comp.genField(arg, prog.DirIn, comp.ptrSize) + desc, _, _ := comp.getArgsBase(arg.Type, arg.Name.Name, true) + typ := comp.genField(arg, comp.ptrSize) // Ignore all types with base (const, flags). We don't have base in syscall args. // Also ignore resources and pointers because fd can be 32-bits and pointer 64-bits, // and then there is no way to fix this. @@ -112,7 +112,7 @@ func (comp *compiler) genSyscalls() []*prog.Syscall { func (comp *compiler) genSyscall(n *ast.Call, argSizes []uint64) *prog.Syscall { var ret prog.Type if n.Ret != nil { - ret = comp.genType(n.Ret, "ret", prog.DirOut, comp.ptrSize) + ret = comp.genType(n.Ret, "ret", comp.ptrSize) } var attrs prog.SyscallAttrs descAttrs := comp.parseAttrs(callAttrs, n, n.Attrs) @@ -129,7 +129,7 @@ func (comp *compiler) genSyscall(n *ast.Call, argSizes []uint64) *prog.Syscall { CallName: n.CallName, NR: n.NR, MissingArgs: len(argSizes) - len(n.Args), - Args: comp.genFieldArray(n.Args, prog.DirIn, argSizes), + Args: comp.genFieldArray(n.Args, argSizes), Ret: ret, Attrs: attrs, } @@ -237,10 +237,7 @@ func (comp *compiler) genStructDescs(syscalls []*prog.Syscall) []*prog.KeyedStru sort.Slice(ctx.structs, func(i, j int) bool { si, sj := ctx.structs[i].Key, ctx.structs[j].Key - if si.Name != sj.Name { - return si.Name < sj.Name - } - return si.Dir < sj.Dir + return si.Name < sj.Name }) return ctx.structs } @@ -371,14 +368,14 @@ func (ctx *structGen) walkUnion(t *prog.UnionType) { } } -func (comp *compiler) genStructDesc(res *prog.StructDesc, n *ast.Struct, dir prog.Dir, varlen bool) { +func (comp *compiler) genStructDesc(res *prog.StructDesc, n *ast.Struct, varlen bool) { // Leave node for genStructDescs to calculate size/padding. comp.structNodes[res] = n - common := genCommon(n.Name.Name, "", sizeUnassigned, dir, false) + common := genCommon(n.Name.Name, "", sizeUnassigned, false) common.IsVarlen = varlen *res = prog.StructDesc{ TypeCommon: common, - Fields: comp.genFieldArray(n.Fields, dir, make([]uint64, len(n.Fields))), + Fields: comp.genFieldArray(n.Fields, make([]uint64, len(n.Fields))), } } @@ -570,25 +567,25 @@ func (comp *compiler) typeAlign(t0 prog.Type) uint64 { func genPad(size uint64) prog.Type { return &prog.ConstType{ - IntTypeCommon: genIntCommon(genCommon("pad", "", size, prog.DirIn, false), 0, false), + IntTypeCommon: genIntCommon(genCommon("pad", "", size, false), 0, false), IsPad: true, } } -func (comp *compiler) genFieldArray(fields []*ast.Field, dir prog.Dir, argSizes []uint64) []prog.Type { +func (comp *compiler) genFieldArray(fields []*ast.Field, argSizes []uint64) []prog.Type { var res []prog.Type for i, f := range fields { - res = append(res, comp.genField(f, dir, argSizes[i])) + res = append(res, comp.genField(f, argSizes[i])) } return res } -func (comp *compiler) genField(f *ast.Field, dir prog.Dir, argSize uint64) prog.Type { - return comp.genType(f.Type, f.Name.Name, dir, argSize) +func (comp *compiler) genField(f *ast.Field, argSize uint64) prog.Type { + return comp.genType(f.Type, f.Name.Name, argSize) } -func (comp *compiler) genType(t *ast.Type, field string, dir prog.Dir, argSize uint64) prog.Type { - desc, args, base := comp.getArgsBase(t, field, dir, argSize != 0) +func (comp *compiler) genType(t *ast.Type, field string, argSize uint64) prog.Type { + desc, args, base := comp.getArgsBase(t, field, argSize != 0) if desc.Gen == nil { panic(fmt.Sprintf("no gen for %v %#v", field, t)) } @@ -605,12 +602,11 @@ func (comp *compiler) genType(t *ast.Type, field string, dir prog.Dir, argSize u return desc.Gen(comp, t, args, base) } -func genCommon(name, field string, size uint64, dir prog.Dir, opt bool) prog.TypeCommon { +func genCommon(name, field string, size uint64, opt bool) prog.TypeCommon { return prog.TypeCommon{ TypeName: name, TypeSize: size, FldName: field, - ArgDir: dir, IsOptional: opt, } } diff --git a/pkg/compiler/types.go b/pkg/compiler/types.go index fce1153de..6f2c26a6e 100644 --- a/pkg/compiler/types.go +++ b/pkg/compiler/types.go @@ -156,14 +156,14 @@ var typePtr = &typeDesc{ CanBeTypedef: true, Args: []namedArg{{Name: "direction", Type: typeArgDir}, {Name: "type", Type: typeArgType}}, Gen: func(comp *compiler, t *ast.Type, args []*ast.Type, base prog.IntTypeCommon) prog.Type { - base.ArgDir = prog.DirIn // pointers are always in base.TypeSize = comp.ptrSize if t.Ident == "ptr64" { base.TypeSize = 8 } return &prog.PtrType{ TypeCommon: base.TypeCommon, - Type: comp.genType(args[1], "", genDir(args[0]), 0), + Type: comp.genType(args[1], "", 0), + ElemDir: genDir(args[0]), } }, } @@ -212,7 +212,7 @@ var typeArray = &typeDesc{ return comp.isZeroSize(args[0]) }, Gen: func(comp *compiler, t *ast.Type, args []*ast.Type, base prog.IntTypeCommon) prog.Type { - elemType := comp.genType(args[0], "", base.ArgDir, 0) + elemType := comp.genType(args[0], "", 0) kind, begin, end := prog.ArrayRandLen, uint64(0), uint64(0) if len(args) > 1 { kind, begin, end = prog.ArrayRangeLen, args[1].Value, args[1].Value @@ -696,7 +696,7 @@ var typeFmt = &typeDesc{ {Name: "value", Type: typeArgType, IsArg: true}, }, Check: func(comp *compiler, t *ast.Type, args []*ast.Type, base prog.IntTypeCommon) { - desc, _, _ := comp.getArgsBase(args[1], "", base.TypeCommon.ArgDir, true) + desc, _, _ := comp.getArgsBase(args[1], "", true) switch desc { case typeResource, typeInt, typeLen, typeFlags, typeProc: default: @@ -718,7 +718,7 @@ var typeFmt = &typeDesc{ format = prog.FormatStrOct size = 23 } - typ := comp.genType(args[1], "", base.TypeCommon.ArgDir, comp.ptrSize) + typ := comp.genType(args[1], "", comp.ptrSize) switch t := typ.(type) { case *prog.ResourceType: t.ArgFormat = format @@ -767,7 +767,7 @@ func init() { baseType = r.Base r = comp.resources[r.Base.Ident] } - baseProgType := comp.genType(baseType, "", prog.DirIn, 0) + baseProgType := comp.genType(baseType, "", 0) base.TypeSize = baseProgType.Size() return &prog.ResourceType{ TypeCommon: base.TypeCommon, @@ -818,14 +818,13 @@ func init() { s := comp.structs[t.Ident] key := prog.StructKey{ Name: t.Ident, - Dir: base.ArgDir, } desc := comp.structDescs[key] if desc == nil { // Need to assign to structDescs before calling genStructDesc to break recursion. desc = new(prog.StructDesc) comp.structDescs[key] = desc - comp.genStructDesc(desc, s, base.ArgDir, typeStruct.Varlen(comp, t, args)) + comp.genStructDesc(desc, s, typeStruct.Varlen(comp, t, args)) } if s.IsUnion { return &prog.UnionType{ diff --git a/prog/analysis.go b/prog/analysis.go index fe022b670..e1fbaa557 100644 --- a/prog/analysis.go +++ b/prog/analysis.go @@ -69,13 +69,13 @@ func (s *state) analyzeImpl(c *Call, resources bool) { switch typ := arg.Type().(type) { case *ResourceType: a := arg.(*ResultArg) - if resources && typ.Dir() != DirIn { + if resources && a.Dir() != DirIn { s.resources[typ.Desc.Name] = append(s.resources[typ.Desc.Name], a) // TODO: negative PIDs and add them as well (that's process groups). } case *BufferType: a := arg.(*DataArg) - if typ.Dir() != DirOut && len(a.Data()) != 0 { + if a.Dir() != DirOut && len(a.Data()) != 0 { val := string(a.Data()) // Remove trailing zero padding. for len(val) >= 2 && val[len(val)-1] == 0 && val[len(val)-2] == 0 { diff --git a/prog/any.go b/prog/any.go index 15ce6ec53..d1433b18d 100644 --- a/prog/any.go +++ b/prog/any.go @@ -54,7 +54,8 @@ func initAnyTypes(target *Target) { TypeSize: target.PtrSize, IsOptional: true, }, - Type: target.any.array, + Type: target.any.array, + ElemDir: DirIn, } target.any.ptr64 = &PtrType{ TypeCommon: TypeCommon{ @@ -63,7 +64,8 @@ func initAnyTypes(target *Target) { TypeSize: 8, IsOptional: true, }, - Type: target.any.array, + Type: target.any.array, + ElemDir: DirIn, } target.any.blob = &BufferType{ TypeCommon: TypeCommon{ @@ -77,7 +79,6 @@ func initAnyTypes(target *Target) { TypeCommon: TypeCommon{ TypeName: name, FldName: name, - ArgDir: DirIn, TypeSize: size, IsOptional: true, }, @@ -100,7 +101,6 @@ func initAnyTypes(target *Target) { TypeName: "ANYUNION", FldName: "ANYUNION", IsVarlen: true, - ArgDir: DirIn, }, Fields: []Type{ target.any.blob, @@ -150,7 +150,7 @@ func (p *Prog) complexPtrs() (res []*PointerArg) { } func (target *Target) isComplexPtr(arg *PointerArg) bool { - if arg.Res == nil || arg.Type().Dir() != DirIn { + if arg.Res == nil || arg.Dir() != DirIn { return false } if target.isAnyPtr(arg.Type()) { @@ -175,6 +175,15 @@ func (target *Target) isComplexPtr(arg *PointerArg) bool { return complex && !hasPtr } +func (target *Target) isAnyRes(name string) bool { + return name == target.any.res16.TypeName || + name == target.any.res32.TypeName || + name == target.any.res64.TypeName || + name == target.any.resdec.TypeName || + name == target.any.reshex.TypeName || + name == target.any.resoct.TypeName +} + func (target *Target) CallContainsAny(c *Call) (res bool) { ForeachArg(c, func(arg Arg, ctx *ArgCtx) { if target.isAnyPtr(arg.Type()) { @@ -208,7 +217,7 @@ func (target *Target) squashPtr(arg *PointerArg, preserveField bool) { field = arg.Type().FieldName() } arg.typ = target.makeAnyPtrType(arg.Type().Size(), field) - arg.Res = MakeGroupArg(arg.typ.(*PtrType).Type, elems) + arg.Res = MakeGroupArg(arg.typ.(*PtrType).Type, DirIn, elems) if size := arg.Res.Size(); size != size0 { panic(fmt.Sprintf("squash changed size %v->%v for %v", size0, size, res0.Type())) } @@ -230,7 +239,7 @@ func (target *Target) squashPtrImpl(a Arg, elems *[]Arg) { } target.squashPtrImpl(arg.Option, elems) case *DataArg: - if arg.Type().Dir() == DirOut { + if arg.Dir() == DirOut { pad = arg.Size() } else { elem := target.ensureDataElem(elems) @@ -299,7 +308,8 @@ func (target *Target) squashResult(arg *ResultArg, elems *[]Arg) { default: panic("bad") } - *elems = append(*elems, MakeUnionArg(target.any.union, arg)) + arg.dir = DirIn + *elems = append(*elems, MakeUnionArg(target.any.union, DirIn, arg)) } func (target *Target) squashGroup(arg *GroupArg, elems *[]Arg) { @@ -375,14 +385,14 @@ func (target *Target) squashedValue(arg *ConstArg) (uint64, BinaryFormat) { func (target *Target) ensureDataElem(elems *[]Arg) *DataArg { if len(*elems) == 0 { - res := MakeDataArg(target.any.blob, nil) - *elems = append(*elems, MakeUnionArg(target.any.union, res)) + res := MakeDataArg(target.any.blob, DirIn, nil) + *elems = append(*elems, MakeUnionArg(target.any.union, DirIn, res)) return res } res, ok := (*elems)[len(*elems)-1].(*UnionArg).Option.(*DataArg) if !ok { - res = MakeDataArg(target.any.blob, nil) - *elems = append(*elems, MakeUnionArg(target.any.union, res)) + res = MakeDataArg(target.any.blob, DirIn, nil) + *elems = append(*elems, MakeUnionArg(target.any.union, DirIn, res)) } return res } diff --git a/prog/encoding.go b/prog/encoding.go index f99ff9d84..c7f8ba56a 100644 --- a/prog/encoding.go +++ b/prog/encoding.go @@ -113,7 +113,7 @@ func (a *PointerArg) serialize(ctx *serializer) { func (a *DataArg) serialize(ctx *serializer) { typ := a.Type().(*BufferType) - if typ.Dir() == DirOut { + if a.Dir() == DirOut { ctx.printf("\"\"/%v", a.Size()) return } @@ -280,7 +280,7 @@ func (p *parser) parseProg() (*Prog, error) { if IsPad(typ) { return nil, fmt.Errorf("padding in syscall %v arguments", name) } - arg, err := p.parseArg(typ) + arg, err := p.parseArg(typ, DirIn) if err != nil { return nil, err } @@ -302,7 +302,7 @@ func (p *parser) parseProg() (*Prog, error) { } for i := len(c.Args); i < len(meta.Args); i++ { p.strictFailf("missing syscall args") - c.Args = append(c.Args, meta.Args[i].DefaultArg()) + c.Args = append(c.Args, meta.Args[i].DefaultArg(DirIn)) } if len(c.Args) != len(meta.Args) { return nil, fmt.Errorf("wrong call arg count: %v, want %v", len(c.Args), len(meta.Args)) @@ -318,7 +318,7 @@ func (p *parser) parseProg() (*Prog, error) { return prog, nil } -func (p *parser) parseArg(typ Type) (Arg, error) { +func (p *parser) parseArg(typ Type, dir Dir) (Arg, error) { r := "" if p.Char() == '<' { p.Parse('<') @@ -326,13 +326,13 @@ func (p *parser) parseArg(typ Type) (Arg, error) { p.Parse('=') p.Parse('>') } - arg, err := p.parseArgImpl(typ) + arg, err := p.parseArgImpl(typ, dir) if err != nil { return nil, err } if arg == nil { if typ != nil { - arg = typ.DefaultArg() + arg = typ.DefaultArg(dir) } else if r != "" { return nil, fmt.Errorf("named nil argument") } @@ -345,26 +345,26 @@ func (p *parser) parseArg(typ Type) (Arg, error) { return arg, nil } -func (p *parser) parseArgImpl(typ Type) (Arg, error) { +func (p *parser) parseArgImpl(typ Type, dir Dir) (Arg, error) { if typ == nil && p.Char() != 'n' { p.eatExcessive(true, "non-nil argument for nil type") return nil, nil } switch p.Char() { case '0': - return p.parseArgInt(typ) + return p.parseArgInt(typ, dir) case 'r': - return p.parseArgRes(typ) + return p.parseArgRes(typ, dir) case '&': - return p.parseArgAddr(typ) + return p.parseArgAddr(typ, dir) case '"', '\'': - return p.parseArgString(typ) + return p.parseArgString(typ, dir) case '{': - return p.parseArgStruct(typ) + return p.parseArgStruct(typ, dir) case '[': - return p.parseArgArray(typ) + return p.parseArgArray(typ, dir) case '@': - return p.parseArgUnion(typ) + return p.parseArgUnion(typ, dir) case 'n': p.Parse('n') p.Parse('i') @@ -375,14 +375,14 @@ func (p *parser) parseArgImpl(typ Type) (Arg, error) { p.Parse('U') p.Parse('T') p.Parse('O') - return p.parseAuto(typ) + return p.parseAuto(typ, dir) default: return nil, fmt.Errorf("failed to parse argument at '%c' (line #%v/%v: %v)", p.Char(), p.l, p.i, p.s) } } -func (p *parser) parseArgInt(typ Type) (Arg, error) { +func (p *parser) parseArgInt(typ Type, dir Dir) (Arg, error) { val := p.Ident() v, err := strconv.ParseUint(val, 0, 64) if err != nil { @@ -390,35 +390,35 @@ func (p *parser) parseArgInt(typ Type) (Arg, error) { } switch typ.(type) { case *ConstType, *IntType, *FlagsType, *ProcType, *CsumType: - arg := Arg(MakeConstArg(typ, v)) - if typ.Dir() == DirOut && !typ.isDefaultArg(arg) { + arg := Arg(MakeConstArg(typ, dir, v)) + if dir == DirOut && !typ.isDefaultArg(arg) { p.strictFailf("out arg %v has non-default value: %v", typ, v) - arg = typ.DefaultArg() + arg = typ.DefaultArg(dir) } return arg, nil case *LenType: - return MakeConstArg(typ, v), nil + return MakeConstArg(typ, dir, v), nil case *ResourceType: - return MakeResultArg(typ, nil, v), nil + return MakeResultArg(typ, dir, nil, v), nil case *PtrType, *VmaType: index := -v % uint64(len(p.target.SpecialPointers)) - return MakeSpecialPointerArg(typ, index), nil + return MakeSpecialPointerArg(typ, dir, index), nil default: p.eatExcessive(true, "wrong int arg %T", typ) - return typ.DefaultArg(), nil + return typ.DefaultArg(dir), nil } } -func (p *parser) parseAuto(typ Type) (Arg, error) { +func (p *parser) parseAuto(typ Type, dir Dir) (Arg, error) { switch typ.(type) { case *ConstType, *LenType, *CsumType: - return p.auto(MakeConstArg(typ, 0)), nil + return p.auto(MakeConstArg(typ, dir, 0)), nil default: return nil, fmt.Errorf("wrong type %T for AUTO", typ) } } -func (p *parser) parseArgRes(typ Type) (Arg, error) { +func (p *parser) parseArgRes(typ Type, dir Dir) (Arg, error) { id := p.Ident() var div, add uint64 if p.Char() == '/' { @@ -442,23 +442,25 @@ func (p *parser) parseArgRes(typ Type) (Arg, error) { v := p.vars[id] if v == nil { p.strictFailf("undeclared variable %v", id) - return typ.DefaultArg(), nil + return typ.DefaultArg(dir), nil } - arg := MakeResultArg(typ, v, 0) + arg := MakeResultArg(typ, dir, v, 0) arg.OpDiv = div arg.OpAdd = add return arg, nil } -func (p *parser) parseArgAddr(typ Type) (Arg, error) { +func (p *parser) parseArgAddr(typ Type, dir Dir) (Arg, error) { var typ1 Type + elemDir := DirInOut switch t1 := typ.(type) { case *PtrType: typ1 = t1.Type + elemDir = t1.ElemDir case *VmaType: default: p.eatExcessive(true, "wrong addr arg") - return typ.DefaultArg(), nil + return typ.DefaultArg(dir), nil } p.Parse('&') auto := false @@ -487,11 +489,11 @@ func (p *parser) parseArgAddr(typ Type) (Arg, error) { p.Parse('N') p.Parse('Y') p.Parse('=') - typ = p.target.makeAnyPtrType(typ.Size(), typ.FieldName()) - typ1 = p.target.any.array + anyPtr := p.target.makeAnyPtrType(typ.Size(), typ.FieldName()) + typ, typ1, elemDir = anyPtr, anyPtr.Type, anyPtr.ElemDir } var err error - inner, err = p.parseArg(typ1) + inner, err = p.parseArg(typ1, elemDir) if err != nil { return nil, err } @@ -501,23 +503,23 @@ func (p *parser) parseArgAddr(typ Type) (Arg, error) { p.strictFailf("unaligned vma address 0x%x", addr) addr &= ^(p.target.PageSize - 1) } - return MakeVmaPointerArg(typ, addr, vmaSize), nil + return MakeVmaPointerArg(typ, dir, addr, vmaSize), nil } if inner == nil { - inner = typ1.DefaultArg() + inner = typ1.DefaultArg(elemDir) } - arg := MakePointerArg(typ, addr, inner) + arg := MakePointerArg(typ, dir, addr, inner) if auto { p.auto(arg) } return arg, nil } -func (p *parser) parseArgString(t Type) (Arg, error) { +func (p *parser) parseArgString(t Type, dir Dir) (Arg, error) { typ, ok := t.(*BufferType) if !ok { p.eatExcessive(true, "wrong string arg") - return t.DefaultArg(), nil + return t.DefaultArg(dir), nil } data, err := p.deserializeData() if err != nil { @@ -542,8 +544,8 @@ func (p *parser) parseArgString(t Type) (Arg, error) { } else if size == ^uint64(0) { size = uint64(len(data)) } - if typ.Dir() == DirOut { - return MakeOutDataArg(typ, size), nil + if dir == DirOut { + return MakeOutDataArg(typ, dir, size), nil } if diff := int(size) - len(data); diff > 0 { data = append(data, make([]byte, diff)...) @@ -564,16 +566,16 @@ func (p *parser) parseArgString(t Type) (Arg, error) { data = []byte(typ.Values[0]) } } - return MakeDataArg(typ, data), nil + return MakeDataArg(typ, dir, data), nil } -func (p *parser) parseArgStruct(typ Type) (Arg, error) { +func (p *parser) parseArgStruct(typ Type, dir Dir) (Arg, error) { p.Parse('{') t1, ok := typ.(*StructType) if !ok { p.eatExcessive(false, "wrong struct arg") p.Parse('}') - return typ.DefaultArg(), nil + return typ.DefaultArg(dir), nil } var inner []Arg for i := 0; p.Char() != '}'; i++ { @@ -583,9 +585,9 @@ func (p *parser) parseArgStruct(typ Type) (Arg, error) { } fld := t1.Fields[i] if IsPad(fld) { - inner = append(inner, MakeConstArg(fld, 0)) + inner = append(inner, MakeConstArg(fld, dir, 0)) } else { - arg, err := p.parseArg(fld) + arg, err := p.parseArg(fld, dir) if err != nil { return nil, err } @@ -601,22 +603,22 @@ func (p *parser) parseArgStruct(typ Type) (Arg, error) { if !IsPad(fld) { p.strictFailf("missing struct %v fields %v/%v", typ.Name(), len(inner), len(t1.Fields)) } - inner = append(inner, fld.DefaultArg()) + inner = append(inner, fld.DefaultArg(dir)) } - return MakeGroupArg(typ, inner), nil + return MakeGroupArg(typ, dir, inner), nil } -func (p *parser) parseArgArray(typ Type) (Arg, error) { +func (p *parser) parseArgArray(typ Type, dir Dir) (Arg, error) { p.Parse('[') t1, ok := typ.(*ArrayType) if !ok { p.eatExcessive(false, "wrong array arg %T", typ) p.Parse(']') - return typ.DefaultArg(), nil + return typ.DefaultArg(dir), nil } var inner []Arg for i := 0; p.Char() != ']'; i++ { - arg, err := p.parseArg(t1.Type) + arg, err := p.parseArg(t1.Type, dir) if err != nil { return nil, err } @@ -629,18 +631,18 @@ func (p *parser) parseArgArray(typ Type) (Arg, error) { if t1.Kind == ArrayRangeLen && t1.RangeBegin == t1.RangeEnd { for uint64(len(inner)) < t1.RangeBegin { p.strictFailf("missing array elements") - inner = append(inner, t1.Type.DefaultArg()) + inner = append(inner, t1.Type.DefaultArg(dir)) } inner = inner[:t1.RangeBegin] } - return MakeGroupArg(typ, inner), nil + return MakeGroupArg(typ, dir, inner), nil } -func (p *parser) parseArgUnion(typ Type) (Arg, error) { +func (p *parser) parseArgUnion(typ Type, dir Dir) (Arg, error) { t1, ok := typ.(*UnionType) if !ok { p.eatExcessive(true, "wrong union arg") - return typ.DefaultArg(), nil + return typ.DefaultArg(dir), nil } p.Parse('@') name := p.Ident() @@ -653,20 +655,20 @@ func (p *parser) parseArgUnion(typ Type) (Arg, error) { } if optType == nil { p.eatExcessive(true, "wrong union option") - return typ.DefaultArg(), nil + return typ.DefaultArg(dir), nil } var opt Arg if p.Char() == '=' { p.Parse('=') var err error - opt, err = p.parseArg(optType) + opt, err = p.parseArg(optType, dir) if err != nil { return nil, err } } else { - opt = optType.DefaultArg() + opt = optType.DefaultArg(dir) } - return MakeUnionArg(typ, opt), nil + return MakeUnionArg(typ, dir, opt), nil } // Eats excessive call arguments and struct fields to recover after description changes. diff --git a/prog/encodingexec.go b/prog/encodingexec.go index b5c410287..99357dfd2 100644 --- a/prog/encodingexec.go +++ b/prog/encodingexec.go @@ -134,7 +134,7 @@ func (w *execContext) writeCopyin(c *Call) { return } typ := arg.Type() - if typ.Dir() == DirOut || IsPad(typ) || (arg.Size() == 0 && !typ.IsBitfield()) { + if arg.Dir() == DirOut || IsPad(typ) || (arg.Size() == 0 && !typ.IsBitfield()) { return } w.write(execInstrCopyin) diff --git a/prog/hints.go b/prog/hints.go index 9a5675b1b..be2b371f2 100644 --- a/prog/hints.go +++ b/prog/hints.go @@ -82,7 +82,7 @@ func (p *Prog) MutateWithHints(callIndex int, comps CompMap, exec func(p *Prog)) func generateHints(compMap CompMap, arg Arg, exec func()) { typ := arg.Type() - if typ == nil || typ.Dir() == DirOut { + if typ == nil || arg.Dir() == DirOut { return } switch t := typ.(type) { diff --git a/prog/hints_test.go b/prog/hints_test.go index 0fc9afa02..caf84e715 100644 --- a/prog/hints_test.go +++ b/prog/hints_test.go @@ -157,7 +157,7 @@ func TestHintsCheckConstArg(t *testing.T) { typ := &IntType{IntTypeCommon: IntTypeCommon{TypeCommon: TypeCommon{ TypeSize: test.size}, BitfieldLen: test.bitsize}} - constArg := MakeConstArg(typ, test.in) + constArg := MakeConstArg(typ, DirIn, test.in) checkConstArg(constArg, test.comps, func() { res = append(res, constArg.Val) }) @@ -295,8 +295,8 @@ func TestHintsCheckDataArg(t *testing.T) { res := make(map[string]bool) // Whatever type here. It's just needed to pass the // dataArg.Type().Dir() == DirIn check. - typ := &ArrayType{TypeCommon{"", "", 0, DirIn, false, true}, nil, 0, 0, 0} - dataArg := MakeDataArg(typ, []byte(test.in)) + typ := &ArrayType{TypeCommon{"", "", 0, false, true}, nil, 0, 0, 0} + dataArg := MakeDataArg(typ, DirIn, []byte(test.in)) checkDataArg(dataArg, test.comps, func() { res[string(dataArg.Data())] = true }) @@ -499,7 +499,7 @@ func TestHintsRandom(t *testing.T) { func extractValues(c *Call) map[uint64]bool { vals := make(map[uint64]bool) ForeachArg(c, func(arg Arg, _ *ArgCtx) { - if typ := arg.Type(); typ == nil || typ.Dir() == DirOut { + if arg.Dir() == DirOut { return } switch a := arg.(type) { diff --git a/prog/minimization.go b/prog/minimization.go index 93a986556..0b5bb05b2 100644 --- a/prog/minimization.go +++ b/prog/minimization.go @@ -137,7 +137,7 @@ func (typ *PtrType) minimize(ctx *minimizeArgsCtx, arg Arg, path string) bool { } if !ctx.triedPaths[path+"->"] { removeArg(a.Res) - replaceArg(a, MakeSpecialPointerArg(a.Type(), 0)) + replaceArg(a, MakeSpecialPointerArg(a.Type(), a.Dir(), 0)) ctx.target.assignSizesCall(ctx.call) if ctx.pred(ctx.p, ctx.callIndex0) { *ctx.p0 = ctx.p @@ -201,7 +201,7 @@ func minimizeInt(ctx *minimizeArgsCtx, arg Arg, path string) bool { return false } a := arg.(*ConstArg) - def := arg.Type().DefaultArg().(*ConstArg) + def := arg.Type().DefaultArg(arg.Dir()).(*ConstArg) if a.Val == def.Val { return false } @@ -239,7 +239,7 @@ func (typ *ResourceType) minimize(ctx *minimizeArgsCtx, arg Arg, path string) bo func (typ *BufferType) minimize(ctx *minimizeArgsCtx, arg Arg, path string) bool { // TODO: try to set individual bytes to 0 - if typ.Kind != BufferBlobRand && typ.Kind != BufferBlobRange || typ.Dir() == DirOut { + if typ.Kind != BufferBlobRand && typ.Kind != BufferBlobRange || arg.Dir() == DirOut { return false } a := arg.(*DataArg) diff --git a/prog/mutation.go b/prog/mutation.go index 7203a86ec..7087b4d86 100644 --- a/prog/mutation.go +++ b/prog/mutation.go @@ -99,7 +99,7 @@ func (ctx *mutator) squashAny() bool { var blobs []*DataArg var bases []*PointerArg ForeachSubArg(ptr, func(arg Arg, ctx *ArgCtx) { - if data, ok := arg.(*DataArg); ok && arg.Type().Dir() != DirOut { + if data, ok := arg.(*DataArg); ok && arg.Dir() != DirOut { blobs = append(blobs, data) bases = append(bases, ctx.Base) } @@ -119,7 +119,7 @@ func (ctx *mutator) squashAny() bool { // Update base pointer if size has increased. if baseSize < base.Res.Size() { s := analyze(ctx.ct, ctx.corpus, p, p.Calls[0]) - newArg := r.allocAddr(s, base.Type(), base.Res.Size(), base.Res) + newArg := r.allocAddr(s, base.Type(), base.Dir(), base.Res.Size(), base.Res) *base = *newArg } return true @@ -252,7 +252,7 @@ func (target *Target) mutateArg(r *randGen, s *state, arg Arg, ctx ArgCtx, updat } // Update base pointer if size has increased. if base := ctx.Base; base != nil && baseSize < base.Res.Size() { - newArg := r.allocAddr(s, base.Type(), base.Res.Size(), base.Res) + newArg := r.allocAddr(s, base.Type(), base.Dir(), base.Res.Size(), base.Res) replaceArg(base, newArg) } return calls, true @@ -260,7 +260,7 @@ func (target *Target) mutateArg(r *randGen, s *state, arg Arg, ctx ArgCtx, updat func regenerate(r *randGen, s *state, arg Arg) (calls []*Call, retry, preserve bool) { var newArg Arg - newArg, calls = r.generateArg(s, arg.Type()) + newArg, calls = r.generateArg(s, arg.Type(), arg.Dir()) replaceArg(arg, newArg) return } @@ -346,7 +346,7 @@ func (t *BufferType) mutate(r *randGen, s *state, arg Arg, ctx ArgCtx) (calls [] minLen, maxLen = t.RangeBegin, t.RangeEnd } a := arg.(*DataArg) - if t.Dir() == DirOut { + if a.Dir() == DirOut { mutateBufferSize(r, a, minLen, maxLen) return } @@ -412,7 +412,7 @@ func (t *ArrayType) mutate(r *randGen, s *state, arg Arg, ctx ArgCtx) (calls []* } if count > uint64(len(a.Inner)) { for count > uint64(len(a.Inner)) { - newArg, newCalls := r.generateArg(s, t.Type) + newArg, newCalls := r.generateArg(s, t.Type, a.Dir()) a.Inner = append(a.Inner, newArg) calls = append(calls, newCalls...) for _, c := range newCalls { @@ -433,11 +433,11 @@ func (t *PtrType) mutate(r *randGen, s *state, arg Arg, ctx ArgCtx) (calls []*Ca if r.oneOf(1000) { removeArg(a.Res) index := r.rand(len(r.target.SpecialPointers)) - newArg := MakeSpecialPointerArg(t, index) + newArg := MakeSpecialPointerArg(t, a.Dir(), index) replaceArg(arg, newArg) return } - newArg := r.allocAddr(s, t, a.Res.Size(), a.Res) + newArg := r.allocAddr(s, t, a.Dir(), a.Res.Size(), a.Res) replaceArg(arg, newArg) return } @@ -448,7 +448,7 @@ func (t *StructType) mutate(r *randGen, s *state, arg Arg, ctx ArgCtx) (calls [] panic("bad arg returned by mutationArgs: StructType") } var newArg Arg - newArg, calls = gen(&Gen{r, s}, t, arg) + newArg, calls = gen(&Gen{r, s}, t, arg.Dir(), arg) a := arg.(*GroupArg) for i, f := range newArg.(*GroupArg).Inner { replaceArg(a.Inner[i], f) @@ -459,7 +459,7 @@ func (t *StructType) mutate(r *randGen, s *state, arg Arg, ctx ArgCtx) (calls [] func (t *UnionType) mutate(r *randGen, s *state, arg Arg, ctx ArgCtx) (calls []*Call, retry, preserve bool) { if gen := r.target.SpecialTypes[t.Name()]; gen != nil { var newArg Arg - newArg, calls = gen(&Gen{r, s}, t, arg) + newArg, calls = gen(&Gen{r, s}, t, arg.Dir(), arg) replaceArg(arg, newArg) return } @@ -481,8 +481,8 @@ func (t *UnionType) mutate(r *randGen, s *state, arg Arg, ctx ArgCtx) (calls []* optType := t.Fields[newIdx] removeArg(a.Option) var newOpt Arg - newOpt, calls = r.generateArg(s, optType) - replaceArg(arg, MakeUnionArg(t, newOpt)) + newOpt, calls = r.generateArg(s, optType, a.Dir()) + replaceArg(arg, MakeUnionArg(t, a.Dir(), newOpt)) return } @@ -522,7 +522,7 @@ func (ma *mutationArgs) collectArg(arg Arg, ctx *ArgCtx) { _, isArrayTyp := typ.(*ArrayType) _, isBufferTyp := typ.(*BufferType) - if !isBufferTyp && !isArrayTyp && typ.Dir() == DirOut || !typ.Varlen() && typ.Size() == 0 { + if !isBufferTyp && !isArrayTyp && arg.Dir() == DirOut || !typ.Varlen() && typ.Size() == 0 { return } @@ -645,7 +645,7 @@ func (t *LenType) getMutationPrio(target *Target, arg Arg, ignoreSpecial bool) ( } func (t *BufferType) getMutationPrio(target *Target, arg Arg, ignoreSpecial bool) (prio float64, stopRecursion bool) { - if t.Dir() == DirOut && !t.Varlen() { + if arg.Dir() == DirOut && !t.Varlen() { return dontMutate, false } if t.Kind == BufferString && len(t.Values) == 1 { diff --git a/prog/prio.go b/prog/prio.go index b67bbaea0..2a9486570 100644 --- a/prog/prio.go +++ b/prog/prio.go @@ -65,11 +65,11 @@ func (target *Target) calcStaticPriorities() [][]float32 { func (target *Target) calcResourceUsage() map[string]map[int]weights { uses := make(map[string]map[int]weights) for _, c := range target.Syscalls { - ForeachType(c, func(t Type) { + foreachType(c, func(t Type, ctx typeCtx) { switch a := t.(type) { case *ResourceType: if target.AuxResources[a.Desc.Name] { - noteUsage(uses, c, 0.1, a.Dir(), "res%v", a.Desc.Name) + noteUsage(uses, c, 0.1, ctx.Dir, "res%v", a.Desc.Name) } else { str := "res" for i, k := range a.Desc.Kind { @@ -78,25 +78,25 @@ func (target *Target) calcResourceUsage() map[string]map[int]weights { if i < len(a.Desc.Kind)-1 { w = 0.2 } - noteUsage(uses, c, float32(w), a.Dir(), str) + noteUsage(uses, c, float32(w), ctx.Dir, str) } } case *PtrType: if _, ok := a.Type.(*StructType); ok { - noteUsage(uses, c, 1.0, a.Dir(), "ptrto-%v", a.Type.Name()) + noteUsage(uses, c, 1.0, ctx.Dir, "ptrto-%v", a.Type.Name()) } if _, ok := a.Type.(*UnionType); ok { - noteUsage(uses, c, 1.0, a.Dir(), "ptrto-%v", a.Type.Name()) + noteUsage(uses, c, 1.0, ctx.Dir, "ptrto-%v", a.Type.Name()) } if arr, ok := a.Type.(*ArrayType); ok { - noteUsage(uses, c, 1.0, a.Dir(), "ptrto-%v", arr.Type.Name()) + noteUsage(uses, c, 1.0, ctx.Dir, "ptrto-%v", arr.Type.Name()) } case *BufferType: switch a.Kind { case BufferBlobRand, BufferBlobRange, BufferText: case BufferString: if a.SubKind != "" { - noteUsage(uses, c, 0.2, a.Dir(), fmt.Sprintf("str-%v", a.SubKind)) + noteUsage(uses, c, 0.2, ctx.Dir, fmt.Sprintf("str-%v", a.SubKind)) } case BufferFilename: noteUsage(uses, c, 1.0, DirIn, "filename") @@ -104,7 +104,7 @@ func (target *Target) calcResourceUsage() map[string]map[int]weights { panic("unknown buffer kind") } case *VmaType: - noteUsage(uses, c, 0.5, a.Dir(), "vma") + noteUsage(uses, c, 0.5, ctx.Dir, "vma") case *IntType: switch a.Kind { case IntPlain, IntRange: diff --git a/prog/prog.go b/prog/prog.go index 1600c0a28..017a0dbbb 100644 --- a/prog/prog.go +++ b/prog/prog.go @@ -22,6 +22,7 @@ type Call struct { type Arg interface { Type() Type + Dir() Dir Size() uint64 validate(ctx *validCtx) error @@ -30,20 +31,25 @@ type Arg interface { type ArgCommon struct { typ Type + dir Dir } func (arg *ArgCommon) Type() Type { return arg.typ } +func (arg *ArgCommon) Dir() Dir { + return arg.dir +} + // Used for ConstType, IntType, FlagsType, LenType, ProcType and CsumType. type ConstArg struct { ArgCommon Val uint64 } -func MakeConstArg(t Type, v uint64) *ConstArg { - return &ConstArg{ArgCommon: ArgCommon{typ: t}, Val: v} +func MakeConstArg(t Type, dir Dir, v uint64) *ConstArg { + return &ConstArg{ArgCommon: ArgCommon{typ: t, dir: dir}, Val: v} } func (arg *ConstArg) Size() uint64 { @@ -84,34 +90,37 @@ type PointerArg struct { Res Arg // pointee (nil for vma) } -func MakePointerArg(t Type, addr uint64, data Arg) *PointerArg { +func MakePointerArg(t Type, dir Dir, addr uint64, data Arg) *PointerArg { if data == nil { panic("nil pointer data arg") } return &PointerArg{ - ArgCommon: ArgCommon{typ: t}, + ArgCommon: ArgCommon{typ: t, dir: DirIn}, // pointers are always in Address: addr, Res: data, } } -func MakeVmaPointerArg(t Type, addr, size uint64) *PointerArg { +func MakeVmaPointerArg(t Type, dir Dir, addr, size uint64) *PointerArg { if addr%1024 != 0 { panic("unaligned vma address") } return &PointerArg{ - ArgCommon: ArgCommon{typ: t}, + ArgCommon: ArgCommon{typ: t, dir: dir}, Address: addr, VmaSize: size, } } -func MakeSpecialPointerArg(t Type, index uint64) *PointerArg { +func MakeSpecialPointerArg(t Type, dir Dir, index uint64) *PointerArg { if index >= maxSpecialPointers { panic("bad special pointer index") } + if _, ok := t.(*PtrType); ok { + dir = DirIn // pointers are always in + } return &PointerArg{ - ArgCommon: ArgCommon{typ: t}, + ArgCommon: ArgCommon{typ: t, dir: dir}, Address: -index, } } @@ -138,18 +147,18 @@ type DataArg struct { size uint64 // for out Args } -func MakeDataArg(t Type, data []byte) *DataArg { - if t.Dir() == DirOut { +func MakeDataArg(t Type, dir Dir, data []byte) *DataArg { + if dir == DirOut { panic("non-empty output data arg") } - return &DataArg{ArgCommon: ArgCommon{typ: t}, data: append([]byte{}, data...)} + return &DataArg{ArgCommon: ArgCommon{typ: t, dir: dir}, data: append([]byte{}, data...)} } -func MakeOutDataArg(t Type, size uint64) *DataArg { - if t.Dir() != DirOut { +func MakeOutDataArg(t Type, dir Dir, size uint64) *DataArg { + if dir != DirOut { panic("empty input data arg") } - return &DataArg{ArgCommon: ArgCommon{typ: t}, size: size} + return &DataArg{ArgCommon: ArgCommon{typ: t, dir: dir}, size: size} } func (arg *DataArg) Size() uint64 { @@ -160,14 +169,14 @@ func (arg *DataArg) Size() uint64 { } func (arg *DataArg) Data() []byte { - if arg.Type().Dir() == DirOut { + if arg.Dir() == DirOut { panic("getting data of output data arg") } return arg.data } func (arg *DataArg) SetData(data []byte) { - if arg.Type().Dir() == DirOut { + if arg.Dir() == DirOut { panic("setting data of output data arg") } arg.data = append([]byte{}, data...) @@ -180,8 +189,8 @@ type GroupArg struct { Inner []Arg } -func MakeGroupArg(t Type, inner []Arg) *GroupArg { - return &GroupArg{ArgCommon: ArgCommon{typ: t}, Inner: inner} +func MakeGroupArg(t Type, dir Dir, inner []Arg) *GroupArg { + return &GroupArg{ArgCommon: ArgCommon{typ: t, dir: dir}, Inner: inner} } func (arg *GroupArg) Size() uint64 { @@ -227,8 +236,8 @@ type UnionArg struct { Option Arg } -func MakeUnionArg(t Type, opt Arg) *UnionArg { - return &UnionArg{ArgCommon: ArgCommon{typ: t}, Option: opt} +func MakeUnionArg(t Type, dir Dir, opt Arg) *UnionArg { + return &UnionArg{ArgCommon: ArgCommon{typ: t, dir: dir}, Option: opt} } func (arg *UnionArg) Size() uint64 { @@ -250,8 +259,8 @@ type ResultArg struct { uses map[*ResultArg]bool // ArgResult args that use this arg } -func MakeResultArg(t Type, r *ResultArg, v uint64) *ResultArg { - arg := &ResultArg{ArgCommon: ArgCommon{typ: t}, Res: r, Val: v} +func MakeResultArg(t Type, dir Dir, r *ResultArg, v uint64) *ResultArg { + arg := &ResultArg{ArgCommon: ArgCommon{typ: t, dir: dir}, Res: r, Val: v} if r == nil { return arg } @@ -266,10 +275,7 @@ func MakeReturnArg(t Type) *ResultArg { if t == nil { return nil } - if t.Dir() != DirOut { - panic("return arg is not out") - } - return &ResultArg{ArgCommon: ArgCommon{typ: t}} + return &ResultArg{ArgCommon: ArgCommon{typ: t, dir: DirOut}} } func (arg *ResultArg) Size() uint64 { @@ -369,7 +375,7 @@ func removeArg(arg0 Arg) { delete(uses, a) } for arg1 := range a.uses { - arg2 := arg1.Type().DefaultArg().(*ResultArg) + arg2 := arg1.Type().DefaultArg(arg1.Dir()).(*ResultArg) replaceResultArg(arg1, arg2) } }) diff --git a/prog/prog_test.go b/prog/prog_test.go index a42e4437f..9b8f442e3 100644 --- a/prog/prog_test.go +++ b/prog/prog_test.go @@ -21,8 +21,8 @@ func TestGeneration(t *testing.T) { func TestDefault(t *testing.T) { target, _, _ := initTest(t) for _, meta := range target.Syscalls { - ForeachType(meta, func(typ Type) { - arg := typ.DefaultArg() + foreachType(meta, func(typ Type, ctx typeCtx) { + arg := typ.DefaultArg(ctx.Dir) if !isDefault(arg) { t.Errorf("default arg is not default: %s\ntype: %#v\narg: %#v", typ, typ, arg) @@ -203,8 +203,8 @@ func TestSpecialStructs(t *testing.T) { t.Run(special, func(t *testing.T) { var typ Type for i := 0; i < len(target.Syscalls) && typ == nil; i++ { - ForeachType(target.Syscalls[i], func(t Type) { - if t.Dir() == DirOut { + foreachType(target.Syscalls[i], func(t Type, ctx typeCtx) { + if ctx.Dir == DirOut { return } if s, ok := t.(*StructType); ok && s.Name() == special { @@ -220,8 +220,13 @@ func TestSpecialStructs(t *testing.T) { } g := &Gen{newRand(target, rs), newState(target, nil, nil)} for i := 0; i < iters/len(target.SpecialTypes); i++ { - arg, _ := gen(g, typ, nil) - gen(g, typ, arg) + var arg Arg + for i := 0; i < 2; i++ { + arg, _ = gen(g, typ, DirInOut, arg) + if arg.Dir() != DirInOut { + t.Fatalf("got wrong arg dir %v", arg.Dir()) + } + } } }) } diff --git a/prog/rand.go b/prog/rand.go index b350e31c0..5277a9814 100644 --- a/prog/rand.go +++ b/prog/rand.go @@ -341,16 +341,16 @@ func (r *randGen) randString(s *state, t *BufferType) []byte { return buf.Bytes() } -func (r *randGen) allocAddr(s *state, typ Type, size uint64, data Arg) *PointerArg { - return MakePointerArg(typ, s.ma.alloc(r, size), data) +func (r *randGen) allocAddr(s *state, typ Type, dir Dir, size uint64, data Arg) *PointerArg { + return MakePointerArg(typ, dir, s.ma.alloc(r, size), data) } -func (r *randGen) allocVMA(s *state, typ Type, numPages uint64) *PointerArg { +func (r *randGen) allocVMA(s *state, typ Type, dir Dir, numPages uint64) *PointerArg { page := s.va.alloc(r, numPages) - return MakeVmaPointerArg(typ, page*r.target.PageSize, numPages*r.target.PageSize) + return MakeVmaPointerArg(typ, dir, page*r.target.PageSize, numPages*r.target.PageSize) } -func (r *randGen) createResource(s *state, res *ResourceType) (arg Arg, calls []*Call) { +func (r *randGen) createResource(s *state, res *ResourceType, dir Dir) (arg Arg, calls []*Call) { if r.inCreateResource { return nil, nil } @@ -385,7 +385,7 @@ func (r *randGen) createResource(s *state, res *ResourceType) (arg Arg, calls [] metas = append(metas, meta) } if len(metas) == 0 { - return res.DefaultArg(), nil + return res.DefaultArg(dir), nil } // Now we have a set of candidate calls that can create the necessary resource. @@ -404,7 +404,7 @@ func (r *randGen) createResource(s *state, res *ResourceType) (arg Arg, calls [] } if len(allres) != 0 { // Bingo! - arg := MakeResultArg(res, allres[r.Intn(len(allres))], 0) + arg := MakeResultArg(res, dir, allres[r.Intn(len(allres))], 0) return arg, calls } // Discard unsuccessful calls. @@ -562,7 +562,7 @@ func (r *randGen) generateParticularCall(s *state, meta *Syscall) (calls []*Call Meta: meta, Ret: MakeReturnArg(meta.Ret), } - c.Args, calls = r.generateArgs(s, meta.Args) + c.Args, calls = r.generateArgs(s, meta.Args, DirIn) r.target.assignSizesCall(c) return append(calls, c) } @@ -601,13 +601,13 @@ func (target *Target) DataMmapProg() *Prog { } } -func (r *randGen) generateArgs(s *state, types []Type) ([]Arg, []*Call) { +func (r *randGen) generateArgs(s *state, types []Type, dir Dir) ([]Arg, []*Call) { var calls []*Call args := make([]Arg, len(types)) // Generate all args. Size args have the default value 0 for now. for i, typ := range types { - arg, calls1 := r.generateArg(s, typ) + arg, calls1 := r.generateArg(s, typ, dir) if arg == nil { panic(fmt.Sprintf("generated arg is nil for type '%v', types: %+v", typ.Name(), types)) } @@ -618,29 +618,28 @@ func (r *randGen) generateArgs(s *state, types []Type) ([]Arg, []*Call) { return args, calls } -func (r *randGen) generateArg(s *state, typ Type) (arg Arg, calls []*Call) { - return r.generateArgImpl(s, typ, false) +func (r *randGen) generateArg(s *state, typ Type, dir Dir) (arg Arg, calls []*Call) { + return r.generateArgImpl(s, typ, dir, false) } -func (r *randGen) generateArgImpl(s *state, typ Type, ignoreSpecial bool) (arg Arg, calls []*Call) { - if typ.Dir() == DirOut { +func (r *randGen) generateArgImpl(s *state, typ Type, dir Dir, ignoreSpecial bool) (arg Arg, calls []*Call) { + if dir == DirOut { // No need to generate something interesting for output scalar arguments. // But we still need to generate the argument itself so that it can be referenced // in subsequent calls. For the same reason we do generate pointer/array/struct // output arguments (their elements can be referenced in subsequent calls). switch typ.(type) { - case *IntType, *FlagsType, *ConstType, *ProcType, - *VmaType, *ResourceType: - return typ.DefaultArg(), nil + case *IntType, *FlagsType, *ConstType, *ProcType, *VmaType, *ResourceType: + return typ.DefaultArg(dir), nil } } if typ.Optional() && r.oneOf(5) { if res, ok := typ.(*ResourceType); ok { v := res.Desc.Values[r.Intn(len(res.Desc.Values))] - return MakeResultArg(typ, nil, v), nil + return MakeResultArg(typ, dir, nil, v), nil } - return typ.DefaultArg(), nil + return typ.DefaultArg(dir), nil } // Allow infinite recursion for optional pointers. @@ -656,70 +655,70 @@ func (r *randGen) generateArgImpl(s *state, typ Type, ignoreSpecial bool) (arg A } }() if r.recDepth[name] >= 3 { - return MakeSpecialPointerArg(typ, 0), nil + return MakeSpecialPointerArg(typ, dir, 0), nil } } } - if !ignoreSpecial && typ.Dir() != DirOut { + if !ignoreSpecial && dir != DirOut { switch typ.(type) { case *StructType, *UnionType: if gen := r.target.SpecialTypes[typ.Name()]; gen != nil { - return gen(&Gen{r, s}, typ, nil) + return gen(&Gen{r, s}, typ, dir, nil) } } } - return typ.generate(r, s) + return typ.generate(r, s, dir) } -func (a *ResourceType) generate(r *randGen, s *state) (arg Arg, calls []*Call) { +func (a *ResourceType) generate(r *randGen, s *state, dir Dir) (arg Arg, calls []*Call) { if r.oneOf(3) { - arg = r.existingResource(s, a) + arg = r.existingResource(s, a, dir) if arg != nil { return } } if r.nOutOf(2, 3) { - arg, calls = r.resourceCentric(s, a) + arg, calls = r.resourceCentric(s, a, dir) if arg != nil { return } } if r.nOutOf(4, 5) { - arg, calls = r.createResource(s, a) + arg, calls = r.createResource(s, a, dir) if arg != nil { return } } special := a.SpecialValues() - arg = MakeResultArg(a, nil, special[r.Intn(len(special))]) + arg = MakeResultArg(a, dir, nil, special[r.Intn(len(special))]) return } -func (a *BufferType) generate(r *randGen, s *state) (arg Arg, calls []*Call) { +func (a *BufferType) generate(r *randGen, s *state, dir Dir) (arg Arg, calls []*Call) { switch a.Kind { case BufferBlobRand, BufferBlobRange: sz := r.randBufLen() if a.Kind == BufferBlobRange { sz = r.randRange(a.RangeBegin, a.RangeEnd) } - if a.Dir() == DirOut { - return MakeOutDataArg(a, sz), nil + if dir == DirOut { + return MakeOutDataArg(a, dir, sz), nil } data := make([]byte, sz) for i := range data { data[i] = byte(r.Intn(256)) } - return MakeDataArg(a, data), nil + return MakeDataArg(a, dir, data), nil case BufferString: data := r.randString(s, a) - if a.Dir() == DirOut { - return MakeOutDataArg(a, uint64(len(data))), nil + if dir == DirOut { + return MakeOutDataArg(a, dir, uint64(len(data))), nil } - return MakeDataArg(a, data), nil + return MakeDataArg(a, dir, data), nil case BufferFilename: - if a.Dir() == DirOut { + if dir == DirOut { var sz uint64 switch { case !a.Varlen(): @@ -731,50 +730,50 @@ func (a *BufferType) generate(r *randGen, s *state) (arg Arg, calls []*Call) { default: sz = 4096 // PATH_MAX } - return MakeOutDataArg(a, sz), nil + return MakeOutDataArg(a, dir, sz), nil } - return MakeDataArg(a, []byte(r.filename(s, a))), nil + return MakeDataArg(a, dir, []byte(r.filename(s, a))), nil case BufferText: - if a.Dir() == DirOut { - return MakeOutDataArg(a, uint64(r.Intn(100))), nil + if dir == DirOut { + return MakeOutDataArg(a, dir, uint64(r.Intn(100))), nil } - return MakeDataArg(a, r.generateText(a.Text)), nil + return MakeDataArg(a, dir, r.generateText(a.Text)), nil default: panic("unknown buffer kind") } } -func (a *VmaType) generate(r *randGen, s *state) (arg Arg, calls []*Call) { +func (a *VmaType) generate(r *randGen, s *state, dir Dir) (arg Arg, calls []*Call) { npages := r.randPageCount() if a.RangeBegin != 0 || a.RangeEnd != 0 { npages = a.RangeBegin + uint64(r.Intn(int(a.RangeEnd-a.RangeBegin+1))) } - return r.allocVMA(s, a, npages), nil + return r.allocVMA(s, a, dir, npages), nil } -func (a *FlagsType) generate(r *randGen, s *state) (arg Arg, calls []*Call) { - return MakeConstArg(a, r.flags(a.Vals, a.BitMask, 0)), nil +func (a *FlagsType) generate(r *randGen, s *state, dir Dir) (arg Arg, calls []*Call) { + return MakeConstArg(a, dir, r.flags(a.Vals, a.BitMask, 0)), nil } -func (a *ConstType) generate(r *randGen, s *state) (arg Arg, calls []*Call) { - return MakeConstArg(a, a.Val), nil +func (a *ConstType) generate(r *randGen, s *state, dir Dir) (arg Arg, calls []*Call) { + return MakeConstArg(a, dir, a.Val), nil } -func (a *IntType) generate(r *randGen, s *state) (arg Arg, calls []*Call) { +func (a *IntType) generate(r *randGen, s *state, dir Dir) (arg Arg, calls []*Call) { bits := a.TypeBitSize() v := r.randInt(bits) switch a.Kind { case IntRange: v = r.randRangeInt(a.RangeBegin, a.RangeEnd, bits, a.Align) } - return MakeConstArg(a, v), nil + return MakeConstArg(a, dir, v), nil } -func (a *ProcType) generate(r *randGen, s *state) (arg Arg, calls []*Call) { - return MakeConstArg(a, r.rand(int(a.ValuesPerProc))), nil +func (a *ProcType) generate(r *randGen, s *state, dir Dir) (arg Arg, calls []*Call) { + return MakeConstArg(a, dir, r.rand(int(a.ValuesPerProc))), nil } -func (a *ArrayType) generate(r *randGen, s *state) (arg Arg, calls []*Call) { +func (a *ArrayType) generate(r *randGen, s *state, dir Dir) (arg Arg, calls []*Call) { var count uint64 switch a.Kind { case ArrayRandLen: @@ -784,46 +783,46 @@ func (a *ArrayType) generate(r *randGen, s *state) (arg Arg, calls []*Call) { } var inner []Arg for i := uint64(0); i < count; i++ { - arg1, calls1 := r.generateArg(s, a.Type) + arg1, calls1 := r.generateArg(s, a.Type, dir) inner = append(inner, arg1) calls = append(calls, calls1...) } - return MakeGroupArg(a, inner), calls + return MakeGroupArg(a, dir, inner), calls } -func (a *StructType) generate(r *randGen, s *state) (arg Arg, calls []*Call) { - args, calls := r.generateArgs(s, a.Fields) - group := MakeGroupArg(a, args) +func (a *StructType) generate(r *randGen, s *state, dir Dir) (arg Arg, calls []*Call) { + args, calls := r.generateArgs(s, a.Fields, dir) + group := MakeGroupArg(a, dir, args) return group, calls } -func (a *UnionType) generate(r *randGen, s *state) (arg Arg, calls []*Call) { +func (a *UnionType) generate(r *randGen, s *state, dir Dir) (arg Arg, calls []*Call) { optType := a.Fields[r.Intn(len(a.Fields))] - opt, calls := r.generateArg(s, optType) - return MakeUnionArg(a, opt), calls + opt, calls := r.generateArg(s, optType, dir) + return MakeUnionArg(a, dir, opt), calls } -func (a *PtrType) generate(r *randGen, s *state) (arg Arg, calls []*Call) { +func (a *PtrType) generate(r *randGen, s *state, dir Dir) (arg Arg, calls []*Call) { if r.oneOf(1000) { index := r.rand(len(r.target.SpecialPointers)) - return MakeSpecialPointerArg(a, index), nil + return MakeSpecialPointerArg(a, dir, index), nil } - inner, calls := r.generateArg(s, a.Type) - arg = r.allocAddr(s, a, inner.Size(), inner) + inner, calls := r.generateArg(s, a.Type, a.ElemDir) + arg = r.allocAddr(s, a, dir, inner.Size(), inner) return arg, calls } -func (a *LenType) generate(r *randGen, s *state) (arg Arg, calls []*Call) { +func (a *LenType) generate(r *randGen, s *state, dir Dir) (arg Arg, calls []*Call) { // Updated later in assignSizesCall. - return MakeConstArg(a, 0), nil + return MakeConstArg(a, dir, 0), nil } -func (a *CsumType) generate(r *randGen, s *state) (arg Arg, calls []*Call) { +func (a *CsumType) generate(r *randGen, s *state, dir Dir) (arg Arg, calls []*Call) { // Filled at runtime by executor. - return MakeConstArg(a, 0), nil + return MakeConstArg(a, dir, 0), nil } -func (r *randGen) existingResource(s *state, res *ResourceType) Arg { +func (r *randGen) existingResource(s *state, res *ResourceType, dir Dir) Arg { alltypes := make([][]*ResultArg, 0, len(s.resources)) for _, res1 := range s.resources { alltypes = append(alltypes, res1) @@ -842,11 +841,11 @@ func (r *randGen) existingResource(s *state, res *ResourceType) Arg { if len(allres) == 0 { return nil } - return MakeResultArg(res, allres[r.Intn(len(allres))], 0) + return MakeResultArg(res, dir, allres[r.Intn(len(allres))], 0) } // Finds a compatible resource with the type `t` and the calls that initialize that resource. -func (r *randGen) resourceCentric(s *state, t *ResourceType) (arg Arg, calls []*Call) { +func (r *randGen) resourceCentric(s *state, t *ResourceType, dir Dir) (arg Arg, calls []*Call) { var p *Prog var resource *ResultArg for idx := range r.Perm(len(s.corpus)) { @@ -898,7 +897,7 @@ func (r *randGen) resourceCentric(s *state, t *ResourceType) (arg Arg, calls []* p.removeCall(i) } - return MakeResultArg(t, resource, 0), p.Calls + return MakeResultArg(t, dir, resource, 0), p.Calls } func getCompatibleResources(p *Prog, resourceType string, r *randGen) (resources []*ResultArg) { @@ -906,7 +905,7 @@ func getCompatibleResources(p *Prog, resourceType string, r *randGen) (resources ForeachArg(c, func(arg Arg, _ *ArgCtx) { // Collect only initialized resources (the ones that are already used in other calls). a, ok := arg.(*ResultArg) - if !ok || len(a.uses) == 0 || a.typ.Dir() != DirOut { + if !ok || len(a.uses) == 0 || a.Dir() != DirOut { return } if !r.target.isCompatibleResource(resourceType, a.typ.Name()) { diff --git a/prog/rand_test.go b/prog/rand_test.go index 09f3b359c..6f251cb7c 100644 --- a/prog/rand_test.go +++ b/prog/rand_test.go @@ -102,14 +102,14 @@ func TestSizeGenerateConstArg(t *testing.T) { target, rs, iters := initRandomTargetTest(t, "test", "64") r := newRand(target, rs) for _, c := range target.Syscalls { - ForeachType(c, func(typ Type) { + foreachType(c, func(typ Type, ctx typeCtx) { if _, ok := typ.(*IntType); !ok { return } bits := typ.TypeBitSize() limit := uint64(1<<bits - 1) for i := 0; i < iters; i++ { - newArg, _ := typ.generate(r, nil) + newArg, _ := typ.generate(r, nil, ctx.Dir) newVal := newArg.(*ConstArg).Val if newVal > limit { t.Fatalf("invalid generated value: %d. (arg bitsize: %d; max value: %d)", newVal, bits, limit) diff --git a/prog/resources.go b/prog/resources.go index fa03e6cd6..b7bcecf95 100644 --- a/prog/resources.go +++ b/prog/resources.go @@ -46,10 +46,10 @@ func (target *Target) populateResourceCtors() { // Find resources that are created by each call. callsResources := make([][]*ResourceDesc, len(target.Syscalls)) for call, meta := range target.Syscalls { - ForeachType(meta, func(typ Type) { + foreachType(meta, func(typ Type, ctx typeCtx) { switch typ1 := typ.(type) { case *ResourceType: - if typ1.Dir() != DirIn { + if ctx.Dir != DirIn { callsResources[call] = append(callsResources[call], typ1.Desc) } } @@ -84,21 +84,19 @@ func (target *Target) populateResourceCtors() { // isCompatibleResource returns true if resource of kind src can be passed as an argument of kind dst. func (target *Target) isCompatibleResource(dst, src string) bool { - if dst == target.any.res16.TypeName || - dst == target.any.res32.TypeName || - dst == target.any.res64.TypeName || - dst == target.any.resdec.TypeName || - dst == target.any.reshex.TypeName || - dst == target.any.resoct.TypeName { + if target.isAnyRes(dst) { return true } + if target.isAnyRes(src) { + return false + } dstRes := target.resourceMap[dst] if dstRes == nil { - panic(fmt.Sprintf("unknown resource '%v'", dst)) + panic(fmt.Sprintf("unknown resource %q", dst)) } srcRes := target.resourceMap[src] if srcRes == nil { - panic(fmt.Sprintf("unknown resource '%v'", src)) + panic(fmt.Sprintf("unknown resource %q", src)) } return isCompatibleResourceImpl(dstRes.Kind, srcRes.Kind, false) } @@ -128,8 +126,8 @@ func isCompatibleResourceImpl(dst, src []string, precise bool) bool { func (target *Target) getInputResources(c *Syscall) []*ResourceDesc { var resources []*ResourceDesc - ForeachType(c, func(typ Type) { - if typ.Dir() == DirOut { + foreachType(c, func(typ Type, ctx typeCtx) { + if ctx.Dir == DirOut { return } switch typ1 := typ.(type) { @@ -148,10 +146,10 @@ func (target *Target) getInputResources(c *Syscall) []*ResourceDesc { func (target *Target) getOutputResources(c *Syscall) []*ResourceDesc { var resources []*ResourceDesc - ForeachType(c, func(typ Type) { + foreachType(c, func(typ Type, ctx typeCtx) { switch typ1 := typ.(type) { case *ResourceType: - if typ1.Dir() != DirIn { + if ctx.Dir != DirIn { resources = append(resources, typ1.Desc) } } diff --git a/prog/rotation.go b/prog/rotation.go index f95ffa03d..47ee2ca81 100644 --- a/prog/rotation.go +++ b/prog/rotation.go @@ -50,7 +50,7 @@ func MakeRotator(target *Target, calls map[*Syscall]bool, rnd *rand.Rand) *Rotat } // VMAs and filenames are effectively resources for our purposes // (but they don't have ctors). - ForeachType(call, func(t Type) { + foreachType(call, func(t Type, _ typeCtx) { switch a := t.(type) { case *BufferType: switch a.Kind { diff --git a/prog/target.go b/prog/target.go index f9e10b6f4..692d0b877 100644 --- a/prog/target.go +++ b/prog/target.go @@ -45,7 +45,7 @@ type Target struct { // allocate memory, etc. typ is the struct/union type. old is the old value of the struct/union // for mutation, or nil for generation. The function returns a new value of the struct/union, // and optionally any calls that need to be inserted before the arg reference. - SpecialTypes map[string]func(g *Gen, typ Type, old Arg) (Arg, []*Call) + SpecialTypes map[string]func(g *Gen, typ Type, dir Dir, old Arg) (Arg, []*Call) // Special strings that can matter for the target. // Used as fallback when string type does not have own dictionary. @@ -197,7 +197,7 @@ func restoreLinks(syscalls []*Syscall, resources []*ResourceDesc, structs []*Key if c.Ret != nil { unref(&c.Ret, types) } - ForeachType(c, func(t0 Type) { + foreachType(c, func(t0 Type, _ typeCtx) { switch t := t0.(type) { case *PtrType: unref(&t.Type, types) @@ -247,20 +247,20 @@ func (g *Gen) NOutOf(n, outOf int) bool { return g.r.nOutOf(n, outOf) } -func (g *Gen) Alloc(ptrType Type, data Arg) (Arg, []*Call) { - return g.r.allocAddr(g.s, ptrType, data.Size(), data), nil +func (g *Gen) Alloc(ptrType Type, dir Dir, data Arg) (Arg, []*Call) { + return g.r.allocAddr(g.s, ptrType, dir, data.Size(), data), nil } -func (g *Gen) GenerateArg(typ Type, pcalls *[]*Call) Arg { - return g.generateArg(typ, pcalls, false) +func (g *Gen) GenerateArg(typ Type, dir Dir, pcalls *[]*Call) Arg { + return g.generateArg(typ, dir, pcalls, false) } -func (g *Gen) GenerateSpecialArg(typ Type, pcalls *[]*Call) Arg { - return g.generateArg(typ, pcalls, true) +func (g *Gen) GenerateSpecialArg(typ Type, dir Dir, pcalls *[]*Call) Arg { + return g.generateArg(typ, dir, pcalls, true) } -func (g *Gen) generateArg(typ Type, pcalls *[]*Call, ignoreSpecial bool) Arg { - arg, calls := g.r.generateArgImpl(g.s, typ, ignoreSpecial) +func (g *Gen) generateArg(typ Type, dir Dir, pcalls *[]*Call, ignoreSpecial bool) Arg { + arg, calls := g.r.generateArgImpl(g.s, typ, dir, ignoreSpecial) *pcalls = append(*pcalls, calls...) g.r.target.assignSizesArray([]Arg{arg}, nil) return arg diff --git a/prog/types.go b/prog/types.go index ac272c3c1..a3b3c9709 100644 --- a/prog/types.go +++ b/prog/types.go @@ -44,7 +44,7 @@ type SyscallAttrs struct { // Executor also knows about this value. const MaxArgs = 9 -type Dir int +type Dir uint8 const ( DirIn Dir = iota @@ -80,7 +80,6 @@ type Type interface { Name() string FieldName() string TemplateName() string // for template structs name without arguments - Dir() Dir Optional() bool Varlen() bool Size() uint64 @@ -95,9 +94,9 @@ type Type interface { UnitSize() uint64 UnitOffset() uint64 - DefaultArg() Arg + DefaultArg(dir Dir) Arg isDefaultArg(arg Arg) bool - generate(r *randGen, s *state) (arg Arg, calls []*Call) + generate(r *randGen, s *state, dir Dir) (arg Arg, calls []*Call) mutate(r *randGen, s *state, arg Arg, ctx ArgCtx) (calls []*Call, retry, preserve bool) getMutationPrio(target *Target, arg Arg, ignoreSpecial bool) (prio float64, stopRecursion bool) minimize(ctx *minimizeArgsCtx, arg Arg, path string) bool @@ -105,25 +104,25 @@ type Type interface { type Ref uint32 -func (ti Ref) String() string { panic("prog.Ref method called") } -func (ti Ref) Name() string { panic("prog.Ref method called") } -func (ti Ref) FieldName() string { panic("prog.Ref method called") } -func (ti Ref) TemplateName() string { panic("prog.Ref method called") } -func (ti Ref) Dir() Dir { panic("prog.Ref method called") } -func (ti Ref) Optional() bool { panic("prog.Ref method called") } -func (ti Ref) Varlen() bool { panic("prog.Ref method called") } -func (ti Ref) Size() uint64 { panic("prog.Ref method called") } -func (ti Ref) TypeBitSize() uint64 { panic("prog.Ref method called") } -func (ti Ref) Format() BinaryFormat { panic("prog.Ref method called") } -func (ti Ref) BitfieldOffset() uint64 { panic("prog.Ref method called") } -func (ti Ref) BitfieldLength() uint64 { panic("prog.Ref method called") } -func (ti Ref) IsBitfield() bool { panic("prog.Ref method called") } -func (ti Ref) UnitSize() uint64 { panic("prog.Ref method called") } -func (ti Ref) UnitOffset() uint64 { panic("prog.Ref method called") } -func (ti Ref) DefaultArg() Arg { panic("prog.Ref method called") } -func (ti Ref) Clone() Type { panic("prog.Ref method called") } -func (ti Ref) isDefaultArg(arg Arg) bool { panic("prog.Ref method called") } -func (ti Ref) generate(r *randGen, s *state) (Arg, []*Call) { panic("prog.Ref method called") } +func (ti Ref) String() string { panic("prog.Ref method called") } +func (ti Ref) Name() string { panic("prog.Ref method called") } +func (ti Ref) FieldName() string { panic("prog.Ref method called") } +func (ti Ref) TemplateName() string { panic("prog.Ref method called") } + +func (ti Ref) Optional() bool { panic("prog.Ref method called") } +func (ti Ref) Varlen() bool { panic("prog.Ref method called") } +func (ti Ref) Size() uint64 { panic("prog.Ref method called") } +func (ti Ref) TypeBitSize() uint64 { panic("prog.Ref method called") } +func (ti Ref) Format() BinaryFormat { panic("prog.Ref method called") } +func (ti Ref) BitfieldOffset() uint64 { panic("prog.Ref method called") } +func (ti Ref) BitfieldLength() uint64 { panic("prog.Ref method called") } +func (ti Ref) IsBitfield() bool { panic("prog.Ref method called") } +func (ti Ref) UnitSize() uint64 { panic("prog.Ref method called") } +func (ti Ref) UnitOffset() uint64 { panic("prog.Ref method called") } +func (ti Ref) DefaultArg(dir Dir) Arg { panic("prog.Ref method called") } +func (ti Ref) Clone() Type { panic("prog.Ref method called") } +func (ti Ref) isDefaultArg(arg Arg) bool { panic("prog.Ref method called") } +func (ti Ref) generate(r *randGen, s *state, dir Dir) (Arg, []*Call) { panic("prog.Ref method called") } func (ti Ref) mutate(r *randGen, s *state, arg Arg, ctx ArgCtx) ([]*Call, bool, bool) { panic("prog.Ref method called") } @@ -146,7 +145,6 @@ type TypeCommon struct { FldName string // for struct fields and named args // Static size of the type, or 0 for variable size types and all but last bitfields in the group. TypeSize uint64 - ArgDir Dir IsOptional bool IsVarlen bool } @@ -210,10 +208,6 @@ func (t *TypeCommon) IsBitfield() bool { return false } -func (t TypeCommon) Dir() Dir { - return t.ArgDir -} - type ResourceDesc struct { Name string Kind []string @@ -236,8 +230,8 @@ func (t *ResourceType) String() string { return t.Name() } -func (t *ResourceType) DefaultArg() Arg { - return MakeResultArg(t, nil, t.Default()) +func (t *ResourceType) DefaultArg(dir Dir) Arg { + return MakeResultArg(t, dir, nil, t.Default()) } func (t *ResourceType) isDefaultArg(arg Arg) bool { @@ -317,8 +311,8 @@ type ConstType struct { IsPad bool } -func (t *ConstType) DefaultArg() Arg { - return MakeConstArg(t, t.Val) +func (t *ConstType) DefaultArg(dir Dir) Arg { + return MakeConstArg(t, dir, t.Val) } func (t *ConstType) isDefaultArg(arg Arg) bool { @@ -347,8 +341,8 @@ type IntType struct { Align uint64 } -func (t *IntType) DefaultArg() Arg { - return MakeConstArg(t, 0) +func (t *IntType) DefaultArg(dir Dir) Arg { + return MakeConstArg(t, dir, 0) } func (t *IntType) isDefaultArg(arg Arg) bool { @@ -361,8 +355,8 @@ type FlagsType struct { BitMask bool } -func (t *FlagsType) DefaultArg() Arg { - return MakeConstArg(t, 0) +func (t *FlagsType) DefaultArg(dir Dir) Arg { + return MakeConstArg(t, dir, 0) } func (t *FlagsType) isDefaultArg(arg Arg) bool { @@ -376,8 +370,8 @@ type LenType struct { Path []string } -func (t *LenType) DefaultArg() Arg { - return MakeConstArg(t, 0) +func (t *LenType) DefaultArg(dir Dir) Arg { + return MakeConstArg(t, dir, 0) } func (t *LenType) isDefaultArg(arg Arg) bool { @@ -395,8 +389,8 @@ const ( procDefaultValue = 0xffffffffffffffff // special value denoting 0 for all procs ) -func (t *ProcType) DefaultArg() Arg { - return MakeConstArg(t, procDefaultValue) +func (t *ProcType) DefaultArg(dir Dir) Arg { + return MakeConstArg(t, dir, procDefaultValue) } func (t *ProcType) isDefaultArg(arg Arg) bool { @@ -421,8 +415,8 @@ func (t *CsumType) String() string { return "csum" } -func (t *CsumType) DefaultArg() Arg { - return MakeConstArg(t, 0) +func (t *CsumType) DefaultArg(dir Dir) Arg { + return MakeConstArg(t, dir, 0) } func (t *CsumType) isDefaultArg(arg Arg) bool { @@ -439,8 +433,8 @@ func (t *VmaType) String() string { return "vma" } -func (t *VmaType) DefaultArg() Arg { - return MakeSpecialPointerArg(t, 0) +func (t *VmaType) DefaultArg(dir Dir) Arg { + return MakeSpecialPointerArg(t, dir, 0) } func (t *VmaType) isDefaultArg(arg Arg) bool { @@ -484,19 +478,19 @@ func (t *BufferType) String() string { return "buffer" } -func (t *BufferType) DefaultArg() Arg { - if t.Dir() == DirOut { +func (t *BufferType) DefaultArg(dir Dir) Arg { + if dir == DirOut { var sz uint64 if !t.Varlen() { sz = t.Size() } - return MakeOutDataArg(t, sz) + return MakeOutDataArg(t, dir, sz) } var data []byte if !t.Varlen() { data = make([]byte, t.Size()) } - return MakeDataArg(t, data) + return MakeDataArg(t, dir, data) } func (t *BufferType) isDefaultArg(arg Arg) bool { @@ -507,7 +501,7 @@ func (t *BufferType) isDefaultArg(arg Arg) bool { if a.Type().Varlen() { return false } - if a.Type().Dir() == DirOut { + if a.Dir() == DirOut { return true } for _, v := range a.Data() { @@ -537,14 +531,14 @@ func (t *ArrayType) String() string { return fmt.Sprintf("array[%v]", t.Type.String()) } -func (t *ArrayType) DefaultArg() Arg { +func (t *ArrayType) DefaultArg(dir Dir) Arg { var elems []Arg if t.Kind == ArrayRangeLen && t.RangeBegin == t.RangeEnd { for i := uint64(0); i < t.RangeBegin; i++ { - elems = append(elems, t.Type.DefaultArg()) + elems = append(elems, t.Type.DefaultArg(dir)) } } - return MakeGroupArg(t, elems) + return MakeGroupArg(t, dir, elems) } func (t *ArrayType) isDefaultArg(arg Arg) bool { @@ -562,18 +556,19 @@ func (t *ArrayType) isDefaultArg(arg Arg) bool { type PtrType struct { TypeCommon - Type Type + Type Type + ElemDir Dir } func (t *PtrType) String() string { - return fmt.Sprintf("ptr[%v, %v]", t.Dir(), t.Type.String()) + return fmt.Sprintf("ptr[%v, %v]", t.ElemDir, t.Type.String()) } -func (t *PtrType) DefaultArg() Arg { +func (t *PtrType) DefaultArg(dir Dir) Arg { if t.Optional() { - return MakeSpecialPointerArg(t, 0) + return MakeSpecialPointerArg(t, dir, 0) } - return MakePointerArg(t, 0, t.Type.DefaultArg()) + return MakePointerArg(t, dir, 0, t.Type.DefaultArg(t.ElemDir)) } func (t *PtrType) isDefaultArg(arg Arg) bool { @@ -598,12 +593,12 @@ func (t *StructType) FieldName() string { return t.FldName } -func (t *StructType) DefaultArg() Arg { +func (t *StructType) DefaultArg(dir Dir) Arg { inner := make([]Arg, len(t.Fields)) for i, field := range t.Fields { - inner[i] = field.DefaultArg() + inner[i] = field.DefaultArg(dir) } - return MakeGroupArg(t, inner) + return MakeGroupArg(t, dir, inner) } func (t *StructType) isDefaultArg(arg Arg) bool { @@ -630,8 +625,8 @@ func (t *UnionType) FieldName() string { return t.FldName } -func (t *UnionType) DefaultArg() Arg { - return MakeUnionArg(t, t.Fields[0].DefaultArg()) +func (t *UnionType) DefaultArg(dir Dir) Arg { + return MakeUnionArg(t, dir, t.Fields[0].DefaultArg(dir)) } func (t *UnionType) isDefaultArg(arg Arg) bool { @@ -651,7 +646,6 @@ func (t *StructDesc) FieldName() string { type StructKey struct { Name string - Dir Dir } type KeyedStruct struct { @@ -664,29 +658,33 @@ type ConstValue struct { Value uint64 } -func ForeachType(meta *Syscall, f func(Type)) { - var rec func(t Type) +type typeCtx struct { + Dir Dir +} + +func foreachType(meta *Syscall, f func(t Type, ctx typeCtx)) { + var rec func(t Type, dir Dir) seen := make(map[*StructDesc]bool) - recStruct := func(desc *StructDesc) { + recStruct := func(desc *StructDesc, dir Dir) { if seen[desc] { return // prune recursion via pointers to structs/unions } seen[desc] = true for _, f := range desc.Fields { - rec(f) + rec(f, dir) } } - rec = func(t Type) { - f(t) + rec = func(t Type, dir Dir) { + f(t, typeCtx{Dir: dir}) switch a := t.(type) { case *PtrType: - rec(a.Type) + rec(a.Type, a.ElemDir) case *ArrayType: - rec(a.Type) + rec(a.Type, dir) case *StructType: - recStruct(a.StructDesc) + recStruct(a.StructDesc, dir) case *UnionType: - recStruct(a.StructDesc) + recStruct(a.StructDesc, dir) case *ResourceType, *BufferType, *VmaType, *LenType, *FlagsType, *ConstType, *IntType, *ProcType, *CsumType: default: @@ -694,10 +692,10 @@ func ForeachType(meta *Syscall, f func(Type)) { } } for _, t := range meta.Args { - rec(t) + rec(t, DirIn) } if meta.Ret != nil { - rec(meta.Ret) + rec(meta.Ret, DirOut) } } diff --git a/prog/validation.go b/prog/validation.go index c8c030aa7..8ec0615da 100644 --- a/prog/validation.go +++ b/prog/validation.go @@ -55,7 +55,7 @@ func (ctx *validCtx) validateCall(c *Call) error { len(c.Meta.Args), len(c.Args)) } for i, arg := range c.Args { - if err := ctx.validateArg(arg, c.Meta.Args[i]); err != nil { + if err := ctx.validateArg(arg, c.Meta.Args[i], DirIn); err != nil { return err } } @@ -72,16 +72,13 @@ func (ctx *validCtx) validateRet(c *Call) error { if c.Ret == nil { return fmt.Errorf("return value is absent") } - if c.Ret.Type().Dir() != DirOut { - return fmt.Errorf("return value %v is not output", c.Ret) - } if c.Ret.Res != nil || c.Ret.Val != 0 || c.Ret.OpDiv != 0 || c.Ret.OpAdd != 0 { return fmt.Errorf("return value %v is not empty", c.Ret) } - return ctx.validateArg(c.Ret, c.Meta.Ret) + return ctx.validateArg(c.Ret, c.Meta.Ret, DirOut) } -func (ctx *validCtx) validateArg(arg Arg, typ Type) error { +func (ctx *validCtx) validateArg(arg Arg, typ Type, dir Dir) error { if arg == nil { return fmt.Errorf("nil arg") } @@ -91,6 +88,12 @@ func (ctx *validCtx) validateArg(arg Arg, typ Type) error { if arg.Type() == nil { return fmt.Errorf("no arg type") } + if _, ok := typ.(*PtrType); ok { + dir = DirIn // pointers are always in + } + if arg.Dir() != dir { + return fmt.Errorf("arg %#v type %v has wrong dir %v, expect %v", arg, arg.Type(), arg.Dir(), dir) + } if !ctx.target.isAnyPtr(arg.Type()) && arg.Type() != typ { return fmt.Errorf("bad arg type %#v, expect %#v", arg.Type(), typ) } @@ -101,7 +104,7 @@ func (ctx *validCtx) validateArg(arg Arg, typ Type) error { func (arg *ConstArg) validate(ctx *validCtx) error { switch typ := arg.Type().(type) { case *IntType: - if typ.Dir() == DirOut && !isDefault(arg) { + if arg.Dir() == DirOut && !isDefault(arg) { return fmt.Errorf("out int arg '%v' has bad const value %v", typ.Name(), arg.Val) } case *ProcType: @@ -116,9 +119,10 @@ func (arg *ConstArg) validate(ctx *validCtx) error { default: return fmt.Errorf("const arg %v has bad type %v", arg, typ.Name()) } - if typ := arg.Type(); typ.Dir() == DirOut { + if arg.Dir() == DirOut { // We generate output len arguments, which makes sense since it can be // a length of a variable-length array which is not known otherwise. + typ := arg.Type() if _, isLen := typ.(*LenType); !isLen { if !typ.isDefaultArg(arg) { return fmt.Errorf("output arg '%v'/'%v' has non default value '%+v'", @@ -143,7 +147,7 @@ func (arg *ResultArg) validate(ctx *validCtx) error { } ctx.uses[u] = arg } - if typ.Dir() == DirOut && arg.Val != 0 && arg.Val != typ.Default() { + if arg.Dir() == DirOut && arg.Val != 0 && arg.Val != typ.Default() { return fmt.Errorf("out resource arg '%v' has bad const value %v", typ.Name(), arg.Val) } if arg.Res != nil { @@ -163,7 +167,7 @@ func (arg *DataArg) validate(ctx *validCtx) error { if !ok { return fmt.Errorf("data arg %v has bad type %v", arg, arg.Type().Name()) } - if typ.Dir() == DirOut && len(arg.data) != 0 { + if arg.Dir() == DirOut && len(arg.data) != 0 { return fmt.Errorf("output arg '%v' has data", typ.Name()) } if !typ.Varlen() && typ.Size() != arg.Size() { @@ -188,7 +192,7 @@ func (arg *GroupArg) validate(ctx *validCtx) error { typ.Name(), len(typ.Fields), len(arg.Inner)) } for i, field := range arg.Inner { - if err := ctx.validateArg(field, typ.Fields[i]); err != nil { + if err := ctx.validateArg(field, typ.Fields[i], arg.Dir()); err != nil { return err } } @@ -199,7 +203,7 @@ func (arg *GroupArg) validate(ctx *validCtx) error { typ.Name(), len(arg.Inner), typ.RangeBegin) } for _, elem := range arg.Inner { - if err := ctx.validateArg(elem, typ.Type); err != nil { + if err := ctx.validateArg(elem, typ.Type, arg.Dir()); err != nil { return err } } @@ -224,7 +228,7 @@ func (arg *UnionArg) validate(ctx *validCtx) error { if optType == nil { return fmt.Errorf("union arg '%v' has bad option", typ.Name()) } - return ctx.validateArg(arg.Option, optType) + return ctx.validateArg(arg.Option, optType, arg.Dir()) } func (arg *PointerArg) validate(ctx *validCtx) error { @@ -235,14 +239,14 @@ func (arg *PointerArg) validate(ctx *validCtx) error { } case *PtrType: if arg.Res != nil { - if err := ctx.validateArg(arg.Res, typ.Type); err != nil { + if err := ctx.validateArg(arg.Res, typ.Type, typ.ElemDir); err != nil { return err } } if arg.VmaSize != 0 { return fmt.Errorf("pointer arg '%v' has nonzero size", typ.Name()) } - if typ.Dir() == DirOut { + if arg.Dir() == DirOut { return fmt.Errorf("pointer arg '%v' has output direction", typ.Name()) } default: diff --git a/sys/linux/init.go b/sys/linux/init.go index c816bbd28..9a6eaf62e 100644 --- a/sys/linux/init.go +++ b/sys/linux/init.go @@ -50,7 +50,7 @@ func InitTarget(target *prog.Target) { target.MakeDataMmap = targets.MakePosixMmap(target, true, true) target.Neutralize = arch.neutralize - target.SpecialTypes = map[string]func(g *prog.Gen, typ prog.Type, old prog.Arg) ( + target.SpecialTypes = map[string]func(g *prog.Gen, typ prog.Type, dir prog.Dir, old prog.Arg) ( prog.Arg, []*prog.Call){ "timespec": arch.generateTimespec, "timeval": arch.generateTimespec, @@ -276,7 +276,8 @@ func (arch *arch) neutralizeIoctl(c *prog.Call) { } } -func (arch *arch) generateTimespec(g *prog.Gen, typ0 prog.Type, old prog.Arg) (arg prog.Arg, calls []*prog.Call) { +func (arch *arch) generateTimespec(g *prog.Gen, typ0 prog.Type, dir prog.Dir, old prog.Arg) ( + arg prog.Arg, calls []*prog.Call) { typ := typ0.(*prog.StructType) // We need to generate timespec/timeval that are either // (1) definitely in the past, or @@ -293,9 +294,9 @@ func (arch *arch) generateTimespec(g *prog.Gen, typ0 prog.Type, old prog.Arg) (a switch { case g.NOutOf(1, 4): // Now for relative, past for absolute. - arg = prog.MakeGroupArg(typ, []prog.Arg{ - prog.MakeResultArg(typ.Fields[0], nil, 0), - prog.MakeResultArg(typ.Fields[1], nil, 0), + arg = prog.MakeGroupArg(typ, dir, []prog.Arg{ + prog.MakeResultArg(typ.Fields[0], dir, nil, 0), + prog.MakeResultArg(typ.Fields[1], dir, nil, 0), }) case g.NOutOf(1, 3): // Few ms ahead for relative, past for absolute @@ -306,38 +307,38 @@ func (arch *arch) generateTimespec(g *prog.Gen, typ0 prog.Type, old prog.Arg) (a if usec { nsec /= 1e3 } - arg = prog.MakeGroupArg(typ, []prog.Arg{ - prog.MakeResultArg(typ.Fields[0], nil, 0), - prog.MakeResultArg(typ.Fields[1], nil, nsec), + arg = prog.MakeGroupArg(typ, dir, []prog.Arg{ + prog.MakeResultArg(typ.Fields[0], dir, nil, 0), + prog.MakeResultArg(typ.Fields[1], dir, nil, nsec), }) case g.NOutOf(1, 2): // Unreachable fututre for both relative and absolute - arg = prog.MakeGroupArg(typ, []prog.Arg{ - prog.MakeResultArg(typ.Fields[0], nil, 2e9), - prog.MakeResultArg(typ.Fields[1], nil, 0), + arg = prog.MakeGroupArg(typ, dir, []prog.Arg{ + prog.MakeResultArg(typ.Fields[0], dir, nil, 2e9), + prog.MakeResultArg(typ.Fields[1], dir, nil, 0), }) default: // Few ms ahead for absolute. meta := arch.clockGettimeSyscall ptrArgType := meta.Args[1].(*prog.PtrType) argType := ptrArgType.Type.(*prog.StructType) - tp := prog.MakeGroupArg(argType, []prog.Arg{ - prog.MakeResultArg(argType.Fields[0], nil, 0), - prog.MakeResultArg(argType.Fields[1], nil, 0), + tp := prog.MakeGroupArg(argType, prog.DirOut, []prog.Arg{ + prog.MakeResultArg(argType.Fields[0], prog.DirOut, nil, 0), + prog.MakeResultArg(argType.Fields[1], prog.DirOut, nil, 0), }) var tpaddr prog.Arg - tpaddr, calls = g.Alloc(ptrArgType, tp) + tpaddr, calls = g.Alloc(ptrArgType, prog.DirIn, tp) gettime := &prog.Call{ Meta: meta, Args: []prog.Arg{ - prog.MakeConstArg(meta.Args[0], arch.CLOCK_REALTIME), + prog.MakeConstArg(meta.Args[0], prog.DirIn, arch.CLOCK_REALTIME), tpaddr, }, Ret: prog.MakeReturnArg(meta.Ret), } calls = append(calls, gettime) - sec := prog.MakeResultArg(typ.Fields[0], tp.Inner[0].(*prog.ResultArg), 0) - nsec := prog.MakeResultArg(typ.Fields[1], tp.Inner[1].(*prog.ResultArg), 0) + sec := prog.MakeResultArg(typ.Fields[0], dir, tp.Inner[0].(*prog.ResultArg), 0) + nsec := prog.MakeResultArg(typ.Fields[1], dir, tp.Inner[1].(*prog.ResultArg), 0) msec := timeout1 if g.NOutOf(1, 2) { msec = timeout2 @@ -348,7 +349,7 @@ func (arch *arch) generateTimespec(g *prog.Gen, typ0 prog.Type, old prog.Arg) (a } else { nsec.OpAdd = msec * 1e6 } - arg = prog.MakeGroupArg(typ, []prog.Arg{sec, nsec}) + arg = prog.MakeGroupArg(typ, dir, []prog.Arg{sec, nsec}) } return } diff --git a/sys/linux/init_alg.go b/sys/linux/init_alg.go index 74734907a..0e6e13fac 100644 --- a/sys/linux/init_alg.go +++ b/sys/linux/init_alg.go @@ -9,58 +9,59 @@ import ( "github.com/google/syzkaller/prog" ) -func (arch *arch) generateSockaddrAlg(g *prog.Gen, typ0 prog.Type, old prog.Arg) ( +func (arch *arch) generateSockaddrAlg(g *prog.Gen, typ0 prog.Type, dir prog.Dir, old prog.Arg) ( arg prog.Arg, calls []*prog.Call) { typ := typ0.(*prog.StructType) - family := g.GenerateArg(typ.Fields[0], &calls) + family := g.GenerateArg(typ.Fields[0], dir, &calls) // There is very little point in generating feat/mask, // because that can only fail otherwise correct bind. - feat := prog.MakeConstArg(typ.Fields[2], 0) - mask := prog.MakeConstArg(typ.Fields[3], 0) + feat := prog.MakeConstArg(typ.Fields[2], dir, 0) + mask := prog.MakeConstArg(typ.Fields[3], dir, 0) if g.NOutOf(1, 1000) { - feat = g.GenerateArg(typ.Fields[2], &calls).(*prog.ConstArg) - mask = g.GenerateArg(typ.Fields[3], &calls).(*prog.ConstArg) + feat = g.GenerateArg(typ.Fields[2], dir, &calls).(*prog.ConstArg) + mask = g.GenerateArg(typ.Fields[3], dir, &calls).(*prog.ConstArg) } algType, algName := generateAlgName(g.Rand()) // Extend/truncate type/name to their fixed sizes. algTypeData := fixedSizeData(algType, typ.Fields[1].Size()) algNameData := fixedSizeData(algName, typ.Fields[4].Size()) - arg = prog.MakeGroupArg(typ, []prog.Arg{ + arg = prog.MakeGroupArg(typ, dir, []prog.Arg{ family, - prog.MakeDataArg(typ.Fields[1], algTypeData), + prog.MakeDataArg(typ.Fields[1], dir, algTypeData), feat, mask, - prog.MakeDataArg(typ.Fields[4], algNameData), + prog.MakeDataArg(typ.Fields[4], dir, algNameData), }) return } -func (arch *arch) generateAlgName(g *prog.Gen, typ prog.Type, old prog.Arg) ( +func (arch *arch) generateAlgName(g *prog.Gen, typ prog.Type, dir prog.Dir, old prog.Arg) ( arg prog.Arg, calls []*prog.Call) { - return generateAlgNameStruct(g, typ, allTypes[g.Rand().Intn(len(allTypes))].typ) + return generateAlgNameStruct(g, typ, dir, allTypes[g.Rand().Intn(len(allTypes))].typ) } -func (arch *arch) generateAlgAeadName(g *prog.Gen, typ prog.Type, old prog.Arg) ( +func (arch *arch) generateAlgAeadName(g *prog.Gen, typ prog.Type, dir prog.Dir, old prog.Arg) ( arg prog.Arg, calls []*prog.Call) { - return generateAlgNameStruct(g, typ, ALG_AEAD) + return generateAlgNameStruct(g, typ, dir, ALG_AEAD) } -func (arch *arch) generateAlgHashName(g *prog.Gen, typ prog.Type, old prog.Arg) ( +func (arch *arch) generateAlgHashName(g *prog.Gen, typ prog.Type, dir prog.Dir, old prog.Arg) ( arg prog.Arg, calls []*prog.Call) { - return generateAlgNameStruct(g, typ, ALG_HASH) + return generateAlgNameStruct(g, typ, dir, ALG_HASH) } -func (arch *arch) generateAlgSkcipherhName(g *prog.Gen, typ prog.Type, old prog.Arg) ( +func (arch *arch) generateAlgSkcipherhName(g *prog.Gen, typ prog.Type, dir prog.Dir, old prog.Arg) ( arg prog.Arg, calls []*prog.Call) { - return generateAlgNameStruct(g, typ, ALG_SKCIPHER) + return generateAlgNameStruct(g, typ, dir, ALG_SKCIPHER) } -func generateAlgNameStruct(g *prog.Gen, typ0 prog.Type, algTyp int) (arg prog.Arg, calls []*prog.Call) { +func generateAlgNameStruct(g *prog.Gen, typ0 prog.Type, dir prog.Dir, algTyp int) ( + arg prog.Arg, calls []*prog.Call) { typ := typ0.(*prog.StructType) algName := generateAlg(g.Rand(), algTyp) algNameData := fixedSizeData(algName, typ.Fields[0].Size()) - arg = prog.MakeGroupArg(typ, []prog.Arg{ - prog.MakeDataArg(typ.Fields[0], algNameData), + arg = prog.MakeGroupArg(typ, dir, []prog.Arg{ + prog.MakeDataArg(typ.Fields[0], dir, algNameData), }) return } diff --git a/sys/linux/init_iptables.go b/sys/linux/init_iptables.go index a1adf3fb0..7e96662fb 100644 --- a/sys/linux/init_iptables.go +++ b/sys/linux/init_iptables.go @@ -9,17 +9,17 @@ import ( "github.com/google/syzkaller/prog" ) -func (arch *arch) generateIptables(g *prog.Gen, typ prog.Type, old prog.Arg) ( +func (arch *arch) generateIptables(g *prog.Gen, typ prog.Type, dir prog.Dir, old prog.Arg) ( arg prog.Arg, calls []*prog.Call) { - return arch.generateNetfilterTable(g, typ, old, true, 5) + return arch.generateNetfilterTable(g, typ, dir, old, true, 5) } -func (arch *arch) generateArptables(g *prog.Gen, typ prog.Type, old prog.Arg) ( +func (arch *arch) generateArptables(g *prog.Gen, typ prog.Type, dir prog.Dir, old prog.Arg) ( arg prog.Arg, calls []*prog.Call) { - return arch.generateNetfilterTable(g, typ, old, false, 3) + return arch.generateNetfilterTable(g, typ, dir, old, false, 3) } -func (arch *arch) generateNetfilterTable(g *prog.Gen, typ prog.Type, old prog.Arg, +func (arch *arch) generateNetfilterTable(g *prog.Gen, typ prog.Type, dir prog.Dir, old prog.Arg, hasUnion bool, hookCount int) (arg prog.Arg, calls []*prog.Call) { const ( hookStart = 4 @@ -27,7 +27,7 @@ func (arch *arch) generateNetfilterTable(g *prog.Gen, typ prog.Type, old prog.Ar unused = uint64(^uint32(0)) ) if old == nil { - arg = g.GenerateSpecialArg(typ, &calls) + arg = g.GenerateSpecialArg(typ, dir, &calls) } else { // TODO(dvyukov): try to restore original hook order after mutation // instead of assigning brand new offsets. @@ -106,10 +106,10 @@ func (arch *arch) generateNetfilterTable(g *prog.Gen, typ prog.Type, old prog.Ar return } -func (arch *arch) generateEbtables(g *prog.Gen, typ prog.Type, old prog.Arg) ( +func (arch *arch) generateEbtables(g *prog.Gen, typ prog.Type, dir prog.Dir, old prog.Arg) ( arg prog.Arg, calls []*prog.Call) { if old == nil { - arg = g.GenerateSpecialArg(typ, &calls) + arg = g.GenerateSpecialArg(typ, dir, &calls) } else { // TODO(dvyukov): try to restore original hook order after mutation // instead of assigning brand new offsets. diff --git a/sys/linux/init_vusb.go b/sys/linux/init_vusb.go index 7a61b0942..fc847be41 100644 --- a/sys/linux/init_vusb.go +++ b/sys/linux/init_vusb.go @@ -50,11 +50,11 @@ type HidDeviceID struct { Product uint32 } -func (arch *arch) generateUsbDeviceDescriptor(g *prog.Gen, typ0 prog.Type, old prog.Arg) ( +func (arch *arch) generateUsbDeviceDescriptor(g *prog.Gen, typ0 prog.Type, dir prog.Dir, old prog.Arg) ( arg prog.Arg, calls []*prog.Call) { if old == nil { - arg = g.GenerateSpecialArg(typ0, &calls) + arg = g.GenerateSpecialArg(typ0, dir, &calls) } else { arg = old calls = g.MutateArg(arg) @@ -140,11 +140,11 @@ func randUsbDeviceID(g *prog.Gen) UsbDeviceID { return id } -func (arch *arch) generateUsbHidDeviceDescriptor(g *prog.Gen, typ0 prog.Type, old prog.Arg) ( +func (arch *arch) generateUsbHidDeviceDescriptor(g *prog.Gen, typ0 prog.Type, dir prog.Dir, old prog.Arg) ( arg prog.Arg, calls []*prog.Call) { if old == nil { - arg = g.GenerateSpecialArg(typ0, &calls) + arg = g.GenerateSpecialArg(typ0, dir, &calls) } else { arg = old calls = g.MutateArg(arg) diff --git a/sys/targets/common.go b/sys/targets/common.go index beac7004d..908be3ff7 100644 --- a/sys/targets/common.go +++ b/sys/targets/common.go @@ -22,19 +22,19 @@ func MakePosixMmap(target *prog.Target, exec, contain bool) func() []*prog.Call const invalidFD = ^uint64(0) makeMmap := func(addr, size, prot uint64) *prog.Call { args := []prog.Arg{ - prog.MakeVmaPointerArg(meta.Args[0], addr, size), - prog.MakeConstArg(meta.Args[1], size), - prog.MakeConstArg(meta.Args[2], prot), - prog.MakeConstArg(meta.Args[3], flags), - prog.MakeResultArg(meta.Args[4], nil, invalidFD), + prog.MakeVmaPointerArg(meta.Args[0], prog.DirIn, addr, size), + prog.MakeConstArg(meta.Args[1], prog.DirIn, size), + prog.MakeConstArg(meta.Args[2], prog.DirIn, prot), + prog.MakeConstArg(meta.Args[3], prog.DirIn, flags), + prog.MakeResultArg(meta.Args[4], prog.DirIn, nil, invalidFD), } i := len(args) // Some targets have a padding argument between fd and offset. if len(meta.Args) > 6 { - args = append(args, prog.MakeConstArg(meta.Args[i], 0)) + args = append(args, prog.MakeConstArg(meta.Args[i], prog.DirIn, 0)) i++ } - args = append(args, prog.MakeConstArg(meta.Args[i], 0)) + args = append(args, prog.MakeConstArg(meta.Args[i], prog.DirIn, 0)) return &prog.Call{ Meta: meta, Args: args, @@ -61,8 +61,8 @@ func MakeSyzMmap(target *prog.Target) func() []*prog.Call { { Meta: meta, Args: []prog.Arg{ - prog.MakeVmaPointerArg(meta.Args[0], 0, size), - prog.MakeConstArg(meta.Args[1], size), + prog.MakeVmaPointerArg(meta.Args[0], prog.DirIn, 0, size), + prog.MakeConstArg(meta.Args[1], prog.DirIn, size), }, Ret: prog.MakeReturnArg(meta.Ret), }, diff --git a/sys/windows/init.go b/sys/windows/init.go index 54d93777f..574123258 100644 --- a/sys/windows/init.go +++ b/sys/windows/init.go @@ -35,10 +35,10 @@ func (arch *arch) makeMmap() []*prog.Call { { Meta: meta, Args: []prog.Arg{ - prog.MakeVmaPointerArg(meta.Args[0], 0, size), - prog.MakeConstArg(meta.Args[1], size), - prog.MakeConstArg(meta.Args[2], arch.MEM_COMMIT|arch.MEM_RESERVE), - prog.MakeConstArg(meta.Args[3], arch.PAGE_EXECUTE_READWRITE), + prog.MakeVmaPointerArg(meta.Args[0], prog.DirIn, 0, size), + prog.MakeConstArg(meta.Args[1], prog.DirIn, size), + prog.MakeConstArg(meta.Args[2], prog.DirIn, arch.MEM_COMMIT|arch.MEM_RESERVE), + prog.MakeConstArg(meta.Args[3], prog.DirIn, arch.PAGE_EXECUTE_READWRITE), }, Ret: prog.MakeReturnArg(meta.Ret), }, diff --git a/tools/syz-trace2syz/proggen/generate_unions.go b/tools/syz-trace2syz/proggen/generate_unions.go index c41a18811..3a70033d8 100644 --- a/tools/syz-trace2syz/proggen/generate_unions.go +++ b/tools/syz-trace2syz/proggen/generate_unions.go @@ -11,7 +11,7 @@ import ( "github.com/google/syzkaller/tools/syz-trace2syz/parser" ) -func (ctx *context) genSockaddrStorage(syzType *prog.UnionType, straceType parser.IrType) prog.Arg { +func (ctx *context) genSockaddrStorage(syzType *prog.UnionType, dir prog.Dir, straceType parser.IrType) prog.Arg { field2Opt := make(map[string]int) for i, field := range syzType.Fields { field2Opt[field.FieldName()] = i @@ -44,10 +44,10 @@ func (ctx *context) genSockaddrStorage(syzType *prog.UnionType, straceType parse default: log.Fatalf("unable to parse sockaddr_storage. Unsupported type: %#v", strType) } - return prog.MakeUnionArg(syzType, ctx.genArgs(syzType.Fields[idx], straceType)) + return prog.MakeUnionArg(syzType, dir, ctx.genArg(syzType.Fields[idx], dir, straceType)) } -func (ctx *context) genSockaddrNetlink(syzType *prog.UnionType, straceType parser.IrType) prog.Arg { +func (ctx *context) genSockaddrNetlink(syzType *prog.UnionType, dir prog.Dir, straceType parser.IrType) prog.Arg { var idx = 2 field2Opt := make(map[string]int) for i, field := range syzType.Fields { @@ -74,14 +74,14 @@ func (ctx *context) genSockaddrNetlink(syzType *prog.UnionType, straceType parse } } } - return prog.MakeUnionArg(syzType, ctx.genArgs(syzType.Fields[idx], straceType)) + return prog.MakeUnionArg(syzType, dir, ctx.genArg(syzType.Fields[idx], dir, straceType)) } -func (ctx *context) genIfrIfru(syzType *prog.UnionType, straceType parser.IrType) prog.Arg { +func (ctx *context) genIfrIfru(syzType *prog.UnionType, dir prog.Dir, straceType parser.IrType) prog.Arg { idx := 0 switch straceType.(type) { case parser.Constant: idx = 2 } - return prog.MakeUnionArg(syzType, ctx.genArgs(syzType.Fields[idx], straceType)) + return prog.MakeUnionArg(syzType, dir, ctx.genArg(syzType.Fields[idx], dir, straceType)) } diff --git a/tools/syz-trace2syz/proggen/proggen.go b/tools/syz-trace2syz/proggen/proggen.go index d027360de..1189fe2ac 100644 --- a/tools/syz-trace2syz/proggen/proggen.go +++ b/tools/syz-trace2syz/proggen/proggen.go @@ -115,7 +115,7 @@ func (ctx *context) genCall() *prog.Call { if i < len(straceCall.Args) { strArg = straceCall.Args[i] } - res := ctx.genArgs(syzCall.Meta.Args[i], strArg) + res := ctx.genArg(syzCall.Meta.Args[i], prog.DirIn, strArg) syzCall.Args = append(syzCall.Args, res) } ctx.genResult(syzCall.Meta.Ret, straceCall.Ret) @@ -144,91 +144,91 @@ func (ctx *context) genResult(syzType prog.Type, straceRet int64) { } } -func (ctx *context) genArgs(syzType prog.Type, traceArg parser.IrType) prog.Arg { +func (ctx *context) genArg(syzType prog.Type, dir prog.Dir, traceArg parser.IrType) prog.Arg { if traceArg == nil { log.Logf(3, "parsing syzType: %s, traceArg is nil. generating default arg...", syzType.Name()) - return syzType.DefaultArg() + return syzType.DefaultArg(dir) } log.Logf(3, "parsing arg of syz type: %s, ir type: %#v", syzType.Name(), traceArg) - if syzType.Dir() == prog.DirOut { + if dir == prog.DirOut { switch syzType.(type) { case *prog.PtrType, *prog.StructType, *prog.ResourceType, *prog.BufferType: // Resource Types need special care. Pointers, Structs can have resource fields e.g. pipe, socketpair // Buffer may need special care in out direction default: - return syzType.DefaultArg() + return syzType.DefaultArg(dir) } } switch a := syzType.(type) { case *prog.IntType, *prog.ConstType, *prog.FlagsType, *prog.CsumType: - return ctx.genConst(a, traceArg) + return ctx.genConst(a, dir, traceArg) case *prog.LenType: - return syzType.DefaultArg() + return syzType.DefaultArg(dir) case *prog.ProcType: - return ctx.parseProc(a, traceArg) + return ctx.parseProc(a, dir, traceArg) case *prog.ResourceType: - return ctx.genResource(a, traceArg) + return ctx.genResource(a, dir, traceArg) case *prog.PtrType: - return ctx.genPtr(a, traceArg) + return ctx.genPtr(a, dir, traceArg) case *prog.BufferType: - return ctx.genBuffer(a, traceArg) + return ctx.genBuffer(a, dir, traceArg) case *prog.StructType: - return ctx.genStruct(a, traceArg) + return ctx.genStruct(a, dir, traceArg) case *prog.ArrayType: - return ctx.genArray(a, traceArg) + return ctx.genArray(a, dir, traceArg) case *prog.UnionType: - return ctx.genUnionArg(a, traceArg) + return ctx.genUnionArg(a, dir, traceArg) case *prog.VmaType: - return ctx.genVma(a, traceArg) + return ctx.genVma(a, dir, traceArg) default: log.Fatalf("unsupported type: %#v", syzType) } return nil } -func (ctx *context) genVma(syzType *prog.VmaType, _ parser.IrType) prog.Arg { +func (ctx *context) genVma(syzType *prog.VmaType, dir prog.Dir, _ parser.IrType) prog.Arg { npages := uint64(1) if syzType.RangeBegin != 0 || syzType.RangeEnd != 0 { npages = syzType.RangeEnd } - return prog.MakeVmaPointerArg(syzType, ctx.builder.AllocateVMA(npages), npages) + return prog.MakeVmaPointerArg(syzType, dir, ctx.builder.AllocateVMA(npages), npages) } -func (ctx *context) genArray(syzType *prog.ArrayType, traceType parser.IrType) prog.Arg { +func (ctx *context) genArray(syzType *prog.ArrayType, dir prog.Dir, traceType parser.IrType) prog.Arg { var args []prog.Arg switch a := traceType.(type) { case *parser.GroupType: for i := 0; i < len(a.Elems); i++ { - args = append(args, ctx.genArgs(syzType.Type, a.Elems[i])) + args = append(args, ctx.genArg(syzType.Type, dir, a.Elems[i])) } default: log.Fatalf("unsupported type for array: %#v", traceType) } - return prog.MakeGroupArg(syzType, args) + return prog.MakeGroupArg(syzType, dir, args) } -func (ctx *context) genStruct(syzType *prog.StructType, traceType parser.IrType) prog.Arg { +func (ctx *context) genStruct(syzType *prog.StructType, dir prog.Dir, traceType parser.IrType) prog.Arg { var args []prog.Arg switch a := traceType.(type) { case *parser.GroupType: j := 0 - if ret, recursed := ctx.recurseStructs(syzType, a); recursed { + if ret, recursed := ctx.recurseStructs(syzType, dir, a); recursed { return ret } for i := range syzType.Fields { if prog.IsPad(syzType.Fields[i]) { - args = append(args, syzType.Fields[i].DefaultArg()) + args = append(args, syzType.Fields[i].DefaultArg(dir)) continue } // If the last n fields of a struct are zero or NULL, strace will occasionally omit those values // this creates a mismatch in the number of elements in the ir type and in // our descriptions. We generate default values for omitted fields if j >= len(a.Elems) { - args = append(args, syzType.Fields[i].DefaultArg()) + args = append(args, syzType.Fields[i].DefaultArg(dir)) } else { - args = append(args, ctx.genArgs(syzType.Fields[i], a.Elems[j])) + args = append(args, ctx.genArg(syzType.Fields[i], dir, a.Elems[j])) } j++ } @@ -236,11 +236,11 @@ func (ctx *context) genStruct(syzType *prog.StructType, traceType parser.IrType) // We could have a case like the following: // ioctl(3, 35111, {ifr_name="\x6c\x6f", ifr_hwaddr=00:00:00:00:00:00}) = 0 // if_hwaddr gets parsed as a BufferType but our syscall descriptions have it as a struct type - return syzType.DefaultArg() + return syzType.DefaultArg(dir) default: log.Fatalf("unsupported type for struct: %#v", a) } - return prog.MakeGroupArg(syzType, args) + return prog.MakeGroupArg(syzType, dir, args) } // recurseStructs handles cases where syzType corresponds to struct descriptions like @@ -249,7 +249,7 @@ func (ctx *context) genStruct(syzType *prog.StructType, traceType parser.IrType) // } [size[SOCKADDR_STORAGE_SIZE], align_ptr] // which need to be recursively generated. It returns true if we needed to recurse // along with the generated argument and false otherwise. -func (ctx *context) recurseStructs(syzType *prog.StructType, traceType *parser.GroupType) (prog.Arg, bool) { +func (ctx *context) recurseStructs(syzType *prog.StructType, dir prog.Dir, traceType *parser.GroupType) (prog.Arg, bool) { // only consider structs with one non-padded field numFields := 0 for _, field := range syzType.Fields { @@ -273,19 +273,19 @@ func (ctx *context) recurseStructs(syzType *prog.StructType, traceType *parser.G if len(t.Fields) != len(traceType.Elems) { return nil, false } - args = append(args, ctx.genStruct(t, traceType)) + args = append(args, ctx.genStruct(t, dir, traceType)) for _, field := range syzType.Fields[1:] { - args = append(args, field.DefaultArg()) + args = append(args, field.DefaultArg(dir)) } - return prog.MakeGroupArg(syzType, args), true + return prog.MakeGroupArg(syzType, dir, args), true } return nil, false } -func (ctx *context) genUnionArg(syzType *prog.UnionType, straceType parser.IrType) prog.Arg { +func (ctx *context) genUnionArg(syzType *prog.UnionType, dir prog.Dir, straceType parser.IrType) prog.Arg { if straceType == nil { log.Logf(1, "generating union arg. straceType is nil") - return syzType.DefaultArg() + return syzType.DefaultArg(dir) } log.Logf(4, "generating union arg: %s %#v", syzType.TypeName, straceType) @@ -295,33 +295,33 @@ func (ctx *context) genUnionArg(syzType *prog.UnionType, straceType parser.IrTyp switch syzType.TypeName { case "sockaddr_storage": - return ctx.genSockaddrStorage(syzType, straceType) + return ctx.genSockaddrStorage(syzType, dir, straceType) case "sockaddr_nl": - return ctx.genSockaddrNetlink(syzType, straceType) + return ctx.genSockaddrNetlink(syzType, dir, straceType) case "ifr_ifru": - return ctx.genIfrIfru(syzType, straceType) + return ctx.genIfrIfru(syzType, dir, straceType) } - return prog.MakeUnionArg(syzType, ctx.genArgs(syzType.Fields[0], straceType)) + return prog.MakeUnionArg(syzType, dir, ctx.genArg(syzType.Fields[0], dir, straceType)) } -func (ctx *context) genBuffer(syzType *prog.BufferType, traceType parser.IrType) prog.Arg { - if syzType.Dir() == prog.DirOut { +func (ctx *context) genBuffer(syzType *prog.BufferType, dir prog.Dir, traceType parser.IrType) prog.Arg { + if dir == prog.DirOut { if !syzType.Varlen() { - return prog.MakeOutDataArg(syzType, syzType.Size()) + return prog.MakeOutDataArg(syzType, dir, syzType.Size()) } switch a := traceType.(type) { case *parser.BufferType: - return prog.MakeOutDataArg(syzType, uint64(len(a.Val))) + return prog.MakeOutDataArg(syzType, dir, uint64(len(a.Val))) default: switch syzType.Kind { case prog.BufferBlobRand: size := rand.Intn(256) - return prog.MakeOutDataArg(syzType, uint64(size)) + return prog.MakeOutDataArg(syzType, dir, uint64(size)) case prog.BufferBlobRange: max := rand.Intn(int(syzType.RangeEnd) - int(syzType.RangeBegin) + 1) size := max + int(syzType.RangeBegin) - return prog.MakeOutDataArg(syzType, uint64(size)) + return prog.MakeOutDataArg(syzType, dir, uint64(size)) default: log.Fatalf("unexpected buffer type kind: %v. call %v arg %#v", syzType.Kind, ctx.currentSyzCall, traceType) } @@ -351,28 +351,28 @@ func (ctx *context) genBuffer(syzType *prog.BufferType, traceType parser.IrType) } bufVal = bufVal[:size] } - return prog.MakeDataArg(syzType, bufVal) + return prog.MakeDataArg(syzType, dir, bufVal) } -func (ctx *context) genPtr(syzType *prog.PtrType, traceType parser.IrType) prog.Arg { +func (ctx *context) genPtr(syzType *prog.PtrType, dir prog.Dir, traceType parser.IrType) prog.Arg { switch a := traceType.(type) { case parser.Constant: if a.Val() == 0 { - return prog.MakeSpecialPointerArg(syzType, 0) + return prog.MakeSpecialPointerArg(syzType, dir, 0) } // Likely have a type of the form bind(3, 0xfffffffff, [3]); - res := syzType.Type.DefaultArg() - return ctx.addr(syzType, res.Size(), res) + res := syzType.Type.DefaultArg(syzType.ElemDir) + return ctx.addr(syzType, dir, res.Size(), res) default: - res := ctx.genArgs(syzType.Type, a) - return ctx.addr(syzType, res.Size(), res) + res := ctx.genArg(syzType.Type, syzType.ElemDir, a) + return ctx.addr(syzType, dir, res.Size(), res) } } -func (ctx *context) genConst(syzType prog.Type, traceType parser.IrType) prog.Arg { +func (ctx *context) genConst(syzType prog.Type, dir prog.Dir, traceType parser.IrType) prog.Arg { switch a := traceType.(type) { case parser.Constant: - return prog.MakeConstArg(syzType, a.Val()) + return prog.MakeConstArg(syzType, dir, a.Val()) case *parser.GroupType: // Sometimes strace represents a pointer to int as [0] which gets parsed // as Array([0], len=1). A good example is ioctl(3, FIONBIO, [1]). We may also have an union int type that @@ -381,9 +381,9 @@ func (ctx *context) genConst(syzType prog.Type, traceType parser.IrType) prog.Ar // For now we choose the first option if len(a.Elems) == 0 { log.Logf(2, "parsing const type, got array type with len 0") - return syzType.DefaultArg() + return syzType.DefaultArg(dir) } - return ctx.genConst(syzType, a.Elems[0]) + return ctx.genConst(syzType, dir, a.Elems[0]) case *parser.BufferType: // strace decodes some arguments as hex strings because those values are network ordered // e.g. sin_port or sin_addr fields of sockaddr_in. @@ -409,19 +409,19 @@ func (ctx *context) genConst(syzType prog.Type, traceType parser.IrType) prog.Ar case 1: val = uint64(a.Val[0]) default: - return syzType.DefaultArg() + return syzType.DefaultArg(dir) } - return prog.MakeConstArg(syzType, val) + return prog.MakeConstArg(syzType, dir, val) default: log.Fatalf("unsupported type for const: %#v", traceType) } return nil } -func (ctx *context) genResource(syzType *prog.ResourceType, traceType parser.IrType) prog.Arg { - if syzType.Dir() == prog.DirOut { +func (ctx *context) genResource(syzType *prog.ResourceType, dir prog.Dir, traceType parser.IrType) prog.Arg { + if dir == prog.DirOut { log.Logf(2, "resource returned by call argument: %s", traceType.String()) - res := prog.MakeResultArg(syzType, nil, syzType.Default()) + res := prog.MakeResultArg(syzType, dir, nil, syzType.Default()) ctx.returnCache.cache(syzType, traceType, res) return res } @@ -429,17 +429,17 @@ func (ctx *context) genResource(syzType *prog.ResourceType, traceType parser.IrT case parser.Constant: val := a.Val() if arg := ctx.returnCache.get(syzType, traceType); arg != nil { - res := prog.MakeResultArg(syzType, arg.(*prog.ResultArg), syzType.Default()) + res := prog.MakeResultArg(syzType, dir, arg.(*prog.ResultArg), syzType.Default()) return res } - res := prog.MakeResultArg(syzType, nil, val) + res := prog.MakeResultArg(syzType, dir, nil, val) return res case *parser.GroupType: if len(a.Elems) == 1 { // For example: 5028 ioctl(3, SIOCSPGRP, [0]) = 0 // last argument is a pointer to a resource. Strace will output a pointer to // a number x as [x]. - res := prog.MakeResultArg(syzType, nil, syzType.Default()) + res := prog.MakeResultArg(syzType, dir, nil, syzType.Default()) ctx.returnCache.cache(syzType, a.Elems[0], res) return res } @@ -450,27 +450,27 @@ func (ctx *context) genResource(syzType *prog.ResourceType, traceType parser.IrT return nil } -func (ctx *context) parseProc(syzType *prog.ProcType, traceType parser.IrType) prog.Arg { +func (ctx *context) parseProc(syzType *prog.ProcType, dir prog.Dir, traceType parser.IrType) prog.Arg { switch a := traceType.(type) { case parser.Constant: val := a.Val() if val >= syzType.ValuesPerProc { - return prog.MakeConstArg(syzType, syzType.ValuesPerProc-1) + return prog.MakeConstArg(syzType, dir, syzType.ValuesPerProc-1) } - return prog.MakeConstArg(syzType, val) + return prog.MakeConstArg(syzType, dir, val) case *parser.BufferType: // Again probably an error case // Something like the following will trigger this // bind(3, {sa_family=AF_INET, sa_data="\xac"}, 3) = -1 EINVAL(Invalid argument) - return syzType.DefaultArg() + return syzType.DefaultArg(dir) default: log.Fatalf("unsupported type for proc: %#v", traceType) } return nil } -func (ctx *context) addr(syzType prog.Type, size uint64, data prog.Arg) prog.Arg { - return prog.MakePointerArg(syzType, ctx.builder.Allocate(size), data) +func (ctx *context) addr(syzType prog.Type, dir prog.Dir, size uint64, data prog.Arg) prog.Arg { + return prog.MakePointerArg(syzType, dir, ctx.builder.Allocate(size), data) } func shouldSkip(c *parser.Syscall) bool { |
