From 58ae5e18624eaaac79cab00e63d6f32c9bd64ee0 Mon Sep 17 00:00:00 2001 From: Dmitry Vyukov Date: Sun, 3 May 2020 11:29:12 +0200 Subject: prog: remove StructDesc Remove StructDesc, KeyedStruct, StructKey and all associated logic/complexity in prog and pkg/compiler. We can now handle recursion more generically with the Ref type, and Dir/FieldName are not a part of the type anymore. This makes StructType/UnionType simpler and more natural. Reduces size of sys/linux/gen/amd64.go from 5201321 to 4180861 (-20%). Update #1580 --- pkg/compiler/compiler.go | 18 ++- pkg/compiler/compiler_test.go | 4 - pkg/compiler/gen.go | 250 ++++++++++++++---------------------------- pkg/compiler/types.go | 38 ++++--- prog/any.go | 29 +++-- prog/prio.go | 91 ++++++++------- prog/prog_test.go | 18 ++- prog/rand_test.go | 28 +++-- prog/resources.go | 20 ++-- prog/rotation.go | 2 +- prog/target.go | 63 +++-------- prog/types.go | 100 ++++++++++------- sys/syz-sysgen/sysgen.go | 8 +- tools/syz-check/check.go | 62 +++++------ 14 files changed, 300 insertions(+), 431 deletions(-) diff --git a/pkg/compiler/compiler.go b/pkg/compiler/compiler.go index f61d02036..55ddc1ba4 100644 --- a/pkg/compiler/compiler.go +++ b/pkg/compiler/compiler.go @@ -34,10 +34,9 @@ import ( // Prog is description compilation result. type Prog struct { - Resources []*prog.ResourceDesc - Syscalls []*prog.Syscall - StructDescs []*prog.KeyedStruct - Types []prog.Type + Resources []*prog.ResourceDesc + Syscalls []*prog.Syscall + Types []prog.Type // Set of unsupported syscalls/flags. Unsupported map[string]bool // Returned if consts was nil. 
@@ -61,9 +60,8 @@ func createCompiler(desc *ast.Description, target *targets.Target, eh ast.ErrorH strFlags: make(map[string]*ast.StrFlags), used: make(map[string]bool), usedTypedefs: make(map[string]bool), - structDescs: make(map[prog.StructKey]*prog.StructDesc), - structNodes: make(map[*prog.StructDesc]*ast.Struct), structVarlen: make(map[string]bool), + structTypes: make(map[string]prog.Type), builtinConsts: map[string]uint64{ "PTR_SIZE": target.PtrSize, }, @@ -104,12 +102,11 @@ func Compile(desc *ast.Description, consts map[string]uint64, target *targets.Ta return nil } syscalls := comp.genSyscalls() - structs := comp.genStructDescs(syscalls) - types := comp.generateTypes(syscalls, structs) + comp.layoutTypes(syscalls) + types := comp.generateTypes(syscalls) prg := &Prog{ Resources: comp.genResources(), Syscalls: syscalls, - StructDescs: structs, Types: types, Unsupported: comp.unsupported, } @@ -139,9 +136,8 @@ type compiler struct { used map[string]bool // contains used structs/resources usedTypedefs map[string]bool - structDescs map[prog.StructKey]*prog.StructDesc - structNodes map[*prog.StructDesc]*ast.Struct structVarlen map[string]bool + structTypes map[string]prog.Type builtinConsts map[string]uint64 } diff --git a/pkg/compiler/compiler_test.go b/pkg/compiler/compiler_test.go index ae3b3d357..3d4ee3e64 100644 --- a/pkg/compiler/compiler_test.go +++ b/pkg/compiler/compiler_test.go @@ -122,8 +122,6 @@ func TestData(t *testing.T) { out := new(bytes.Buffer) fmt.Fprintf(out, "\n\nRESOURCES:\n") serializer.Write(out, desc.Resources) - fmt.Fprintf(out, "\n\nSTRUCTS:\n") - serializer.Write(out, desc.StructDescs) fmt.Fprintf(out, "\n\nSYSCALLS:\n") serializer.Write(out, desc.Syscalls) if false { @@ -188,8 +186,6 @@ s2 { if p == nil { t.Fatal("failed to compile") } - got := p.StructDescs[0].Desc - t.Logf("got: %#v", got) } func TestCollectUnusedError(t *testing.T) { diff --git a/pkg/compiler/gen.go b/pkg/compiler/gen.go index 748941bae..a86190863 100644 --- 
a/pkg/compiler/gen.go +++ b/pkg/compiler/gen.go @@ -138,26 +138,42 @@ func (comp *compiler) genSyscall(n *ast.Call, argSizes []uint64) *prog.Syscall { type typeProxy struct { typ prog.Type id string + ref prog.Ref locations []*prog.Type } -func (comp *compiler) generateTypes(syscalls []*prog.Syscall, structs []*prog.KeyedStruct) []prog.Type { +func (comp *compiler) generateTypes(syscalls []*prog.Syscall) []prog.Type { // Replace all Type's in the descriptions with Ref's // and prepare a sorted array of corresponding real types. proxies := make(map[string]*typeProxy) - for _, call := range syscalls { - for i := range call.Args { - comp.collectTypes(proxies, &call.Args[i].Type) + prog.ForeachTypePost(syscalls, func(typ prog.Type, ctx prog.TypeCtx) { + if _, ok := typ.(prog.Ref); ok { + return } - if call.Ret != nil { - comp.collectTypes(proxies, &call.Ret) + if !typ.Varlen() && typ.Size() == sizeUnassigned { + panic("unassigned size") } - } - for _, str := range structs { - for i := range str.Desc.Fields { - comp.collectTypes(proxies, &str.Desc.Fields[i].Type) + id := typ.Name() + switch typ.(type) { + case *prog.StructType, *prog.UnionType: + // These types can be uniquely identified with the name. 
+ default: + buf := new(bytes.Buffer) + serializer.Write(buf, typ) + id = buf.String() } - } + proxy := proxies[id] + if proxy == nil { + proxy = &typeProxy{ + typ: typ, + id: id, + ref: prog.Ref(len(proxies)), + } + proxies[id] = proxy + } + *ctx.Ptr = proxy.ref + proxy.locations = append(proxy.locations, ctx.Ptr) + }) array := make([]*typeProxy, 0, len(proxies)) for _, proxy := range proxies { array = append(array, proxy) @@ -175,142 +191,75 @@ func (comp *compiler) generateTypes(syscalls []*prog.Syscall, structs []*prog.Ke return types } -func (comp *compiler) collectTypes(proxies map[string]*typeProxy, tptr *prog.Type) { - typ := *tptr - switch t := typ.(type) { - case *prog.PtrType: - comp.collectTypes(proxies, &t.Elem) - case *prog.ArrayType: - comp.collectTypes(proxies, &t.Elem) - case *prog.ResourceType, *prog.BufferType, *prog.VmaType, *prog.LenType, - *prog.FlagsType, *prog.ConstType, *prog.IntType, *prog.ProcType, - *prog.CsumType, *prog.StructType, *prog.UnionType: - default: - panic("unknown type") - } - buf := new(bytes.Buffer) - serializer.Write(buf, typ) - id := buf.String() - proxy := proxies[id] - if proxy == nil { - proxy = &typeProxy{ - typ: typ, - id: id, - } - proxies[id] = proxy - } - proxy.locations = append(proxy.locations, tptr) -} - -func (comp *compiler) genStructDescs(syscalls []*prog.Syscall) []*prog.KeyedStruct { - // Calculate struct/union/array sizes, add padding to structs and detach - // StructDesc's from StructType's. StructType's can be recursive so it's - // not possible to write them out inline as other types. To break the - // recursion detach them, and write StructDesc's out as separate array - // of KeyedStruct's. prog package will reattach them during init. - ctx := &structGen{ - comp: comp, - padded: make(map[interface{}]bool), - detach: make(map[**prog.StructDesc]bool), - } - // We have to do this in the loop until we pad nothing new - // due to recursive structs. 
- for { - start := len(ctx.padded) - for _, c := range syscalls { - for _, a := range c.Args { - ctx.walk(a.Type) - } - if c.Ret != nil { - ctx.walk(c.Ret) - } - } - if start == len(ctx.padded) { - break - } - } - - // Detach StructDesc's from StructType's. prog will reattach them again. - for descp := range ctx.detach { - *descp = nil - } - - sort.Slice(ctx.structs, func(i, j int) bool { - si, sj := ctx.structs[i].Key, ctx.structs[j].Key - return si.Name < sj.Name +func (comp *compiler) layoutTypes(syscalls []*prog.Syscall) { + // Calculate struct/union/array sizes, add padding to structs, mark bitfields. + padded := make(map[prog.Type]bool) + prog.ForeachTypePost(syscalls, func(typ prog.Type, _ prog.TypeCtx) { + comp.layoutType(typ, padded) }) - return ctx.structs -} - -type structGen struct { - comp *compiler - padded map[interface{}]bool - detach map[**prog.StructDesc]bool - structs []*prog.KeyedStruct } -func (ctx *structGen) check(key prog.StructKey, descp **prog.StructDesc) bool { - ctx.detach[descp] = true - desc := *descp - if ctx.padded[desc] { - return false - } - ctx.padded[desc] = true - for _, f := range desc.Fields { - ctx.walk(f.Type) - if !f.Varlen() && f.Size() == sizeUnassigned { - // An inner struct is not padded yet. - // Leave this struct for next iteration. 
- delete(ctx.padded, desc) - return false - } - } - if ctx.comp.used[key.Name] { - ctx.structs = append(ctx.structs, &prog.KeyedStruct{ - Key: key, - Desc: desc, - }) +func (comp *compiler) layoutType(typ prog.Type, padded map[prog.Type]bool) { + if padded[typ] { + return } - return true -} - -func (ctx *structGen) walk(t0 prog.Type) { - switch t := t0.(type) { - case *prog.PtrType: - ctx.walk(t.Elem) + switch t := typ.(type) { case *prog.ArrayType: - ctx.walkArray(t) + comp.layoutType(t.Elem, padded) + comp.layoutArray(t) case *prog.StructType: - ctx.walkStruct(t) + for _, f := range t.Fields { + comp.layoutType(f.Type, padded) + } + comp.layoutStruct(t) case *prog.UnionType: - ctx.walkUnion(t) - } -} - -func (ctx *structGen) walkArray(t *prog.ArrayType) { - if ctx.padded[t] { + for _, f := range t.Fields { + comp.layoutType(f.Type, padded) + } + comp.layoutUnion(t) + default: return } - ctx.walk(t.Elem) - if !t.Elem.Varlen() && t.Elem.Size() == sizeUnassigned { - // An inner struct is not padded yet. - // Leave this array for next iteration. 
- return + if !typ.Varlen() && typ.Size() == sizeUnassigned { + panic("size unassigned") } - ctx.padded[t] = true + padded[typ] = true +} + +func (comp *compiler) layoutArray(t *prog.ArrayType) { t.TypeSize = 0 if t.Kind == prog.ArrayRangeLen && t.RangeBegin == t.RangeEnd && !t.Elem.Varlen() { t.TypeSize = t.RangeBegin * t.Elem.Size() } } -func (ctx *structGen) walkStruct(t *prog.StructType) { - if !ctx.check(t.Key, &t.StructDesc) { +func (comp *compiler) layoutUnion(t *prog.UnionType) { + structNode := comp.structs[t.TypeName] + attrs := comp.parseAttrs(unionAttrs, structNode, structNode.Attrs) + t.TypeSize = 0 + if attrs[attrVarlen] != 0 { return } - comp := ctx.comp - structNode := comp.structNodes[t.StructDesc] + sizeAttr, hasSize := attrs[attrSize] + for i, fld := range t.Fields { + sz := fld.Size() + if hasSize && sz > sizeAttr { + comp.error(structNode.Fields[i].Pos, "union %v has size attribute %v"+ + " which is less than field %v size %v", + structNode.Name.Name, sizeAttr, fld.Type.Name(), sz) + } + if t.TypeSize < sz { + t.TypeSize = sz + } + } + if hasSize { + t.TypeSize = sizeAttr + } +} + +func (comp *compiler) layoutStruct(t *prog.StructType) { // Add paddings, calculate size, mark bitfields. 
+ structNode := comp.structs[t.TypeName] varlen := false for _, f := range t.Fields { if f.Varlen() { @@ -319,7 +268,7 @@ func (ctx *structGen) walkStruct(t *prog.StructType) { } attrs := comp.parseAttrs(structAttrs, structNode, structNode.Attrs) t.AlignAttr = attrs[attrAlign] - comp.layoutStruct(t, varlen, attrs[attrPacked] != 0) + comp.layoutStructFields(t, varlen, attrs[attrPacked] != 0) t.TypeSize = 0 if !varlen { for _, f := range t.Fields { @@ -340,46 +289,7 @@ func (ctx *structGen) walkStruct(t *prog.StructType) { } } -func (ctx *structGen) walkUnion(t *prog.UnionType) { - if !ctx.check(t.Key, &t.StructDesc) { - return - } - comp := ctx.comp - structNode := comp.structNodes[t.StructDesc] - attrs := comp.parseAttrs(unionAttrs, structNode, structNode.Attrs) - t.TypeSize = 0 - if attrs[attrVarlen] != 0 { - return - } - sizeAttr, hasSize := attrs[attrSize] - for i, fld := range t.Fields { - sz := fld.Size() - if hasSize && sz > sizeAttr { - comp.error(structNode.Fields[i].Pos, "union %v has size attribute %v"+ - " which is less than field %v size %v", - structNode.Name.Name, sizeAttr, fld.Type.Name(), sz) - } - if t.TypeSize < sz { - t.TypeSize = sz - } - } - if hasSize { - t.TypeSize = sizeAttr - } -} - -func (comp *compiler) genStructDesc(res *prog.StructDesc, n *ast.Struct, varlen bool) { - // Leave node for genStructDescs to calculate size/padding. 
- comp.structNodes[res] = n - common := genCommon(n.Name.Name, sizeUnassigned, false) - common.IsVarlen = varlen - *res = prog.StructDesc{ - TypeCommon: common, - Fields: comp.genFieldArray(n.Fields, make([]uint64, len(n.Fields))), - } -} - -func (comp *compiler) layoutStruct(t *prog.StructType, varlen, packed bool) { +func (comp *compiler) layoutStructFields(t *prog.StructType, varlen, packed bool) { var newFields []prog.Field var structAlign, byteOffset, bitOffset uint64 for i, field := range t.Fields { @@ -538,7 +448,7 @@ func (comp *compiler) typeAlign(t0 prog.Type) uint64 { case *prog.ArrayType: return comp.typeAlign(t.Elem) case *prog.StructType: - n := comp.structNodes[t.StructDesc] + n := comp.structs[t.TypeName] attrs := comp.parseAttrs(structAttrs, n, n.Attrs) if align := attrs[attrAlign]; align != 0 { return align // overrided by user attribute diff --git a/pkg/compiler/types.go b/pkg/compiler/types.go index 7ed12afd2..44ccb1f1d 100644 --- a/pkg/compiler/types.go +++ b/pkg/compiler/types.go @@ -237,7 +237,7 @@ var typeArray = &typeDesc{ RangeEnd: end, } } - // TypeSize is assigned later in genStructDescs. + // TypeSize is assigned later in layoutArray. return &prog.ArrayType{ TypeCommon: base.TypeCommon, Elem: elemType, @@ -815,27 +815,31 @@ func init() { return true } typeStruct.Gen = func(comp *compiler, t *ast.Type, args []*ast.Type, base prog.IntTypeCommon) prog.Type { - s := comp.structs[t.Ident] - key := prog.StructKey{ - Name: t.Ident, - } - desc := comp.structDescs[key] - if desc == nil { - // Need to assign to structDescs before calling genStructDesc to break recursion. 
- desc = new(prog.StructDesc) - comp.structDescs[key] = desc - comp.genStructDesc(desc, s, typeStruct.Varlen(comp, t, args)) + if typ := comp.structTypes[t.Ident]; typ != nil { + return typ } + s := comp.structs[t.Ident] + common := genCommon(t.Ident, sizeUnassigned, false) + common.IsVarlen = typeStruct.Varlen(comp, t, args) + var typ prog.Type if s.IsUnion { - return &prog.UnionType{ - Key: key, - StructDesc: desc, + typ = &prog.UnionType{ + TypeCommon: common, + } + } else { + typ = &prog.StructType{ + TypeCommon: common, } } - return &prog.StructType{ - Key: key, - StructDesc: desc, + // Need to cache type in structTypes before generating fields to break recursion. + comp.structTypes[t.Ident] = typ + fields := comp.genFieldArray(s.Fields, make([]uint64, len(s.Fields))) + if s.IsUnion { + typ.(*prog.UnionType).Fields = fields + } else { + typ.(*prog.StructType).Fields = fields } + return typ } } diff --git a/prog/any.go b/prog/any.go index 1ce444ce2..ce3736a4b 100644 --- a/prog/any.go +++ b/prog/any.go @@ -36,7 +36,12 @@ type anyTypes struct { // resoct fmt[oct, ANYRES64] // ] [varlen] func initAnyTypes(target *Target) { - target.any.union = &UnionType{} + target.any.union = &UnionType{ + TypeCommon: TypeCommon{ + TypeName: "ANYUNION", + IsVarlen: true, + }, + } target.any.array = &ArrayType{ TypeCommon: TypeCommon{ TypeName: "ANYARRAY", @@ -89,20 +94,14 @@ func initAnyTypes(target *Target) { target.any.resdec = createResource("ANYRESDEC", "int64", FormatStrDec, 20) target.any.reshex = createResource("ANYRESHEX", "int64", FormatStrHex, 18) target.any.resoct = createResource("ANYRESOCT", "int64", FormatStrOct, 23) - target.any.union.StructDesc = &StructDesc{ - TypeCommon: TypeCommon{ - TypeName: "ANYUNION", - IsVarlen: true, - }, - Fields: []Field{ - {Name: "ANYBLOB", Type: target.any.blob}, - {Name: "ANYRES16", Type: target.any.res16}, - {Name: "ANYRES32", Type: target.any.res32}, - {Name: "ANYRES64", Type: target.any.res64}, - {Name: "ANYRESDEC", Type: 
target.any.resdec}, - {Name: "ANYRESHEX", Type: target.any.reshex}, - {Name: "ANYRESOCT", Type: target.any.resoct}, - }, + target.any.union.Fields = []Field{ + {Name: "ANYBLOB", Type: target.any.blob}, + {Name: "ANYRES16", Type: target.any.res16}, + {Name: "ANYRES32", Type: target.any.res32}, + {Name: "ANYRES64", Type: target.any.res64}, + {Name: "ANYRESDEC", Type: target.any.resdec}, + {Name: "ANYRESHEX", Type: target.any.reshex}, + {Name: "ANYRESOCT", Type: target.any.resoct}, } } diff --git a/prog/prio.go b/prog/prio.go index 43668d48d..648af0422 100644 --- a/prog/prio.go +++ b/prog/prio.go @@ -64,56 +64,55 @@ func (target *Target) calcStaticPriorities() [][]float32 { func (target *Target) calcResourceUsage() map[string]map[int]weights { uses := make(map[string]map[int]weights) - for _, c := range target.Syscalls { - foreachType(c, func(t Type, ctx typeCtx) { - switch a := t.(type) { - case *ResourceType: - if target.AuxResources[a.Desc.Name] { - noteUsage(uses, c, 0.1, ctx.Dir, "res%v", a.Desc.Name) - } else { - str := "res" - for i, k := range a.Desc.Kind { - str += "-" + k - w := 1.0 - if i < len(a.Desc.Kind)-1 { - w = 0.2 - } - noteUsage(uses, c, float32(w), ctx.Dir, str) + ForeachType(target.Syscalls, func(t Type, ctx TypeCtx) { + c := ctx.Meta + switch a := t.(type) { + case *ResourceType: + if target.AuxResources[a.Desc.Name] { + noteUsage(uses, c, 0.1, ctx.Dir, "res%v", a.Desc.Name) + } else { + str := "res" + for i, k := range a.Desc.Kind { + str += "-" + k + w := 1.0 + if i < len(a.Desc.Kind)-1 { + w = 0.2 } + noteUsage(uses, c, float32(w), ctx.Dir, str) } - case *PtrType: - if _, ok := a.Elem.(*StructType); ok { - noteUsage(uses, c, 1.0, ctx.Dir, "ptrto-%v", a.Elem.Name()) - } - if _, ok := a.Elem.(*UnionType); ok { - noteUsage(uses, c, 1.0, ctx.Dir, "ptrto-%v", a.Elem.Name()) - } - if arr, ok := a.Elem.(*ArrayType); ok { - noteUsage(uses, c, 1.0, ctx.Dir, "ptrto-%v", arr.Elem.Name()) - } - case *BufferType: - switch a.Kind { - case BufferBlobRand, 
BufferBlobRange, BufferText: - case BufferString: - if a.SubKind != "" { - noteUsage(uses, c, 0.2, ctx.Dir, fmt.Sprintf("str-%v", a.SubKind)) - } - case BufferFilename: - noteUsage(uses, c, 1.0, DirIn, "filename") - default: - panic("unknown buffer kind") - } - case *VmaType: - noteUsage(uses, c, 0.5, ctx.Dir, "vma") - case *IntType: - switch a.Kind { - case IntPlain, IntRange: - default: - panic("unknown int kind") + } + case *PtrType: + if _, ok := a.Elem.(*StructType); ok { + noteUsage(uses, c, 1.0, ctx.Dir, "ptrto-%v", a.Elem.Name()) + } + if _, ok := a.Elem.(*UnionType); ok { + noteUsage(uses, c, 1.0, ctx.Dir, "ptrto-%v", a.Elem.Name()) + } + if arr, ok := a.Elem.(*ArrayType); ok { + noteUsage(uses, c, 1.0, ctx.Dir, "ptrto-%v", arr.Elem.Name()) + } + case *BufferType: + switch a.Kind { + case BufferBlobRand, BufferBlobRange, BufferText: + case BufferString: + if a.SubKind != "" { + noteUsage(uses, c, 0.2, ctx.Dir, fmt.Sprintf("str-%v", a.SubKind)) } + case BufferFilename: + noteUsage(uses, c, 1.0, DirIn, "filename") + default: + panic("unknown buffer kind") } - }) - } + case *VmaType: + noteUsage(uses, c, 0.5, ctx.Dir, "vma") + case *IntType: + switch a.Kind { + case IntPlain, IntRange: + default: + panic("unknown int kind") + } + } + }) return uses } diff --git a/prog/prog_test.go b/prog/prog_test.go index 9b8f442e3..6374b6d25 100644 --- a/prog/prog_test.go +++ b/prog/prog_test.go @@ -20,15 +20,13 @@ func TestGeneration(t *testing.T) { func TestDefault(t *testing.T) { target, _, _ := initTest(t) - for _, meta := range target.Syscalls { - foreachType(meta, func(typ Type, ctx typeCtx) { - arg := typ.DefaultArg(ctx.Dir) - if !isDefault(arg) { - t.Errorf("default arg is not default: %s\ntype: %#v\narg: %#v", - typ, typ, arg) - } - }) - } + ForeachType(target.Syscalls, func(typ Type, ctx TypeCtx) { + arg := typ.DefaultArg(ctx.Dir) + if !isDefault(arg) { + t.Errorf("default arg is not default: %s\ntype: %#v\narg: %#v", + typ, typ, arg) + } + }) } func 
TestDefaultCallArgs(t *testing.T) { @@ -203,7 +201,7 @@ func TestSpecialStructs(t *testing.T) { t.Run(special, func(t *testing.T) { var typ Type for i := 0; i < len(target.Syscalls) && typ == nil; i++ { - foreachType(target.Syscalls[i], func(t Type, ctx typeCtx) { + ForeachCallType(target.Syscalls[i], func(t Type, ctx TypeCtx) { if ctx.Dir == DirOut { return } diff --git a/prog/rand_test.go b/prog/rand_test.go index 6f251cb7c..da0871505 100644 --- a/prog/rand_test.go +++ b/prog/rand_test.go @@ -101,22 +101,20 @@ func TestEnabledCalls(t *testing.T) { func TestSizeGenerateConstArg(t *testing.T) { target, rs, iters := initRandomTargetTest(t, "test", "64") r := newRand(target, rs) - for _, c := range target.Syscalls { - foreachType(c, func(typ Type, ctx typeCtx) { - if _, ok := typ.(*IntType); !ok { - return - } - bits := typ.TypeBitSize() - limit := uint64(1<<bits - 1) - for i := 0; i < iters; i++ { - newArg, _ := typ.generate(r, nil, ctx.Dir) - newVal := newArg.(*ConstArg).Val - if newVal > limit { - t.Fatalf("invalid generated value: %d. (arg bitsize: %d; max value: %d)", newVal, bits, limit) - } + ForeachType(target.Syscalls, func(typ Type, ctx TypeCtx) { + if _, ok := typ.(*IntType); !ok { + return + } + bits := typ.TypeBitSize() + limit := uint64(1<<bits - 1) + for i := 0; i < iters; i++ { + newArg, _ := typ.generate(r, nil, ctx.Dir) + newVal := newArg.(*ConstArg).Val + if newVal > limit { + t.Fatalf("invalid generated value: %d. (arg bitsize: %d; max value: %d)", newVal, bits, limit) } - }) - } + }) } func TestFlags(t *testing.T) { diff --git a/prog/resources.go b/prog/resources.go index b7bcecf95..311eb72dd 100644 --- a/prog/resources.go +++ b/prog/resources.go @@ -45,16 +45,14 @@ func (target *Target) calcResourceCtors(res *ResourceDesc, precise bool) []*Sysc func (target *Target) populateResourceCtors() { // Find resources that are created by each call. 
callsResources := make([][]*ResourceDesc, len(target.Syscalls)) - for call, meta := range target.Syscalls { - foreachType(meta, func(typ Type, ctx typeCtx) { - switch typ1 := typ.(type) { - case *ResourceType: - if ctx.Dir != DirIn { - callsResources[call] = append(callsResources[call], typ1.Desc) - } + ForeachType(target.Syscalls, func(typ Type, ctx TypeCtx) { + switch typ1 := typ.(type) { + case *ResourceType: + if ctx.Dir != DirIn { + callsResources[ctx.Meta.ID] = append(callsResources[ctx.Meta.ID], typ1.Desc) } - }) - } + } + }) // Populate resource ctors accounting for resource compatibility. for _, res := range target.Resources { @@ -126,7 +124,7 @@ func isCompatibleResourceImpl(dst, src []string, precise bool) bool { func (target *Target) getInputResources(c *Syscall) []*ResourceDesc { var resources []*ResourceDesc - foreachType(c, func(typ Type, ctx typeCtx) { + ForeachCallType(c, func(typ Type, ctx TypeCtx) { if ctx.Dir == DirOut { return } @@ -146,7 +144,7 @@ func (target *Target) getInputResources(c *Syscall) []*ResourceDesc { func (target *Target) getOutputResources(c *Syscall) []*ResourceDesc { var resources []*ResourceDesc - foreachType(c, func(typ Type, ctx typeCtx) { + ForeachCallType(c, func(typ Type, ctx TypeCtx) { switch typ1 := typ.(type) { case *ResourceType: if ctx.Dir != DirIn { diff --git a/prog/rotation.go b/prog/rotation.go index 47ee2ca81..0171e6ace 100644 --- a/prog/rotation.go +++ b/prog/rotation.go @@ -50,7 +50,7 @@ func MakeRotator(target *Target, calls map[*Syscall]bool, rnd *rand.Rand) *Rotat } // VMAs and filenames are effectively resources for our purposes // (but they don't have ctors). 
- foreachType(call, func(t Type, _ typeCtx) { + ForeachCallType(call, func(t Type, _ TypeCtx) { switch a := t.(type) { case *BufferType: switch a.Kind { diff --git a/prog/target.go b/prog/target.go index 89940a61a..8999f25ac 100644 --- a/prog/target.go +++ b/prog/target.go @@ -22,7 +22,6 @@ type Target struct { Syscalls []*Syscall Resources []*ResourceDesc - Structs []*KeyedStruct Consts []ConstValue Types []Type @@ -137,8 +136,7 @@ func (target *Target) initTarget() { target.ConstMap[c.Name] = c.Value } - target.resourceMap = restoreLinks(target.Syscalls, target.Resources, target.Structs, target.Types) - target.Structs = nil + target.resourceMap = restoreLinks(target.Syscalls, target.Resources, target.Types) target.Types = nil target.SyscallMap = make(map[string]*Syscall) @@ -173,63 +171,30 @@ func (target *Target) sanitize(c *Call, fix bool) error { return nil } -func RestoreLinks(syscalls []*Syscall, resources []*ResourceDesc, structs []*KeyedStruct, types []Type) { - restoreLinks(syscalls, resources, structs, types) +func RestoreLinks(syscalls []*Syscall, resources []*ResourceDesc, types []Type) { + restoreLinks(syscalls, resources, types) } -func restoreLinks(syscalls []*Syscall, resources []*ResourceDesc, structs []*KeyedStruct, - types []Type) map[string]*ResourceDesc { +func restoreLinks(syscalls []*Syscall, resources []*ResourceDesc, types []Type) map[string]*ResourceDesc { resourceMap := make(map[string]*ResourceDesc) for _, res := range resources { resourceMap[res.Name] = res } - keyedStructs := make(map[StructKey]*StructDesc) - for _, desc := range structs { - keyedStructs[desc.Key] = desc.Desc - for i := range desc.Desc.Fields { - unref(&desc.Desc.Fields[i].Type, types) + ForeachType(syscalls, func(_ Type, ctx TypeCtx) { + if ref, ok := (*ctx.Ptr).(Ref); ok { + *ctx.Ptr = types[ref] } - } - for _, c := range syscalls { - for i := range c.Args { - unref(&c.Args[i].Type, types) - } - if c.Ret != nil { - unref(&c.Ret, types) - } - foreachType(c, func(t0 
Type, _ typeCtx) { - switch t := t0.(type) { - case *PtrType: - unref(&t.Elem, types) - case *ArrayType: - unref(&t.Elem, types) - case *ResourceType: - t.Desc = resourceMap[t.TypeName] - if t.Desc == nil { - panic("no resource desc") - } - case *StructType: - t.StructDesc = keyedStructs[t.Key] - if t.StructDesc == nil { - panic("no struct desc") - } - case *UnionType: - t.StructDesc = keyedStructs[t.Key] - if t.StructDesc == nil { - panic("no union desc") - } + switch t := (*ctx.Ptr).(type) { + case *ResourceType: + t.Desc = resourceMap[t.TypeName] + if t.Desc == nil { + panic("no resource desc") } - }) - } + } + }) return resourceMap } -func unref(tp *Type, types []Type) { - if ref, ok := (*tp).(Ref); ok { - *tp = types[ref] - } -} - type Gen struct { r *randGen s *state diff --git a/prog/types.go b/prog/types.go index 4412239d3..33c7bd356 100644 --- a/prog/types.go +++ b/prog/types.go @@ -110,7 +110,6 @@ type Ref uint32 func (ti Ref) String() string { panic("prog.Ref method called") } func (ti Ref) Name() string { panic("prog.Ref method called") } -func (ti Ref) FieldName() string { panic("prog.Ref method called") } func (ti Ref) TemplateName() string { panic("prog.Ref method called") } func (ti Ref) Optional() bool { panic("prog.Ref method called") } @@ -579,8 +578,9 @@ func (t *PtrType) isDefaultArg(arg Arg) bool { } type StructType struct { - Key StructKey - *StructDesc + TypeCommon + Fields []Field + AlignAttr uint64 } func (t *StructType) String() string { @@ -606,8 +606,8 @@ func (t *StructType) isDefaultArg(arg Arg) bool { } type UnionType struct { - Key StructKey - *StructDesc + TypeCommon + Fields []Field } func (t *UnionType) String() string { @@ -623,64 +623,80 @@ func (t *UnionType) isDefaultArg(arg Arg) bool { return a.Index == 0 && isDefault(a.Option) } -type StructDesc struct { - TypeCommon - Fields []Field - AlignAttr uint64 +type ConstValue struct { + Name string + Value uint64 } -type StructKey struct { - Name string +type TypeCtx struct { + 
Meta *Syscall + Dir Dir + Ptr *Type } -type KeyedStruct struct { - Key StructKey - Desc *StructDesc +func ForeachType(syscalls []*Syscall, f func(t Type, ctx TypeCtx)) { + for _, meta := range syscalls { + foreachTypeImpl(meta, true, f) + } } -type ConstValue struct { - Name string - Value uint64 +func ForeachTypePost(syscalls []*Syscall, f func(t Type, ctx TypeCtx)) { + for _, meta := range syscalls { + foreachTypeImpl(meta, false, f) + } } -type typeCtx struct { - Dir Dir +func ForeachCallType(meta *Syscall, f func(t Type, ctx TypeCtx)) { + foreachTypeImpl(meta, true, f) } -func foreachType(meta *Syscall, f func(t Type, ctx typeCtx)) { - var rec func(t Type, dir Dir) - seen := make(map[*StructDesc]bool) - recStruct := func(desc *StructDesc, dir Dir) { - if seen[desc] { - return // prune recursion via pointers to structs/unions - } - seen[desc] = true - for _, f := range desc.Fields { - rec(f.Type, dir) +func foreachTypeImpl(meta *Syscall, preorder bool, f func(t Type, ctx TypeCtx)) { + // Note: we specifically don't create seen in ForeachType. + // It would prune recursion more (across syscalls), but lots of users need to + // visit each struct per-syscall (e.g. prio, used resources). 
+ seen := make(map[Type]bool) + var rec func(*Type, Dir) + rec = func(ptr *Type, dir Dir) { + if preorder { + f(*ptr, TypeCtx{Meta: meta, Dir: dir, Ptr: ptr}) } - } - rec = func(t Type, dir Dir) { - f(t, typeCtx{Dir: dir}) - switch a := t.(type) { + switch a := (*ptr).(type) { case *PtrType: - rec(a.Elem, a.ElemDir) + rec(&a.Elem, a.ElemDir) case *ArrayType: - rec(a.Elem, dir) + rec(&a.Elem, dir) case *StructType: - recStruct(a.StructDesc, dir) + if seen[a] { + break // prune recursion via pointers to structs/unions + } + seen[a] = true + for i := range a.Fields { + rec(&a.Fields[i].Type, dir) + } case *UnionType: - recStruct(a.StructDesc, dir) - case *ResourceType, *BufferType, *VmaType, *LenType, - *FlagsType, *ConstType, *IntType, *ProcType, *CsumType: + if seen[a] { + break // prune recursion via pointers to structs/unions + } + seen[a] = true + for i := range a.Fields { + rec(&a.Fields[i].Type, dir) + } + case *ResourceType, *BufferType, *VmaType, *LenType, *FlagsType, + *ConstType, *IntType, *ProcType, *CsumType: + case Ref: + // This is only needed for pkg/compiler. 
default: panic("unknown type") } + if !preorder { + f(*ptr, TypeCtx{Meta: meta, Dir: dir, Ptr: ptr}) + } } - for _, field := range meta.Args { - rec(field.Type, DirIn) + for i := range meta.Args { + rec(&meta.Args[i].Type, DirIn) } if meta.Ret != nil { - rec(meta.Ret, DirOut) + rec(&meta.Ret, DirOut) } } diff --git a/sys/syz-sysgen/sysgen.go b/sys/syz-sysgen/sysgen.go index 1197a8cce..bf39ba9ee 100644 --- a/sys/syz-sysgen/sysgen.go +++ b/sys/syz-sysgen/sysgen.go @@ -189,20 +189,16 @@ func generate(target *targets.Target, prg *compiler.Prog, consts map[string]uint fmt.Fprintf(out, "\tRegisterTarget(&Target{"+ "OS: %q, Arch: %q, Revision: revision_%v, PtrSize: %v, "+ "PageSize: %v, NumPages: %v, DataOffset: %v, Syscalls: syscalls_%v, "+ - "Resources: resources_%v, Structs: structDescs_%v, Types: types_%v, Consts: consts_%v}, "+ + "Resources: resources_%v, Types: types_%v, Consts: consts_%v}, "+ "InitTarget)\n}\n\n", target.OS, target.Arch, target.Arch, target.PtrSize, target.PageSize, target.NumPages, target.DataOffset, - target.Arch, target.Arch, target.Arch, target.Arch, target.Arch) + target.Arch, target.Arch, target.Arch, target.Arch) fmt.Fprintf(out, "var resources_%v = ", target.Arch) serializer.Write(out, prg.Resources) fmt.Fprintf(out, "\n\n") - fmt.Fprintf(out, "var structDescs_%v = ", target.Arch) - serializer.Write(out, prg.StructDescs) - fmt.Fprintf(out, "\n\n") - fmt.Fprintf(out, "var syscalls_%v = ", target.Arch) serializer.Write(out, prg.Syscalls) fmt.Fprintf(out, "\n\n") diff --git a/tools/syz-check/check.go b/tools/syz-check/check.go index 1d74584af..49152f481 100644 --- a/tools/syz-check/check.go +++ b/tools/syz-check/check.go @@ -89,7 +89,7 @@ func check(OS, arch, obj string, dwarf, netlink bool) ([]Warn, error) { if obj == "" { return nil, fmt.Errorf("no object file in -obj-%v flag", arch) } - structDescs, locs, warnings1, err := parseDescriptions(OS, arch) + structTypes, locs, warnings1, err := parseDescriptions(OS, arch) if err != nil { return 
nil, err } @@ -99,14 +99,14 @@ func check(OS, arch, obj string, dwarf, netlink bool) ([]Warn, error) { if err != nil { return nil, err } - warnings2, err := checkImpl(structs, structDescs, locs) + warnings2, err := checkImpl(structs, structTypes, locs) if err != nil { return nil, err } warnings = append(warnings, warnings2...) } if netlink { - warnings3, err := checkNetlink(OS, arch, obj, structDescs, locs) + warnings3, err := checkNetlink(OS, arch, obj, structTypes, locs) if err != nil { return nil, err } @@ -197,16 +197,10 @@ func writeWarnings(OS string, narches int, warnings []Warn) error { return nil } -func checkImpl(structs map[string]*dwarf.StructType, structDescs []*prog.KeyedStruct, +func checkImpl(structs map[string]*dwarf.StructType, structTypes []*prog.StructType, locs map[string]*ast.Struct) ([]Warn, error) { var warnings []Warn - checked := make(map[string]bool) - for _, str := range structDescs { - typ := str.Desc - if checked[typ.Name()] { - continue - } - checked[typ.Name()] = true + for _, typ := range structTypes { name := typ.TemplateName() astStruct := locs[name] if astStruct == nil { @@ -221,7 +215,7 @@ func checkImpl(structs map[string]*dwarf.StructType, structDescs []*prog.KeyedSt return warnings, nil } -func checkStruct(typ *prog.StructDesc, astStruct *ast.Struct, str *dwarf.StructType) ([]Warn, error) { +func checkStruct(typ *prog.StructType, astStruct *ast.Struct, str *dwarf.StructType) ([]Warn, error) { var warnings []Warn warn := func(pos ast.Pos, typ, msg string, args ...interface{}) { warnings = append(warnings, Warn{pos: pos, typ: typ, msg: fmt.Sprintf(msg, args...)}) @@ -307,7 +301,7 @@ func checkStruct(typ *prog.StructDesc, astStruct *ast.Struct, str *dwarf.StructT return warnings, nil } -func parseDescriptions(OS, arch string) ([]*prog.KeyedStruct, map[string]*ast.Struct, []Warn, error) { +func parseDescriptions(OS, arch string) ([]*prog.StructType, map[string]*ast.Struct, []Warn, error) { errorBuf := new(bytes.Buffer) var 
warnings []Warn eh := func(pos ast.Pos, msg string) { @@ -326,7 +320,7 @@ func parseDescriptions(OS, arch string) ([]*prog.KeyedStruct, map[string]*ast.St if prg == nil { return nil, nil, nil, fmt.Errorf("failed to compile descriptions:\n%s", errorBuf.Bytes()) } - prog.RestoreLinks(prg.Syscalls, prg.Resources, prg.StructDescs, prg.Types) + prog.RestoreLinks(prg.Syscalls, prg.Resources, prg.Types) locs := make(map[string]*ast.Struct) for _, decl := range top.Nodes { switch n := decl.(type) { @@ -338,7 +332,13 @@ func parseDescriptions(OS, arch string) ([]*prog.KeyedStruct, map[string]*ast.St } } } - return prg.StructDescs, locs, warnings, nil + var structs []*prog.StructType + for _, typ := range prg.Types { + if t, ok := typ.(*prog.StructType); ok { + structs = append(structs, t) + } + } + return structs, locs, warnings, nil } // Overall idea of netlink checking. @@ -350,7 +350,7 @@ func parseDescriptions(OS, arch string) ([]*prog.KeyedStruct, map[string]*ast.St // Then read in the symbol data, which is an array of nla_policy structs. // These structs allow to easily figure out type/size of attributes. // Finally we compare our descriptions with the kernel policy description. 
-func checkNetlink(OS, arch, obj string, structDescs []*prog.KeyedStruct, +func checkNetlink(OS, arch, obj string, structTypes []*prog.StructType, locs map[string]*ast.Struct) ([]Warn, error) { if arch != "amd64" { // Netlink policies are arch-independent (?), @@ -374,24 +374,18 @@ func checkNetlink(OS, arch, obj string, structDescs []*prog.KeyedStruct, warn := func(pos ast.Pos, typ, msg string, args ...interface{}) { warnings = append(warnings, Warn{pos: pos, typ: typ, msg: fmt.Sprintf(msg, args...)}) } - structMap := make(map[string]*prog.StructDesc) - for _, str := range structDescs { - structMap[str.Desc.Name()] = str.Desc + structMap := make(map[string]*prog.StructType) + for _, typ := range structTypes { + structMap[typ.Name()] = typ } - checked := make(map[string]bool) checkedAttrs := make(map[string]*checkAttr) - for _, str := range structDescs { - typ := str.Desc - if checked[typ.Name()] { - continue - } - checked[typ.Name()] = true + for _, typ := range structTypes { name := typ.TemplateName() astStruct := locs[name] if astStruct == nil { continue } - if !isNetlinkPolicy(typ) { + if !isNetlinkPolicy(typ.Fields) { continue } kernelName := name @@ -499,9 +493,9 @@ func checkMissingAttrs(checkedAttrs map[string]*checkAttr) []Warn { return warnings } -func isNetlinkPolicy(typ *prog.StructDesc) bool { +func isNetlinkPolicy(fields []prog.Field) bool { haveAttr := false - for _, fld := range typ.Fields { + for _, fld := range fields { field := fld.Type if prog.IsPad(field) { continue @@ -514,12 +508,12 @@ func isNetlinkPolicy(typ *prog.StructDesc) bool { field = arr.Elem } if field1, ok := field.(*prog.StructType); ok { - if isNetlinkPolicy(field1.StructDesc) { + if isNetlinkPolicy(field1.Fields) { continue } } if field1, ok := field.(*prog.UnionType); ok { - if isNetlinkPolicy(field1.StructDesc) { + if isNetlinkPolicy(field1.Fields) { continue } } @@ -533,7 +527,7 @@ func isNlattr(typ prog.Type) bool { return name == "nlattr_t" || name == "nlattr_tt" } -func 
checkNetlinkPolicy(structMap map[string]*prog.StructDesc, typ *prog.StructDesc, +func checkNetlinkPolicy(structMap map[string]*prog.StructType, typ *prog.StructType, astStruct *ast.Struct, policy []nlaPolicy) ([]Warn, map[int]bool, error) { var warnings []Warn warn := func(pos ast.Pos, typ, msg string, args ...interface{}) { @@ -571,7 +565,7 @@ func checkNetlinkPolicy(structMap map[string]*prog.StructDesc, typ *prog.StructD return warnings, checked, nil } -func checkNetlinkAttr(typ *prog.StructDesc, policy nlaPolicy) string { +func checkNetlinkAttr(typ *prog.StructType, policy nlaPolicy) string { payload := typ.Fields[2].Type if typ.TemplateName() == "nlattr_tt" { payload = typ.Fields[4].Type @@ -637,7 +631,7 @@ func minTypeSize(typ prog.Type) int { return -1 } -func checkAttrType(typ *prog.StructDesc, payload prog.Type, policy nlaPolicy) string { +func checkAttrType(typ *prog.StructType, payload prog.Type, policy nlaPolicy) string { switch policy.typ { case NLA_STRING, NLA_NUL_STRING: if _, ok := payload.(*prog.BufferType); !ok { -- cgit mrf-deployment