diff options
| author | Dmitry Vyukov <dvyukov@google.com> | 2020-05-01 17:19:27 +0200 |
|---|---|---|
| committer | Dmitry Vyukov <dvyukov@google.com> | 2020-05-02 12:16:06 +0200 |
| commit | 58da4c35b15200b7279f18ea15bc8644618aae78 (patch) | |
| tree | 412d59572c980c4eb582d6d0e187eb6ec32345c9 /prog | |
| parent | bc734e7ada413654f1b7d948b2a857260a52dd9c (diff) | |
prog: introduce Field type
Remvoe FieldName from Type and add a separate Field type
that holds field name. Use Field for struct fields, union options
and syscalls arguments, only these really have names.
Reduces size of sys/linux/gen/amd64.go from 5665583 to 5201321 (-8.2%).
Allows to not create new type for squashed any pointer.
But main advantages will follow, e.g. removing StructDesc,
using TypeRef in Arg, etc.
Update #1580
Diffstat (limited to 'prog')
| -rw-r--r-- | prog/analysis.go | 5 | ||||
| -rw-r--r-- | prog/any.go | 70 | ||||
| -rw-r--r-- | prog/any_test.go | 4 | ||||
| -rw-r--r-- | prog/checksum.go | 25 | ||||
| -rw-r--r-- | prog/encoding.go | 36 | ||||
| -rw-r--r-- | prog/hints_test.go | 2 | ||||
| -rw-r--r-- | prog/minimization.go | 23 | ||||
| -rw-r--r-- | prog/mutation.go | 24 | ||||
| -rw-r--r-- | prog/prog.go | 5 | ||||
| -rw-r--r-- | prog/rand.go | 15 | ||||
| -rw-r--r-- | prog/size.go | 70 | ||||
| -rw-r--r-- | prog/target.go | 6 | ||||
| -rw-r--r-- | prog/types.go | 43 | ||||
| -rw-r--r-- | prog/validation.go | 19 |
14 files changed, 156 insertions, 191 deletions
diff --git a/prog/analysis.go b/prog/analysis.go index e1fbaa557..9300240b5 100644 --- a/prog/analysis.go +++ b/prog/analysis.go @@ -101,6 +101,7 @@ func (s *state) analyzeImpl(c *Call, resources bool) { type ArgCtx struct { Parent *[]Arg // GroupArg.Inner (for structs) or Call.Args containing this arg + Fields []Field // Fields of the parent struct/syscall Base *PointerArg // pointer to the base of the heap object containing this arg Offset uint64 // offset of this arg from the base Stop bool // if set by the callback, subargs of this arg are not visited @@ -116,6 +117,7 @@ func ForeachArg(c *Call, f func(Arg, *ArgCtx)) { foreachArgImpl(c.Ret, ctx, f) } ctx.Parent = &c.Args + ctx.Fields = c.Meta.Args for _, arg := range c.Args { foreachArgImpl(arg, ctx, f) } @@ -128,8 +130,9 @@ func foreachArgImpl(arg Arg, ctx ArgCtx, f func(Arg, *ArgCtx)) { } switch a := arg.(type) { case *GroupArg: - if _, ok := a.Type().(*StructType); ok { + if typ, ok := a.Type().(*StructType); ok { ctx.Parent = &a.Inner + ctx.Fields = typ.Fields } var totalSize uint64 for _, arg1 := range a.Inner { diff --git a/prog/any.go b/prog/any.go index 6e90698b9..1ce444ce2 100644 --- a/prog/any.go +++ b/prog/any.go @@ -36,13 +36,10 @@ type anyTypes struct { // resoct fmt[oct, ANYRES64] // ] [varlen] func initAnyTypes(target *Target) { - target.any.union = &UnionType{ - FldName: "ANYUNION", - } + target.any.union = &UnionType{} target.any.array = &ArrayType{ TypeCommon: TypeCommon{ TypeName: "ANYARRAY", - FldName: "ANYARRAY", IsVarlen: true, }, Elem: target.any.union, @@ -50,7 +47,6 @@ func initAnyTypes(target *Target) { target.any.ptrPtr = &PtrType{ TypeCommon: TypeCommon{ TypeName: "ANYPTR", - FldName: "ANYPTR", TypeSize: target.PtrSize, IsOptional: true, }, @@ -60,7 +56,6 @@ func initAnyTypes(target *Target) { target.any.ptr64 = &PtrType{ TypeCommon: TypeCommon{ TypeName: "ANYPTR64", - FldName: "ANYPTR64", TypeSize: 8, IsOptional: true, }, @@ -70,7 +65,6 @@ func initAnyTypes(target *Target) { target.any.blob = &BufferType{ TypeCommon: TypeCommon{ TypeName: "ANYBLOB", - FldName: "ANYBLOB", IsVarlen: true, }, } @@ -78,7 +72,6 @@ func initAnyTypes(target *Target) { return &ResourceType{ TypeCommon: TypeCommon{ TypeName: name, - FldName: name, TypeSize: size, IsOptional: true, }, @@ -99,37 +92,27 @@ func initAnyTypes(target *Target) { target.any.union.StructDesc = &StructDesc{ TypeCommon: TypeCommon{ TypeName: "ANYUNION", - FldName: "ANYUNION", IsVarlen: true, }, - Fields: []Type{ - target.any.blob, - target.any.res16, - target.any.res32, - target.any.res64, - target.any.resdec, - target.any.reshex, - target.any.resoct, + Fields: []Field{ + {Name: "ANYBLOB", Type: target.any.blob}, + {Name: "ANYRES16", Type: target.any.res16}, + {Name: "ANYRES32", Type: target.any.res32}, + {Name: "ANYRES64", Type: target.any.res64}, + {Name: "ANYRESDEC", Type: target.any.resdec}, + {Name: "ANYRESHEX", Type: target.any.reshex}, + {Name: "ANYRESOCT", Type: target.any.resoct}, }, } } -func (target *Target) makeAnyPtrType(size uint64, field string) *PtrType { - // We need to make a copy because type holds field name, - // and field names are used as len target. - var typ PtrType +func (target *Target) getAnyPtrType(size uint64) *PtrType { if size == target.PtrSize { - typ = *target.any.ptrPtr + return target.any.ptrPtr } else if size == 8 { - typ = *target.any.ptr64 - } else { - panic(fmt.Sprintf("bad pointer size %v", size)) - } - typ.TypeSize = size - if field != "" { - typ.FldName = field + return target.any.ptr64 } - return &typ + panic(fmt.Sprintf("bad pointer size %v", size)) } func (target *Target) isAnyPtr(typ Type) bool { @@ -204,7 +187,7 @@ func (target *Target) ArgContainsAny(arg0 Arg) (res bool) { return } -func (target *Target) squashPtr(arg *PointerArg, preserveField bool) { +func (target *Target) squashPtr(arg *PointerArg) { if arg.Res == nil || arg.VmaSize != 0 { panic("bad ptr arg") } @@ -212,11 +195,7 @@ func (target *Target) squashPtr(arg *PointerArg, preserveField bool) { size0 := res0.Size() var elems []Arg target.squashPtrImpl(arg.Res, &elems) - field := "" - if preserveField { - field = arg.Type().FieldName() - } - arg.typ = target.makeAnyPtrType(arg.Type().Size(), field) + arg.typ = target.getAnyPtrType(arg.Type().Size()) arg.Res = MakeGroupArg(arg.typ.(*PtrType).Elem, DirIn, elems) if size := arg.Res.Size(); size != size0 { panic(fmt.Sprintf("squash changed size %v->%v for %v", size0, size, res0.Type())) @@ -287,29 +266,30 @@ func (target *Target) squashConst(arg *ConstArg, elems *[]Arg) { } func (target *Target) squashResult(arg *ResultArg, elems *[]Arg) { + index := -1 switch arg.Type().Format() { case FormatNative, FormatBigEndian: switch arg.Size() { case 2: - arg.typ = target.any.res16 + arg.typ, index = target.any.res16, 1 case 4: - arg.typ = target.any.res32 + arg.typ, index = target.any.res32, 2 case 8: - arg.typ = target.any.res64 + arg.typ, index = target.any.res64, 3 default: panic("bad size") } case FormatStrDec: - arg.typ = target.any.resdec + arg.typ, index = target.any.resdec, 4 case FormatStrHex: - arg.typ = target.any.reshex + arg.typ, index = target.any.reshex, 5 case FormatStrOct: - arg.typ = target.any.resoct + arg.typ, index = target.any.resoct, 6 default: panic("bad") } arg.dir = DirIn - *elems = append(*elems, MakeUnionArg(target.any.union, DirIn, arg)) + *elems = append(*elems, MakeUnionArg(target.any.union, DirIn, arg, index)) } func (target *Target) squashGroup(arg *GroupArg, elems *[]Arg) { @@ -386,13 +366,13 @@ func (target *Target) squashedValue(arg *ConstArg) (uint64, BinaryFormat) { func (target *Target) ensureDataElem(elems *[]Arg) *DataArg { if len(*elems) == 0 { res := MakeDataArg(target.any.blob, DirIn, nil) - *elems = append(*elems, MakeUnionArg(target.any.union, DirIn, res)) + *elems = append(*elems, MakeUnionArg(target.any.union, DirIn, res, 0)) return res } res, ok := (*elems)[len(*elems)-1].(*UnionArg).Option.(*DataArg) if !ok { res = MakeDataArg(target.any.blob, DirIn, nil) - *elems = append(*elems, MakeUnionArg(target.any.union, DirIn, res)) + *elems = append(*elems, MakeUnionArg(target.any.union, DirIn, res, 0)) } return res } diff --git a/prog/any_test.go b/prog/any_test.go index 968d1624a..e481ef1c1 100644 --- a/prog/any_test.go +++ b/prog/any_test.go @@ -61,12 +61,12 @@ func TestSquash(t *testing.T) { if target.ArgContainsAny(ptrArg) { t.Fatalf("arg is already squashed") } - target.squashPtr(ptrArg, true) + target.squashPtr(ptrArg) if !target.ArgContainsAny(ptrArg) { t.Fatalf("arg is not squashed") } p1 := strings.TrimSpace(string(p.Serialize())) - target.squashPtr(ptrArg, true) + target.squashPtr(ptrArg) p2 := strings.TrimSpace(string(p.Serialize())) if p1 != p2 { t.Fatalf("double squash changed program:\n%v\nvs:\n%v", p1, p2) diff --git a/prog/checksum.go b/prog/checksum.go index 2db825f1e..ddad5411f 100644 --- a/prog/checksum.go +++ b/prog/checksum.go @@ -115,19 +115,19 @@ func calcChecksumsCall(c *Call) (map[Arg]CsumInfo, map[Arg]struct{}) { func findCsummedArg(arg Arg, typ *CsumType, parentsMap map[Arg]Arg) Arg { if typ.Buf == ParentRef { - if csummedArg, ok := parentsMap[arg]; ok { - return csummedArg + csummedArg := parentsMap[arg] + if csummedArg == nil { + panic(fmt.Sprintf("%q for %q is not in parents map", ParentRef, typ.Name())) } - panic(fmt.Sprintf("%v for %v is not in parents map", ParentRef, typ.Name())) - } else { - for parent := parentsMap[arg]; parent != nil; parent = parentsMap[parent] { - // TODO(dvyukov): support template argument names as in size calculation. - if typ.Buf == parent.Type().Name() { - return parent - } + return csummedArg + } + for parent := parentsMap[arg]; parent != nil; parent = parentsMap[parent] { + // TODO(dvyukov): support template argument names as in size calculation. + if typ.Buf == parent.Type().Name() { + return parent } } - panic(fmt.Sprintf("csum field '%v' references non existent field '%v'", typ.FieldName(), typ.Buf)) + panic(fmt.Sprintf("csum field %q references non existent field %q", typ.Name(), typ.Buf)) } func composePseudoCsumIPv4(tcpPacket, srcAddr, dstAddr Arg, protocol uint8) CsumInfo { @@ -160,8 +160,9 @@ func extractHeaderParams(arg *GroupArg, size uint64) (Arg, Arg) { } func getFieldByName(arg *GroupArg, name string) Arg { - for _, field := range arg.Inner { - if field.Type().FieldName() == name { + typ := arg.Type().(*StructType) + for i, field := range arg.Inner { + if typ.Fields[i].Name == name { return field } } diff --git a/prog/encoding.go b/prog/encoding.go index 38935b302..f1fc97da4 100644 --- a/prog/encoding.go +++ b/prog/encoding.go @@ -167,7 +167,8 @@ func (a *GroupArg) serialize(ctx *serializer) { } func (a *UnionArg) serialize(ctx *serializer) { - ctx.printf("@%v", a.Option.Type().FieldName()) + typ := a.Type().(*UnionType) + ctx.printf("@%v", typ.Fields[a.Index].Name) if !ctx.verbose && isDefault(a.Option) { return } @@ -276,11 +277,11 @@ func (p *parser) parseProg() (*Prog, error) { p.eatExcessive(false, "excessive syscall arguments") break } - typ := meta.Args[i] - if IsPad(typ) { + field := meta.Args[i] + if IsPad(field.Type) { return nil, fmt.Errorf("padding in syscall %v arguments", name) } - arg, err := p.parseArg(typ, DirIn) + arg, err := p.parseArg(field.Type, DirIn) if err != nil { return nil, err } @@ -488,7 +489,7 @@ func (p *parser) parseArgAddr(typ Type, dir Dir) (Arg, error) { p.Parse('N') p.Parse('Y') p.Parse('=') - anyPtr := p.target.makeAnyPtrType(typ.Size(), typ.FieldName()) + anyPtr := p.target.getAnyPtrType(typ.Size()) typ, elem, elemDir = anyPtr, anyPtr.Elem, anyPtr.ElemDir } var err error @@ -582,11 +583,11 @@ func (p *parser) parseArgStruct(typ Type, dir Dir) (Arg, error) { p.eatExcessive(false, "excessive struct %v fields", typ.Name()) break } - fld := t1.Fields[i] - if IsPad(fld) { - inner = append(inner, MakeConstArg(fld, dir, 0)) + field := t1.Fields[i] + if IsPad(field.Type) { + inner = append(inner, MakeConstArg(field.Type, dir, 0)) } else { - arg, err := p.parseArg(fld, dir) + arg, err := p.parseArg(field.Type, dir) if err != nil { return nil, err } @@ -598,11 +599,11 @@ func (p *parser) parseArgStruct(typ Type, dir Dir) (Arg, error) { } p.Parse('}') for len(inner) < len(t1.Fields) { - fld := t1.Fields[len(inner)] - if !IsPad(fld) { + field := t1.Fields[len(inner)] + if !IsPad(field.Type) { p.strictFailf("missing struct %v fields %v/%v", typ.Name(), len(inner), len(t1.Fields)) } - inner = append(inner, fld.DefaultArg(dir)) + inner = append(inner, field.Type.DefaultArg(dir)) } return MakeGroupArg(typ, dir, inner), nil } @@ -646,9 +647,10 @@ func (p *parser) parseArgUnion(typ Type, dir Dir) (Arg, error) { p.Parse('@') name := p.Ident() var optType Type - for _, t2 := range t1.Fields { - if name == t2.FieldName() { - optType = t2 + index := -1 + for i, field := range t1.Fields { + if name == field.Name { + optType, index = field.Type, i break } } @@ -667,7 +669,7 @@ func (p *parser) parseArgUnion(typ Type, dir Dir) (Arg, error) { } else { opt = optType.DefaultArg(dir) } - return MakeUnionArg(typ, dir, opt), nil + return MakeUnionArg(typ, dir, opt, index), nil } // Eats excessive call arguments and struct fields to recover after description changes. @@ -1004,7 +1006,7 @@ func (p *parser) auto(arg Arg) Arg { func (p *parser) fixupAutos(prog *Prog) { s := analyze(nil, nil, prog, nil) for _, c := range prog.Calls { - p.target.assignSizesArray(c.Args, p.autos) + p.target.assignSizesArray(c.Args, c.Meta.Args, p.autos) ForeachArg(c, func(arg Arg, _ *ArgCtx) { if !p.autos[arg] { return diff --git a/prog/hints_test.go b/prog/hints_test.go index caf84e715..852f7bd86 100644 --- a/prog/hints_test.go +++ b/prog/hints_test.go @@ -295,7 +295,7 @@ func TestHintsCheckDataArg(t *testing.T) { res := make(map[string]bool) // Whatever type here. It's just needed to pass the // dataArg.Type().Dir() == DirIn check. - typ := &ArrayType{TypeCommon{"", "", 0, false, true}, nil, 0, 0, 0} + typ := &ArrayType{TypeCommon{"", 0, false, true}, nil, 0, 0, 0} dataArg := MakeDataArg(typ, DirIn, []byte(test.in)) checkDataArg(dataArg, test.comps, func() { res[string(dataArg.Data())] = true diff --git a/prog/minimization.go b/prog/minimization.go index 0b5bb05b2..88c401b2c 100644 --- a/prog/minimization.go +++ b/prog/minimization.go @@ -41,8 +41,8 @@ func Minimize(p0 *Prog, callIndex0 int, crash bool, pred0 func(*Prog, int) bool) again: ctx.p = p0.Clone() ctx.call = ctx.p.Calls[i] - for j, arg := range ctx.call.Args { - if ctx.do(arg, fmt.Sprintf("%v", j)) { + for j, field := range ctx.call.Meta.Args { + if ctx.do(ctx.call.Args[j], field.Name, "") { goto again } } @@ -88,8 +88,8 @@ type minimizeArgsCtx struct { triedPaths map[string]bool } -func (ctx *minimizeArgsCtx) do(arg Arg, path string) bool { - path += fmt.Sprintf("-%v", arg.Type().FieldName()) +func (ctx *minimizeArgsCtx) do(arg Arg, field, path string) bool { + path += fmt.Sprintf("-%v", field) if ctx.triedPaths[path] { return false } @@ -118,8 +118,8 @@ func (typ *TypeCommon) minimize(ctx *minimizeArgsCtx, arg Arg, path string) bool func (typ *StructType) minimize(ctx *minimizeArgsCtx, arg Arg, path string) bool { a := arg.(*GroupArg) - for _, innerArg := range a.Inner { - if ctx.do(innerArg, path) { + for i, innerArg := range a.Inner { + if ctx.do(innerArg, typ.Fields[i].Name, path) { return true } } @@ -127,7 +127,8 @@ func (typ *StructType) minimize(ctx *minimizeArgsCtx, arg Arg, path string) bool } func (typ *UnionType) minimize(ctx *minimizeArgsCtx, arg Arg, path string) bool { - return ctx.do(arg.(*UnionArg).Option, path) + a := arg.(*UnionArg) + return ctx.do(a.Option, typ.Fields[a.Index].Name, path) } func (typ *PtrType) minimize(ctx *minimizeArgsCtx, arg Arg, path string) bool { @@ -135,17 +136,17 @@ func (typ *PtrType) minimize(ctx *minimizeArgsCtx, arg Arg, path string) bool { if a.Res == nil { return false } - if !ctx.triedPaths[path+"->"] { + if path1 := path + ">"; !ctx.triedPaths[path1] { removeArg(a.Res) replaceArg(a, MakeSpecialPointerArg(a.Type(), a.Dir(), 0)) ctx.target.assignSizesCall(ctx.call) if ctx.pred(ctx.p, ctx.callIndex0) { *ctx.p0 = ctx.p } - ctx.triedPaths[path+"->"] = true + ctx.triedPaths[path1] = true return true } - return ctx.do(a.Res, path) + return ctx.do(a.Res, "", path) } func (typ *ArrayType) minimize(ctx *minimizeArgsCtx, arg Arg, path string) bool { @@ -167,7 +168,7 @@ func (typ *ArrayType) minimize(ctx *minimizeArgsCtx, arg Arg, path string) bool } return true } - if ctx.do(elem, elemPath) { + if ctx.do(elem, "", elemPath) { return true } } diff --git a/prog/mutation.go b/prog/mutation.go index ada2febdb..8758b94f2 100644 --- a/prog/mutation.go +++ b/prog/mutation.go @@ -94,7 +94,7 @@ func (ctx *mutator) squashAny() bool { } ptr := complexPtrs[r.Intn(len(complexPtrs))] if !p.Target.isAnyPtr(ptr.Type()) { - p.Target.squashPtr(ptr, true) + p.Target.squashPtr(ptr) } var blobs []*DataArg var bases []*PointerArg @@ -320,7 +320,7 @@ func (t *FlagsType) mutate(r *randGen, s *state, arg Arg, ctx ArgCtx) (calls []* } func (t *LenType) mutate(r *randGen, s *state, arg Arg, ctx ArgCtx) (calls []*Call, retry, preserve bool) { - if !r.mutateSize(arg.(*ConstArg), *ctx.Parent) { + if !r.mutateSize(arg.(*ConstArg), *ctx.Parent, ctx.Fields) { retry = true return } @@ -464,25 +464,15 @@ func (t *UnionType) mutate(r *randGen, s *state, arg Arg, ctx ArgCtx) (calls []* return } a := arg.(*UnionArg) - current := -1 - for i, option := range t.Fields { - if a.Option.Type().FieldName() == option.FieldName() { - current = i - break - } - } - if current == -1 { - panic("can't find current option in union") - } - newIdx := r.Intn(len(t.Fields) - 1) - if newIdx >= current { - newIdx++ + index := r.Intn(len(t.Fields) - 1) + if index >= a.Index { + index++ } - optType := t.Fields[newIdx] + optType := t.Fields[index].Type removeArg(a.Option) var newOpt Arg newOpt, calls = r.generateArg(s, optType, a.Dir()) - replaceArg(arg, MakeUnionArg(t, a.Dir(), newOpt)) + replaceArg(arg, MakeUnionArg(t, a.Dir(), newOpt, index)) return } diff --git a/prog/prog.go b/prog/prog.go index 017a0dbbb..dc1c8dd6c 100644 --- a/prog/prog.go +++ b/prog/prog.go @@ -234,10 +234,11 @@ func (arg *GroupArg) fixedInnerSize() bool { type UnionArg struct { ArgCommon Option Arg + Index int // Index of the selected option in the union type. } -func MakeUnionArg(t Type, dir Dir, opt Arg) *UnionArg { - return &UnionArg{ArgCommon: ArgCommon{typ: t, dir: dir}, Option: opt} +func MakeUnionArg(t Type, dir Dir, opt Arg, index int) *UnionArg { + return &UnionArg{ArgCommon: ArgCommon{typ: t, dir: dir}, Option: opt, Index: index} } func (arg *UnionArg) Size() uint64 { diff --git a/prog/rand.go b/prog/rand.go index 603996114..135469fe7 100644 --- a/prog/rand.go +++ b/prog/rand.go @@ -601,15 +601,15 @@ func (target *Target) DataMmapProg() *Prog { } } -func (r *randGen) generateArgs(s *state, types []Type, dir Dir) ([]Arg, []*Call) { +func (r *randGen) generateArgs(s *state, fields []Field, dir Dir) ([]Arg, []*Call) { var calls []*Call - args := make([]Arg, len(types)) + args := make([]Arg, len(fields)) // Generate all args. Size args have the default value 0 for now. - for i, typ := range types { - arg, calls1 := r.generateArg(s, typ, dir) + for i, field := range fields { + arg, calls1 := r.generateArg(s, field.Type, dir) if arg == nil { - panic(fmt.Sprintf("generated arg is nil for type '%v', types: %+v", typ.Name(), types)) + panic(fmt.Sprintf("generated arg is nil for field '%v', fields: %+v", field.Type.Name(), fields)) } args[i] = arg calls = append(calls, calls1...) @@ -797,9 +797,10 @@ func (a *StructType) generate(r *randGen, s *state, dir Dir) (arg Arg, calls []* } func (a *UnionType) generate(r *randGen, s *state, dir Dir) (arg Arg, calls []*Call) { - optType := a.Fields[r.Intn(len(a.Fields))] + index := r.Intn(len(a.Fields)) + optType := a.Fields[index].Type opt, calls := r.generateArg(s, optType, dir) - return MakeUnionArg(a, dir, opt), calls + return MakeUnionArg(a, dir, opt, index), calls } func (a *PtrType) generate(r *randGen, s *state, dir Dir) (arg Arg, calls []*Call) { diff --git a/prog/size.go b/prog/size.go index a7fdb093e..61d382723 100644 --- a/prog/size.go +++ b/prog/size.go @@ -14,7 +14,8 @@ const ( SyscallRef = "syscall" ) -func (target *Target) assignSizes(args []Arg, parentsMap map[Arg]Arg, syscallArgs []Arg, autos map[Arg]bool) { +func (target *Target) assignSizes(args []Arg, fields []Field, parentsMap map[Arg]Arg, + syscallArgs []Arg, syscallFields []Field, autos map[Arg]bool) { for _, arg := range args { if arg = InnerArg(arg); arg == nil { continue // Pointer to optional len field, no need to fill in value. @@ -31,19 +32,26 @@ func (target *Target) assignSizes(args []Arg, parentsMap map[Arg]Arg, syscallArg } a := arg.(*ConstArg) if typ.Path[0] == SyscallRef { - target.assignSize(a, nil, typ.Path[1:], syscallArgs, parentsMap) + target.assignSize(a, nil, typ.Path[1:], syscallArgs, syscallFields, parentsMap) } else { - target.assignSize(a, a, typ.Path, args, parentsMap) + target.assignSize(a, a, typ.Path, args, fields, parentsMap) } } } -func (target *Target) assignSize(dst *ConstArg, pos Arg, path []string, args []Arg, parentsMap map[Arg]Arg) { +func (target *Target) assignSizeStruct(dst *ConstArg, buf Arg, path []string, parentsMap map[Arg]Arg) { + arg := buf.(*GroupArg) + typ := arg.Type().(*StructType) + target.assignSize(dst, buf, path, arg.Inner, typ.Fields, parentsMap) +} + +func (target *Target) assignSize(dst *ConstArg, pos Arg, path []string, args []Arg, + fields []Field, parentsMap map[Arg]Arg) { elem := path[0] path = path[1:] var offset uint64 - for _, buf := range args { - if elem != buf.Type().FieldName() { + for i, buf := range args { + if elem != fields[i].Name { offset += buf.Size() continue } @@ -58,39 +66,39 @@ func (target *Target) assignSize(dst *ConstArg, pos Arg, path []string, args []A dst.Val = 0 // target is an optional pointer return } - if len(path) == 0 { - dst.Val = target.computeSize(buf, offset, dst.Type().(*LenType)) - } else { - target.assignSize(dst, buf, path, buf.(*GroupArg).Inner, parentsMap) + if len(path) != 0 { + target.assignSizeStruct(dst, buf, path, parentsMap) + return } + dst.Val = target.computeSize(buf, offset, dst.Type().(*LenType)) return } if elem == ParentRef { buf := parentsMap[pos] - if len(path) == 0 { - dst.Val = target.computeSize(buf, noOffset, dst.Type().(*LenType)) - } else { - target.assignSize(dst, buf, path, buf.(*GroupArg).Inner, parentsMap) + if len(path) != 0 { + target.assignSizeStruct(dst, buf, path, parentsMap) + return } + dst.Val = target.computeSize(buf, noOffset, dst.Type().(*LenType)) return } for buf := parentsMap[pos]; buf != nil; buf = parentsMap[buf] { if elem != buf.Type().TemplateName() { continue } - if len(path) == 0 { - dst.Val = target.computeSize(buf, noOffset, dst.Type().(*LenType)) - } else { - target.assignSize(dst, buf, path, buf.(*GroupArg).Inner, parentsMap) + if len(path) != 0 { + target.assignSizeStruct(dst, buf, path, parentsMap) + return } + dst.Val = target.computeSize(buf, noOffset, dst.Type().(*LenType)) return } - var argNames []string - for _, arg := range args { - argNames = append(argNames, arg.Type().FieldName()) + var fieldNames []string + for _, field := range fields { + fieldNames = append(fieldNames, field.Name) } - panic(fmt.Sprintf("len field %q references non existent field %q, pos=%q/%q, argsMap: %+v", - dst.Type().FieldName(), elem, pos.Type().Name(), pos.Type().FieldName(), argNames)) + panic(fmt.Sprintf("len field %q references non existent field %q, pos=%q, argsMap: %v, path: %v", + dst.Type().Name(), elem, pos.Type().Name(), fieldNames, path)) } const noOffset = ^uint64(0) @@ -121,7 +129,7 @@ func (target *Target) computeSize(arg Arg, offset uint64, lenType *LenType) uint } } -func (target *Target) assignSizesArray(args []Arg, autos map[Arg]bool) { +func (target *Target) assignSizesArray(args []Arg, fields []Field, autos map[Arg]bool) { parentsMap := make(map[Arg]Arg) for _, arg := range args { ForeachSubArg(arg, func(arg Arg, _ *ArgCtx) { @@ -132,29 +140,29 @@ func (target *Target) assignSizesArray(args []Arg, autos map[Arg]bool) { } }) } - target.assignSizes(args, parentsMap, args, autos) + target.assignSizes(args, fields, parentsMap, args, fields, autos) for _, arg := range args { ForeachSubArg(arg, func(arg Arg, _ *ArgCtx) { - if _, ok := arg.Type().(*StructType); ok { - target.assignSizes(arg.(*GroupArg).Inner, parentsMap, args, autos) + if typ, ok := arg.Type().(*StructType); ok { + target.assignSizes(arg.(*GroupArg).Inner, typ.Fields, parentsMap, args, fields, autos) } }) } } func (target *Target) assignSizesCall(c *Call) { - target.assignSizesArray(c.Args, nil) + target.assignSizesArray(c.Args, c.Meta.Args, nil) } -func (r *randGen) mutateSize(arg *ConstArg, parent []Arg) bool { +func (r *randGen) mutateSize(arg *ConstArg, parent []Arg, fields []Field) bool { typ := arg.Type().(*LenType) elemSize := typ.BitSize / 8 if elemSize == 0 { elemSize = 1 // TODO(dvyukov): implement path support for size mutation. if len(typ.Path) == 1 { - for _, field := range parent { - if typ.Path[0] != field.Type().FieldName() { + for i, field := range parent { + if typ.Path[0] != fields[i].Name { continue } if inner := InnerArg(field); inner != nil { diff --git a/prog/target.go b/prog/target.go index f71730b1b..89940a61a 100644 --- a/prog/target.go +++ b/prog/target.go @@ -187,12 +187,12 @@ func restoreLinks(syscalls []*Syscall, resources []*ResourceDesc, structs []*Key for _, desc := range structs { keyedStructs[desc.Key] = desc.Desc for i := range desc.Desc.Fields { - unref(&desc.Desc.Fields[i], types) + unref(&desc.Desc.Fields[i].Type, types) } } for _, c := range syscalls { for i := range c.Args { - unref(&c.Args[i], types) + unref(&c.Args[i].Type, types) } if c.Ret != nil { unref(&c.Ret, types) @@ -262,7 +262,7 @@ func (g *Gen) GenerateSpecialArg(typ Type, dir Dir, pcalls *[]*Call) Arg { func (g *Gen) generateArg(typ Type, dir Dir, pcalls *[]*Call, ignoreSpecial bool) Arg { arg, calls := g.r.generateArgImpl(g.s, typ, dir, ignoreSpecial) *pcalls = append(*pcalls, calls...) - g.r.target.assignSizesArray([]Arg{arg}, nil) + g.r.target.assignSizesArray([]Arg{arg}, []Field{{Name: "", Type: arg.Type()}}, nil) return arg } diff --git a/prog/types.go b/prog/types.go index 2d9e8659c..4412239d3 100644 --- a/prog/types.go +++ b/prog/types.go @@ -15,7 +15,7 @@ type Syscall struct { Name string CallName string MissingArgs int // number of trailing args that should be zero-filled - Args []Type + Args []Field Ret Type Attrs SyscallAttrs @@ -65,6 +65,11 @@ func (dir Dir) String() string { } } +type Field struct { + Name string + Type +} + type BinaryFormat int const ( @@ -78,7 +83,6 @@ const ( type Type interface { String() string Name() string - FieldName() string TemplateName() string // for template structs name without arguments Optional() bool Varlen() bool @@ -142,7 +146,6 @@ func IsPad(t Type) bool { type TypeCommon struct { TypeName string - FldName string // for struct fields and named args // Static size of the type, or 0 for variable size types and all but last bitfields in the group. TypeSize uint64 IsOptional bool @@ -153,10 +156,6 @@ func (t *TypeCommon) Name() string { return t.TypeName } -func (t *TypeCommon) FieldName() string { - return t.FldName -} - func (t *TypeCommon) TemplateName() string { name := t.TypeName if pos := strings.IndexByte(name, '['); pos != -1 { @@ -580,8 +579,7 @@ func (t *PtrType) isDefaultArg(arg Arg) bool { } type StructType struct { - Key StructKey - FldName string + Key StructKey *StructDesc } @@ -589,10 +587,6 @@ func (t *StructType) String() string { return t.Name() } -func (t *StructType) FieldName() string { - return t.FldName -} - func (t *StructType) DefaultArg(dir Dir) Arg { inner := make([]Arg, len(t.Fields)) for i, field := range t.Fields { @@ -612,8 +606,7 @@ func (t *StructType) isDefaultArg(arg Arg) bool { } type UnionType struct { - Key StructKey - FldName string + Key StructKey *StructDesc } @@ -621,29 +614,21 @@ func (t *UnionType) String() string { return t.Name() } -func (t *UnionType) FieldName() string { - return t.FldName -} - func (t *UnionType) DefaultArg(dir Dir) Arg { - return MakeUnionArg(t, dir, t.Fields[0].DefaultArg(dir)) + return MakeUnionArg(t, dir, t.Fields[0].DefaultArg(dir), 0) } func (t *UnionType) isDefaultArg(arg Arg) bool { a := arg.(*UnionArg) - return a.Option.Type().FieldName() == t.Fields[0].FieldName() && isDefault(a.Option) + return a.Index == 0 && isDefault(a.Option) } type StructDesc struct { TypeCommon - Fields []Type + Fields []Field AlignAttr uint64 } -func (t *StructDesc) FieldName() string { - panic("must not be called") -} - type StructKey struct { Name string } @@ -671,7 +656,7 @@ func foreachType(meta *Syscall, f func(t Type, ctx typeCtx)) { } seen[desc] = true for _, f := range desc.Fields { - rec(f, dir) + rec(f.Type, dir) } } rec = func(t Type, dir Dir) { @@ -691,8 +676,8 @@ func foreachType(meta *Syscall, f func(t Type, ctx typeCtx)) { panic("unknown type") } } - for _, t := range meta.Args { - rec(t, DirIn) + for _, field := range meta.Args { + rec(field.Type, DirIn) } if meta.Ret != nil { rec(meta.Ret, DirOut) diff --git a/prog/validation.go b/prog/validation.go index 54b288905..e4613600c 100644 --- a/prog/validation.go +++ b/prog/validation.go @@ -55,7 +55,7 @@ func (ctx *validCtx) validateCall(c *Call) error { len(c.Meta.Args), len(c.Args)) } for i, arg := range c.Args { - if err := ctx.validateArg(arg, c.Meta.Args[i], DirIn); err != nil { + if err := ctx.validateArg(arg, c.Meta.Args[i].Type, DirIn); err != nil { return err } } @@ -125,8 +125,7 @@ func (arg *ConstArg) validate(ctx *validCtx) error { typ := arg.Type() if _, isLen := typ.(*LenType); !isLen { if !typ.isDefaultArg(arg) { - return fmt.Errorf("output arg '%v'/'%v' has non default value '%+v'", - typ.FieldName(), typ.Name(), arg) + return fmt.Errorf("output arg %q has non default value %+v", typ.Name(), arg) } } } @@ -192,7 +191,7 @@ func (arg *GroupArg) validate(ctx *validCtx) error { typ.Name(), len(typ.Fields), len(arg.Inner)) } for i, field := range arg.Inner { - if err := ctx.validateArg(field, typ.Fields[i], arg.Dir()); err != nil { + if err := ctx.validateArg(field, typ.Fields[i].Type, arg.Dir()); err != nil { return err } } @@ -218,16 +217,10 @@ func (arg *UnionArg) validate(ctx *validCtx) error { if !ok { return fmt.Errorf("union arg %v has bad type %v", arg, arg.Type().Name()) } - var optType Type - for _, typ1 := range typ.Fields { - if arg.Option.Type().FieldName() == typ1.FieldName() { - optType = typ1 - break - } - } - if optType == nil { - return fmt.Errorf("union arg '%v' has bad option", typ.Name()) + if arg.Index < 0 || arg.Index >= len(typ.Fields) { + return fmt.Errorf("union arg %v has bad index %v/%v", arg, arg.Index, len(typ.Fields)) } + optType := typ.Fields[arg.Index].Type return ctx.validateArg(arg.Option, optType, arg.Dir()) } |
