aboutsummaryrefslogtreecommitdiffstats
path: root/prog
diff options
context:
space:
mode:
Diffstat (limited to 'prog')
-rw-r--r--prog/analysis.go4
-rw-r--r--prog/any.go34
-rw-r--r--prog/encoding.go118
-rw-r--r--prog/encodingexec.go2
-rw-r--r--prog/hints.go2
-rw-r--r--prog/hints_test.go8
-rw-r--r--prog/minimization.go6
-rw-r--r--prog/mutation.go28
-rw-r--r--prog/prio.go16
-rw-r--r--prog/prog.go60
-rw-r--r--prog/prog_test.go17
-rw-r--r--prog/rand.go145
-rw-r--r--prog/rand_test.go4
-rw-r--r--prog/resources.go26
-rw-r--r--prog/rotation.go2
-rw-r--r--prog/target.go20
-rw-r--r--prog/types.go148
-rw-r--r--prog/validation.go34
18 files changed, 348 insertions, 326 deletions
diff --git a/prog/analysis.go b/prog/analysis.go
index fe022b670..e1fbaa557 100644
--- a/prog/analysis.go
+++ b/prog/analysis.go
@@ -69,13 +69,13 @@ func (s *state) analyzeImpl(c *Call, resources bool) {
switch typ := arg.Type().(type) {
case *ResourceType:
a := arg.(*ResultArg)
- if resources && typ.Dir() != DirIn {
+ if resources && a.Dir() != DirIn {
s.resources[typ.Desc.Name] = append(s.resources[typ.Desc.Name], a)
// TODO: negative PIDs and add them as well (that's process groups).
}
case *BufferType:
a := arg.(*DataArg)
- if typ.Dir() != DirOut && len(a.Data()) != 0 {
+ if a.Dir() != DirOut && len(a.Data()) != 0 {
val := string(a.Data())
// Remove trailing zero padding.
for len(val) >= 2 && val[len(val)-1] == 0 && val[len(val)-2] == 0 {
diff --git a/prog/any.go b/prog/any.go
index 15ce6ec53..d1433b18d 100644
--- a/prog/any.go
+++ b/prog/any.go
@@ -54,7 +54,8 @@ func initAnyTypes(target *Target) {
TypeSize: target.PtrSize,
IsOptional: true,
},
- Type: target.any.array,
+ Type: target.any.array,
+ ElemDir: DirIn,
}
target.any.ptr64 = &PtrType{
TypeCommon: TypeCommon{
@@ -63,7 +64,8 @@ func initAnyTypes(target *Target) {
TypeSize: 8,
IsOptional: true,
},
- Type: target.any.array,
+ Type: target.any.array,
+ ElemDir: DirIn,
}
target.any.blob = &BufferType{
TypeCommon: TypeCommon{
@@ -77,7 +79,6 @@ func initAnyTypes(target *Target) {
TypeCommon: TypeCommon{
TypeName: name,
FldName: name,
- ArgDir: DirIn,
TypeSize: size,
IsOptional: true,
},
@@ -100,7 +101,6 @@ func initAnyTypes(target *Target) {
TypeName: "ANYUNION",
FldName: "ANYUNION",
IsVarlen: true,
- ArgDir: DirIn,
},
Fields: []Type{
target.any.blob,
@@ -150,7 +150,7 @@ func (p *Prog) complexPtrs() (res []*PointerArg) {
}
func (target *Target) isComplexPtr(arg *PointerArg) bool {
- if arg.Res == nil || arg.Type().Dir() != DirIn {
+ if arg.Res == nil || arg.Dir() != DirIn {
return false
}
if target.isAnyPtr(arg.Type()) {
@@ -175,6 +175,15 @@ func (target *Target) isComplexPtr(arg *PointerArg) bool {
return complex && !hasPtr
}
+func (target *Target) isAnyRes(name string) bool {
+ return name == target.any.res16.TypeName ||
+ name == target.any.res32.TypeName ||
+ name == target.any.res64.TypeName ||
+ name == target.any.resdec.TypeName ||
+ name == target.any.reshex.TypeName ||
+ name == target.any.resoct.TypeName
+}
+
func (target *Target) CallContainsAny(c *Call) (res bool) {
ForeachArg(c, func(arg Arg, ctx *ArgCtx) {
if target.isAnyPtr(arg.Type()) {
@@ -208,7 +217,7 @@ func (target *Target) squashPtr(arg *PointerArg, preserveField bool) {
field = arg.Type().FieldName()
}
arg.typ = target.makeAnyPtrType(arg.Type().Size(), field)
- arg.Res = MakeGroupArg(arg.typ.(*PtrType).Type, elems)
+ arg.Res = MakeGroupArg(arg.typ.(*PtrType).Type, DirIn, elems)
if size := arg.Res.Size(); size != size0 {
panic(fmt.Sprintf("squash changed size %v->%v for %v", size0, size, res0.Type()))
}
@@ -230,7 +239,7 @@ func (target *Target) squashPtrImpl(a Arg, elems *[]Arg) {
}
target.squashPtrImpl(arg.Option, elems)
case *DataArg:
- if arg.Type().Dir() == DirOut {
+ if arg.Dir() == DirOut {
pad = arg.Size()
} else {
elem := target.ensureDataElem(elems)
@@ -299,7 +308,8 @@ func (target *Target) squashResult(arg *ResultArg, elems *[]Arg) {
default:
panic("bad")
}
- *elems = append(*elems, MakeUnionArg(target.any.union, arg))
+ arg.dir = DirIn
+ *elems = append(*elems, MakeUnionArg(target.any.union, DirIn, arg))
}
func (target *Target) squashGroup(arg *GroupArg, elems *[]Arg) {
@@ -375,14 +385,14 @@ func (target *Target) squashedValue(arg *ConstArg) (uint64, BinaryFormat) {
func (target *Target) ensureDataElem(elems *[]Arg) *DataArg {
if len(*elems) == 0 {
- res := MakeDataArg(target.any.blob, nil)
- *elems = append(*elems, MakeUnionArg(target.any.union, res))
+ res := MakeDataArg(target.any.blob, DirIn, nil)
+ *elems = append(*elems, MakeUnionArg(target.any.union, DirIn, res))
return res
}
res, ok := (*elems)[len(*elems)-1].(*UnionArg).Option.(*DataArg)
if !ok {
- res = MakeDataArg(target.any.blob, nil)
- *elems = append(*elems, MakeUnionArg(target.any.union, res))
+ res = MakeDataArg(target.any.blob, DirIn, nil)
+ *elems = append(*elems, MakeUnionArg(target.any.union, DirIn, res))
}
return res
}
diff --git a/prog/encoding.go b/prog/encoding.go
index f99ff9d84..c7f8ba56a 100644
--- a/prog/encoding.go
+++ b/prog/encoding.go
@@ -113,7 +113,7 @@ func (a *PointerArg) serialize(ctx *serializer) {
func (a *DataArg) serialize(ctx *serializer) {
typ := a.Type().(*BufferType)
- if typ.Dir() == DirOut {
+ if a.Dir() == DirOut {
ctx.printf("\"\"/%v", a.Size())
return
}
@@ -280,7 +280,7 @@ func (p *parser) parseProg() (*Prog, error) {
if IsPad(typ) {
return nil, fmt.Errorf("padding in syscall %v arguments", name)
}
- arg, err := p.parseArg(typ)
+ arg, err := p.parseArg(typ, DirIn)
if err != nil {
return nil, err
}
@@ -302,7 +302,7 @@ func (p *parser) parseProg() (*Prog, error) {
}
for i := len(c.Args); i < len(meta.Args); i++ {
p.strictFailf("missing syscall args")
- c.Args = append(c.Args, meta.Args[i].DefaultArg())
+ c.Args = append(c.Args, meta.Args[i].DefaultArg(DirIn))
}
if len(c.Args) != len(meta.Args) {
return nil, fmt.Errorf("wrong call arg count: %v, want %v", len(c.Args), len(meta.Args))
@@ -318,7 +318,7 @@ func (p *parser) parseProg() (*Prog, error) {
return prog, nil
}
-func (p *parser) parseArg(typ Type) (Arg, error) {
+func (p *parser) parseArg(typ Type, dir Dir) (Arg, error) {
r := ""
if p.Char() == '<' {
p.Parse('<')
@@ -326,13 +326,13 @@ func (p *parser) parseArg(typ Type) (Arg, error) {
p.Parse('=')
p.Parse('>')
}
- arg, err := p.parseArgImpl(typ)
+ arg, err := p.parseArgImpl(typ, dir)
if err != nil {
return nil, err
}
if arg == nil {
if typ != nil {
- arg = typ.DefaultArg()
+ arg = typ.DefaultArg(dir)
} else if r != "" {
return nil, fmt.Errorf("named nil argument")
}
@@ -345,26 +345,26 @@ func (p *parser) parseArg(typ Type) (Arg, error) {
return arg, nil
}
-func (p *parser) parseArgImpl(typ Type) (Arg, error) {
+func (p *parser) parseArgImpl(typ Type, dir Dir) (Arg, error) {
if typ == nil && p.Char() != 'n' {
p.eatExcessive(true, "non-nil argument for nil type")
return nil, nil
}
switch p.Char() {
case '0':
- return p.parseArgInt(typ)
+ return p.parseArgInt(typ, dir)
case 'r':
- return p.parseArgRes(typ)
+ return p.parseArgRes(typ, dir)
case '&':
- return p.parseArgAddr(typ)
+ return p.parseArgAddr(typ, dir)
case '"', '\'':
- return p.parseArgString(typ)
+ return p.parseArgString(typ, dir)
case '{':
- return p.parseArgStruct(typ)
+ return p.parseArgStruct(typ, dir)
case '[':
- return p.parseArgArray(typ)
+ return p.parseArgArray(typ, dir)
case '@':
- return p.parseArgUnion(typ)
+ return p.parseArgUnion(typ, dir)
case 'n':
p.Parse('n')
p.Parse('i')
@@ -375,14 +375,14 @@ func (p *parser) parseArgImpl(typ Type) (Arg, error) {
p.Parse('U')
p.Parse('T')
p.Parse('O')
- return p.parseAuto(typ)
+ return p.parseAuto(typ, dir)
default:
return nil, fmt.Errorf("failed to parse argument at '%c' (line #%v/%v: %v)",
p.Char(), p.l, p.i, p.s)
}
}
-func (p *parser) parseArgInt(typ Type) (Arg, error) {
+func (p *parser) parseArgInt(typ Type, dir Dir) (Arg, error) {
val := p.Ident()
v, err := strconv.ParseUint(val, 0, 64)
if err != nil {
@@ -390,35 +390,35 @@ func (p *parser) parseArgInt(typ Type) (Arg, error) {
}
switch typ.(type) {
case *ConstType, *IntType, *FlagsType, *ProcType, *CsumType:
- arg := Arg(MakeConstArg(typ, v))
- if typ.Dir() == DirOut && !typ.isDefaultArg(arg) {
+ arg := Arg(MakeConstArg(typ, dir, v))
+ if dir == DirOut && !typ.isDefaultArg(arg) {
p.strictFailf("out arg %v has non-default value: %v", typ, v)
- arg = typ.DefaultArg()
+ arg = typ.DefaultArg(dir)
}
return arg, nil
case *LenType:
- return MakeConstArg(typ, v), nil
+ return MakeConstArg(typ, dir, v), nil
case *ResourceType:
- return MakeResultArg(typ, nil, v), nil
+ return MakeResultArg(typ, dir, nil, v), nil
case *PtrType, *VmaType:
index := -v % uint64(len(p.target.SpecialPointers))
- return MakeSpecialPointerArg(typ, index), nil
+ return MakeSpecialPointerArg(typ, dir, index), nil
default:
p.eatExcessive(true, "wrong int arg %T", typ)
- return typ.DefaultArg(), nil
+ return typ.DefaultArg(dir), nil
}
}
-func (p *parser) parseAuto(typ Type) (Arg, error) {
+func (p *parser) parseAuto(typ Type, dir Dir) (Arg, error) {
switch typ.(type) {
case *ConstType, *LenType, *CsumType:
- return p.auto(MakeConstArg(typ, 0)), nil
+ return p.auto(MakeConstArg(typ, dir, 0)), nil
default:
return nil, fmt.Errorf("wrong type %T for AUTO", typ)
}
}
-func (p *parser) parseArgRes(typ Type) (Arg, error) {
+func (p *parser) parseArgRes(typ Type, dir Dir) (Arg, error) {
id := p.Ident()
var div, add uint64
if p.Char() == '/' {
@@ -442,23 +442,25 @@ func (p *parser) parseArgRes(typ Type) (Arg, error) {
v := p.vars[id]
if v == nil {
p.strictFailf("undeclared variable %v", id)
- return typ.DefaultArg(), nil
+ return typ.DefaultArg(dir), nil
}
- arg := MakeResultArg(typ, v, 0)
+ arg := MakeResultArg(typ, dir, v, 0)
arg.OpDiv = div
arg.OpAdd = add
return arg, nil
}
-func (p *parser) parseArgAddr(typ Type) (Arg, error) {
+func (p *parser) parseArgAddr(typ Type, dir Dir) (Arg, error) {
var typ1 Type
+ elemDir := DirInOut
switch t1 := typ.(type) {
case *PtrType:
typ1 = t1.Type
+ elemDir = t1.ElemDir
case *VmaType:
default:
p.eatExcessive(true, "wrong addr arg")
- return typ.DefaultArg(), nil
+ return typ.DefaultArg(dir), nil
}
p.Parse('&')
auto := false
@@ -487,11 +489,11 @@ func (p *parser) parseArgAddr(typ Type) (Arg, error) {
p.Parse('N')
p.Parse('Y')
p.Parse('=')
- typ = p.target.makeAnyPtrType(typ.Size(), typ.FieldName())
- typ1 = p.target.any.array
+ anyPtr := p.target.makeAnyPtrType(typ.Size(), typ.FieldName())
+ typ, typ1, elemDir = anyPtr, anyPtr.Type, anyPtr.ElemDir
}
var err error
- inner, err = p.parseArg(typ1)
+ inner, err = p.parseArg(typ1, elemDir)
if err != nil {
return nil, err
}
@@ -501,23 +503,23 @@ func (p *parser) parseArgAddr(typ Type) (Arg, error) {
p.strictFailf("unaligned vma address 0x%x", addr)
addr &= ^(p.target.PageSize - 1)
}
- return MakeVmaPointerArg(typ, addr, vmaSize), nil
+ return MakeVmaPointerArg(typ, dir, addr, vmaSize), nil
}
if inner == nil {
- inner = typ1.DefaultArg()
+ inner = typ1.DefaultArg(elemDir)
}
- arg := MakePointerArg(typ, addr, inner)
+ arg := MakePointerArg(typ, dir, addr, inner)
if auto {
p.auto(arg)
}
return arg, nil
}
-func (p *parser) parseArgString(t Type) (Arg, error) {
+func (p *parser) parseArgString(t Type, dir Dir) (Arg, error) {
typ, ok := t.(*BufferType)
if !ok {
p.eatExcessive(true, "wrong string arg")
- return t.DefaultArg(), nil
+ return t.DefaultArg(dir), nil
}
data, err := p.deserializeData()
if err != nil {
@@ -542,8 +544,8 @@ func (p *parser) parseArgString(t Type) (Arg, error) {
} else if size == ^uint64(0) {
size = uint64(len(data))
}
- if typ.Dir() == DirOut {
- return MakeOutDataArg(typ, size), nil
+ if dir == DirOut {
+ return MakeOutDataArg(typ, dir, size), nil
}
if diff := int(size) - len(data); diff > 0 {
data = append(data, make([]byte, diff)...)
@@ -564,16 +566,16 @@ func (p *parser) parseArgString(t Type) (Arg, error) {
data = []byte(typ.Values[0])
}
}
- return MakeDataArg(typ, data), nil
+ return MakeDataArg(typ, dir, data), nil
}
-func (p *parser) parseArgStruct(typ Type) (Arg, error) {
+func (p *parser) parseArgStruct(typ Type, dir Dir) (Arg, error) {
p.Parse('{')
t1, ok := typ.(*StructType)
if !ok {
p.eatExcessive(false, "wrong struct arg")
p.Parse('}')
- return typ.DefaultArg(), nil
+ return typ.DefaultArg(dir), nil
}
var inner []Arg
for i := 0; p.Char() != '}'; i++ {
@@ -583,9 +585,9 @@ func (p *parser) parseArgStruct(typ Type) (Arg, error) {
}
fld := t1.Fields[i]
if IsPad(fld) {
- inner = append(inner, MakeConstArg(fld, 0))
+ inner = append(inner, MakeConstArg(fld, dir, 0))
} else {
- arg, err := p.parseArg(fld)
+ arg, err := p.parseArg(fld, dir)
if err != nil {
return nil, err
}
@@ -601,22 +603,22 @@ func (p *parser) parseArgStruct(typ Type) (Arg, error) {
if !IsPad(fld) {
p.strictFailf("missing struct %v fields %v/%v", typ.Name(), len(inner), len(t1.Fields))
}
- inner = append(inner, fld.DefaultArg())
+ inner = append(inner, fld.DefaultArg(dir))
}
- return MakeGroupArg(typ, inner), nil
+ return MakeGroupArg(typ, dir, inner), nil
}
-func (p *parser) parseArgArray(typ Type) (Arg, error) {
+func (p *parser) parseArgArray(typ Type, dir Dir) (Arg, error) {
p.Parse('[')
t1, ok := typ.(*ArrayType)
if !ok {
p.eatExcessive(false, "wrong array arg %T", typ)
p.Parse(']')
- return typ.DefaultArg(), nil
+ return typ.DefaultArg(dir), nil
}
var inner []Arg
for i := 0; p.Char() != ']'; i++ {
- arg, err := p.parseArg(t1.Type)
+ arg, err := p.parseArg(t1.Type, dir)
if err != nil {
return nil, err
}
@@ -629,18 +631,18 @@ func (p *parser) parseArgArray(typ Type) (Arg, error) {
if t1.Kind == ArrayRangeLen && t1.RangeBegin == t1.RangeEnd {
for uint64(len(inner)) < t1.RangeBegin {
p.strictFailf("missing array elements")
- inner = append(inner, t1.Type.DefaultArg())
+ inner = append(inner, t1.Type.DefaultArg(dir))
}
inner = inner[:t1.RangeBegin]
}
- return MakeGroupArg(typ, inner), nil
+ return MakeGroupArg(typ, dir, inner), nil
}
-func (p *parser) parseArgUnion(typ Type) (Arg, error) {
+func (p *parser) parseArgUnion(typ Type, dir Dir) (Arg, error) {
t1, ok := typ.(*UnionType)
if !ok {
p.eatExcessive(true, "wrong union arg")
- return typ.DefaultArg(), nil
+ return typ.DefaultArg(dir), nil
}
p.Parse('@')
name := p.Ident()
@@ -653,20 +655,20 @@ func (p *parser) parseArgUnion(typ Type) (Arg, error) {
}
if optType == nil {
p.eatExcessive(true, "wrong union option")
- return typ.DefaultArg(), nil
+ return typ.DefaultArg(dir), nil
}
var opt Arg
if p.Char() == '=' {
p.Parse('=')
var err error
- opt, err = p.parseArg(optType)
+ opt, err = p.parseArg(optType, dir)
if err != nil {
return nil, err
}
} else {
- opt = optType.DefaultArg()
+ opt = optType.DefaultArg(dir)
}
- return MakeUnionArg(typ, opt), nil
+ return MakeUnionArg(typ, dir, opt), nil
}
// Eats excessive call arguments and struct fields to recover after description changes.
diff --git a/prog/encodingexec.go b/prog/encodingexec.go
index b5c410287..99357dfd2 100644
--- a/prog/encodingexec.go
+++ b/prog/encodingexec.go
@@ -134,7 +134,7 @@ func (w *execContext) writeCopyin(c *Call) {
return
}
typ := arg.Type()
- if typ.Dir() == DirOut || IsPad(typ) || (arg.Size() == 0 && !typ.IsBitfield()) {
+ if arg.Dir() == DirOut || IsPad(typ) || (arg.Size() == 0 && !typ.IsBitfield()) {
return
}
w.write(execInstrCopyin)
diff --git a/prog/hints.go b/prog/hints.go
index 9a5675b1b..be2b371f2 100644
--- a/prog/hints.go
+++ b/prog/hints.go
@@ -82,7 +82,7 @@ func (p *Prog) MutateWithHints(callIndex int, comps CompMap, exec func(p *Prog))
func generateHints(compMap CompMap, arg Arg, exec func()) {
typ := arg.Type()
- if typ == nil || typ.Dir() == DirOut {
+ if typ == nil || arg.Dir() == DirOut {
return
}
switch t := typ.(type) {
diff --git a/prog/hints_test.go b/prog/hints_test.go
index 0fc9afa02..caf84e715 100644
--- a/prog/hints_test.go
+++ b/prog/hints_test.go
@@ -157,7 +157,7 @@ func TestHintsCheckConstArg(t *testing.T) {
typ := &IntType{IntTypeCommon: IntTypeCommon{TypeCommon: TypeCommon{
TypeSize: test.size},
BitfieldLen: test.bitsize}}
- constArg := MakeConstArg(typ, test.in)
+ constArg := MakeConstArg(typ, DirIn, test.in)
checkConstArg(constArg, test.comps, func() {
res = append(res, constArg.Val)
})
@@ -295,8 +295,8 @@ func TestHintsCheckDataArg(t *testing.T) {
res := make(map[string]bool)
// Whatever type here. It's just needed to pass the
// dataArg.Type().Dir() == DirIn check.
- typ := &ArrayType{TypeCommon{"", "", 0, DirIn, false, true}, nil, 0, 0, 0}
- dataArg := MakeDataArg(typ, []byte(test.in))
+ typ := &ArrayType{TypeCommon{"", "", 0, false, true}, nil, 0, 0, 0}
+ dataArg := MakeDataArg(typ, DirIn, []byte(test.in))
checkDataArg(dataArg, test.comps, func() {
res[string(dataArg.Data())] = true
})
@@ -499,7 +499,7 @@ func TestHintsRandom(t *testing.T) {
func extractValues(c *Call) map[uint64]bool {
vals := make(map[uint64]bool)
ForeachArg(c, func(arg Arg, _ *ArgCtx) {
- if typ := arg.Type(); typ == nil || typ.Dir() == DirOut {
+ if arg.Dir() == DirOut {
return
}
switch a := arg.(type) {
diff --git a/prog/minimization.go b/prog/minimization.go
index 93a986556..0b5bb05b2 100644
--- a/prog/minimization.go
+++ b/prog/minimization.go
@@ -137,7 +137,7 @@ func (typ *PtrType) minimize(ctx *minimizeArgsCtx, arg Arg, path string) bool {
}
if !ctx.triedPaths[path+"->"] {
removeArg(a.Res)
- replaceArg(a, MakeSpecialPointerArg(a.Type(), 0))
+ replaceArg(a, MakeSpecialPointerArg(a.Type(), a.Dir(), 0))
ctx.target.assignSizesCall(ctx.call)
if ctx.pred(ctx.p, ctx.callIndex0) {
*ctx.p0 = ctx.p
@@ -201,7 +201,7 @@ func minimizeInt(ctx *minimizeArgsCtx, arg Arg, path string) bool {
return false
}
a := arg.(*ConstArg)
- def := arg.Type().DefaultArg().(*ConstArg)
+ def := arg.Type().DefaultArg(arg.Dir()).(*ConstArg)
if a.Val == def.Val {
return false
}
@@ -239,7 +239,7 @@ func (typ *ResourceType) minimize(ctx *minimizeArgsCtx, arg Arg, path string) bo
func (typ *BufferType) minimize(ctx *minimizeArgsCtx, arg Arg, path string) bool {
// TODO: try to set individual bytes to 0
- if typ.Kind != BufferBlobRand && typ.Kind != BufferBlobRange || typ.Dir() == DirOut {
+ if typ.Kind != BufferBlobRand && typ.Kind != BufferBlobRange || arg.Dir() == DirOut {
return false
}
a := arg.(*DataArg)
diff --git a/prog/mutation.go b/prog/mutation.go
index 7203a86ec..7087b4d86 100644
--- a/prog/mutation.go
+++ b/prog/mutation.go
@@ -99,7 +99,7 @@ func (ctx *mutator) squashAny() bool {
var blobs []*DataArg
var bases []*PointerArg
ForeachSubArg(ptr, func(arg Arg, ctx *ArgCtx) {
- if data, ok := arg.(*DataArg); ok && arg.Type().Dir() != DirOut {
+ if data, ok := arg.(*DataArg); ok && arg.Dir() != DirOut {
blobs = append(blobs, data)
bases = append(bases, ctx.Base)
}
@@ -119,7 +119,7 @@ func (ctx *mutator) squashAny() bool {
// Update base pointer if size has increased.
if baseSize < base.Res.Size() {
s := analyze(ctx.ct, ctx.corpus, p, p.Calls[0])
- newArg := r.allocAddr(s, base.Type(), base.Res.Size(), base.Res)
+ newArg := r.allocAddr(s, base.Type(), base.Dir(), base.Res.Size(), base.Res)
*base = *newArg
}
return true
@@ -252,7 +252,7 @@ func (target *Target) mutateArg(r *randGen, s *state, arg Arg, ctx ArgCtx, updat
}
// Update base pointer if size has increased.
if base := ctx.Base; base != nil && baseSize < base.Res.Size() {
- newArg := r.allocAddr(s, base.Type(), base.Res.Size(), base.Res)
+ newArg := r.allocAddr(s, base.Type(), base.Dir(), base.Res.Size(), base.Res)
replaceArg(base, newArg)
}
return calls, true
@@ -260,7 +260,7 @@ func (target *Target) mutateArg(r *randGen, s *state, arg Arg, ctx ArgCtx, updat
func regenerate(r *randGen, s *state, arg Arg) (calls []*Call, retry, preserve bool) {
var newArg Arg
- newArg, calls = r.generateArg(s, arg.Type())
+ newArg, calls = r.generateArg(s, arg.Type(), arg.Dir())
replaceArg(arg, newArg)
return
}
@@ -346,7 +346,7 @@ func (t *BufferType) mutate(r *randGen, s *state, arg Arg, ctx ArgCtx) (calls []
minLen, maxLen = t.RangeBegin, t.RangeEnd
}
a := arg.(*DataArg)
- if t.Dir() == DirOut {
+ if a.Dir() == DirOut {
mutateBufferSize(r, a, minLen, maxLen)
return
}
@@ -412,7 +412,7 @@ func (t *ArrayType) mutate(r *randGen, s *state, arg Arg, ctx ArgCtx) (calls []*
}
if count > uint64(len(a.Inner)) {
for count > uint64(len(a.Inner)) {
- newArg, newCalls := r.generateArg(s, t.Type)
+ newArg, newCalls := r.generateArg(s, t.Type, a.Dir())
a.Inner = append(a.Inner, newArg)
calls = append(calls, newCalls...)
for _, c := range newCalls {
@@ -433,11 +433,11 @@ func (t *PtrType) mutate(r *randGen, s *state, arg Arg, ctx ArgCtx) (calls []*Ca
if r.oneOf(1000) {
removeArg(a.Res)
index := r.rand(len(r.target.SpecialPointers))
- newArg := MakeSpecialPointerArg(t, index)
+ newArg := MakeSpecialPointerArg(t, a.Dir(), index)
replaceArg(arg, newArg)
return
}
- newArg := r.allocAddr(s, t, a.Res.Size(), a.Res)
+ newArg := r.allocAddr(s, t, a.Dir(), a.Res.Size(), a.Res)
replaceArg(arg, newArg)
return
}
@@ -448,7 +448,7 @@ func (t *StructType) mutate(r *randGen, s *state, arg Arg, ctx ArgCtx) (calls []
panic("bad arg returned by mutationArgs: StructType")
}
var newArg Arg
- newArg, calls = gen(&Gen{r, s}, t, arg)
+ newArg, calls = gen(&Gen{r, s}, t, arg.Dir(), arg)
a := arg.(*GroupArg)
for i, f := range newArg.(*GroupArg).Inner {
replaceArg(a.Inner[i], f)
@@ -459,7 +459,7 @@ func (t *StructType) mutate(r *randGen, s *state, arg Arg, ctx ArgCtx) (calls []
func (t *UnionType) mutate(r *randGen, s *state, arg Arg, ctx ArgCtx) (calls []*Call, retry, preserve bool) {
if gen := r.target.SpecialTypes[t.Name()]; gen != nil {
var newArg Arg
- newArg, calls = gen(&Gen{r, s}, t, arg)
+ newArg, calls = gen(&Gen{r, s}, t, arg.Dir(), arg)
replaceArg(arg, newArg)
return
}
@@ -481,8 +481,8 @@ func (t *UnionType) mutate(r *randGen, s *state, arg Arg, ctx ArgCtx) (calls []*
optType := t.Fields[newIdx]
removeArg(a.Option)
var newOpt Arg
- newOpt, calls = r.generateArg(s, optType)
- replaceArg(arg, MakeUnionArg(t, newOpt))
+ newOpt, calls = r.generateArg(s, optType, a.Dir())
+ replaceArg(arg, MakeUnionArg(t, a.Dir(), newOpt))
return
}
@@ -522,7 +522,7 @@ func (ma *mutationArgs) collectArg(arg Arg, ctx *ArgCtx) {
_, isArrayTyp := typ.(*ArrayType)
_, isBufferTyp := typ.(*BufferType)
- if !isBufferTyp && !isArrayTyp && typ.Dir() == DirOut || !typ.Varlen() && typ.Size() == 0 {
+ if !isBufferTyp && !isArrayTyp && arg.Dir() == DirOut || !typ.Varlen() && typ.Size() == 0 {
return
}
@@ -645,7 +645,7 @@ func (t *LenType) getMutationPrio(target *Target, arg Arg, ignoreSpecial bool) (
}
func (t *BufferType) getMutationPrio(target *Target, arg Arg, ignoreSpecial bool) (prio float64, stopRecursion bool) {
- if t.Dir() == DirOut && !t.Varlen() {
+ if arg.Dir() == DirOut && !t.Varlen() {
return dontMutate, false
}
if t.Kind == BufferString && len(t.Values) == 1 {
diff --git a/prog/prio.go b/prog/prio.go
index b67bbaea0..2a9486570 100644
--- a/prog/prio.go
+++ b/prog/prio.go
@@ -65,11 +65,11 @@ func (target *Target) calcStaticPriorities() [][]float32 {
func (target *Target) calcResourceUsage() map[string]map[int]weights {
uses := make(map[string]map[int]weights)
for _, c := range target.Syscalls {
- ForeachType(c, func(t Type) {
+ foreachType(c, func(t Type, ctx typeCtx) {
switch a := t.(type) {
case *ResourceType:
if target.AuxResources[a.Desc.Name] {
- noteUsage(uses, c, 0.1, a.Dir(), "res%v", a.Desc.Name)
+ noteUsage(uses, c, 0.1, ctx.Dir, "res%v", a.Desc.Name)
} else {
str := "res"
for i, k := range a.Desc.Kind {
@@ -78,25 +78,25 @@ func (target *Target) calcResourceUsage() map[string]map[int]weights {
if i < len(a.Desc.Kind)-1 {
w = 0.2
}
- noteUsage(uses, c, float32(w), a.Dir(), str)
+ noteUsage(uses, c, float32(w), ctx.Dir, str)
}
}
case *PtrType:
if _, ok := a.Type.(*StructType); ok {
- noteUsage(uses, c, 1.0, a.Dir(), "ptrto-%v", a.Type.Name())
+ noteUsage(uses, c, 1.0, ctx.Dir, "ptrto-%v", a.Type.Name())
}
if _, ok := a.Type.(*UnionType); ok {
- noteUsage(uses, c, 1.0, a.Dir(), "ptrto-%v", a.Type.Name())
+ noteUsage(uses, c, 1.0, ctx.Dir, "ptrto-%v", a.Type.Name())
}
if arr, ok := a.Type.(*ArrayType); ok {
- noteUsage(uses, c, 1.0, a.Dir(), "ptrto-%v", arr.Type.Name())
+ noteUsage(uses, c, 1.0, ctx.Dir, "ptrto-%v", arr.Type.Name())
}
case *BufferType:
switch a.Kind {
case BufferBlobRand, BufferBlobRange, BufferText:
case BufferString:
if a.SubKind != "" {
- noteUsage(uses, c, 0.2, a.Dir(), fmt.Sprintf("str-%v", a.SubKind))
+ noteUsage(uses, c, 0.2, ctx.Dir, fmt.Sprintf("str-%v", a.SubKind))
}
case BufferFilename:
noteUsage(uses, c, 1.0, DirIn, "filename")
@@ -104,7 +104,7 @@ func (target *Target) calcResourceUsage() map[string]map[int]weights {
panic("unknown buffer kind")
}
case *VmaType:
- noteUsage(uses, c, 0.5, a.Dir(), "vma")
+ noteUsage(uses, c, 0.5, ctx.Dir, "vma")
case *IntType:
switch a.Kind {
case IntPlain, IntRange:
diff --git a/prog/prog.go b/prog/prog.go
index 1600c0a28..017a0dbbb 100644
--- a/prog/prog.go
+++ b/prog/prog.go
@@ -22,6 +22,7 @@ type Call struct {
type Arg interface {
Type() Type
+ Dir() Dir
Size() uint64
validate(ctx *validCtx) error
@@ -30,20 +31,25 @@ type Arg interface {
type ArgCommon struct {
typ Type
+ dir Dir
}
func (arg *ArgCommon) Type() Type {
return arg.typ
}
+func (arg *ArgCommon) Dir() Dir {
+ return arg.dir
+}
+
// Used for ConstType, IntType, FlagsType, LenType, ProcType and CsumType.
type ConstArg struct {
ArgCommon
Val uint64
}
-func MakeConstArg(t Type, v uint64) *ConstArg {
- return &ConstArg{ArgCommon: ArgCommon{typ: t}, Val: v}
+func MakeConstArg(t Type, dir Dir, v uint64) *ConstArg {
+ return &ConstArg{ArgCommon: ArgCommon{typ: t, dir: dir}, Val: v}
}
func (arg *ConstArg) Size() uint64 {
@@ -84,34 +90,37 @@ type PointerArg struct {
Res Arg // pointee (nil for vma)
}
-func MakePointerArg(t Type, addr uint64, data Arg) *PointerArg {
+func MakePointerArg(t Type, dir Dir, addr uint64, data Arg) *PointerArg {
if data == nil {
panic("nil pointer data arg")
}
return &PointerArg{
- ArgCommon: ArgCommon{typ: t},
+ ArgCommon: ArgCommon{typ: t, dir: DirIn}, // pointers are always in
Address: addr,
Res: data,
}
}
-func MakeVmaPointerArg(t Type, addr, size uint64) *PointerArg {
+func MakeVmaPointerArg(t Type, dir Dir, addr, size uint64) *PointerArg {
if addr%1024 != 0 {
panic("unaligned vma address")
}
return &PointerArg{
- ArgCommon: ArgCommon{typ: t},
+ ArgCommon: ArgCommon{typ: t, dir: dir},
Address: addr,
VmaSize: size,
}
}
-func MakeSpecialPointerArg(t Type, index uint64) *PointerArg {
+func MakeSpecialPointerArg(t Type, dir Dir, index uint64) *PointerArg {
if index >= maxSpecialPointers {
panic("bad special pointer index")
}
+ if _, ok := t.(*PtrType); ok {
+ dir = DirIn // pointers are always in
+ }
return &PointerArg{
- ArgCommon: ArgCommon{typ: t},
+ ArgCommon: ArgCommon{typ: t, dir: dir},
Address: -index,
}
}
@@ -138,18 +147,18 @@ type DataArg struct {
size uint64 // for out Args
}
-func MakeDataArg(t Type, data []byte) *DataArg {
- if t.Dir() == DirOut {
+func MakeDataArg(t Type, dir Dir, data []byte) *DataArg {
+ if dir == DirOut {
panic("non-empty output data arg")
}
- return &DataArg{ArgCommon: ArgCommon{typ: t}, data: append([]byte{}, data...)}
+ return &DataArg{ArgCommon: ArgCommon{typ: t, dir: dir}, data: append([]byte{}, data...)}
}
-func MakeOutDataArg(t Type, size uint64) *DataArg {
- if t.Dir() != DirOut {
+func MakeOutDataArg(t Type, dir Dir, size uint64) *DataArg {
+ if dir != DirOut {
panic("empty input data arg")
}
- return &DataArg{ArgCommon: ArgCommon{typ: t}, size: size}
+ return &DataArg{ArgCommon: ArgCommon{typ: t, dir: dir}, size: size}
}
func (arg *DataArg) Size() uint64 {
@@ -160,14 +169,14 @@ func (arg *DataArg) Size() uint64 {
}
func (arg *DataArg) Data() []byte {
- if arg.Type().Dir() == DirOut {
+ if arg.Dir() == DirOut {
panic("getting data of output data arg")
}
return arg.data
}
func (arg *DataArg) SetData(data []byte) {
- if arg.Type().Dir() == DirOut {
+ if arg.Dir() == DirOut {
panic("setting data of output data arg")
}
arg.data = append([]byte{}, data...)
@@ -180,8 +189,8 @@ type GroupArg struct {
Inner []Arg
}
-func MakeGroupArg(t Type, inner []Arg) *GroupArg {
- return &GroupArg{ArgCommon: ArgCommon{typ: t}, Inner: inner}
+func MakeGroupArg(t Type, dir Dir, inner []Arg) *GroupArg {
+ return &GroupArg{ArgCommon: ArgCommon{typ: t, dir: dir}, Inner: inner}
}
func (arg *GroupArg) Size() uint64 {
@@ -227,8 +236,8 @@ type UnionArg struct {
Option Arg
}
-func MakeUnionArg(t Type, opt Arg) *UnionArg {
- return &UnionArg{ArgCommon: ArgCommon{typ: t}, Option: opt}
+func MakeUnionArg(t Type, dir Dir, opt Arg) *UnionArg {
+ return &UnionArg{ArgCommon: ArgCommon{typ: t, dir: dir}, Option: opt}
}
func (arg *UnionArg) Size() uint64 {
@@ -250,8 +259,8 @@ type ResultArg struct {
uses map[*ResultArg]bool // ArgResult args that use this arg
}
-func MakeResultArg(t Type, r *ResultArg, v uint64) *ResultArg {
- arg := &ResultArg{ArgCommon: ArgCommon{typ: t}, Res: r, Val: v}
+func MakeResultArg(t Type, dir Dir, r *ResultArg, v uint64) *ResultArg {
+ arg := &ResultArg{ArgCommon: ArgCommon{typ: t, dir: dir}, Res: r, Val: v}
if r == nil {
return arg
}
@@ -266,10 +275,7 @@ func MakeReturnArg(t Type) *ResultArg {
if t == nil {
return nil
}
- if t.Dir() != DirOut {
- panic("return arg is not out")
- }
- return &ResultArg{ArgCommon: ArgCommon{typ: t}}
+ return &ResultArg{ArgCommon: ArgCommon{typ: t, dir: DirOut}}
}
func (arg *ResultArg) Size() uint64 {
@@ -369,7 +375,7 @@ func removeArg(arg0 Arg) {
delete(uses, a)
}
for arg1 := range a.uses {
- arg2 := arg1.Type().DefaultArg().(*ResultArg)
+ arg2 := arg1.Type().DefaultArg(arg1.Dir()).(*ResultArg)
replaceResultArg(arg1, arg2)
}
})
diff --git a/prog/prog_test.go b/prog/prog_test.go
index a42e4437f..9b8f442e3 100644
--- a/prog/prog_test.go
+++ b/prog/prog_test.go
@@ -21,8 +21,8 @@ func TestGeneration(t *testing.T) {
func TestDefault(t *testing.T) {
target, _, _ := initTest(t)
for _, meta := range target.Syscalls {
- ForeachType(meta, func(typ Type) {
- arg := typ.DefaultArg()
+ foreachType(meta, func(typ Type, ctx typeCtx) {
+ arg := typ.DefaultArg(ctx.Dir)
if !isDefault(arg) {
t.Errorf("default arg is not default: %s\ntype: %#v\narg: %#v",
typ, typ, arg)
@@ -203,8 +203,8 @@ func TestSpecialStructs(t *testing.T) {
t.Run(special, func(t *testing.T) {
var typ Type
for i := 0; i < len(target.Syscalls) && typ == nil; i++ {
- ForeachType(target.Syscalls[i], func(t Type) {
- if t.Dir() == DirOut {
+ foreachType(target.Syscalls[i], func(t Type, ctx typeCtx) {
+ if ctx.Dir == DirOut {
return
}
if s, ok := t.(*StructType); ok && s.Name() == special {
@@ -220,8 +220,13 @@ func TestSpecialStructs(t *testing.T) {
}
g := &Gen{newRand(target, rs), newState(target, nil, nil)}
for i := 0; i < iters/len(target.SpecialTypes); i++ {
- arg, _ := gen(g, typ, nil)
- gen(g, typ, arg)
+ var arg Arg
+ for i := 0; i < 2; i++ {
+ arg, _ = gen(g, typ, DirInOut, arg)
+ if arg.Dir() != DirInOut {
+ t.Fatalf("got wrong arg dir %v", arg.Dir())
+ }
+ }
}
})
}
diff --git a/prog/rand.go b/prog/rand.go
index b350e31c0..5277a9814 100644
--- a/prog/rand.go
+++ b/prog/rand.go
@@ -341,16 +341,16 @@ func (r *randGen) randString(s *state, t *BufferType) []byte {
return buf.Bytes()
}
-func (r *randGen) allocAddr(s *state, typ Type, size uint64, data Arg) *PointerArg {
- return MakePointerArg(typ, s.ma.alloc(r, size), data)
+func (r *randGen) allocAddr(s *state, typ Type, dir Dir, size uint64, data Arg) *PointerArg {
+ return MakePointerArg(typ, dir, s.ma.alloc(r, size), data)
}
-func (r *randGen) allocVMA(s *state, typ Type, numPages uint64) *PointerArg {
+func (r *randGen) allocVMA(s *state, typ Type, dir Dir, numPages uint64) *PointerArg {
page := s.va.alloc(r, numPages)
- return MakeVmaPointerArg(typ, page*r.target.PageSize, numPages*r.target.PageSize)
+ return MakeVmaPointerArg(typ, dir, page*r.target.PageSize, numPages*r.target.PageSize)
}
-func (r *randGen) createResource(s *state, res *ResourceType) (arg Arg, calls []*Call) {
+func (r *randGen) createResource(s *state, res *ResourceType, dir Dir) (arg Arg, calls []*Call) {
if r.inCreateResource {
return nil, nil
}
@@ -385,7 +385,7 @@ func (r *randGen) createResource(s *state, res *ResourceType) (arg Arg, calls []
metas = append(metas, meta)
}
if len(metas) == 0 {
- return res.DefaultArg(), nil
+ return res.DefaultArg(dir), nil
}
// Now we have a set of candidate calls that can create the necessary resource.
@@ -404,7 +404,7 @@ func (r *randGen) createResource(s *state, res *ResourceType) (arg Arg, calls []
}
if len(allres) != 0 {
// Bingo!
- arg := MakeResultArg(res, allres[r.Intn(len(allres))], 0)
+ arg := MakeResultArg(res, dir, allres[r.Intn(len(allres))], 0)
return arg, calls
}
// Discard unsuccessful calls.
@@ -562,7 +562,7 @@ func (r *randGen) generateParticularCall(s *state, meta *Syscall) (calls []*Call
Meta: meta,
Ret: MakeReturnArg(meta.Ret),
}
- c.Args, calls = r.generateArgs(s, meta.Args)
+ c.Args, calls = r.generateArgs(s, meta.Args, DirIn)
r.target.assignSizesCall(c)
return append(calls, c)
}
@@ -601,13 +601,13 @@ func (target *Target) DataMmapProg() *Prog {
}
}
-func (r *randGen) generateArgs(s *state, types []Type) ([]Arg, []*Call) {
+func (r *randGen) generateArgs(s *state, types []Type, dir Dir) ([]Arg, []*Call) {
var calls []*Call
args := make([]Arg, len(types))
// Generate all args. Size args have the default value 0 for now.
for i, typ := range types {
- arg, calls1 := r.generateArg(s, typ)
+ arg, calls1 := r.generateArg(s, typ, dir)
if arg == nil {
panic(fmt.Sprintf("generated arg is nil for type '%v', types: %+v", typ.Name(), types))
}
@@ -618,29 +618,28 @@ func (r *randGen) generateArgs(s *state, types []Type) ([]Arg, []*Call) {
return args, calls
}
-func (r *randGen) generateArg(s *state, typ Type) (arg Arg, calls []*Call) {
- return r.generateArgImpl(s, typ, false)
+func (r *randGen) generateArg(s *state, typ Type, dir Dir) (arg Arg, calls []*Call) {
+ return r.generateArgImpl(s, typ, dir, false)
}
-func (r *randGen) generateArgImpl(s *state, typ Type, ignoreSpecial bool) (arg Arg, calls []*Call) {
- if typ.Dir() == DirOut {
+func (r *randGen) generateArgImpl(s *state, typ Type, dir Dir, ignoreSpecial bool) (arg Arg, calls []*Call) {
+ if dir == DirOut {
// No need to generate something interesting for output scalar arguments.
// But we still need to generate the argument itself so that it can be referenced
// in subsequent calls. For the same reason we do generate pointer/array/struct
// output arguments (their elements can be referenced in subsequent calls).
switch typ.(type) {
- case *IntType, *FlagsType, *ConstType, *ProcType,
- *VmaType, *ResourceType:
- return typ.DefaultArg(), nil
+ case *IntType, *FlagsType, *ConstType, *ProcType, *VmaType, *ResourceType:
+ return typ.DefaultArg(dir), nil
}
}
if typ.Optional() && r.oneOf(5) {
if res, ok := typ.(*ResourceType); ok {
v := res.Desc.Values[r.Intn(len(res.Desc.Values))]
- return MakeResultArg(typ, nil, v), nil
+ return MakeResultArg(typ, dir, nil, v), nil
}
- return typ.DefaultArg(), nil
+ return typ.DefaultArg(dir), nil
}
// Allow infinite recursion for optional pointers.
@@ -656,70 +655,70 @@ func (r *randGen) generateArgImpl(s *state, typ Type, ignoreSpecial bool) (arg A
}
}()
if r.recDepth[name] >= 3 {
- return MakeSpecialPointerArg(typ, 0), nil
+ return MakeSpecialPointerArg(typ, dir, 0), nil
}
}
}
- if !ignoreSpecial && typ.Dir() != DirOut {
+ if !ignoreSpecial && dir != DirOut {
switch typ.(type) {
case *StructType, *UnionType:
if gen := r.target.SpecialTypes[typ.Name()]; gen != nil {
- return gen(&Gen{r, s}, typ, nil)
+ return gen(&Gen{r, s}, typ, dir, nil)
}
}
}
- return typ.generate(r, s)
+ return typ.generate(r, s, dir)
}
-func (a *ResourceType) generate(r *randGen, s *state) (arg Arg, calls []*Call) {
+func (a *ResourceType) generate(r *randGen, s *state, dir Dir) (arg Arg, calls []*Call) {
if r.oneOf(3) {
- arg = r.existingResource(s, a)
+ arg = r.existingResource(s, a, dir)
if arg != nil {
return
}
}
if r.nOutOf(2, 3) {
- arg, calls = r.resourceCentric(s, a)
+ arg, calls = r.resourceCentric(s, a, dir)
if arg != nil {
return
}
}
if r.nOutOf(4, 5) {
- arg, calls = r.createResource(s, a)
+ arg, calls = r.createResource(s, a, dir)
if arg != nil {
return
}
}
special := a.SpecialValues()
- arg = MakeResultArg(a, nil, special[r.Intn(len(special))])
+ arg = MakeResultArg(a, dir, nil, special[r.Intn(len(special))])
return
}
-func (a *BufferType) generate(r *randGen, s *state) (arg Arg, calls []*Call) {
+func (a *BufferType) generate(r *randGen, s *state, dir Dir) (arg Arg, calls []*Call) {
switch a.Kind {
case BufferBlobRand, BufferBlobRange:
sz := r.randBufLen()
if a.Kind == BufferBlobRange {
sz = r.randRange(a.RangeBegin, a.RangeEnd)
}
- if a.Dir() == DirOut {
- return MakeOutDataArg(a, sz), nil
+ if dir == DirOut {
+ return MakeOutDataArg(a, dir, sz), nil
}
data := make([]byte, sz)
for i := range data {
data[i] = byte(r.Intn(256))
}
- return MakeDataArg(a, data), nil
+ return MakeDataArg(a, dir, data), nil
case BufferString:
data := r.randString(s, a)
- if a.Dir() == DirOut {
- return MakeOutDataArg(a, uint64(len(data))), nil
+ if dir == DirOut {
+ return MakeOutDataArg(a, dir, uint64(len(data))), nil
}
- return MakeDataArg(a, data), nil
+ return MakeDataArg(a, dir, data), nil
case BufferFilename:
- if a.Dir() == DirOut {
+ if dir == DirOut {
var sz uint64
switch {
case !a.Varlen():
@@ -731,50 +730,50 @@ func (a *BufferType) generate(r *randGen, s *state) (arg Arg, calls []*Call) {
default:
sz = 4096 // PATH_MAX
}
- return MakeOutDataArg(a, sz), nil
+ return MakeOutDataArg(a, dir, sz), nil
}
- return MakeDataArg(a, []byte(r.filename(s, a))), nil
+ return MakeDataArg(a, dir, []byte(r.filename(s, a))), nil
case BufferText:
- if a.Dir() == DirOut {
- return MakeOutDataArg(a, uint64(r.Intn(100))), nil
+ if dir == DirOut {
+ return MakeOutDataArg(a, dir, uint64(r.Intn(100))), nil
}
- return MakeDataArg(a, r.generateText(a.Text)), nil
+ return MakeDataArg(a, dir, r.generateText(a.Text)), nil
default:
panic("unknown buffer kind")
}
}
-func (a *VmaType) generate(r *randGen, s *state) (arg Arg, calls []*Call) {
+func (a *VmaType) generate(r *randGen, s *state, dir Dir) (arg Arg, calls []*Call) {
npages := r.randPageCount()
if a.RangeBegin != 0 || a.RangeEnd != 0 {
npages = a.RangeBegin + uint64(r.Intn(int(a.RangeEnd-a.RangeBegin+1)))
}
- return r.allocVMA(s, a, npages), nil
+ return r.allocVMA(s, a, dir, npages), nil
}
-func (a *FlagsType) generate(r *randGen, s *state) (arg Arg, calls []*Call) {
- return MakeConstArg(a, r.flags(a.Vals, a.BitMask, 0)), nil
+func (a *FlagsType) generate(r *randGen, s *state, dir Dir) (arg Arg, calls []*Call) {
+ return MakeConstArg(a, dir, r.flags(a.Vals, a.BitMask, 0)), nil
}
-func (a *ConstType) generate(r *randGen, s *state) (arg Arg, calls []*Call) {
- return MakeConstArg(a, a.Val), nil
+func (a *ConstType) generate(r *randGen, s *state, dir Dir) (arg Arg, calls []*Call) {
+ return MakeConstArg(a, dir, a.Val), nil
}
-func (a *IntType) generate(r *randGen, s *state) (arg Arg, calls []*Call) {
+func (a *IntType) generate(r *randGen, s *state, dir Dir) (arg Arg, calls []*Call) {
bits := a.TypeBitSize()
v := r.randInt(bits)
switch a.Kind {
case IntRange:
v = r.randRangeInt(a.RangeBegin, a.RangeEnd, bits, a.Align)
}
- return MakeConstArg(a, v), nil
+ return MakeConstArg(a, dir, v), nil
}
-func (a *ProcType) generate(r *randGen, s *state) (arg Arg, calls []*Call) {
- return MakeConstArg(a, r.rand(int(a.ValuesPerProc))), nil
+func (a *ProcType) generate(r *randGen, s *state, dir Dir) (arg Arg, calls []*Call) {
+ return MakeConstArg(a, dir, r.rand(int(a.ValuesPerProc))), nil
}
-func (a *ArrayType) generate(r *randGen, s *state) (arg Arg, calls []*Call) {
+func (a *ArrayType) generate(r *randGen, s *state, dir Dir) (arg Arg, calls []*Call) {
var count uint64
switch a.Kind {
case ArrayRandLen:
@@ -784,46 +783,46 @@ func (a *ArrayType) generate(r *randGen, s *state) (arg Arg, calls []*Call) {
}
var inner []Arg
for i := uint64(0); i < count; i++ {
- arg1, calls1 := r.generateArg(s, a.Type)
+ arg1, calls1 := r.generateArg(s, a.Type, dir)
inner = append(inner, arg1)
calls = append(calls, calls1...)
}
- return MakeGroupArg(a, inner), calls
+ return MakeGroupArg(a, dir, inner), calls
}
-func (a *StructType) generate(r *randGen, s *state) (arg Arg, calls []*Call) {
- args, calls := r.generateArgs(s, a.Fields)
- group := MakeGroupArg(a, args)
+func (a *StructType) generate(r *randGen, s *state, dir Dir) (arg Arg, calls []*Call) {
+ args, calls := r.generateArgs(s, a.Fields, dir)
+ group := MakeGroupArg(a, dir, args)
return group, calls
}
-func (a *UnionType) generate(r *randGen, s *state) (arg Arg, calls []*Call) {
+func (a *UnionType) generate(r *randGen, s *state, dir Dir) (arg Arg, calls []*Call) {
optType := a.Fields[r.Intn(len(a.Fields))]
- opt, calls := r.generateArg(s, optType)
- return MakeUnionArg(a, opt), calls
+ opt, calls := r.generateArg(s, optType, dir)
+ return MakeUnionArg(a, dir, opt), calls
}
-func (a *PtrType) generate(r *randGen, s *state) (arg Arg, calls []*Call) {
+func (a *PtrType) generate(r *randGen, s *state, dir Dir) (arg Arg, calls []*Call) {
if r.oneOf(1000) {
index := r.rand(len(r.target.SpecialPointers))
- return MakeSpecialPointerArg(a, index), nil
+ return MakeSpecialPointerArg(a, dir, index), nil
}
- inner, calls := r.generateArg(s, a.Type)
- arg = r.allocAddr(s, a, inner.Size(), inner)
+ inner, calls := r.generateArg(s, a.Type, a.ElemDir)
+ arg = r.allocAddr(s, a, dir, inner.Size(), inner)
return arg, calls
}
-func (a *LenType) generate(r *randGen, s *state) (arg Arg, calls []*Call) {
+func (a *LenType) generate(r *randGen, s *state, dir Dir) (arg Arg, calls []*Call) {
// Updated later in assignSizesCall.
- return MakeConstArg(a, 0), nil
+ return MakeConstArg(a, dir, 0), nil
}
-func (a *CsumType) generate(r *randGen, s *state) (arg Arg, calls []*Call) {
+func (a *CsumType) generate(r *randGen, s *state, dir Dir) (arg Arg, calls []*Call) {
// Filled at runtime by executor.
- return MakeConstArg(a, 0), nil
+ return MakeConstArg(a, dir, 0), nil
}
-func (r *randGen) existingResource(s *state, res *ResourceType) Arg {
+func (r *randGen) existingResource(s *state, res *ResourceType, dir Dir) Arg {
alltypes := make([][]*ResultArg, 0, len(s.resources))
for _, res1 := range s.resources {
alltypes = append(alltypes, res1)
@@ -842,11 +841,11 @@ func (r *randGen) existingResource(s *state, res *ResourceType) Arg {
if len(allres) == 0 {
return nil
}
- return MakeResultArg(res, allres[r.Intn(len(allres))], 0)
+ return MakeResultArg(res, dir, allres[r.Intn(len(allres))], 0)
}
// Finds a compatible resource with the type `t` and the calls that initialize that resource.
-func (r *randGen) resourceCentric(s *state, t *ResourceType) (arg Arg, calls []*Call) {
+func (r *randGen) resourceCentric(s *state, t *ResourceType, dir Dir) (arg Arg, calls []*Call) {
var p *Prog
var resource *ResultArg
for idx := range r.Perm(len(s.corpus)) {
@@ -898,7 +897,7 @@ func (r *randGen) resourceCentric(s *state, t *ResourceType) (arg Arg, calls []*
p.removeCall(i)
}
- return MakeResultArg(t, resource, 0), p.Calls
+ return MakeResultArg(t, dir, resource, 0), p.Calls
}
func getCompatibleResources(p *Prog, resourceType string, r *randGen) (resources []*ResultArg) {
@@ -906,7 +905,7 @@ func getCompatibleResources(p *Prog, resourceType string, r *randGen) (resources
ForeachArg(c, func(arg Arg, _ *ArgCtx) {
// Collect only initialized resources (the ones that are already used in other calls).
a, ok := arg.(*ResultArg)
- if !ok || len(a.uses) == 0 || a.typ.Dir() != DirOut {
+ if !ok || len(a.uses) == 0 || a.Dir() != DirOut {
return
}
if !r.target.isCompatibleResource(resourceType, a.typ.Name()) {
diff --git a/prog/rand_test.go b/prog/rand_test.go
index 09f3b359c..6f251cb7c 100644
--- a/prog/rand_test.go
+++ b/prog/rand_test.go
@@ -102,14 +102,14 @@ func TestSizeGenerateConstArg(t *testing.T) {
target, rs, iters := initRandomTargetTest(t, "test", "64")
r := newRand(target, rs)
for _, c := range target.Syscalls {
- ForeachType(c, func(typ Type) {
+ foreachType(c, func(typ Type, ctx typeCtx) {
if _, ok := typ.(*IntType); !ok {
return
}
bits := typ.TypeBitSize()
limit := uint64(1<<bits - 1)
for i := 0; i < iters; i++ {
- newArg, _ := typ.generate(r, nil)
+ newArg, _ := typ.generate(r, nil, ctx.Dir)
newVal := newArg.(*ConstArg).Val
if newVal > limit {
t.Fatalf("invalid generated value: %d. (arg bitsize: %d; max value: %d)", newVal, bits, limit)
diff --git a/prog/resources.go b/prog/resources.go
index fa03e6cd6..b7bcecf95 100644
--- a/prog/resources.go
+++ b/prog/resources.go
@@ -46,10 +46,10 @@ func (target *Target) populateResourceCtors() {
// Find resources that are created by each call.
callsResources := make([][]*ResourceDesc, len(target.Syscalls))
for call, meta := range target.Syscalls {
- ForeachType(meta, func(typ Type) {
+ foreachType(meta, func(typ Type, ctx typeCtx) {
switch typ1 := typ.(type) {
case *ResourceType:
- if typ1.Dir() != DirIn {
+ if ctx.Dir != DirIn {
callsResources[call] = append(callsResources[call], typ1.Desc)
}
}
@@ -84,21 +84,19 @@ func (target *Target) populateResourceCtors() {
// isCompatibleResource returns true if resource of kind src can be passed as an argument of kind dst.
func (target *Target) isCompatibleResource(dst, src string) bool {
- if dst == target.any.res16.TypeName ||
- dst == target.any.res32.TypeName ||
- dst == target.any.res64.TypeName ||
- dst == target.any.resdec.TypeName ||
- dst == target.any.reshex.TypeName ||
- dst == target.any.resoct.TypeName {
+ if target.isAnyRes(dst) {
return true
}
+ if target.isAnyRes(src) {
+ return false
+ }
dstRes := target.resourceMap[dst]
if dstRes == nil {
- panic(fmt.Sprintf("unknown resource '%v'", dst))
+ panic(fmt.Sprintf("unknown resource %q", dst))
}
srcRes := target.resourceMap[src]
if srcRes == nil {
- panic(fmt.Sprintf("unknown resource '%v'", src))
+ panic(fmt.Sprintf("unknown resource %q", src))
}
return isCompatibleResourceImpl(dstRes.Kind, srcRes.Kind, false)
}
@@ -128,8 +126,8 @@ func isCompatibleResourceImpl(dst, src []string, precise bool) bool {
func (target *Target) getInputResources(c *Syscall) []*ResourceDesc {
var resources []*ResourceDesc
- ForeachType(c, func(typ Type) {
- if typ.Dir() == DirOut {
+ foreachType(c, func(typ Type, ctx typeCtx) {
+ if ctx.Dir == DirOut {
return
}
switch typ1 := typ.(type) {
@@ -148,10 +146,10 @@ func (target *Target) getInputResources(c *Syscall) []*ResourceDesc {
func (target *Target) getOutputResources(c *Syscall) []*ResourceDesc {
var resources []*ResourceDesc
- ForeachType(c, func(typ Type) {
+ foreachType(c, func(typ Type, ctx typeCtx) {
switch typ1 := typ.(type) {
case *ResourceType:
- if typ1.Dir() != DirIn {
+ if ctx.Dir != DirIn {
resources = append(resources, typ1.Desc)
}
}
diff --git a/prog/rotation.go b/prog/rotation.go
index f95ffa03d..47ee2ca81 100644
--- a/prog/rotation.go
+++ b/prog/rotation.go
@@ -50,7 +50,7 @@ func MakeRotator(target *Target, calls map[*Syscall]bool, rnd *rand.Rand) *Rotat
}
// VMAs and filenames are effectively resources for our purposes
// (but they don't have ctors).
- ForeachType(call, func(t Type) {
+ foreachType(call, func(t Type, _ typeCtx) {
switch a := t.(type) {
case *BufferType:
switch a.Kind {
diff --git a/prog/target.go b/prog/target.go
index f9e10b6f4..692d0b877 100644
--- a/prog/target.go
+++ b/prog/target.go
@@ -45,7 +45,7 @@ type Target struct {
// allocate memory, etc. typ is the struct/union type. old is the old value of the struct/union
// for mutation, or nil for generation. The function returns a new value of the struct/union,
// and optionally any calls that need to be inserted before the arg reference.
- SpecialTypes map[string]func(g *Gen, typ Type, old Arg) (Arg, []*Call)
+ SpecialTypes map[string]func(g *Gen, typ Type, dir Dir, old Arg) (Arg, []*Call)
// Special strings that can matter for the target.
// Used as fallback when string type does not have own dictionary.
@@ -197,7 +197,7 @@ func restoreLinks(syscalls []*Syscall, resources []*ResourceDesc, structs []*Key
if c.Ret != nil {
unref(&c.Ret, types)
}
- ForeachType(c, func(t0 Type) {
+ foreachType(c, func(t0 Type, _ typeCtx) {
switch t := t0.(type) {
case *PtrType:
unref(&t.Type, types)
@@ -247,20 +247,20 @@ func (g *Gen) NOutOf(n, outOf int) bool {
return g.r.nOutOf(n, outOf)
}
-func (g *Gen) Alloc(ptrType Type, data Arg) (Arg, []*Call) {
- return g.r.allocAddr(g.s, ptrType, data.Size(), data), nil
+func (g *Gen) Alloc(ptrType Type, dir Dir, data Arg) (Arg, []*Call) {
+ return g.r.allocAddr(g.s, ptrType, dir, data.Size(), data), nil
}
-func (g *Gen) GenerateArg(typ Type, pcalls *[]*Call) Arg {
- return g.generateArg(typ, pcalls, false)
+func (g *Gen) GenerateArg(typ Type, dir Dir, pcalls *[]*Call) Arg {
+ return g.generateArg(typ, dir, pcalls, false)
}
-func (g *Gen) GenerateSpecialArg(typ Type, pcalls *[]*Call) Arg {
- return g.generateArg(typ, pcalls, true)
+func (g *Gen) GenerateSpecialArg(typ Type, dir Dir, pcalls *[]*Call) Arg {
+ return g.generateArg(typ, dir, pcalls, true)
}
-func (g *Gen) generateArg(typ Type, pcalls *[]*Call, ignoreSpecial bool) Arg {
- arg, calls := g.r.generateArgImpl(g.s, typ, ignoreSpecial)
+func (g *Gen) generateArg(typ Type, dir Dir, pcalls *[]*Call, ignoreSpecial bool) Arg {
+ arg, calls := g.r.generateArgImpl(g.s, typ, dir, ignoreSpecial)
*pcalls = append(*pcalls, calls...)
g.r.target.assignSizesArray([]Arg{arg}, nil)
return arg
diff --git a/prog/types.go b/prog/types.go
index ac272c3c1..a3b3c9709 100644
--- a/prog/types.go
+++ b/prog/types.go
@@ -44,7 +44,7 @@ type SyscallAttrs struct {
// Executor also knows about this value.
const MaxArgs = 9
-type Dir int
+type Dir uint8
const (
DirIn Dir = iota
@@ -80,7 +80,6 @@ type Type interface {
Name() string
FieldName() string
TemplateName() string // for template structs name without arguments
- Dir() Dir
Optional() bool
Varlen() bool
Size() uint64
@@ -95,9 +94,9 @@ type Type interface {
UnitSize() uint64
UnitOffset() uint64
- DefaultArg() Arg
+ DefaultArg(dir Dir) Arg
isDefaultArg(arg Arg) bool
- generate(r *randGen, s *state) (arg Arg, calls []*Call)
+ generate(r *randGen, s *state, dir Dir) (arg Arg, calls []*Call)
mutate(r *randGen, s *state, arg Arg, ctx ArgCtx) (calls []*Call, retry, preserve bool)
getMutationPrio(target *Target, arg Arg, ignoreSpecial bool) (prio float64, stopRecursion bool)
minimize(ctx *minimizeArgsCtx, arg Arg, path string) bool
@@ -105,25 +104,25 @@ type Type interface {
type Ref uint32
-func (ti Ref) String() string { panic("prog.Ref method called") }
-func (ti Ref) Name() string { panic("prog.Ref method called") }
-func (ti Ref) FieldName() string { panic("prog.Ref method called") }
-func (ti Ref) TemplateName() string { panic("prog.Ref method called") }
-func (ti Ref) Dir() Dir { panic("prog.Ref method called") }
-func (ti Ref) Optional() bool { panic("prog.Ref method called") }
-func (ti Ref) Varlen() bool { panic("prog.Ref method called") }
-func (ti Ref) Size() uint64 { panic("prog.Ref method called") }
-func (ti Ref) TypeBitSize() uint64 { panic("prog.Ref method called") }
-func (ti Ref) Format() BinaryFormat { panic("prog.Ref method called") }
-func (ti Ref) BitfieldOffset() uint64 { panic("prog.Ref method called") }
-func (ti Ref) BitfieldLength() uint64 { panic("prog.Ref method called") }
-func (ti Ref) IsBitfield() bool { panic("prog.Ref method called") }
-func (ti Ref) UnitSize() uint64 { panic("prog.Ref method called") }
-func (ti Ref) UnitOffset() uint64 { panic("prog.Ref method called") }
-func (ti Ref) DefaultArg() Arg { panic("prog.Ref method called") }
-func (ti Ref) Clone() Type { panic("prog.Ref method called") }
-func (ti Ref) isDefaultArg(arg Arg) bool { panic("prog.Ref method called") }
-func (ti Ref) generate(r *randGen, s *state) (Arg, []*Call) { panic("prog.Ref method called") }
+func (ti Ref) String() string { panic("prog.Ref method called") }
+func (ti Ref) Name() string { panic("prog.Ref method called") }
+func (ti Ref) FieldName() string { panic("prog.Ref method called") }
+func (ti Ref) TemplateName() string { panic("prog.Ref method called") }
+
+func (ti Ref) Optional() bool { panic("prog.Ref method called") }
+func (ti Ref) Varlen() bool { panic("prog.Ref method called") }
+func (ti Ref) Size() uint64 { panic("prog.Ref method called") }
+func (ti Ref) TypeBitSize() uint64 { panic("prog.Ref method called") }
+func (ti Ref) Format() BinaryFormat { panic("prog.Ref method called") }
+func (ti Ref) BitfieldOffset() uint64 { panic("prog.Ref method called") }
+func (ti Ref) BitfieldLength() uint64 { panic("prog.Ref method called") }
+func (ti Ref) IsBitfield() bool { panic("prog.Ref method called") }
+func (ti Ref) UnitSize() uint64 { panic("prog.Ref method called") }
+func (ti Ref) UnitOffset() uint64 { panic("prog.Ref method called") }
+func (ti Ref) DefaultArg(dir Dir) Arg { panic("prog.Ref method called") }
+func (ti Ref) Clone() Type { panic("prog.Ref method called") }
+func (ti Ref) isDefaultArg(arg Arg) bool { panic("prog.Ref method called") }
+func (ti Ref) generate(r *randGen, s *state, dir Dir) (Arg, []*Call) { panic("prog.Ref method called") }
func (ti Ref) mutate(r *randGen, s *state, arg Arg, ctx ArgCtx) ([]*Call, bool, bool) {
panic("prog.Ref method called")
}
@@ -146,7 +145,6 @@ type TypeCommon struct {
FldName string // for struct fields and named args
// Static size of the type, or 0 for variable size types and all but last bitfields in the group.
TypeSize uint64
- ArgDir Dir
IsOptional bool
IsVarlen bool
}
@@ -210,10 +208,6 @@ func (t *TypeCommon) IsBitfield() bool {
return false
}
-func (t TypeCommon) Dir() Dir {
- return t.ArgDir
-}
-
type ResourceDesc struct {
Name string
Kind []string
@@ -236,8 +230,8 @@ func (t *ResourceType) String() string {
return t.Name()
}
-func (t *ResourceType) DefaultArg() Arg {
- return MakeResultArg(t, nil, t.Default())
+func (t *ResourceType) DefaultArg(dir Dir) Arg {
+ return MakeResultArg(t, dir, nil, t.Default())
}
func (t *ResourceType) isDefaultArg(arg Arg) bool {
@@ -317,8 +311,8 @@ type ConstType struct {
IsPad bool
}
-func (t *ConstType) DefaultArg() Arg {
- return MakeConstArg(t, t.Val)
+func (t *ConstType) DefaultArg(dir Dir) Arg {
+ return MakeConstArg(t, dir, t.Val)
}
func (t *ConstType) isDefaultArg(arg Arg) bool {
@@ -347,8 +341,8 @@ type IntType struct {
Align uint64
}
-func (t *IntType) DefaultArg() Arg {
- return MakeConstArg(t, 0)
+func (t *IntType) DefaultArg(dir Dir) Arg {
+ return MakeConstArg(t, dir, 0)
}
func (t *IntType) isDefaultArg(arg Arg) bool {
@@ -361,8 +355,8 @@ type FlagsType struct {
BitMask bool
}
-func (t *FlagsType) DefaultArg() Arg {
- return MakeConstArg(t, 0)
+func (t *FlagsType) DefaultArg(dir Dir) Arg {
+ return MakeConstArg(t, dir, 0)
}
func (t *FlagsType) isDefaultArg(arg Arg) bool {
@@ -376,8 +370,8 @@ type LenType struct {
Path []string
}
-func (t *LenType) DefaultArg() Arg {
- return MakeConstArg(t, 0)
+func (t *LenType) DefaultArg(dir Dir) Arg {
+ return MakeConstArg(t, dir, 0)
}
func (t *LenType) isDefaultArg(arg Arg) bool {
@@ -395,8 +389,8 @@ const (
procDefaultValue = 0xffffffffffffffff // special value denoting 0 for all procs
)
-func (t *ProcType) DefaultArg() Arg {
- return MakeConstArg(t, procDefaultValue)
+func (t *ProcType) DefaultArg(dir Dir) Arg {
+ return MakeConstArg(t, dir, procDefaultValue)
}
func (t *ProcType) isDefaultArg(arg Arg) bool {
@@ -421,8 +415,8 @@ func (t *CsumType) String() string {
return "csum"
}
-func (t *CsumType) DefaultArg() Arg {
- return MakeConstArg(t, 0)
+func (t *CsumType) DefaultArg(dir Dir) Arg {
+ return MakeConstArg(t, dir, 0)
}
func (t *CsumType) isDefaultArg(arg Arg) bool {
@@ -439,8 +433,8 @@ func (t *VmaType) String() string {
return "vma"
}
-func (t *VmaType) DefaultArg() Arg {
- return MakeSpecialPointerArg(t, 0)
+func (t *VmaType) DefaultArg(dir Dir) Arg {
+ return MakeSpecialPointerArg(t, dir, 0)
}
func (t *VmaType) isDefaultArg(arg Arg) bool {
@@ -484,19 +478,19 @@ func (t *BufferType) String() string {
return "buffer"
}
-func (t *BufferType) DefaultArg() Arg {
- if t.Dir() == DirOut {
+func (t *BufferType) DefaultArg(dir Dir) Arg {
+ if dir == DirOut {
var sz uint64
if !t.Varlen() {
sz = t.Size()
}
- return MakeOutDataArg(t, sz)
+ return MakeOutDataArg(t, dir, sz)
}
var data []byte
if !t.Varlen() {
data = make([]byte, t.Size())
}
- return MakeDataArg(t, data)
+ return MakeDataArg(t, dir, data)
}
func (t *BufferType) isDefaultArg(arg Arg) bool {
@@ -507,7 +501,7 @@ func (t *BufferType) isDefaultArg(arg Arg) bool {
if a.Type().Varlen() {
return false
}
- if a.Type().Dir() == DirOut {
+ if a.Dir() == DirOut {
return true
}
for _, v := range a.Data() {
@@ -537,14 +531,14 @@ func (t *ArrayType) String() string {
return fmt.Sprintf("array[%v]", t.Type.String())
}
-func (t *ArrayType) DefaultArg() Arg {
+func (t *ArrayType) DefaultArg(dir Dir) Arg {
var elems []Arg
if t.Kind == ArrayRangeLen && t.RangeBegin == t.RangeEnd {
for i := uint64(0); i < t.RangeBegin; i++ {
- elems = append(elems, t.Type.DefaultArg())
+ elems = append(elems, t.Type.DefaultArg(dir))
}
}
- return MakeGroupArg(t, elems)
+ return MakeGroupArg(t, dir, elems)
}
func (t *ArrayType) isDefaultArg(arg Arg) bool {
@@ -562,18 +556,19 @@ func (t *ArrayType) isDefaultArg(arg Arg) bool {
type PtrType struct {
TypeCommon
- Type Type
+ Type Type
+ ElemDir Dir
}
func (t *PtrType) String() string {
- return fmt.Sprintf("ptr[%v, %v]", t.Dir(), t.Type.String())
+ return fmt.Sprintf("ptr[%v, %v]", t.ElemDir, t.Type.String())
}
-func (t *PtrType) DefaultArg() Arg {
+func (t *PtrType) DefaultArg(dir Dir) Arg {
if t.Optional() {
- return MakeSpecialPointerArg(t, 0)
+ return MakeSpecialPointerArg(t, dir, 0)
}
- return MakePointerArg(t, 0, t.Type.DefaultArg())
+ return MakePointerArg(t, dir, 0, t.Type.DefaultArg(t.ElemDir))
}
func (t *PtrType) isDefaultArg(arg Arg) bool {
@@ -598,12 +593,12 @@ func (t *StructType) FieldName() string {
return t.FldName
}
-func (t *StructType) DefaultArg() Arg {
+func (t *StructType) DefaultArg(dir Dir) Arg {
inner := make([]Arg, len(t.Fields))
for i, field := range t.Fields {
- inner[i] = field.DefaultArg()
+ inner[i] = field.DefaultArg(dir)
}
- return MakeGroupArg(t, inner)
+ return MakeGroupArg(t, dir, inner)
}
func (t *StructType) isDefaultArg(arg Arg) bool {
@@ -630,8 +625,8 @@ func (t *UnionType) FieldName() string {
return t.FldName
}
-func (t *UnionType) DefaultArg() Arg {
- return MakeUnionArg(t, t.Fields[0].DefaultArg())
+func (t *UnionType) DefaultArg(dir Dir) Arg {
+ return MakeUnionArg(t, dir, t.Fields[0].DefaultArg(dir))
}
func (t *UnionType) isDefaultArg(arg Arg) bool {
@@ -651,7 +646,6 @@ func (t *StructDesc) FieldName() string {
type StructKey struct {
Name string
- Dir Dir
}
type KeyedStruct struct {
@@ -664,29 +658,33 @@ type ConstValue struct {
Value uint64
}
-func ForeachType(meta *Syscall, f func(Type)) {
- var rec func(t Type)
+type typeCtx struct {
+ Dir Dir
+}
+
+func foreachType(meta *Syscall, f func(t Type, ctx typeCtx)) {
+ var rec func(t Type, dir Dir)
seen := make(map[*StructDesc]bool)
- recStruct := func(desc *StructDesc) {
+ recStruct := func(desc *StructDesc, dir Dir) {
if seen[desc] {
return // prune recursion via pointers to structs/unions
}
seen[desc] = true
for _, f := range desc.Fields {
- rec(f)
+ rec(f, dir)
}
}
- rec = func(t Type) {
- f(t)
+ rec = func(t Type, dir Dir) {
+ f(t, typeCtx{Dir: dir})
switch a := t.(type) {
case *PtrType:
- rec(a.Type)
+ rec(a.Type, a.ElemDir)
case *ArrayType:
- rec(a.Type)
+ rec(a.Type, dir)
case *StructType:
- recStruct(a.StructDesc)
+ recStruct(a.StructDesc, dir)
case *UnionType:
- recStruct(a.StructDesc)
+ recStruct(a.StructDesc, dir)
case *ResourceType, *BufferType, *VmaType, *LenType,
*FlagsType, *ConstType, *IntType, *ProcType, *CsumType:
default:
@@ -694,10 +692,10 @@ func ForeachType(meta *Syscall, f func(Type)) {
}
}
for _, t := range meta.Args {
- rec(t)
+ rec(t, DirIn)
}
if meta.Ret != nil {
- rec(meta.Ret)
+ rec(meta.Ret, DirOut)
}
}
diff --git a/prog/validation.go b/prog/validation.go
index c8c030aa7..8ec0615da 100644
--- a/prog/validation.go
+++ b/prog/validation.go
@@ -55,7 +55,7 @@ func (ctx *validCtx) validateCall(c *Call) error {
len(c.Meta.Args), len(c.Args))
}
for i, arg := range c.Args {
- if err := ctx.validateArg(arg, c.Meta.Args[i]); err != nil {
+ if err := ctx.validateArg(arg, c.Meta.Args[i], DirIn); err != nil {
return err
}
}
@@ -72,16 +72,13 @@ func (ctx *validCtx) validateRet(c *Call) error {
if c.Ret == nil {
return fmt.Errorf("return value is absent")
}
- if c.Ret.Type().Dir() != DirOut {
- return fmt.Errorf("return value %v is not output", c.Ret)
- }
if c.Ret.Res != nil || c.Ret.Val != 0 || c.Ret.OpDiv != 0 || c.Ret.OpAdd != 0 {
return fmt.Errorf("return value %v is not empty", c.Ret)
}
- return ctx.validateArg(c.Ret, c.Meta.Ret)
+ return ctx.validateArg(c.Ret, c.Meta.Ret, DirOut)
}
-func (ctx *validCtx) validateArg(arg Arg, typ Type) error {
+func (ctx *validCtx) validateArg(arg Arg, typ Type, dir Dir) error {
if arg == nil {
return fmt.Errorf("nil arg")
}
@@ -91,6 +88,12 @@ func (ctx *validCtx) validateArg(arg Arg, typ Type) error {
if arg.Type() == nil {
return fmt.Errorf("no arg type")
}
+ if _, ok := typ.(*PtrType); ok {
+ dir = DirIn // pointers are always in
+ }
+ if arg.Dir() != dir {
+ return fmt.Errorf("arg %#v type %v has wrong dir %v, expect %v", arg, arg.Type(), arg.Dir(), dir)
+ }
if !ctx.target.isAnyPtr(arg.Type()) && arg.Type() != typ {
return fmt.Errorf("bad arg type %#v, expect %#v", arg.Type(), typ)
}
@@ -101,7 +104,7 @@ func (ctx *validCtx) validateArg(arg Arg, typ Type) error {
func (arg *ConstArg) validate(ctx *validCtx) error {
switch typ := arg.Type().(type) {
case *IntType:
- if typ.Dir() == DirOut && !isDefault(arg) {
+ if arg.Dir() == DirOut && !isDefault(arg) {
return fmt.Errorf("out int arg '%v' has bad const value %v", typ.Name(), arg.Val)
}
case *ProcType:
@@ -116,9 +119,10 @@ func (arg *ConstArg) validate(ctx *validCtx) error {
default:
return fmt.Errorf("const arg %v has bad type %v", arg, typ.Name())
}
- if typ := arg.Type(); typ.Dir() == DirOut {
+ if arg.Dir() == DirOut {
// We generate output len arguments, which makes sense since it can be
// a length of a variable-length array which is not known otherwise.
+ typ := arg.Type()
if _, isLen := typ.(*LenType); !isLen {
if !typ.isDefaultArg(arg) {
return fmt.Errorf("output arg '%v'/'%v' has non default value '%+v'",
@@ -143,7 +147,7 @@ func (arg *ResultArg) validate(ctx *validCtx) error {
}
ctx.uses[u] = arg
}
- if typ.Dir() == DirOut && arg.Val != 0 && arg.Val != typ.Default() {
+ if arg.Dir() == DirOut && arg.Val != 0 && arg.Val != typ.Default() {
return fmt.Errorf("out resource arg '%v' has bad const value %v", typ.Name(), arg.Val)
}
if arg.Res != nil {
@@ -163,7 +167,7 @@ func (arg *DataArg) validate(ctx *validCtx) error {
if !ok {
return fmt.Errorf("data arg %v has bad type %v", arg, arg.Type().Name())
}
- if typ.Dir() == DirOut && len(arg.data) != 0 {
+ if arg.Dir() == DirOut && len(arg.data) != 0 {
return fmt.Errorf("output arg '%v' has data", typ.Name())
}
if !typ.Varlen() && typ.Size() != arg.Size() {
@@ -188,7 +192,7 @@ func (arg *GroupArg) validate(ctx *validCtx) error {
typ.Name(), len(typ.Fields), len(arg.Inner))
}
for i, field := range arg.Inner {
- if err := ctx.validateArg(field, typ.Fields[i]); err != nil {
+ if err := ctx.validateArg(field, typ.Fields[i], arg.Dir()); err != nil {
return err
}
}
@@ -199,7 +203,7 @@ func (arg *GroupArg) validate(ctx *validCtx) error {
typ.Name(), len(arg.Inner), typ.RangeBegin)
}
for _, elem := range arg.Inner {
- if err := ctx.validateArg(elem, typ.Type); err != nil {
+ if err := ctx.validateArg(elem, typ.Type, arg.Dir()); err != nil {
return err
}
}
@@ -224,7 +228,7 @@ func (arg *UnionArg) validate(ctx *validCtx) error {
if optType == nil {
return fmt.Errorf("union arg '%v' has bad option", typ.Name())
}
- return ctx.validateArg(arg.Option, optType)
+ return ctx.validateArg(arg.Option, optType, arg.Dir())
}
func (arg *PointerArg) validate(ctx *validCtx) error {
@@ -235,14 +239,14 @@ func (arg *PointerArg) validate(ctx *validCtx) error {
}
case *PtrType:
if arg.Res != nil {
- if err := ctx.validateArg(arg.Res, typ.Type); err != nil {
+ if err := ctx.validateArg(arg.Res, typ.Type, typ.ElemDir); err != nil {
return err
}
}
if arg.VmaSize != 0 {
return fmt.Errorf("pointer arg '%v' has nonzero size", typ.Name())
}
- if typ.Dir() == DirOut {
+ if arg.Dir() == DirOut {
return fmt.Errorf("pointer arg '%v' has output direction", typ.Name())
}
default: