diff options
| author | Dmitry Vyukov <dvyukov@google.com> | 2020-04-26 14:14:14 +0200 |
|---|---|---|
| committer | Dmitry Vyukov <dvyukov@google.com> | 2020-05-01 13:31:17 +0200 |
| commit | e54e9781a4e043b3140b0c908ba4f4e469fd317e (patch) | |
| tree | 16e6387d78a8577c5f3d9fb8d05a51752da6338e /prog | |
| parent | 3f4dbb2f6fff9479d6c250e224bc3cb7f5cd66ed (diff) | |
prog: remove Dir from Type
Having Dir is Type is handy, but forces us to duplicate lots of types.
E.g. if a struct is referenced as both in and out, then we need to
have 2 copies and 2 copies of structs/types it includes.
If also prevents us from having the struct type as struct identity
(because we can have up to 3 of them).
Revert to the old way we used to do it: propagate Dir as we walk
syscall arguments. This moves lots of dir passing from pkg/compiler
to prog package.
Now Arg contains the dir, so once we build the tree, we can use dirs
as before.
Reduces size of sys/linux/gen/amd64.go from 6058336 to 5661150 (-6.6%).
Update #1580
Diffstat (limited to 'prog')
| -rw-r--r-- | prog/analysis.go | 4 | ||||
| -rw-r--r-- | prog/any.go | 34 | ||||
| -rw-r--r-- | prog/encoding.go | 118 | ||||
| -rw-r--r-- | prog/encodingexec.go | 2 | ||||
| -rw-r--r-- | prog/hints.go | 2 | ||||
| -rw-r--r-- | prog/hints_test.go | 8 | ||||
| -rw-r--r-- | prog/minimization.go | 6 | ||||
| -rw-r--r-- | prog/mutation.go | 28 | ||||
| -rw-r--r-- | prog/prio.go | 16 | ||||
| -rw-r--r-- | prog/prog.go | 60 | ||||
| -rw-r--r-- | prog/prog_test.go | 17 | ||||
| -rw-r--r-- | prog/rand.go | 145 | ||||
| -rw-r--r-- | prog/rand_test.go | 4 | ||||
| -rw-r--r-- | prog/resources.go | 26 | ||||
| -rw-r--r-- | prog/rotation.go | 2 | ||||
| -rw-r--r-- | prog/target.go | 20 | ||||
| -rw-r--r-- | prog/types.go | 148 | ||||
| -rw-r--r-- | prog/validation.go | 34 |
18 files changed, 348 insertions, 326 deletions
diff --git a/prog/analysis.go b/prog/analysis.go index fe022b670..e1fbaa557 100644 --- a/prog/analysis.go +++ b/prog/analysis.go @@ -69,13 +69,13 @@ func (s *state) analyzeImpl(c *Call, resources bool) { switch typ := arg.Type().(type) { case *ResourceType: a := arg.(*ResultArg) - if resources && typ.Dir() != DirIn { + if resources && a.Dir() != DirIn { s.resources[typ.Desc.Name] = append(s.resources[typ.Desc.Name], a) // TODO: negative PIDs and add them as well (that's process groups). } case *BufferType: a := arg.(*DataArg) - if typ.Dir() != DirOut && len(a.Data()) != 0 { + if a.Dir() != DirOut && len(a.Data()) != 0 { val := string(a.Data()) // Remove trailing zero padding. for len(val) >= 2 && val[len(val)-1] == 0 && val[len(val)-2] == 0 { diff --git a/prog/any.go b/prog/any.go index 15ce6ec53..d1433b18d 100644 --- a/prog/any.go +++ b/prog/any.go @@ -54,7 +54,8 @@ func initAnyTypes(target *Target) { TypeSize: target.PtrSize, IsOptional: true, }, - Type: target.any.array, + Type: target.any.array, + ElemDir: DirIn, } target.any.ptr64 = &PtrType{ TypeCommon: TypeCommon{ @@ -63,7 +64,8 @@ func initAnyTypes(target *Target) { TypeSize: 8, IsOptional: true, }, - Type: target.any.array, + Type: target.any.array, + ElemDir: DirIn, } target.any.blob = &BufferType{ TypeCommon: TypeCommon{ @@ -77,7 +79,6 @@ func initAnyTypes(target *Target) { TypeCommon: TypeCommon{ TypeName: name, FldName: name, - ArgDir: DirIn, TypeSize: size, IsOptional: true, }, @@ -100,7 +101,6 @@ func initAnyTypes(target *Target) { TypeName: "ANYUNION", FldName: "ANYUNION", IsVarlen: true, - ArgDir: DirIn, }, Fields: []Type{ target.any.blob, @@ -150,7 +150,7 @@ func (p *Prog) complexPtrs() (res []*PointerArg) { } func (target *Target) isComplexPtr(arg *PointerArg) bool { - if arg.Res == nil || arg.Type().Dir() != DirIn { + if arg.Res == nil || arg.Dir() != DirIn { return false } if target.isAnyPtr(arg.Type()) { @@ -175,6 +175,15 @@ func (target *Target) isComplexPtr(arg *PointerArg) bool { return complex && !hasPtr } +func (target *Target) isAnyRes(name string) bool { + return name == target.any.res16.TypeName || + name == target.any.res32.TypeName || + name == target.any.res64.TypeName || + name == target.any.resdec.TypeName || + name == target.any.reshex.TypeName || + name == target.any.resoct.TypeName +} + func (target *Target) CallContainsAny(c *Call) (res bool) { ForeachArg(c, func(arg Arg, ctx *ArgCtx) { if target.isAnyPtr(arg.Type()) { @@ -208,7 +217,7 @@ func (target *Target) squashPtr(arg *PointerArg, preserveField bool) { field = arg.Type().FieldName() } arg.typ = target.makeAnyPtrType(arg.Type().Size(), field) - arg.Res = MakeGroupArg(arg.typ.(*PtrType).Type, elems) + arg.Res = MakeGroupArg(arg.typ.(*PtrType).Type, DirIn, elems) if size := arg.Res.Size(); size != size0 { panic(fmt.Sprintf("squash changed size %v->%v for %v", size0, size, res0.Type())) } @@ -230,7 +239,7 @@ func (target *Target) squashPtrImpl(a Arg, elems *[]Arg) { } target.squashPtrImpl(arg.Option, elems) case *DataArg: - if arg.Type().Dir() == DirOut { + if arg.Dir() == DirOut { pad = arg.Size() } else { elem := target.ensureDataElem(elems) @@ -299,7 +308,8 @@ func (target *Target) squashResult(arg *ResultArg, elems *[]Arg) { default: panic("bad") } - *elems = append(*elems, MakeUnionArg(target.any.union, arg)) + arg.dir = DirIn + *elems = append(*elems, MakeUnionArg(target.any.union, DirIn, arg)) } func (target *Target) squashGroup(arg *GroupArg, elems *[]Arg) { @@ -375,14 +385,14 @@ func (target *Target) squashedValue(arg *ConstArg) (uint64, BinaryFormat) { func (target *Target) ensureDataElem(elems *[]Arg) *DataArg { if len(*elems) == 0 { - res := MakeDataArg(target.any.blob, nil) - *elems = append(*elems, MakeUnionArg(target.any.union, res)) + res := MakeDataArg(target.any.blob, DirIn, nil) + *elems = append(*elems, MakeUnionArg(target.any.union, DirIn, res)) return res } res, ok := (*elems)[len(*elems)-1].(*UnionArg).Option.(*DataArg) if !ok { - res = MakeDataArg(target.any.blob, nil) - *elems = append(*elems, MakeUnionArg(target.any.union, res)) + res = MakeDataArg(target.any.blob, DirIn, nil) + *elems = append(*elems, MakeUnionArg(target.any.union, DirIn, res)) } return res } diff --git a/prog/encoding.go b/prog/encoding.go index f99ff9d84..c7f8ba56a 100644 --- a/prog/encoding.go +++ b/prog/encoding.go @@ -113,7 +113,7 @@ func (a *PointerArg) serialize(ctx *serializer) { func (a *DataArg) serialize(ctx *serializer) { typ := a.Type().(*BufferType) - if typ.Dir() == DirOut { + if a.Dir() == DirOut { ctx.printf("\"\"/%v", a.Size()) return } @@ -280,7 +280,7 @@ func (p *parser) parseProg() (*Prog, error) { if IsPad(typ) { return nil, fmt.Errorf("padding in syscall %v arguments", name) } - arg, err := p.parseArg(typ) + arg, err := p.parseArg(typ, DirIn) if err != nil { return nil, err } @@ -302,7 +302,7 @@ func (p *parser) parseProg() (*Prog, error) { } for i := len(c.Args); i < len(meta.Args); i++ { p.strictFailf("missing syscall args") - c.Args = append(c.Args, meta.Args[i].DefaultArg()) + c.Args = append(c.Args, meta.Args[i].DefaultArg(DirIn)) } if len(c.Args) != len(meta.Args) { return nil, fmt.Errorf("wrong call arg count: %v, want %v", len(c.Args), len(meta.Args)) @@ -318,7 +318,7 @@ func (p *parser) parseProg() (*Prog, error) { return prog, nil } -func (p *parser) parseArg(typ Type) (Arg, error) { +func (p *parser) parseArg(typ Type, dir Dir) (Arg, error) { r := "" if p.Char() == '<' { p.Parse('<') @@ -326,13 +326,13 @@ func (p *parser) parseArg(typ Type) (Arg, error) { p.Parse('=') p.Parse('>') } - arg, err := p.parseArgImpl(typ) + arg, err := p.parseArgImpl(typ, dir) if err != nil { return nil, err } if arg == nil { if typ != nil { - arg = typ.DefaultArg() + arg = typ.DefaultArg(dir) } else if r != "" { return nil, fmt.Errorf("named nil argument") } @@ -345,26 +345,26 @@ func (p *parser) parseArg(typ Type) (Arg, error) { return arg, nil } -func (p *parser) parseArgImpl(typ Type) (Arg, error) { +func (p *parser) parseArgImpl(typ Type, dir Dir) (Arg, error) { if typ == nil && p.Char() != 'n' { p.eatExcessive(true, "non-nil argument for nil type") return nil, nil } switch p.Char() { case '0': - return p.parseArgInt(typ) + return p.parseArgInt(typ, dir) case 'r': - return p.parseArgRes(typ) + return p.parseArgRes(typ, dir) case '&': - return p.parseArgAddr(typ) + return p.parseArgAddr(typ, dir) case '"', '\'': - return p.parseArgString(typ) + return p.parseArgString(typ, dir) case '{': - return p.parseArgStruct(typ) + return p.parseArgStruct(typ, dir) case '[': - return p.parseArgArray(typ) + return p.parseArgArray(typ, dir) case '@': - return p.parseArgUnion(typ) + return p.parseArgUnion(typ, dir) case 'n': p.Parse('n') p.Parse('i') @@ -375,14 +375,14 @@ func (p *parser) parseArgImpl(typ Type) (Arg, error) { p.Parse('U') p.Parse('T') p.Parse('O') - return p.parseAuto(typ) + return p.parseAuto(typ, dir) default: return nil, fmt.Errorf("failed to parse argument at '%c' (line #%v/%v: %v)", p.Char(), p.l, p.i, p.s) } } -func (p *parser) parseArgInt(typ Type) (Arg, error) { +func (p *parser) parseArgInt(typ Type, dir Dir) (Arg, error) { val := p.Ident() v, err := strconv.ParseUint(val, 0, 64) if err != nil { @@ -390,35 +390,35 @@ func (p *parser) parseArgInt(typ Type) (Arg, error) { } switch typ.(type) { case *ConstType, *IntType, *FlagsType, *ProcType, *CsumType: - arg := Arg(MakeConstArg(typ, v)) - if typ.Dir() == DirOut && !typ.isDefaultArg(arg) { + arg := Arg(MakeConstArg(typ, dir, v)) + if dir == DirOut && !typ.isDefaultArg(arg) { p.strictFailf("out arg %v has non-default value: %v", typ, v) - arg = typ.DefaultArg() + arg = typ.DefaultArg(dir) } return arg, nil case *LenType: - return MakeConstArg(typ, v), nil + return MakeConstArg(typ, dir, v), nil case *ResourceType: - return MakeResultArg(typ, nil, v), nil + return MakeResultArg(typ, dir, nil, v), nil case *PtrType, *VmaType: index := -v % uint64(len(p.target.SpecialPointers)) - return MakeSpecialPointerArg(typ, index), nil + return MakeSpecialPointerArg(typ, dir, index), nil default: p.eatExcessive(true, "wrong int arg %T", typ) - return typ.DefaultArg(), nil + return typ.DefaultArg(dir), nil } } -func (p *parser) parseAuto(typ Type) (Arg, error) { +func (p *parser) parseAuto(typ Type, dir Dir) (Arg, error) { switch typ.(type) { case *ConstType, *LenType, *CsumType: - return p.auto(MakeConstArg(typ, 0)), nil + return p.auto(MakeConstArg(typ, dir, 0)), nil default: return nil, fmt.Errorf("wrong type %T for AUTO", typ) } } -func (p *parser) parseArgRes(typ Type) (Arg, error) { +func (p *parser) parseArgRes(typ Type, dir Dir) (Arg, error) { id := p.Ident() var div, add uint64 if p.Char() == '/' { @@ -442,23 +442,25 @@ func (p *parser) parseArgRes(typ Type) (Arg, error) { v := p.vars[id] if v == nil { p.strictFailf("undeclared variable %v", id) - return typ.DefaultArg(), nil + return typ.DefaultArg(dir), nil } - arg := MakeResultArg(typ, v, 0) + arg := MakeResultArg(typ, dir, v, 0) arg.OpDiv = div arg.OpAdd = add return arg, nil } -func (p *parser) parseArgAddr(typ Type) (Arg, error) { +func (p *parser) parseArgAddr(typ Type, dir Dir) (Arg, error) { var typ1 Type + elemDir := DirInOut switch t1 := typ.(type) { case *PtrType: typ1 = t1.Type + elemDir = t1.ElemDir case *VmaType: default: p.eatExcessive(true, "wrong addr arg") - return typ.DefaultArg(), nil + return typ.DefaultArg(dir), nil } p.Parse('&') auto := false @@ -487,11 +489,11 @@ func (p *parser) parseArgAddr(typ Type) (Arg, error) { p.Parse('N') p.Parse('Y') p.Parse('=') - typ = p.target.makeAnyPtrType(typ.Size(), typ.FieldName()) - typ1 = p.target.any.array + anyPtr := p.target.makeAnyPtrType(typ.Size(), typ.FieldName()) + typ, typ1, elemDir = anyPtr, anyPtr.Type, anyPtr.ElemDir } var err error - inner, err = p.parseArg(typ1) + inner, err = p.parseArg(typ1, elemDir) if err != nil { return nil, err } @@ -501,23 +503,23 @@ func (p *parser) parseArgAddr(typ Type) (Arg, error) { p.strictFailf("unaligned vma address 0x%x", addr) addr &= ^(p.target.PageSize - 1) } - return MakeVmaPointerArg(typ, addr, vmaSize), nil + return MakeVmaPointerArg(typ, dir, addr, vmaSize), nil } if inner == nil { - inner = typ1.DefaultArg() + inner = typ1.DefaultArg(elemDir) } - arg := MakePointerArg(typ, addr, inner) + arg := MakePointerArg(typ, dir, addr, inner) if auto { p.auto(arg) } return arg, nil } -func (p *parser) parseArgString(t Type) (Arg, error) { +func (p *parser) parseArgString(t Type, dir Dir) (Arg, error) { typ, ok := t.(*BufferType) if !ok { p.eatExcessive(true, "wrong string arg") - return t.DefaultArg(), nil + return t.DefaultArg(dir), nil } data, err := p.deserializeData() if err != nil { @@ -542,8 +544,8 @@ func (p *parser) parseArgString(t Type) (Arg, error) { } else if size == ^uint64(0) { size = uint64(len(data)) } - if typ.Dir() == DirOut { - return MakeOutDataArg(typ, size), nil + if dir == DirOut { + return MakeOutDataArg(typ, dir, size), nil } if diff := int(size) - len(data); diff > 0 { data = append(data, make([]byte, diff)...) @@ -564,16 +566,16 @@ func (p *parser) parseArgString(t Type) (Arg, error) { data = []byte(typ.Values[0]) } } - return MakeDataArg(typ, data), nil + return MakeDataArg(typ, dir, data), nil } -func (p *parser) parseArgStruct(typ Type) (Arg, error) { +func (p *parser) parseArgStruct(typ Type, dir Dir) (Arg, error) { p.Parse('{') t1, ok := typ.(*StructType) if !ok { p.eatExcessive(false, "wrong struct arg") p.Parse('}') - return typ.DefaultArg(), nil + return typ.DefaultArg(dir), nil } var inner []Arg for i := 0; p.Char() != '}'; i++ { @@ -583,9 +585,9 @@ func (p *parser) parseArgStruct(typ Type) (Arg, error) { } fld := t1.Fields[i] if IsPad(fld) { - inner = append(inner, MakeConstArg(fld, 0)) + inner = append(inner, MakeConstArg(fld, dir, 0)) } else { - arg, err := p.parseArg(fld) + arg, err := p.parseArg(fld, dir) if err != nil { return nil, err } @@ -601,22 +603,22 @@ func (p *parser) parseArgStruct(typ Type) (Arg, error) { if !IsPad(fld) { p.strictFailf("missing struct %v fields %v/%v", typ.Name(), len(inner), len(t1.Fields)) } - inner = append(inner, fld.DefaultArg()) + inner = append(inner, fld.DefaultArg(dir)) } - return MakeGroupArg(typ, inner), nil + return MakeGroupArg(typ, dir, inner), nil } -func (p *parser) parseArgArray(typ Type) (Arg, error) { +func (p *parser) parseArgArray(typ Type, dir Dir) (Arg, error) { p.Parse('[') t1, ok := typ.(*ArrayType) if !ok { p.eatExcessive(false, "wrong array arg %T", typ) p.Parse(']') - return typ.DefaultArg(), nil + return typ.DefaultArg(dir), nil } var inner []Arg for i := 0; p.Char() != ']'; i++ { - arg, err := p.parseArg(t1.Type) + arg, err := p.parseArg(t1.Type, dir) if err != nil { return nil, err } @@ -629,18 +631,18 @@ func (p *parser) parseArgArray(typ Type) (Arg, error) { if t1.Kind == ArrayRangeLen && t1.RangeBegin == t1.RangeEnd { for uint64(len(inner)) < t1.RangeBegin { p.strictFailf("missing array elements") - inner = append(inner, t1.Type.DefaultArg()) + inner = append(inner, t1.Type.DefaultArg(dir)) } inner = inner[:t1.RangeBegin] } - return MakeGroupArg(typ, inner), nil + return MakeGroupArg(typ, dir, inner), nil } -func (p *parser) parseArgUnion(typ Type) (Arg, error) { +func (p *parser) parseArgUnion(typ Type, dir Dir) (Arg, error) { t1, ok := typ.(*UnionType) if !ok { p.eatExcessive(true, "wrong union arg") - return typ.DefaultArg(), nil + return typ.DefaultArg(dir), nil } p.Parse('@') name := p.Ident() @@ -653,20 +655,20 @@ func (p *parser) parseArgUnion(typ Type) (Arg, error) { } if optType == nil { p.eatExcessive(true, "wrong union option") - return typ.DefaultArg(), nil + return typ.DefaultArg(dir), nil } var opt Arg if p.Char() == '=' { p.Parse('=') var err error - opt, err = p.parseArg(optType) + opt, err = p.parseArg(optType, dir) if err != nil { return nil, err } } else { - opt = optType.DefaultArg() + opt = optType.DefaultArg(dir) } - return MakeUnionArg(typ, opt), nil + return MakeUnionArg(typ, dir, opt), nil } // Eats excessive call arguments and struct fields to recover after description changes. diff --git a/prog/encodingexec.go b/prog/encodingexec.go index b5c410287..99357dfd2 100644 --- a/prog/encodingexec.go +++ b/prog/encodingexec.go @@ -134,7 +134,7 @@ func (w *execContext) writeCopyin(c *Call) { return } typ := arg.Type() - if typ.Dir() == DirOut || IsPad(typ) || (arg.Size() == 0 && !typ.IsBitfield()) { + if arg.Dir() == DirOut || IsPad(typ) || (arg.Size() == 0 && !typ.IsBitfield()) { return } w.write(execInstrCopyin) diff --git a/prog/hints.go b/prog/hints.go index 9a5675b1b..be2b371f2 100644 --- a/prog/hints.go +++ b/prog/hints.go @@ -82,7 +82,7 @@ func (p *Prog) MutateWithHints(callIndex int, comps CompMap, exec func(p *Prog)) func generateHints(compMap CompMap, arg Arg, exec func()) { typ := arg.Type() - if typ == nil || typ.Dir() == DirOut { + if typ == nil || arg.Dir() == DirOut { return } switch t := typ.(type) { diff --git a/prog/hints_test.go b/prog/hints_test.go index 0fc9afa02..caf84e715 100644 --- a/prog/hints_test.go +++ b/prog/hints_test.go @@ -157,7 +157,7 @@ func TestHintsCheckConstArg(t *testing.T) { typ := &IntType{IntTypeCommon: IntTypeCommon{TypeCommon: TypeCommon{ TypeSize: test.size}, BitfieldLen: test.bitsize}} - constArg := MakeConstArg(typ, test.in) + constArg := MakeConstArg(typ, DirIn, test.in) checkConstArg(constArg, test.comps, func() { res = append(res, constArg.Val) }) @@ -295,8 +295,8 @@ func TestHintsCheckDataArg(t *testing.T) { res := make(map[string]bool) // Whatever type here. It's just needed to pass the // dataArg.Type().Dir() == DirIn check. - typ := &ArrayType{TypeCommon{"", "", 0, DirIn, false, true}, nil, 0, 0, 0} - dataArg := MakeDataArg(typ, []byte(test.in)) + typ := &ArrayType{TypeCommon{"", "", 0, false, true}, nil, 0, 0, 0} + dataArg := MakeDataArg(typ, DirIn, []byte(test.in)) checkDataArg(dataArg, test.comps, func() { res[string(dataArg.Data())] = true }) @@ -499,7 +499,7 @@ func TestHintsRandom(t *testing.T) { func extractValues(c *Call) map[uint64]bool { vals := make(map[uint64]bool) ForeachArg(c, func(arg Arg, _ *ArgCtx) { - if typ := arg.Type(); typ == nil || typ.Dir() == DirOut { + if arg.Dir() == DirOut { return } switch a := arg.(type) { diff --git a/prog/minimization.go b/prog/minimization.go index 93a986556..0b5bb05b2 100644 --- a/prog/minimization.go +++ b/prog/minimization.go @@ -137,7 +137,7 @@ func (typ *PtrType) minimize(ctx *minimizeArgsCtx, arg Arg, path string) bool { } if !ctx.triedPaths[path+"->"] { removeArg(a.Res) - replaceArg(a, MakeSpecialPointerArg(a.Type(), 0)) + replaceArg(a, MakeSpecialPointerArg(a.Type(), a.Dir(), 0)) ctx.target.assignSizesCall(ctx.call) if ctx.pred(ctx.p, ctx.callIndex0) { *ctx.p0 = ctx.p @@ -201,7 +201,7 @@ func minimizeInt(ctx *minimizeArgsCtx, arg Arg, path string) bool { return false } a := arg.(*ConstArg) - def := arg.Type().DefaultArg().(*ConstArg) + def := arg.Type().DefaultArg(arg.Dir()).(*ConstArg) if a.Val == def.Val { return false } @@ -239,7 +239,7 @@ func (typ *ResourceType) minimize(ctx *minimizeArgsCtx, arg Arg, path string) bo func (typ *BufferType) minimize(ctx *minimizeArgsCtx, arg Arg, path string) bool { // TODO: try to set individual bytes to 0 - if typ.Kind != BufferBlobRand && typ.Kind != BufferBlobRange || typ.Dir() == DirOut { + if typ.Kind != BufferBlobRand && typ.Kind != BufferBlobRange || arg.Dir() == DirOut { return false } a := arg.(*DataArg) diff --git a/prog/mutation.go b/prog/mutation.go index 7203a86ec..7087b4d86 100644 --- a/prog/mutation.go +++ b/prog/mutation.go @@ -99,7 +99,7 @@ func (ctx *mutator) squashAny() bool { var blobs []*DataArg var bases []*PointerArg ForeachSubArg(ptr, func(arg Arg, ctx *ArgCtx) { - if data, ok := arg.(*DataArg); ok && arg.Type().Dir() != DirOut { + if data, ok := arg.(*DataArg); ok && arg.Dir() != DirOut { blobs = append(blobs, data) bases = append(bases, ctx.Base) } @@ -119,7 +119,7 @@ func (ctx *mutator) squashAny() bool { // Update base pointer if size has increased. if baseSize < base.Res.Size() { s := analyze(ctx.ct, ctx.corpus, p, p.Calls[0]) - newArg := r.allocAddr(s, base.Type(), base.Res.Size(), base.Res) + newArg := r.allocAddr(s, base.Type(), base.Dir(), base.Res.Size(), base.Res) *base = *newArg } return true @@ -252,7 +252,7 @@ func (target *Target) mutateArg(r *randGen, s *state, arg Arg, ctx ArgCtx, updat } // Update base pointer if size has increased. if base := ctx.Base; base != nil && baseSize < base.Res.Size() { - newArg := r.allocAddr(s, base.Type(), base.Res.Size(), base.Res) + newArg := r.allocAddr(s, base.Type(), base.Dir(), base.Res.Size(), base.Res) replaceArg(base, newArg) } return calls, true @@ -260,7 +260,7 @@ func (target *Target) mutateArg(r *randGen, s *state, arg Arg, ctx ArgCtx, updat func regenerate(r *randGen, s *state, arg Arg) (calls []*Call, retry, preserve bool) { var newArg Arg - newArg, calls = r.generateArg(s, arg.Type()) + newArg, calls = r.generateArg(s, arg.Type(), arg.Dir()) replaceArg(arg, newArg) return } @@ -346,7 +346,7 @@ func (t *BufferType) mutate(r *randGen, s *state, arg Arg, ctx ArgCtx) (calls [] minLen, maxLen = t.RangeBegin, t.RangeEnd } a := arg.(*DataArg) - if t.Dir() == DirOut { + if a.Dir() == DirOut { mutateBufferSize(r, a, minLen, maxLen) return } @@ -412,7 +412,7 @@ func (t *ArrayType) mutate(r *randGen, s *state, arg Arg, ctx ArgCtx) (calls []* } if count > uint64(len(a.Inner)) { for count > uint64(len(a.Inner)) { - newArg, newCalls := r.generateArg(s, t.Type) + newArg, newCalls := r.generateArg(s, t.Type, a.Dir()) a.Inner = append(a.Inner, newArg) calls = append(calls, newCalls...) for _, c := range newCalls { @@ -433,11 +433,11 @@ func (t *PtrType) mutate(r *randGen, s *state, arg Arg, ctx ArgCtx) (calls []*Ca if r.oneOf(1000) { removeArg(a.Res) index := r.rand(len(r.target.SpecialPointers)) - newArg := MakeSpecialPointerArg(t, index) + newArg := MakeSpecialPointerArg(t, a.Dir(), index) replaceArg(arg, newArg) return } - newArg := r.allocAddr(s, t, a.Res.Size(), a.Res) + newArg := r.allocAddr(s, t, a.Dir(), a.Res.Size(), a.Res) replaceArg(arg, newArg) return } @@ -448,7 +448,7 @@ func (t *StructType) mutate(r *randGen, s *state, arg Arg, ctx ArgCtx) (calls [] panic("bad arg returned by mutationArgs: StructType") } var newArg Arg - newArg, calls = gen(&Gen{r, s}, t, arg) + newArg, calls = gen(&Gen{r, s}, t, arg.Dir(), arg) a := arg.(*GroupArg) for i, f := range newArg.(*GroupArg).Inner { replaceArg(a.Inner[i], f) @@ -459,7 +459,7 @@ func (t *StructType) mutate(r *randGen, s *state, arg Arg, ctx ArgCtx) (calls [] func (t *UnionType) mutate(r *randGen, s *state, arg Arg, ctx ArgCtx) (calls []*Call, retry, preserve bool) { if gen := r.target.SpecialTypes[t.Name()]; gen != nil { var newArg Arg - newArg, calls = gen(&Gen{r, s}, t, arg) + newArg, calls = gen(&Gen{r, s}, t, arg.Dir(), arg) replaceArg(arg, newArg) return } @@ -481,8 +481,8 @@ func (t *UnionType) mutate(r *randGen, s *state, arg Arg, ctx ArgCtx) (calls []* optType := t.Fields[newIdx] removeArg(a.Option) var newOpt Arg - newOpt, calls = r.generateArg(s, optType) - replaceArg(arg, MakeUnionArg(t, newOpt)) + newOpt, calls = r.generateArg(s, optType, a.Dir()) + replaceArg(arg, MakeUnionArg(t, a.Dir(), newOpt)) return } @@ -522,7 +522,7 @@ func (ma *mutationArgs) collectArg(arg Arg, ctx *ArgCtx) { _, isArrayTyp := typ.(*ArrayType) _, isBufferTyp := typ.(*BufferType) - if !isBufferTyp && !isArrayTyp && typ.Dir() == DirOut || !typ.Varlen() && typ.Size() == 0 { + if !isBufferTyp && !isArrayTyp && arg.Dir() == DirOut || !typ.Varlen() && typ.Size() == 0 { return } @@ -645,7 +645,7 @@ func (t *LenType) getMutationPrio(target *Target, arg Arg, ignoreSpecial bool) ( } func (t *BufferType) getMutationPrio(target *Target, arg Arg, ignoreSpecial bool) (prio float64, stopRecursion bool) { - if t.Dir() == DirOut && !t.Varlen() { + if arg.Dir() == DirOut && !t.Varlen() { return dontMutate, false } if t.Kind == BufferString && len(t.Values) == 1 { diff --git a/prog/prio.go b/prog/prio.go index b67bbaea0..2a9486570 100644 --- a/prog/prio.go +++ b/prog/prio.go @@ -65,11 +65,11 @@ func (target *Target) calcStaticPriorities() [][]float32 { func (target *Target) calcResourceUsage() map[string]map[int]weights { uses := make(map[string]map[int]weights) for _, c := range target.Syscalls { - ForeachType(c, func(t Type) { + foreachType(c, func(t Type, ctx typeCtx) { switch a := t.(type) { case *ResourceType: if target.AuxResources[a.Desc.Name] { - noteUsage(uses, c, 0.1, a.Dir(), "res%v", a.Desc.Name) + noteUsage(uses, c, 0.1, ctx.Dir, "res%v", a.Desc.Name) } else { str := "res" for i, k := range a.Desc.Kind { @@ -78,25 +78,25 @@ func (target *Target) calcResourceUsage() map[string]map[int]weights { if i < len(a.Desc.Kind)-1 { w = 0.2 } - noteUsage(uses, c, float32(w), a.Dir(), str) + noteUsage(uses, c, float32(w), ctx.Dir, str) } } case *PtrType: if _, ok := a.Type.(*StructType); ok { - noteUsage(uses, c, 1.0, a.Dir(), "ptrto-%v", a.Type.Name()) + noteUsage(uses, c, 1.0, ctx.Dir, "ptrto-%v", a.Type.Name()) } if _, ok := a.Type.(*UnionType); ok { - noteUsage(uses, c, 1.0, a.Dir(), "ptrto-%v", a.Type.Name()) + noteUsage(uses, c, 1.0, ctx.Dir, "ptrto-%v", a.Type.Name()) } if arr, ok := a.Type.(*ArrayType); ok { - noteUsage(uses, c, 1.0, a.Dir(), "ptrto-%v", arr.Type.Name()) + noteUsage(uses, c, 1.0, ctx.Dir, "ptrto-%v", arr.Type.Name()) } case *BufferType: switch a.Kind { case BufferBlobRand, BufferBlobRange, BufferText: case BufferString: if a.SubKind != "" { - noteUsage(uses, c, 0.2, a.Dir(), fmt.Sprintf("str-%v", a.SubKind)) + noteUsage(uses, c, 0.2, ctx.Dir, fmt.Sprintf("str-%v", a.SubKind)) } case BufferFilename: noteUsage(uses, c, 1.0, DirIn, "filename") @@ -104,7 +104,7 @@ func (target *Target) calcResourceUsage() map[string]map[int]weights { panic("unknown buffer kind") } case *VmaType: - noteUsage(uses, c, 0.5, a.Dir(), "vma") + noteUsage(uses, c, 0.5, ctx.Dir, "vma") case *IntType: switch a.Kind { case IntPlain, IntRange: diff --git a/prog/prog.go b/prog/prog.go index 1600c0a28..017a0dbbb 100644 --- a/prog/prog.go +++ b/prog/prog.go @@ -22,6 +22,7 @@ type Call struct { type Arg interface { Type() Type + Dir() Dir Size() uint64 validate(ctx *validCtx) error @@ -30,20 +31,25 @@ type Arg interface { type ArgCommon struct { typ Type + dir Dir } func (arg *ArgCommon) Type() Type { return arg.typ } +func (arg *ArgCommon) Dir() Dir { + return arg.dir +} + // Used for ConstType, IntType, FlagsType, LenType, ProcType and CsumType. type ConstArg struct { ArgCommon Val uint64 } -func MakeConstArg(t Type, v uint64) *ConstArg { - return &ConstArg{ArgCommon: ArgCommon{typ: t}, Val: v} +func MakeConstArg(t Type, dir Dir, v uint64) *ConstArg { + return &ConstArg{ArgCommon: ArgCommon{typ: t, dir: dir}, Val: v} } func (arg *ConstArg) Size() uint64 { @@ -84,34 +90,37 @@ type PointerArg struct { Res Arg // pointee (nil for vma) } -func MakePointerArg(t Type, addr uint64, data Arg) *PointerArg { +func MakePointerArg(t Type, dir Dir, addr uint64, data Arg) *PointerArg { if data == nil { panic("nil pointer data arg") } return &PointerArg{ - ArgCommon: ArgCommon{typ: t}, + ArgCommon: ArgCommon{typ: t, dir: DirIn}, // pointers are always in Address: addr, Res: data, } } -func MakeVmaPointerArg(t Type, addr, size uint64) *PointerArg { +func MakeVmaPointerArg(t Type, dir Dir, addr, size uint64) *PointerArg { if addr%1024 != 0 { panic("unaligned vma address") } return &PointerArg{ - ArgCommon: ArgCommon{typ: t}, + ArgCommon: ArgCommon{typ: t, dir: dir}, Address: addr, VmaSize: size, } } -func MakeSpecialPointerArg(t Type, index uint64) *PointerArg { +func MakeSpecialPointerArg(t Type, dir Dir, index uint64) *PointerArg { if index >= maxSpecialPointers { panic("bad special pointer index") } + if _, ok := t.(*PtrType); ok { + dir = DirIn // pointers are always in + } return &PointerArg{ - ArgCommon: ArgCommon{typ: t}, + ArgCommon: ArgCommon{typ: t, dir: dir}, Address: -index, } } @@ -138,18 +147,18 @@ type DataArg struct { size uint64 // for out Args } -func MakeDataArg(t Type, data []byte) *DataArg { - if t.Dir() == DirOut { +func MakeDataArg(t Type, dir Dir, data []byte) *DataArg { + if dir == DirOut { panic("non-empty output data arg") } - return &DataArg{ArgCommon: ArgCommon{typ: t}, data: append([]byte{}, data...)} + return &DataArg{ArgCommon: ArgCommon{typ: t, dir: dir}, data: append([]byte{}, data...)} } -func MakeOutDataArg(t Type, size uint64) *DataArg { - if t.Dir() != DirOut { +func MakeOutDataArg(t Type, dir Dir, size uint64) *DataArg { + if dir != DirOut { panic("empty input data arg") } - return &DataArg{ArgCommon: ArgCommon{typ: t}, size: size} + return &DataArg{ArgCommon: ArgCommon{typ: t, dir: dir}, size: size} } func (arg *DataArg) Size() uint64 { @@ -160,14 +169,14 @@ func (arg *DataArg) Size() uint64 { } func (arg *DataArg) Data() []byte { - if arg.Type().Dir() == DirOut { + if arg.Dir() == DirOut { panic("getting data of output data arg") } return arg.data } func (arg *DataArg) SetData(data []byte) { - if arg.Type().Dir() == DirOut { + if arg.Dir() == DirOut { panic("setting data of output data arg") } arg.data = append([]byte{}, data...) @@ -180,8 +189,8 @@ type GroupArg struct { Inner []Arg } -func MakeGroupArg(t Type, inner []Arg) *GroupArg { - return &GroupArg{ArgCommon: ArgCommon{typ: t}, Inner: inner} +func MakeGroupArg(t Type, dir Dir, inner []Arg) *GroupArg { + return &GroupArg{ArgCommon: ArgCommon{typ: t, dir: dir}, Inner: inner} } func (arg *GroupArg) Size() uint64 { @@ -227,8 +236,8 @@ type UnionArg struct { Option Arg } -func MakeUnionArg(t Type, opt Arg) *UnionArg { - return &UnionArg{ArgCommon: ArgCommon{typ: t}, Option: opt} +func MakeUnionArg(t Type, dir Dir, opt Arg) *UnionArg { + return &UnionArg{ArgCommon: ArgCommon{typ: t, dir: dir}, Option: opt} } func (arg *UnionArg) Size() uint64 { @@ -250,8 +259,8 @@ type ResultArg struct { uses map[*ResultArg]bool // ArgResult args that use this arg } -func MakeResultArg(t Type, r *ResultArg, v uint64) *ResultArg { - arg := &ResultArg{ArgCommon: ArgCommon{typ: t}, Res: r, Val: v} +func MakeResultArg(t Type, dir Dir, r *ResultArg, v uint64) *ResultArg { + arg := &ResultArg{ArgCommon: ArgCommon{typ: t, dir: dir}, Res: r, Val: v} if r == nil { return arg } @@ -266,10 +275,7 @@ func MakeReturnArg(t Type) *ResultArg { if t == nil { return nil } - if t.Dir() != DirOut { - panic("return arg is not out") - } - return &ResultArg{ArgCommon: ArgCommon{typ: t}} + return &ResultArg{ArgCommon: ArgCommon{typ: t, dir: DirOut}} } func (arg *ResultArg) Size() uint64 { @@ -369,7 +375,7 @@ func removeArg(arg0 Arg) { delete(uses, a) } for arg1 := range a.uses { - arg2 := arg1.Type().DefaultArg().(*ResultArg) + arg2 := arg1.Type().DefaultArg(arg1.Dir()).(*ResultArg) replaceResultArg(arg1, arg2) } }) diff --git a/prog/prog_test.go b/prog/prog_test.go index a42e4437f..9b8f442e3 100644 --- a/prog/prog_test.go +++ b/prog/prog_test.go @@ -21,8 +21,8 @@ func TestGeneration(t *testing.T) { func TestDefault(t *testing.T) { target, _, _ := initTest(t) for _, meta := range target.Syscalls { - ForeachType(meta, func(typ Type) { - arg := typ.DefaultArg() + foreachType(meta, func(typ Type, ctx typeCtx) { + arg := typ.DefaultArg(ctx.Dir) if !isDefault(arg) { t.Errorf("default arg is not default: %s\ntype: %#v\narg: %#v", typ, typ, arg) @@ -203,8 +203,8 @@ func TestSpecialStructs(t *testing.T) { t.Run(special, func(t *testing.T) { var typ Type for i := 0; i < len(target.Syscalls) && typ == nil; i++ { - ForeachType(target.Syscalls[i], func(t Type) { - if t.Dir() == DirOut { + foreachType(target.Syscalls[i], func(t Type, ctx typeCtx) { + if ctx.Dir == DirOut { return } if s, ok := t.(*StructType); ok && s.Name() == special { @@ -220,8 +220,13 @@ func TestSpecialStructs(t *testing.T) { } g := &Gen{newRand(target, rs), newState(target, nil, nil)} for i := 0; i < iters/len(target.SpecialTypes); i++ { - arg, _ := gen(g, typ, nil) - gen(g, typ, arg) + var arg Arg + for i := 0; i < 2; i++ { + arg, _ = gen(g, typ, DirInOut, arg) + if arg.Dir() != DirInOut { + t.Fatalf("got wrong arg dir %v", arg.Dir()) + } + } } }) } diff --git a/prog/rand.go b/prog/rand.go index b350e31c0..5277a9814 100644 --- a/prog/rand.go +++ b/prog/rand.go @@ -341,16 +341,16 @@ func (r *randGen) randString(s *state, t *BufferType) []byte { return buf.Bytes() } -func (r *randGen) allocAddr(s *state, typ Type, size uint64, data Arg) *PointerArg { - return MakePointerArg(typ, s.ma.alloc(r, size), data) +func (r *randGen) allocAddr(s *state, typ Type, dir Dir, size uint64, data Arg) *PointerArg { + return MakePointerArg(typ, dir, s.ma.alloc(r, size), data) } -func (r *randGen) allocVMA(s *state, typ Type, numPages uint64) *PointerArg { +func (r *randGen) allocVMA(s *state, typ Type, dir Dir, numPages uint64) *PointerArg { page := s.va.alloc(r, numPages) - return MakeVmaPointerArg(typ, page*r.target.PageSize, numPages*r.target.PageSize) + return MakeVmaPointerArg(typ, dir, page*r.target.PageSize, numPages*r.target.PageSize) } -func (r *randGen) createResource(s *state, res *ResourceType) (arg Arg, calls []*Call) { +func (r *randGen) createResource(s *state, res *ResourceType, dir Dir) (arg Arg, calls []*Call) { if r.inCreateResource { return nil, nil } @@ -385,7 +385,7 @@ func (r *randGen) createResource(s *state, res *ResourceType) (arg Arg, calls [] metas = append(metas, meta) } if len(metas) == 0 { - return res.DefaultArg(), nil + return res.DefaultArg(dir), nil } // Now we have a set of candidate calls that can create the necessary resource. @@ -404,7 +404,7 @@ func (r *randGen) createResource(s *state, res *ResourceType) (arg Arg, calls [] } if len(allres) != 0 { // Bingo! - arg := MakeResultArg(res, allres[r.Intn(len(allres))], 0) + arg := MakeResultArg(res, dir, allres[r.Intn(len(allres))], 0) return arg, calls } // Discard unsuccessful calls. @@ -562,7 +562,7 @@ func (r *randGen) generateParticularCall(s *state, meta *Syscall) (calls []*Call Meta: meta, Ret: MakeReturnArg(meta.Ret), } - c.Args, calls = r.generateArgs(s, meta.Args) + c.Args, calls = r.generateArgs(s, meta.Args, DirIn) r.target.assignSizesCall(c) return append(calls, c) } @@ -601,13 +601,13 @@ func (target *Target) DataMmapProg() *Prog { } } -func (r *randGen) generateArgs(s *state, types []Type) ([]Arg, []*Call) { +func (r *randGen) generateArgs(s *state, types []Type, dir Dir) ([]Arg, []*Call) { var calls []*Call args := make([]Arg, len(types)) // Generate all args. Size args have the default value 0 for now. for i, typ := range types { - arg, calls1 := r.generateArg(s, typ) + arg, calls1 := r.generateArg(s, typ, dir) if arg == nil { panic(fmt.Sprintf("generated arg is nil for type '%v', types: %+v", typ.Name(), types)) } @@ -618,29 +618,28 @@ func (r *randGen) generateArgs(s *state, types []Type) ([]Arg, []*Call) { return args, calls } -func (r *randGen) generateArg(s *state, typ Type) (arg Arg, calls []*Call) { - return r.generateArgImpl(s, typ, false) +func (r *randGen) generateArg(s *state, typ Type, dir Dir) (arg Arg, calls []*Call) { + return r.generateArgImpl(s, typ, dir, false) } -func (r *randGen) generateArgImpl(s *state, typ Type, ignoreSpecial bool) (arg Arg, calls []*Call) { - if typ.Dir() == DirOut { +func (r *randGen) generateArgImpl(s *state, typ Type, dir Dir, ignoreSpecial bool) (arg Arg, calls []*Call) { + if dir == DirOut { // No need to generate something interesting for output scalar arguments. // But we still need to generate the argument itself so that it can be referenced // in subsequent calls. For the same reason we do generate pointer/array/struct // output arguments (their elements can be referenced in subsequent calls). switch typ.(type) { - case *IntType, *FlagsType, *ConstType, *ProcType, - *VmaType, *ResourceType: - return typ.DefaultArg(), nil + case *IntType, *FlagsType, *ConstType, *ProcType, *VmaType, *ResourceType: + return typ.DefaultArg(dir), nil } } if typ.Optional() && r.oneOf(5) { if res, ok := typ.(*ResourceType); ok { v := res.Desc.Values[r.Intn(len(res.Desc.Values))] - return MakeResultArg(typ, nil, v), nil + return MakeResultArg(typ, dir, nil, v), nil } - return typ.DefaultArg(), nil + return typ.DefaultArg(dir), nil } // Allow infinite recursion for optional pointers. @@ -656,70 +655,70 @@ func (r *randGen) generateArgImpl(s *state, typ Type, ignoreSpecial bool) (arg A } }() if r.recDepth[name] >= 3 { - return MakeSpecialPointerArg(typ, 0), nil + return MakeSpecialPointerArg(typ, dir, 0), nil } } } - if !ignoreSpecial && typ.Dir() != DirOut { + if !ignoreSpecial && dir != DirOut { switch typ.(type) { case *StructType, *UnionType: if gen := r.target.SpecialTypes[typ.Name()]; gen != nil { - return gen(&Gen{r, s}, typ, nil) + return gen(&Gen{r, s}, typ, dir, nil) } } } - return typ.generate(r, s) + return typ.generate(r, s, dir) } -func (a *ResourceType) generate(r *randGen, s *state) (arg Arg, calls []*Call) { +func (a *ResourceType) generate(r *randGen, s *state, dir Dir) (arg Arg, calls []*Call) { if r.oneOf(3) { - arg = r.existingResource(s, a) + arg = r.existingResource(s, a, dir) if arg != nil { return } } if r.nOutOf(2, 3) { - arg, calls = r.resourceCentric(s, a) + arg, calls = r.resourceCentric(s, a, dir) if arg != nil { return } } if r.nOutOf(4, 5) { - arg, calls = r.createResource(s, a) + arg, calls = r.createResource(s, a, dir) if arg != nil { return } } special := a.SpecialValues() - arg = MakeResultArg(a, nil, special[r.Intn(len(special))]) + arg = MakeResultArg(a, dir, nil, special[r.Intn(len(special))]) return } -func (a *BufferType) generate(r *randGen, s *state) (arg Arg, calls []*Call) { +func (a *BufferType) generate(r *randGen, s *state, dir Dir) (arg Arg, calls []*Call) { switch a.Kind { case BufferBlobRand, BufferBlobRange: sz := r.randBufLen() if a.Kind == BufferBlobRange { sz = r.randRange(a.RangeBegin, a.RangeEnd) } - if a.Dir() == DirOut { - return MakeOutDataArg(a, sz), nil + if dir == DirOut { + return MakeOutDataArg(a, dir, sz), nil } data := make([]byte, sz) for i := range data { data[i] = byte(r.Intn(256)) } - return MakeDataArg(a, data), nil + return MakeDataArg(a, dir, data), nil case BufferString: data := r.randString(s, a) - if a.Dir() == DirOut { - return MakeOutDataArg(a, uint64(len(data))), nil + if dir == DirOut { + return MakeOutDataArg(a, dir, uint64(len(data))), nil } - return MakeDataArg(a, data), nil + return MakeDataArg(a, dir, data), nil case BufferFilename: - if a.Dir() == DirOut { + if dir == DirOut { var sz uint64 switch { case !a.Varlen(): @@ -731,50 +730,50 @@ func (a *BufferType) generate(r *randGen, s *state) (arg Arg, calls []*Call) { default: sz = 4096 // PATH_MAX } - return MakeOutDataArg(a, sz), nil + return MakeOutDataArg(a, dir, sz), nil } - return MakeDataArg(a, []byte(r.filename(s, a))), nil + return MakeDataArg(a, dir, []byte(r.filename(s, a))), nil case BufferText: - if a.Dir() == DirOut { - return MakeOutDataArg(a, uint64(r.Intn(100))), nil + if dir == DirOut { + return MakeOutDataArg(a, dir, uint64(r.Intn(100))), nil } - return MakeDataArg(a, r.generateText(a.Text)), nil + return MakeDataArg(a, dir, r.generateText(a.Text)), nil default: panic("unknown buffer kind") } } -func (a *VmaType) generate(r *randGen, s *state) (arg Arg, calls []*Call) { +func (a *VmaType) generate(r *randGen, s *state, dir Dir) (arg Arg, calls []*Call) { npages := r.randPageCount() if a.RangeBegin != 0 || a.RangeEnd != 0 { npages = a.RangeBegin + uint64(r.Intn(int(a.RangeEnd-a.RangeBegin+1))) } - return r.allocVMA(s, a, npages), nil + return r.allocVMA(s, a, dir, npages), nil } -func (a *FlagsType) generate(r *randGen, s *state) (arg Arg, calls []*Call) { - return MakeConstArg(a, r.flags(a.Vals, a.BitMask, 0)), nil +func (a *FlagsType) generate(r *randGen, s *state, dir Dir) (arg Arg, calls []*Call) { + return MakeConstArg(a, dir, r.flags(a.Vals, a.BitMask, 0)), nil } -func (a *ConstType) generate(r *randGen, s *state) (arg Arg, calls []*Call) { - return MakeConstArg(a, a.Val), nil +func (a *ConstType) generate(r *randGen, s *state, dir Dir) (arg Arg, calls []*Call) { + return MakeConstArg(a, dir, a.Val), nil } -func (a *IntType) generate(r *randGen, s *state) (arg Arg, calls []*Call) { +func (a *IntType) generate(r *randGen, s *state, dir Dir) (arg Arg, calls []*Call) { bits := a.TypeBitSize() v := r.randInt(bits) switch a.Kind { case IntRange: v = r.randRangeInt(a.RangeBegin, a.RangeEnd, bits, a.Align) } - return MakeConstArg(a, v), nil + return MakeConstArg(a, dir, v), nil } -func (a *ProcType) generate(r *randGen, s *state) (arg Arg, calls []*Call) { - return MakeConstArg(a, r.rand(int(a.ValuesPerProc))), nil +func (a *ProcType) generate(r *randGen, s *state, dir Dir) (arg Arg, calls []*Call) { + return MakeConstArg(a, dir, r.rand(int(a.ValuesPerProc))), nil } -func (a *ArrayType) generate(r *randGen, s *state) (arg Arg, calls []*Call) { +func (a *ArrayType) generate(r *randGen, s *state, dir Dir) (arg Arg, calls []*Call) { var count uint64 switch a.Kind { case ArrayRandLen: @@ -784,46 +783,46 @@ func (a *ArrayType) generate(r *randGen, s *state) (arg Arg, calls []*Call) { } var inner []Arg for i := uint64(0); i < count; i++ { - arg1, calls1 := r.generateArg(s, a.Type) + arg1, calls1 := r.generateArg(s, a.Type, dir) inner = append(inner, arg1) calls = append(calls, calls1...) } - return MakeGroupArg(a, inner), calls + return MakeGroupArg(a, dir, inner), calls } -func (a *StructType) generate(r *randGen, s *state) (arg Arg, calls []*Call) { - args, calls := r.generateArgs(s, a.Fields) - group := MakeGroupArg(a, args) +func (a *StructType) generate(r *randGen, s *state, dir Dir) (arg Arg, calls []*Call) { + args, calls := r.generateArgs(s, a.Fields, dir) + group := MakeGroupArg(a, dir, args) return group, calls } -func (a *UnionType) generate(r *randGen, s *state) (arg Arg, calls []*Call) { +func (a *UnionType) generate(r *randGen, s *state, dir Dir) (arg Arg, calls []*Call) { optType := a.Fields[r.Intn(len(a.Fields))] - opt, calls := r.generateArg(s, optType) - return MakeUnionArg(a, opt), calls + opt, calls := r.generateArg(s, optType, dir) + return MakeUnionArg(a, dir, opt), calls } -func (a *PtrType) generate(r *randGen, s *state) (arg Arg, calls []*Call) { +func (a *PtrType) generate(r *randGen, s *state, dir Dir) (arg Arg, calls []*Call) { if r.oneOf(1000) { index := r.rand(len(r.target.SpecialPointers)) - return MakeSpecialPointerArg(a, index), nil + return MakeSpecialPointerArg(a, dir, index), nil } - inner, calls := r.generateArg(s, a.Type) - arg = r.allocAddr(s, a, inner.Size(), inner) + inner, calls := r.generateArg(s, a.Type, a.ElemDir) + arg = r.allocAddr(s, a, dir, inner.Size(), inner) return arg, calls } -func (a *LenType) generate(r *randGen, s *state) (arg Arg, calls []*Call) { +func (a *LenType) generate(r *randGen, s *state, dir Dir) (arg Arg, calls []*Call) { // Updated later in assignSizesCall. - return MakeConstArg(a, 0), nil + return MakeConstArg(a, dir, 0), nil } -func (a *CsumType) generate(r *randGen, s *state) (arg Arg, calls []*Call) { +func (a *CsumType) generate(r *randGen, s *state, dir Dir) (arg Arg, calls []*Call) { // Filled at runtime by executor. - return MakeConstArg(a, 0), nil + return MakeConstArg(a, dir, 0), nil } -func (r *randGen) existingResource(s *state, res *ResourceType) Arg { +func (r *randGen) existingResource(s *state, res *ResourceType, dir Dir) Arg { alltypes := make([][]*ResultArg, 0, len(s.resources)) for _, res1 := range s.resources { alltypes = append(alltypes, res1) @@ -842,11 +841,11 @@ func (r *randGen) existingResource(s *state, res *ResourceType) Arg { if len(allres) == 0 { return nil } - return MakeResultArg(res, allres[r.Intn(len(allres))], 0) + return MakeResultArg(res, dir, allres[r.Intn(len(allres))], 0) } // Finds a compatible resource with the type `t` and the calls that initialize that resource. -func (r *randGen) resourceCentric(s *state, t *ResourceType) (arg Arg, calls []*Call) { +func (r *randGen) resourceCentric(s *state, t *ResourceType, dir Dir) (arg Arg, calls []*Call) { var p *Prog var resource *ResultArg for idx := range r.Perm(len(s.corpus)) { @@ -898,7 +897,7 @@ func (r *randGen) resourceCentric(s *state, t *ResourceType) (arg Arg, calls []* p.removeCall(i) } - return MakeResultArg(t, resource, 0), p.Calls + return MakeResultArg(t, dir, resource, 0), p.Calls } func getCompatibleResources(p *Prog, resourceType string, r *randGen) (resources []*ResultArg) { @@ -906,7 +905,7 @@ func getCompatibleResources(p *Prog, resourceType string, r *randGen) (resources ForeachArg(c, func(arg Arg, _ *ArgCtx) { // Collect only initialized resources (the ones that are already used in other calls). a, ok := arg.(*ResultArg) - if !ok || len(a.uses) == 0 || a.typ.Dir() != DirOut { + if !ok || len(a.uses) == 0 || a.Dir() != DirOut { return } if !r.target.isCompatibleResource(resourceType, a.typ.Name()) { diff --git a/prog/rand_test.go b/prog/rand_test.go index 09f3b359c..6f251cb7c 100644 --- a/prog/rand_test.go +++ b/prog/rand_test.go @@ -102,14 +102,14 @@ func TestSizeGenerateConstArg(t *testing.T) { target, rs, iters := initRandomTargetTest(t, "test", "64") r := newRand(target, rs) for _, c := range target.Syscalls { - ForeachType(c, func(typ Type) { + foreachType(c, func(typ Type, ctx typeCtx) { if _, ok := typ.(*IntType); !ok { return } bits := typ.TypeBitSize() limit := uint64(1<<bits - 1) for i := 0; i < iters; i++ { - newArg, _ := typ.generate(r, nil) + newArg, _ := typ.generate(r, nil, ctx.Dir) newVal := newArg.(*ConstArg).Val if newVal > limit { t.Fatalf("invalid generated value: %d. (arg bitsize: %d; max value: %d)", newVal, bits, limit) diff --git a/prog/resources.go b/prog/resources.go index fa03e6cd6..b7bcecf95 100644 --- a/prog/resources.go +++ b/prog/resources.go @@ -46,10 +46,10 @@ func (target *Target) populateResourceCtors() { // Find resources that are created by each call. callsResources := make([][]*ResourceDesc, len(target.Syscalls)) for call, meta := range target.Syscalls { - ForeachType(meta, func(typ Type) { + foreachType(meta, func(typ Type, ctx typeCtx) { switch typ1 := typ.(type) { case *ResourceType: - if typ1.Dir() != DirIn { + if ctx.Dir != DirIn { callsResources[call] = append(callsResources[call], typ1.Desc) } } @@ -84,21 +84,19 @@ func (target *Target) populateResourceCtors() { // isCompatibleResource returns true if resource of kind src can be passed as an argument of kind dst. func (target *Target) isCompatibleResource(dst, src string) bool { - if dst == target.any.res16.TypeName || - dst == target.any.res32.TypeName || - dst == target.any.res64.TypeName || - dst == target.any.resdec.TypeName || - dst == target.any.reshex.TypeName || - dst == target.any.resoct.TypeName { + if target.isAnyRes(dst) { return true } + if target.isAnyRes(src) { + return false + } dstRes := target.resourceMap[dst] if dstRes == nil { - panic(fmt.Sprintf("unknown resource '%v'", dst)) + panic(fmt.Sprintf("unknown resource %q", dst)) } srcRes := target.resourceMap[src] if srcRes == nil { - panic(fmt.Sprintf("unknown resource '%v'", src)) + panic(fmt.Sprintf("unknown resource %q", src)) } return isCompatibleResourceImpl(dstRes.Kind, srcRes.Kind, false) } @@ -128,8 +126,8 @@ func isCompatibleResourceImpl(dst, src []string, precise bool) bool { func (target *Target) getInputResources(c *Syscall) []*ResourceDesc { var resources []*ResourceDesc - ForeachType(c, func(typ Type) { - if typ.Dir() == DirOut { + foreachType(c, func(typ Type, ctx typeCtx) { + if ctx.Dir == DirOut { return } switch typ1 := typ.(type) { @@ -148,10 +146,10 @@ func (target *Target) getInputResources(c *Syscall) []*ResourceDesc { func (target *Target) getOutputResources(c *Syscall) []*ResourceDesc { var resources []*ResourceDesc - ForeachType(c, func(typ Type) { + foreachType(c, func(typ Type, ctx typeCtx) { switch typ1 := typ.(type) { case *ResourceType: - if typ1.Dir() != DirIn { + if ctx.Dir != DirIn { resources = append(resources, typ1.Desc) } } diff --git a/prog/rotation.go b/prog/rotation.go index f95ffa03d..47ee2ca81 100644 --- a/prog/rotation.go +++ b/prog/rotation.go @@ -50,7 +50,7 @@ func MakeRotator(target *Target, calls map[*Syscall]bool, rnd *rand.Rand) *Rotat } // VMAs and filenames are effectively resources for our purposes // (but they don't have ctors). - ForeachType(call, func(t Type) { + foreachType(call, func(t Type, _ typeCtx) { switch a := t.(type) { case *BufferType: switch a.Kind { diff --git a/prog/target.go b/prog/target.go index f9e10b6f4..692d0b877 100644 --- a/prog/target.go +++ b/prog/target.go @@ -45,7 +45,7 @@ type Target struct { // allocate memory, etc. typ is the struct/union type. old is the old value of the struct/union // for mutation, or nil for generation. The function returns a new value of the struct/union, // and optionally any calls that need to be inserted before the arg reference. - SpecialTypes map[string]func(g *Gen, typ Type, old Arg) (Arg, []*Call) + SpecialTypes map[string]func(g *Gen, typ Type, dir Dir, old Arg) (Arg, []*Call) // Special strings that can matter for the target. // Used as fallback when string type does not have own dictionary. @@ -197,7 +197,7 @@ func restoreLinks(syscalls []*Syscall, resources []*ResourceDesc, structs []*Key if c.Ret != nil { unref(&c.Ret, types) } - ForeachType(c, func(t0 Type) { + foreachType(c, func(t0 Type, _ typeCtx) { switch t := t0.(type) { case *PtrType: unref(&t.Type, types) @@ -247,20 +247,20 @@ func (g *Gen) NOutOf(n, outOf int) bool { return g.r.nOutOf(n, outOf) } -func (g *Gen) Alloc(ptrType Type, data Arg) (Arg, []*Call) { - return g.r.allocAddr(g.s, ptrType, data.Size(), data), nil +func (g *Gen) Alloc(ptrType Type, dir Dir, data Arg) (Arg, []*Call) { + return g.r.allocAddr(g.s, ptrType, dir, data.Size(), data), nil } -func (g *Gen) GenerateArg(typ Type, pcalls *[]*Call) Arg { - return g.generateArg(typ, pcalls, false) +func (g *Gen) GenerateArg(typ Type, dir Dir, pcalls *[]*Call) Arg { + return g.generateArg(typ, dir, pcalls, false) } -func (g *Gen) GenerateSpecialArg(typ Type, pcalls *[]*Call) Arg { - return g.generateArg(typ, pcalls, true) +func (g *Gen) GenerateSpecialArg(typ Type, dir Dir, pcalls *[]*Call) Arg { + return g.generateArg(typ, dir, pcalls, true) } -func (g *Gen) generateArg(typ Type, pcalls *[]*Call, ignoreSpecial bool) Arg { - arg, calls := g.r.generateArgImpl(g.s, typ, ignoreSpecial) +func (g *Gen) generateArg(typ Type, dir Dir, pcalls *[]*Call, ignoreSpecial bool) Arg { + arg, calls := g.r.generateArgImpl(g.s, typ, dir, ignoreSpecial) *pcalls = append(*pcalls, calls...) g.r.target.assignSizesArray([]Arg{arg}, nil) return arg diff --git a/prog/types.go b/prog/types.go index ac272c3c1..a3b3c9709 100644 --- a/prog/types.go +++ b/prog/types.go @@ -44,7 +44,7 @@ type SyscallAttrs struct { // Executor also knows about this value. const MaxArgs = 9 -type Dir int +type Dir uint8 const ( DirIn Dir = iota @@ -80,7 +80,6 @@ type Type interface { Name() string FieldName() string TemplateName() string // for template structs name without arguments - Dir() Dir Optional() bool Varlen() bool Size() uint64 @@ -95,9 +94,9 @@ type Type interface { UnitSize() uint64 UnitOffset() uint64 - DefaultArg() Arg + DefaultArg(dir Dir) Arg isDefaultArg(arg Arg) bool - generate(r *randGen, s *state) (arg Arg, calls []*Call) + generate(r *randGen, s *state, dir Dir) (arg Arg, calls []*Call) mutate(r *randGen, s *state, arg Arg, ctx ArgCtx) (calls []*Call, retry, preserve bool) getMutationPrio(target *Target, arg Arg, ignoreSpecial bool) (prio float64, stopRecursion bool) minimize(ctx *minimizeArgsCtx, arg Arg, path string) bool @@ -105,25 +104,25 @@ type Type interface { type Ref uint32 -func (ti Ref) String() string { panic("prog.Ref method called") } -func (ti Ref) Name() string { panic("prog.Ref method called") } -func (ti Ref) FieldName() string { panic("prog.Ref method called") } -func (ti Ref) TemplateName() string { panic("prog.Ref method called") } -func (ti Ref) Dir() Dir { panic("prog.Ref method called") } -func (ti Ref) Optional() bool { panic("prog.Ref method called") } -func (ti Ref) Varlen() bool { panic("prog.Ref method called") } -func (ti Ref) Size() uint64 { panic("prog.Ref method called") } -func (ti Ref) TypeBitSize() uint64 { panic("prog.Ref method called") } -func (ti Ref) Format() BinaryFormat { panic("prog.Ref method called") } -func (ti Ref) BitfieldOffset() uint64 { panic("prog.Ref method called") } -func (ti Ref) BitfieldLength() uint64 { panic("prog.Ref method called") } -func (ti Ref) IsBitfield() bool { panic("prog.Ref method called") } -func (ti Ref) UnitSize() uint64 { panic("prog.Ref method called") } -func (ti Ref) UnitOffset() uint64 { panic("prog.Ref method called") } -func (ti Ref) DefaultArg() Arg { panic("prog.Ref method called") } -func (ti Ref) Clone() Type { panic("prog.Ref method called") } -func (ti Ref) isDefaultArg(arg Arg) bool { panic("prog.Ref method called") } -func (ti Ref) generate(r *randGen, s *state) (Arg, []*Call) { panic("prog.Ref method called") } +func (ti Ref) String() string { panic("prog.Ref method called") } +func (ti Ref) Name() string { panic("prog.Ref method called") } +func (ti Ref) FieldName() string { panic("prog.Ref method called") } +func (ti Ref) TemplateName() string { panic("prog.Ref method called") } + +func (ti Ref) Optional() bool { panic("prog.Ref method called") } +func (ti Ref) Varlen() bool { panic("prog.Ref method called") } +func (ti Ref) Size() uint64 { panic("prog.Ref method called") } +func (ti Ref) TypeBitSize() uint64 { panic("prog.Ref method called") } +func (ti Ref) Format() BinaryFormat { panic("prog.Ref method called") } +func (ti Ref) BitfieldOffset() uint64 { panic("prog.Ref method called") } +func (ti Ref) BitfieldLength() uint64 { panic("prog.Ref method called") } +func (ti Ref) IsBitfield() bool { panic("prog.Ref method called") } +func (ti Ref) UnitSize() uint64 { panic("prog.Ref method called") } +func (ti Ref) UnitOffset() uint64 { panic("prog.Ref method called") } +func (ti Ref) DefaultArg(dir Dir) Arg { panic("prog.Ref method called") } +func (ti Ref) Clone() Type { panic("prog.Ref method called") } +func (ti Ref) isDefaultArg(arg Arg) bool { panic("prog.Ref method called") } +func (ti Ref) generate(r *randGen, s *state, dir Dir) (Arg, []*Call) { panic("prog.Ref method called") } func (ti Ref) mutate(r *randGen, s *state, arg Arg, ctx ArgCtx) ([]*Call, bool, bool) { panic("prog.Ref method called") } @@ -146,7 +145,6 @@ type TypeCommon struct { FldName string // for struct fields and named args // Static size of the type, or 0 for variable size types and all but last bitfields in the group. TypeSize uint64 - ArgDir Dir IsOptional bool IsVarlen bool } @@ -210,10 +208,6 @@ func (t *TypeCommon) IsBitfield() bool { return false } -func (t TypeCommon) Dir() Dir { - return t.ArgDir -} - type ResourceDesc struct { Name string Kind []string @@ -236,8 +230,8 @@ func (t *ResourceType) String() string { return t.Name() } -func (t *ResourceType) DefaultArg() Arg { - return MakeResultArg(t, nil, t.Default()) +func (t *ResourceType) DefaultArg(dir Dir) Arg { + return MakeResultArg(t, dir, nil, t.Default()) } func (t *ResourceType) isDefaultArg(arg Arg) bool { @@ -317,8 +311,8 @@ type ConstType struct { IsPad bool } -func (t *ConstType) DefaultArg() Arg { - return MakeConstArg(t, t.Val) +func (t *ConstType) DefaultArg(dir Dir) Arg { + return MakeConstArg(t, dir, t.Val) } func (t *ConstType) isDefaultArg(arg Arg) bool { @@ -347,8 +341,8 @@ type IntType struct { Align uint64 } -func (t *IntType) DefaultArg() Arg { - return MakeConstArg(t, 0) +func (t *IntType) DefaultArg(dir Dir) Arg { + return MakeConstArg(t, dir, 0) } func (t *IntType) isDefaultArg(arg Arg) bool { @@ -361,8 +355,8 @@ type FlagsType struct { BitMask bool } -func (t *FlagsType) DefaultArg() Arg { - return MakeConstArg(t, 0) +func (t *FlagsType) DefaultArg(dir Dir) Arg { + return MakeConstArg(t, dir, 0) } func (t *FlagsType) isDefaultArg(arg Arg) bool { @@ -376,8 +370,8 @@ type LenType struct { Path []string } -func (t *LenType) DefaultArg() Arg { - return MakeConstArg(t, 0) +func (t *LenType) DefaultArg(dir Dir) Arg { + return MakeConstArg(t, dir, 0) } func (t *LenType) isDefaultArg(arg Arg) bool { @@ -395,8 +389,8 @@ const ( procDefaultValue = 0xffffffffffffffff // special value denoting 0 for all procs ) -func (t *ProcType) DefaultArg() Arg { - return MakeConstArg(t, procDefaultValue) +func (t *ProcType) DefaultArg(dir Dir) Arg { + return MakeConstArg(t, dir, procDefaultValue) } func (t *ProcType) isDefaultArg(arg Arg) bool { @@ -421,8 +415,8 @@ func (t *CsumType) String() string { return "csum" } -func (t *CsumType) DefaultArg() Arg { - return MakeConstArg(t, 0) +func (t *CsumType) DefaultArg(dir Dir) Arg { + return MakeConstArg(t, dir, 0) } func (t *CsumType) isDefaultArg(arg Arg) bool { @@ -439,8 +433,8 @@ func (t *VmaType) String() string { return "vma" } -func (t *VmaType) DefaultArg() Arg { - return MakeSpecialPointerArg(t, 0) +func (t *VmaType) DefaultArg(dir Dir) Arg { + return MakeSpecialPointerArg(t, dir, 0) } func (t *VmaType) isDefaultArg(arg Arg) bool { @@ -484,19 +478,19 @@ func (t *BufferType) String() string { return "buffer" } -func (t *BufferType) DefaultArg() Arg { - if t.Dir() == DirOut { +func (t *BufferType) DefaultArg(dir Dir) Arg { + if dir == DirOut { var sz uint64 if !t.Varlen() { sz = t.Size() } - return MakeOutDataArg(t, sz) + return MakeOutDataArg(t, dir, sz) } var data []byte if !t.Varlen() { data = make([]byte, t.Size()) } - return MakeDataArg(t, data) + return MakeDataArg(t, dir, data) } func (t *BufferType) isDefaultArg(arg Arg) bool { @@ -507,7 +501,7 @@ func (t *BufferType) isDefaultArg(arg Arg) bool { if a.Type().Varlen() { return false } - if a.Type().Dir() == DirOut { + if a.Dir() == DirOut { return true } for _, v := range a.Data() { @@ -537,14 +531,14 @@ func (t *ArrayType) String() string { return fmt.Sprintf("array[%v]", t.Type.String()) } -func (t *ArrayType) DefaultArg() Arg { +func (t *ArrayType) DefaultArg(dir Dir) Arg { var elems []Arg if t.Kind == ArrayRangeLen && t.RangeBegin == t.RangeEnd { for i := uint64(0); i < t.RangeBegin; i++ { - elems = append(elems, t.Type.DefaultArg()) + elems = append(elems, t.Type.DefaultArg(dir)) } } - return MakeGroupArg(t, elems) + return MakeGroupArg(t, dir, elems) } func (t *ArrayType) isDefaultArg(arg Arg) bool { @@ -562,18 +556,19 @@ func (t *ArrayType) isDefaultArg(arg Arg) bool { type PtrType struct { TypeCommon - Type Type + Type Type + ElemDir Dir } func (t *PtrType) String() string { - return fmt.Sprintf("ptr[%v, %v]", t.Dir(), t.Type.String()) + return fmt.Sprintf("ptr[%v, %v]", t.ElemDir, t.Type.String()) } -func (t *PtrType) DefaultArg() Arg { +func (t *PtrType) DefaultArg(dir Dir) Arg { if t.Optional() { - return MakeSpecialPointerArg(t, 0) + return MakeSpecialPointerArg(t, dir, 0) } - return MakePointerArg(t, 0, t.Type.DefaultArg()) + return MakePointerArg(t, dir, 0, t.Type.DefaultArg(t.ElemDir)) } func (t *PtrType) isDefaultArg(arg Arg) bool { @@ -598,12 +593,12 @@ func (t *StructType) FieldName() string { return t.FldName } -func (t *StructType) DefaultArg() Arg { +func (t *StructType) DefaultArg(dir Dir) Arg { inner := make([]Arg, len(t.Fields)) for i, field := range t.Fields { - inner[i] = field.DefaultArg() + inner[i] = field.DefaultArg(dir) } - return MakeGroupArg(t, inner) + return MakeGroupArg(t, dir, inner) } func (t *StructType) isDefaultArg(arg Arg) bool { @@ -630,8 +625,8 @@ func (t *UnionType) FieldName() string { return t.FldName } -func (t *UnionType) DefaultArg() Arg { - return MakeUnionArg(t, t.Fields[0].DefaultArg()) +func (t *UnionType) DefaultArg(dir Dir) Arg { + return MakeUnionArg(t, dir, t.Fields[0].DefaultArg(dir)) } func (t *UnionType) isDefaultArg(arg Arg) bool { @@ -651,7 +646,6 @@ func (t *StructDesc) FieldName() string { type StructKey struct { Name string - Dir Dir } type KeyedStruct struct { @@ -664,29 +658,33 @@ type ConstValue struct { Value uint64 } -func ForeachType(meta *Syscall, f func(Type)) { - var rec func(t Type) +type typeCtx struct { + Dir Dir +} + +func foreachType(meta *Syscall, f func(t Type, ctx typeCtx)) { + var rec func(t Type, dir Dir) seen := make(map[*StructDesc]bool) - recStruct := func(desc *StructDesc) { + recStruct := func(desc *StructDesc, dir Dir) { if seen[desc] { return // prune recursion via pointers to structs/unions } seen[desc] = true for _, f := range desc.Fields { - rec(f) + rec(f, dir) } } - rec = func(t Type) { - f(t) + rec = func(t Type, dir Dir) { + f(t, typeCtx{Dir: dir}) switch a := t.(type) { case *PtrType: - rec(a.Type) + rec(a.Type, a.ElemDir) case *ArrayType: - rec(a.Type) + rec(a.Type, dir) case *StructType: - recStruct(a.StructDesc) + recStruct(a.StructDesc, dir) case *UnionType: - recStruct(a.StructDesc) + recStruct(a.StructDesc, dir) case *ResourceType, *BufferType, *VmaType, *LenType, *FlagsType, *ConstType, *IntType, *ProcType, *CsumType: default: @@ -694,10 +692,10 @@ func ForeachType(meta *Syscall, f func(Type)) { } } for _, t := range meta.Args { - rec(t) + rec(t, DirIn) } if meta.Ret != nil { - rec(meta.Ret) + rec(meta.Ret, DirOut) } } diff --git a/prog/validation.go b/prog/validation.go index c8c030aa7..8ec0615da 100644 --- a/prog/validation.go +++ b/prog/validation.go @@ -55,7 +55,7 @@ func (ctx *validCtx) validateCall(c *Call) error { len(c.Meta.Args), len(c.Args)) } for i, arg := range c.Args { - if err := ctx.validateArg(arg, c.Meta.Args[i]); err != nil { + if err := ctx.validateArg(arg, c.Meta.Args[i], DirIn); err != nil { return err } } @@ -72,16 +72,13 @@ func (ctx *validCtx) validateRet(c *Call) error { if c.Ret == nil { return fmt.Errorf("return value is absent") } - if c.Ret.Type().Dir() != DirOut { - return fmt.Errorf("return value %v is not output", c.Ret) - } if c.Ret.Res != nil || c.Ret.Val != 0 || c.Ret.OpDiv != 0 || c.Ret.OpAdd != 0 { return fmt.Errorf("return value %v is not empty", c.Ret) } - return ctx.validateArg(c.Ret, c.Meta.Ret) + return ctx.validateArg(c.Ret, c.Meta.Ret, DirOut) } -func (ctx *validCtx) validateArg(arg Arg, typ Type) error { +func (ctx *validCtx) validateArg(arg Arg, typ Type, dir Dir) error { if arg == nil { return fmt.Errorf("nil arg") } @@ -91,6 +88,12 @@ func (ctx *validCtx) validateArg(arg Arg, typ Type) error { if arg.Type() == nil { return fmt.Errorf("no arg type") } + if _, ok := typ.(*PtrType); ok { + dir = DirIn // pointers are always in + } + if arg.Dir() != dir { + return fmt.Errorf("arg %#v type %v has wrong dir %v, expect %v", arg, arg.Type(), arg.Dir(), dir) + } if !ctx.target.isAnyPtr(arg.Type()) && arg.Type() != typ { return fmt.Errorf("bad arg type %#v, expect %#v", arg.Type(), typ) } @@ -101,7 +104,7 @@ func (ctx *validCtx) validateArg(arg Arg, typ Type) error { func (arg *ConstArg) validate(ctx *validCtx) error { switch typ := arg.Type().(type) { case *IntType: - if typ.Dir() == DirOut && !isDefault(arg) { + if arg.Dir() == DirOut && !isDefault(arg) { return fmt.Errorf("out int arg '%v' has bad const value %v", typ.Name(), arg.Val) } case *ProcType: @@ -116,9 +119,10 @@ func (arg *ConstArg) validate(ctx *validCtx) error { default: return fmt.Errorf("const arg %v has bad type %v", arg, typ.Name()) } - if typ := arg.Type(); typ.Dir() == DirOut { + if arg.Dir() == DirOut { // We generate output len arguments, which makes sense since it can be // a length of a variable-length array which is not known otherwise. + typ := arg.Type() if _, isLen := typ.(*LenType); !isLen { if !typ.isDefaultArg(arg) { return fmt.Errorf("output arg '%v'/'%v' has non default value '%+v'", @@ -143,7 +147,7 @@ func (arg *ResultArg) validate(ctx *validCtx) error { } ctx.uses[u] = arg } - if typ.Dir() == DirOut && arg.Val != 0 && arg.Val != typ.Default() { + if arg.Dir() == DirOut && arg.Val != 0 && arg.Val != typ.Default() { return fmt.Errorf("out resource arg '%v' has bad const value %v", typ.Name(), arg.Val) } if arg.Res != nil { @@ -163,7 +167,7 @@ func (arg *DataArg) validate(ctx *validCtx) error { if !ok { return fmt.Errorf("data arg %v has bad type %v", arg, arg.Type().Name()) } - if typ.Dir() == DirOut && len(arg.data) != 0 { + if arg.Dir() == DirOut && len(arg.data) != 0 { return fmt.Errorf("output arg '%v' has data", typ.Name()) } if !typ.Varlen() && typ.Size() != arg.Size() { @@ -188,7 +192,7 @@ func (arg *GroupArg) validate(ctx *validCtx) error { typ.Name(), len(typ.Fields), len(arg.Inner)) } for i, field := range arg.Inner { - if err := ctx.validateArg(field, typ.Fields[i]); err != nil { + if err := ctx.validateArg(field, typ.Fields[i], arg.Dir()); err != nil { return err } } @@ -199,7 +203,7 @@ func (arg *GroupArg) validate(ctx *validCtx) error { typ.Name(), len(arg.Inner), typ.RangeBegin) } for _, elem := range arg.Inner { - if err := ctx.validateArg(elem, typ.Type); err != nil { + if err := ctx.validateArg(elem, typ.Type, arg.Dir()); err != nil { return err } } @@ -224,7 +228,7 @@ func (arg *UnionArg) validate(ctx *validCtx) error { if optType == nil { return fmt.Errorf("union arg '%v' has bad option", typ.Name()) } - return ctx.validateArg(arg.Option, optType) + return ctx.validateArg(arg.Option, optType, arg.Dir()) } func (arg *PointerArg) validate(ctx *validCtx) error { @@ -235,14 +239,14 @@ func (arg *PointerArg) validate(ctx *validCtx) error { } case *PtrType: if arg.Res != nil { - if err := ctx.validateArg(arg.Res, typ.Type); err != nil { + if err := ctx.validateArg(arg.Res, typ.Type, typ.ElemDir); err != nil { return err } } if arg.VmaSize != 0 { return fmt.Errorf("pointer arg '%v' has nonzero size", typ.Name()) } - if typ.Dir() == DirOut { + if arg.Dir() == DirOut { return fmt.Errorf("pointer arg '%v' has output direction", typ.Name()) } default: |
