aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorDmitry Vyukov <dvyukov@google.com>2020-05-03 11:29:12 +0200
committerDmitry Vyukov <dvyukov@google.com>2020-05-03 12:55:42 +0200
commit58ae5e18624eaaac79cab00e63d6f32c9bd64ee0 (patch)
tree00515dd9b2e461102e898930df00bc80400bf996
parent5457883a514281287bbd81364c4e26e25828563d (diff)
prog: remove StructDesc
Remove StructDesc, KeyedStruct, StructKey and all associated logic/complexity in prog and pkg/compiler. We can now handle recursion more generically with the Ref type, and Dir/FieldName are not a part of the type anymore. This makes StructType/UnionType simpler and more natural. Reduces size of sys/linux/gen/amd64.go from 5201321 to 4180861 (-20%). Update #1580
-rw-r--r--pkg/compiler/compiler.go18
-rw-r--r--pkg/compiler/compiler_test.go4
-rw-r--r--pkg/compiler/gen.go250
-rw-r--r--pkg/compiler/types.go38
-rw-r--r--prog/any.go29
-rw-r--r--prog/prio.go91
-rw-r--r--prog/prog_test.go18
-rw-r--r--prog/rand_test.go28
-rw-r--r--prog/resources.go20
-rw-r--r--prog/rotation.go2
-rw-r--r--prog/target.go63
-rw-r--r--prog/types.go100
-rw-r--r--sys/syz-sysgen/sysgen.go8
-rw-r--r--tools/syz-check/check.go62
14 files changed, 300 insertions, 431 deletions
diff --git a/pkg/compiler/compiler.go b/pkg/compiler/compiler.go
index f61d02036..55ddc1ba4 100644
--- a/pkg/compiler/compiler.go
+++ b/pkg/compiler/compiler.go
@@ -34,10 +34,9 @@ import (
// Prog is description compilation result.
type Prog struct {
- Resources []*prog.ResourceDesc
- Syscalls []*prog.Syscall
- StructDescs []*prog.KeyedStruct
- Types []prog.Type
+ Resources []*prog.ResourceDesc
+ Syscalls []*prog.Syscall
+ Types []prog.Type
// Set of unsupported syscalls/flags.
Unsupported map[string]bool
// Returned if consts was nil.
@@ -61,9 +60,8 @@ func createCompiler(desc *ast.Description, target *targets.Target, eh ast.ErrorH
strFlags: make(map[string]*ast.StrFlags),
used: make(map[string]bool),
usedTypedefs: make(map[string]bool),
- structDescs: make(map[prog.StructKey]*prog.StructDesc),
- structNodes: make(map[*prog.StructDesc]*ast.Struct),
structVarlen: make(map[string]bool),
+ structTypes: make(map[string]prog.Type),
builtinConsts: map[string]uint64{
"PTR_SIZE": target.PtrSize,
},
@@ -104,12 +102,11 @@ func Compile(desc *ast.Description, consts map[string]uint64, target *targets.Ta
return nil
}
syscalls := comp.genSyscalls()
- structs := comp.genStructDescs(syscalls)
- types := comp.generateTypes(syscalls, structs)
+ comp.layoutTypes(syscalls)
+ types := comp.generateTypes(syscalls)
prg := &Prog{
Resources: comp.genResources(),
Syscalls: syscalls,
- StructDescs: structs,
Types: types,
Unsupported: comp.unsupported,
}
@@ -139,9 +136,8 @@ type compiler struct {
used map[string]bool // contains used structs/resources
usedTypedefs map[string]bool
- structDescs map[prog.StructKey]*prog.StructDesc
- structNodes map[*prog.StructDesc]*ast.Struct
structVarlen map[string]bool
+ structTypes map[string]prog.Type
builtinConsts map[string]uint64
}
diff --git a/pkg/compiler/compiler_test.go b/pkg/compiler/compiler_test.go
index ae3b3d357..3d4ee3e64 100644
--- a/pkg/compiler/compiler_test.go
+++ b/pkg/compiler/compiler_test.go
@@ -122,8 +122,6 @@ func TestData(t *testing.T) {
out := new(bytes.Buffer)
fmt.Fprintf(out, "\n\nRESOURCES:\n")
serializer.Write(out, desc.Resources)
- fmt.Fprintf(out, "\n\nSTRUCTS:\n")
- serializer.Write(out, desc.StructDescs)
fmt.Fprintf(out, "\n\nSYSCALLS:\n")
serializer.Write(out, desc.Syscalls)
if false {
@@ -188,8 +186,6 @@ s2 {
if p == nil {
t.Fatal("failed to compile")
}
- got := p.StructDescs[0].Desc
- t.Logf("got: %#v", got)
}
func TestCollectUnusedError(t *testing.T) {
diff --git a/pkg/compiler/gen.go b/pkg/compiler/gen.go
index 748941bae..a86190863 100644
--- a/pkg/compiler/gen.go
+++ b/pkg/compiler/gen.go
@@ -138,26 +138,42 @@ func (comp *compiler) genSyscall(n *ast.Call, argSizes []uint64) *prog.Syscall {
type typeProxy struct {
typ prog.Type
id string
+ ref prog.Ref
locations []*prog.Type
}
-func (comp *compiler) generateTypes(syscalls []*prog.Syscall, structs []*prog.KeyedStruct) []prog.Type {
+func (comp *compiler) generateTypes(syscalls []*prog.Syscall) []prog.Type {
// Replace all Type's in the descriptions with Ref's
// and prepare a sorted array of corresponding real types.
proxies := make(map[string]*typeProxy)
- for _, call := range syscalls {
- for i := range call.Args {
- comp.collectTypes(proxies, &call.Args[i].Type)
+ prog.ForeachTypePost(syscalls, func(typ prog.Type, ctx prog.TypeCtx) {
+ if _, ok := typ.(prog.Ref); ok {
+ return
}
- if call.Ret != nil {
- comp.collectTypes(proxies, &call.Ret)
+ if !typ.Varlen() && typ.Size() == sizeUnassigned {
+ panic("unassigned size")
}
- }
- for _, str := range structs {
- for i := range str.Desc.Fields {
- comp.collectTypes(proxies, &str.Desc.Fields[i].Type)
+ id := typ.Name()
+ switch typ.(type) {
+ case *prog.StructType, *prog.UnionType:
+ // There types can be uniquely identified with the name.
+ default:
+ buf := new(bytes.Buffer)
+ serializer.Write(buf, typ)
+ id = buf.String()
}
- }
+ proxy := proxies[id]
+ if proxy == nil {
+ proxy = &typeProxy{
+ typ: typ,
+ id: id,
+ ref: prog.Ref(len(proxies)),
+ }
+ proxies[id] = proxy
+ }
+ *ctx.Ptr = proxy.ref
+ proxy.locations = append(proxy.locations, ctx.Ptr)
+ })
array := make([]*typeProxy, 0, len(proxies))
for _, proxy := range proxies {
array = append(array, proxy)
@@ -175,142 +191,75 @@ func (comp *compiler) generateTypes(syscalls []*prog.Syscall, structs []*prog.Ke
return types
}
-func (comp *compiler) collectTypes(proxies map[string]*typeProxy, tptr *prog.Type) {
- typ := *tptr
- switch t := typ.(type) {
- case *prog.PtrType:
- comp.collectTypes(proxies, &t.Elem)
- case *prog.ArrayType:
- comp.collectTypes(proxies, &t.Elem)
- case *prog.ResourceType, *prog.BufferType, *prog.VmaType, *prog.LenType,
- *prog.FlagsType, *prog.ConstType, *prog.IntType, *prog.ProcType,
- *prog.CsumType, *prog.StructType, *prog.UnionType:
- default:
- panic("unknown type")
- }
- buf := new(bytes.Buffer)
- serializer.Write(buf, typ)
- id := buf.String()
- proxy := proxies[id]
- if proxy == nil {
- proxy = &typeProxy{
- typ: typ,
- id: id,
- }
- proxies[id] = proxy
- }
- proxy.locations = append(proxy.locations, tptr)
-}
-
-func (comp *compiler) genStructDescs(syscalls []*prog.Syscall) []*prog.KeyedStruct {
- // Calculate struct/union/array sizes, add padding to structs and detach
- // StructDesc's from StructType's. StructType's can be recursive so it's
- // not possible to write them out inline as other types. To break the
- // recursion detach them, and write StructDesc's out as separate array
- // of KeyedStruct's. prog package will reattach them during init.
- ctx := &structGen{
- comp: comp,
- padded: make(map[interface{}]bool),
- detach: make(map[**prog.StructDesc]bool),
- }
- // We have to do this in the loop until we pad nothing new
- // due to recursive structs.
- for {
- start := len(ctx.padded)
- for _, c := range syscalls {
- for _, a := range c.Args {
- ctx.walk(a.Type)
- }
- if c.Ret != nil {
- ctx.walk(c.Ret)
- }
- }
- if start == len(ctx.padded) {
- break
- }
- }
-
- // Detach StructDesc's from StructType's. prog will reattach them again.
- for descp := range ctx.detach {
- *descp = nil
- }
-
- sort.Slice(ctx.structs, func(i, j int) bool {
- si, sj := ctx.structs[i].Key, ctx.structs[j].Key
- return si.Name < sj.Name
+func (comp *compiler) layoutTypes(syscalls []*prog.Syscall) {
+ // Calculate struct/union/array sizes, add padding to structs, mark bitfields.
+ padded := make(map[prog.Type]bool)
+ prog.ForeachTypePost(syscalls, func(typ prog.Type, _ prog.TypeCtx) {
+ comp.layoutType(typ, padded)
})
- return ctx.structs
-}
-
-type structGen struct {
- comp *compiler
- padded map[interface{}]bool
- detach map[**prog.StructDesc]bool
- structs []*prog.KeyedStruct
}
-func (ctx *structGen) check(key prog.StructKey, descp **prog.StructDesc) bool {
- ctx.detach[descp] = true
- desc := *descp
- if ctx.padded[desc] {
- return false
- }
- ctx.padded[desc] = true
- for _, f := range desc.Fields {
- ctx.walk(f.Type)
- if !f.Varlen() && f.Size() == sizeUnassigned {
- // An inner struct is not padded yet.
- // Leave this struct for next iteration.
- delete(ctx.padded, desc)
- return false
- }
- }
- if ctx.comp.used[key.Name] {
- ctx.structs = append(ctx.structs, &prog.KeyedStruct{
- Key: key,
- Desc: desc,
- })
+func (comp *compiler) layoutType(typ prog.Type, padded map[prog.Type]bool) {
+ if padded[typ] {
+ return
}
- return true
-}
-
-func (ctx *structGen) walk(t0 prog.Type) {
- switch t := t0.(type) {
- case *prog.PtrType:
- ctx.walk(t.Elem)
+ switch t := typ.(type) {
case *prog.ArrayType:
- ctx.walkArray(t)
+ comp.layoutType(t.Elem, padded)
+ comp.layoutArray(t)
case *prog.StructType:
- ctx.walkStruct(t)
+ for _, f := range t.Fields {
+ comp.layoutType(f.Type, padded)
+ }
+ comp.layoutStruct(t)
case *prog.UnionType:
- ctx.walkUnion(t)
- }
-}
-
-func (ctx *structGen) walkArray(t *prog.ArrayType) {
- if ctx.padded[t] {
+ for _, f := range t.Fields {
+ comp.layoutType(f.Type, padded)
+ }
+ comp.layoutUnion(t)
+ default:
return
}
- ctx.walk(t.Elem)
- if !t.Elem.Varlen() && t.Elem.Size() == sizeUnassigned {
- // An inner struct is not padded yet.
- // Leave this array for next iteration.
- return
+ if !typ.Varlen() && typ.Size() == sizeUnassigned {
+ panic("size unassigned")
}
- ctx.padded[t] = true
+ padded[typ] = true
+}
+
+func (comp *compiler) layoutArray(t *prog.ArrayType) {
t.TypeSize = 0
if t.Kind == prog.ArrayRangeLen && t.RangeBegin == t.RangeEnd && !t.Elem.Varlen() {
t.TypeSize = t.RangeBegin * t.Elem.Size()
}
}
-func (ctx *structGen) walkStruct(t *prog.StructType) {
- if !ctx.check(t.Key, &t.StructDesc) {
+func (comp *compiler) layoutUnion(t *prog.UnionType) {
+ structNode := comp.structs[t.TypeName]
+ attrs := comp.parseAttrs(unionAttrs, structNode, structNode.Attrs)
+ t.TypeSize = 0
+ if attrs[attrVarlen] != 0 {
return
}
- comp := ctx.comp
- structNode := comp.structNodes[t.StructDesc]
+ sizeAttr, hasSize := attrs[attrSize]
+ for i, fld := range t.Fields {
+ sz := fld.Size()
+ if hasSize && sz > sizeAttr {
+ comp.error(structNode.Fields[i].Pos, "union %v has size attribute %v"+
+ " which is less than field %v size %v",
+ structNode.Name.Name, sizeAttr, fld.Type.Name(), sz)
+ }
+ if t.TypeSize < sz {
+ t.TypeSize = sz
+ }
+ }
+ if hasSize {
+ t.TypeSize = sizeAttr
+ }
+}
+
+func (comp *compiler) layoutStruct(t *prog.StructType) {
// Add paddings, calculate size, mark bitfields.
+ structNode := comp.structs[t.TypeName]
varlen := false
for _, f := range t.Fields {
if f.Varlen() {
@@ -319,7 +268,7 @@ func (ctx *structGen) walkStruct(t *prog.StructType) {
}
attrs := comp.parseAttrs(structAttrs, structNode, structNode.Attrs)
t.AlignAttr = attrs[attrAlign]
- comp.layoutStruct(t, varlen, attrs[attrPacked] != 0)
+ comp.layoutStructFields(t, varlen, attrs[attrPacked] != 0)
t.TypeSize = 0
if !varlen {
for _, f := range t.Fields {
@@ -340,46 +289,7 @@ func (ctx *structGen) walkStruct(t *prog.StructType) {
}
}
-func (ctx *structGen) walkUnion(t *prog.UnionType) {
- if !ctx.check(t.Key, &t.StructDesc) {
- return
- }
- comp := ctx.comp
- structNode := comp.structNodes[t.StructDesc]
- attrs := comp.parseAttrs(unionAttrs, structNode, structNode.Attrs)
- t.TypeSize = 0
- if attrs[attrVarlen] != 0 {
- return
- }
- sizeAttr, hasSize := attrs[attrSize]
- for i, fld := range t.Fields {
- sz := fld.Size()
- if hasSize && sz > sizeAttr {
- comp.error(structNode.Fields[i].Pos, "union %v has size attribute %v"+
- " which is less than field %v size %v",
- structNode.Name.Name, sizeAttr, fld.Type.Name(), sz)
- }
- if t.TypeSize < sz {
- t.TypeSize = sz
- }
- }
- if hasSize {
- t.TypeSize = sizeAttr
- }
-}
-
-func (comp *compiler) genStructDesc(res *prog.StructDesc, n *ast.Struct, varlen bool) {
- // Leave node for genStructDescs to calculate size/padding.
- comp.structNodes[res] = n
- common := genCommon(n.Name.Name, sizeUnassigned, false)
- common.IsVarlen = varlen
- *res = prog.StructDesc{
- TypeCommon: common,
- Fields: comp.genFieldArray(n.Fields, make([]uint64, len(n.Fields))),
- }
-}
-
-func (comp *compiler) layoutStruct(t *prog.StructType, varlen, packed bool) {
+func (comp *compiler) layoutStructFields(t *prog.StructType, varlen, packed bool) {
var newFields []prog.Field
var structAlign, byteOffset, bitOffset uint64
for i, field := range t.Fields {
@@ -538,7 +448,7 @@ func (comp *compiler) typeAlign(t0 prog.Type) uint64 {
case *prog.ArrayType:
return comp.typeAlign(t.Elem)
case *prog.StructType:
- n := comp.structNodes[t.StructDesc]
+ n := comp.structs[t.TypeName]
attrs := comp.parseAttrs(structAttrs, n, n.Attrs)
if align := attrs[attrAlign]; align != 0 {
return align // overrided by user attribute
diff --git a/pkg/compiler/types.go b/pkg/compiler/types.go
index 7ed12afd2..44ccb1f1d 100644
--- a/pkg/compiler/types.go
+++ b/pkg/compiler/types.go
@@ -237,7 +237,7 @@ var typeArray = &typeDesc{
RangeEnd: end,
}
}
- // TypeSize is assigned later in genStructDescs.
+ // TypeSize is assigned later in layoutArray.
return &prog.ArrayType{
TypeCommon: base.TypeCommon,
Elem: elemType,
@@ -815,27 +815,31 @@ func init() {
return true
}
typeStruct.Gen = func(comp *compiler, t *ast.Type, args []*ast.Type, base prog.IntTypeCommon) prog.Type {
- s := comp.structs[t.Ident]
- key := prog.StructKey{
- Name: t.Ident,
- }
- desc := comp.structDescs[key]
- if desc == nil {
- // Need to assign to structDescs before calling genStructDesc to break recursion.
- desc = new(prog.StructDesc)
- comp.structDescs[key] = desc
- comp.genStructDesc(desc, s, typeStruct.Varlen(comp, t, args))
+ if typ := comp.structTypes[t.Ident]; typ != nil {
+ return typ
}
+ s := comp.structs[t.Ident]
+ common := genCommon(t.Ident, sizeUnassigned, false)
+ common.IsVarlen = typeStruct.Varlen(comp, t, args)
+ var typ prog.Type
if s.IsUnion {
- return &prog.UnionType{
- Key: key,
- StructDesc: desc,
+ typ = &prog.UnionType{
+ TypeCommon: common,
+ }
+ } else {
+ typ = &prog.StructType{
+ TypeCommon: common,
}
}
- return &prog.StructType{
- Key: key,
- StructDesc: desc,
+ // Need to cache type in structTypes before generating fields to break recursion.
+ comp.structTypes[t.Ident] = typ
+ fields := comp.genFieldArray(s.Fields, make([]uint64, len(s.Fields)))
+ if s.IsUnion {
+ typ.(*prog.UnionType).Fields = fields
+ } else {
+ typ.(*prog.StructType).Fields = fields
}
+ return typ
}
}
diff --git a/prog/any.go b/prog/any.go
index 1ce444ce2..ce3736a4b 100644
--- a/prog/any.go
+++ b/prog/any.go
@@ -36,7 +36,12 @@ type anyTypes struct {
// resoct fmt[oct, ANYRES64]
// ] [varlen]
func initAnyTypes(target *Target) {
- target.any.union = &UnionType{}
+ target.any.union = &UnionType{
+ TypeCommon: TypeCommon{
+ TypeName: "ANYUNION",
+ IsVarlen: true,
+ },
+ }
target.any.array = &ArrayType{
TypeCommon: TypeCommon{
TypeName: "ANYARRAY",
@@ -89,20 +94,14 @@ func initAnyTypes(target *Target) {
target.any.resdec = createResource("ANYRESDEC", "int64", FormatStrDec, 20)
target.any.reshex = createResource("ANYRESHEX", "int64", FormatStrHex, 18)
target.any.resoct = createResource("ANYRESOCT", "int64", FormatStrOct, 23)
- target.any.union.StructDesc = &StructDesc{
- TypeCommon: TypeCommon{
- TypeName: "ANYUNION",
- IsVarlen: true,
- },
- Fields: []Field{
- {Name: "ANYBLOB", Type: target.any.blob},
- {Name: "ANYRES16", Type: target.any.res16},
- {Name: "ANYRES32", Type: target.any.res32},
- {Name: "ANYRES64", Type: target.any.res64},
- {Name: "ANYRESDEC", Type: target.any.resdec},
- {Name: "ANYRESHEX", Type: target.any.reshex},
- {Name: "ANYRESOCT", Type: target.any.resoct},
- },
+ target.any.union.Fields = []Field{
+ {Name: "ANYBLOB", Type: target.any.blob},
+ {Name: "ANYRES16", Type: target.any.res16},
+ {Name: "ANYRES32", Type: target.any.res32},
+ {Name: "ANYRES64", Type: target.any.res64},
+ {Name: "ANYRESDEC", Type: target.any.resdec},
+ {Name: "ANYRESHEX", Type: target.any.reshex},
+ {Name: "ANYRESOCT", Type: target.any.resoct},
}
}
diff --git a/prog/prio.go b/prog/prio.go
index 43668d48d..648af0422 100644
--- a/prog/prio.go
+++ b/prog/prio.go
@@ -64,56 +64,55 @@ func (target *Target) calcStaticPriorities() [][]float32 {
func (target *Target) calcResourceUsage() map[string]map[int]weights {
uses := make(map[string]map[int]weights)
- for _, c := range target.Syscalls {
- foreachType(c, func(t Type, ctx typeCtx) {
- switch a := t.(type) {
- case *ResourceType:
- if target.AuxResources[a.Desc.Name] {
- noteUsage(uses, c, 0.1, ctx.Dir, "res%v", a.Desc.Name)
- } else {
- str := "res"
- for i, k := range a.Desc.Kind {
- str += "-" + k
- w := 1.0
- if i < len(a.Desc.Kind)-1 {
- w = 0.2
- }
- noteUsage(uses, c, float32(w), ctx.Dir, str)
+ ForeachType(target.Syscalls, func(t Type, ctx TypeCtx) {
+ c := ctx.Meta
+ switch a := t.(type) {
+ case *ResourceType:
+ if target.AuxResources[a.Desc.Name] {
+ noteUsage(uses, c, 0.1, ctx.Dir, "res%v", a.Desc.Name)
+ } else {
+ str := "res"
+ for i, k := range a.Desc.Kind {
+ str += "-" + k
+ w := 1.0
+ if i < len(a.Desc.Kind)-1 {
+ w = 0.2
}
+ noteUsage(uses, c, float32(w), ctx.Dir, str)
}
- case *PtrType:
- if _, ok := a.Elem.(*StructType); ok {
- noteUsage(uses, c, 1.0, ctx.Dir, "ptrto-%v", a.Elem.Name())
- }
- if _, ok := a.Elem.(*UnionType); ok {
- noteUsage(uses, c, 1.0, ctx.Dir, "ptrto-%v", a.Elem.Name())
- }
- if arr, ok := a.Elem.(*ArrayType); ok {
- noteUsage(uses, c, 1.0, ctx.Dir, "ptrto-%v", arr.Elem.Name())
- }
- case *BufferType:
- switch a.Kind {
- case BufferBlobRand, BufferBlobRange, BufferText:
- case BufferString:
- if a.SubKind != "" {
- noteUsage(uses, c, 0.2, ctx.Dir, fmt.Sprintf("str-%v", a.SubKind))
- }
- case BufferFilename:
- noteUsage(uses, c, 1.0, DirIn, "filename")
- default:
- panic("unknown buffer kind")
- }
- case *VmaType:
- noteUsage(uses, c, 0.5, ctx.Dir, "vma")
- case *IntType:
- switch a.Kind {
- case IntPlain, IntRange:
- default:
- panic("unknown int kind")
+ }
+ case *PtrType:
+ if _, ok := a.Elem.(*StructType); ok {
+ noteUsage(uses, c, 1.0, ctx.Dir, "ptrto-%v", a.Elem.Name())
+ }
+ if _, ok := a.Elem.(*UnionType); ok {
+ noteUsage(uses, c, 1.0, ctx.Dir, "ptrto-%v", a.Elem.Name())
+ }
+ if arr, ok := a.Elem.(*ArrayType); ok {
+ noteUsage(uses, c, 1.0, ctx.Dir, "ptrto-%v", arr.Elem.Name())
+ }
+ case *BufferType:
+ switch a.Kind {
+ case BufferBlobRand, BufferBlobRange, BufferText:
+ case BufferString:
+ if a.SubKind != "" {
+ noteUsage(uses, c, 0.2, ctx.Dir, fmt.Sprintf("str-%v", a.SubKind))
}
+ case BufferFilename:
+ noteUsage(uses, c, 1.0, DirIn, "filename")
+ default:
+ panic("unknown buffer kind")
}
- })
- }
+ case *VmaType:
+ noteUsage(uses, c, 0.5, ctx.Dir, "vma")
+ case *IntType:
+ switch a.Kind {
+ case IntPlain, IntRange:
+ default:
+ panic("unknown int kind")
+ }
+ }
+ })
return uses
}
diff --git a/prog/prog_test.go b/prog/prog_test.go
index 9b8f442e3..6374b6d25 100644
--- a/prog/prog_test.go
+++ b/prog/prog_test.go
@@ -20,15 +20,13 @@ func TestGeneration(t *testing.T) {
func TestDefault(t *testing.T) {
target, _, _ := initTest(t)
- for _, meta := range target.Syscalls {
- foreachType(meta, func(typ Type, ctx typeCtx) {
- arg := typ.DefaultArg(ctx.Dir)
- if !isDefault(arg) {
- t.Errorf("default arg is not default: %s\ntype: %#v\narg: %#v",
- typ, typ, arg)
- }
- })
- }
+ ForeachType(target.Syscalls, func(typ Type, ctx TypeCtx) {
+ arg := typ.DefaultArg(ctx.Dir)
+ if !isDefault(arg) {
+ t.Errorf("default arg is not default: %s\ntype: %#v\narg: %#v",
+ typ, typ, arg)
+ }
+ })
}
func TestDefaultCallArgs(t *testing.T) {
@@ -203,7 +201,7 @@ func TestSpecialStructs(t *testing.T) {
t.Run(special, func(t *testing.T) {
var typ Type
for i := 0; i < len(target.Syscalls) && typ == nil; i++ {
- foreachType(target.Syscalls[i], func(t Type, ctx typeCtx) {
+ ForeachCallType(target.Syscalls[i], func(t Type, ctx TypeCtx) {
if ctx.Dir == DirOut {
return
}
diff --git a/prog/rand_test.go b/prog/rand_test.go
index 6f251cb7c..da0871505 100644
--- a/prog/rand_test.go
+++ b/prog/rand_test.go
@@ -101,22 +101,20 @@ func TestEnabledCalls(t *testing.T) {
func TestSizeGenerateConstArg(t *testing.T) {
target, rs, iters := initRandomTargetTest(t, "test", "64")
r := newRand(target, rs)
- for _, c := range target.Syscalls {
- foreachType(c, func(typ Type, ctx typeCtx) {
- if _, ok := typ.(*IntType); !ok {
- return
- }
- bits := typ.TypeBitSize()
- limit := uint64(1<<bits - 1)
- for i := 0; i < iters; i++ {
- newArg, _ := typ.generate(r, nil, ctx.Dir)
- newVal := newArg.(*ConstArg).Val
- if newVal > limit {
- t.Fatalf("invalid generated value: %d. (arg bitsize: %d; max value: %d)", newVal, bits, limit)
- }
+ ForeachType(target.Syscalls, func(typ Type, ctx TypeCtx) {
+ if _, ok := typ.(*IntType); !ok {
+ return
+ }
+ bits := typ.TypeBitSize()
+ limit := uint64(1<<bits - 1)
+ for i := 0; i < iters; i++ {
+ newArg, _ := typ.generate(r, nil, ctx.Dir)
+ newVal := newArg.(*ConstArg).Val
+ if newVal > limit {
+ t.Fatalf("invalid generated value: %d. (arg bitsize: %d; max value: %d)", newVal, bits, limit)
}
- })
- }
+ }
+ })
}
func TestFlags(t *testing.T) {
diff --git a/prog/resources.go b/prog/resources.go
index b7bcecf95..311eb72dd 100644
--- a/prog/resources.go
+++ b/prog/resources.go
@@ -45,16 +45,14 @@ func (target *Target) calcResourceCtors(res *ResourceDesc, precise bool) []*Sysc
func (target *Target) populateResourceCtors() {
// Find resources that are created by each call.
callsResources := make([][]*ResourceDesc, len(target.Syscalls))
- for call, meta := range target.Syscalls {
- foreachType(meta, func(typ Type, ctx typeCtx) {
- switch typ1 := typ.(type) {
- case *ResourceType:
- if ctx.Dir != DirIn {
- callsResources[call] = append(callsResources[call], typ1.Desc)
- }
+ ForeachType(target.Syscalls, func(typ Type, ctx TypeCtx) {
+ switch typ1 := typ.(type) {
+ case *ResourceType:
+ if ctx.Dir != DirIn {
+ callsResources[ctx.Meta.ID] = append(callsResources[ctx.Meta.ID], typ1.Desc)
}
- })
- }
+ }
+ })
// Populate resource ctors accounting for resource compatibility.
for _, res := range target.Resources {
@@ -126,7 +124,7 @@ func isCompatibleResourceImpl(dst, src []string, precise bool) bool {
func (target *Target) getInputResources(c *Syscall) []*ResourceDesc {
var resources []*ResourceDesc
- foreachType(c, func(typ Type, ctx typeCtx) {
+ ForeachCallType(c, func(typ Type, ctx TypeCtx) {
if ctx.Dir == DirOut {
return
}
@@ -146,7 +144,7 @@ func (target *Target) getInputResources(c *Syscall) []*ResourceDesc {
func (target *Target) getOutputResources(c *Syscall) []*ResourceDesc {
var resources []*ResourceDesc
- foreachType(c, func(typ Type, ctx typeCtx) {
+ ForeachCallType(c, func(typ Type, ctx TypeCtx) {
switch typ1 := typ.(type) {
case *ResourceType:
if ctx.Dir != DirIn {
diff --git a/prog/rotation.go b/prog/rotation.go
index 47ee2ca81..0171e6ace 100644
--- a/prog/rotation.go
+++ b/prog/rotation.go
@@ -50,7 +50,7 @@ func MakeRotator(target *Target, calls map[*Syscall]bool, rnd *rand.Rand) *Rotat
}
// VMAs and filenames are effectively resources for our purposes
// (but they don't have ctors).
- foreachType(call, func(t Type, _ typeCtx) {
+ ForeachCallType(call, func(t Type, _ TypeCtx) {
switch a := t.(type) {
case *BufferType:
switch a.Kind {
diff --git a/prog/target.go b/prog/target.go
index 89940a61a..8999f25ac 100644
--- a/prog/target.go
+++ b/prog/target.go
@@ -22,7 +22,6 @@ type Target struct {
Syscalls []*Syscall
Resources []*ResourceDesc
- Structs []*KeyedStruct
Consts []ConstValue
Types []Type
@@ -137,8 +136,7 @@ func (target *Target) initTarget() {
target.ConstMap[c.Name] = c.Value
}
- target.resourceMap = restoreLinks(target.Syscalls, target.Resources, target.Structs, target.Types)
- target.Structs = nil
+ target.resourceMap = restoreLinks(target.Syscalls, target.Resources, target.Types)
target.Types = nil
target.SyscallMap = make(map[string]*Syscall)
@@ -173,63 +171,30 @@ func (target *Target) sanitize(c *Call, fix bool) error {
return nil
}
-func RestoreLinks(syscalls []*Syscall, resources []*ResourceDesc, structs []*KeyedStruct, types []Type) {
- restoreLinks(syscalls, resources, structs, types)
+func RestoreLinks(syscalls []*Syscall, resources []*ResourceDesc, types []Type) {
+ restoreLinks(syscalls, resources, types)
}
-func restoreLinks(syscalls []*Syscall, resources []*ResourceDesc, structs []*KeyedStruct,
- types []Type) map[string]*ResourceDesc {
+func restoreLinks(syscalls []*Syscall, resources []*ResourceDesc, types []Type) map[string]*ResourceDesc {
resourceMap := make(map[string]*ResourceDesc)
for _, res := range resources {
resourceMap[res.Name] = res
}
- keyedStructs := make(map[StructKey]*StructDesc)
- for _, desc := range structs {
- keyedStructs[desc.Key] = desc.Desc
- for i := range desc.Desc.Fields {
- unref(&desc.Desc.Fields[i].Type, types)
+ ForeachType(syscalls, func(_ Type, ctx TypeCtx) {
+ if ref, ok := (*ctx.Ptr).(Ref); ok {
+ *ctx.Ptr = types[ref]
}
- }
- for _, c := range syscalls {
- for i := range c.Args {
- unref(&c.Args[i].Type, types)
- }
- if c.Ret != nil {
- unref(&c.Ret, types)
- }
- foreachType(c, func(t0 Type, _ typeCtx) {
- switch t := t0.(type) {
- case *PtrType:
- unref(&t.Elem, types)
- case *ArrayType:
- unref(&t.Elem, types)
- case *ResourceType:
- t.Desc = resourceMap[t.TypeName]
- if t.Desc == nil {
- panic("no resource desc")
- }
- case *StructType:
- t.StructDesc = keyedStructs[t.Key]
- if t.StructDesc == nil {
- panic("no struct desc")
- }
- case *UnionType:
- t.StructDesc = keyedStructs[t.Key]
- if t.StructDesc == nil {
- panic("no union desc")
- }
+ switch t := (*ctx.Ptr).(type) {
+ case *ResourceType:
+ t.Desc = resourceMap[t.TypeName]
+ if t.Desc == nil {
+ panic("no resource desc")
}
- })
- }
+ }
+ })
return resourceMap
}
-func unref(tp *Type, types []Type) {
- if ref, ok := (*tp).(Ref); ok {
- *tp = types[ref]
- }
-}
-
type Gen struct {
r *randGen
s *state
diff --git a/prog/types.go b/prog/types.go
index 4412239d3..33c7bd356 100644
--- a/prog/types.go
+++ b/prog/types.go
@@ -110,7 +110,6 @@ type Ref uint32
func (ti Ref) String() string { panic("prog.Ref method called") }
func (ti Ref) Name() string { panic("prog.Ref method called") }
-func (ti Ref) FieldName() string { panic("prog.Ref method called") }
func (ti Ref) TemplateName() string { panic("prog.Ref method called") }
func (ti Ref) Optional() bool { panic("prog.Ref method called") }
@@ -579,8 +578,9 @@ func (t *PtrType) isDefaultArg(arg Arg) bool {
}
type StructType struct {
- Key StructKey
- *StructDesc
+ TypeCommon
+ Fields []Field
+ AlignAttr uint64
}
func (t *StructType) String() string {
@@ -606,8 +606,8 @@ func (t *StructType) isDefaultArg(arg Arg) bool {
}
type UnionType struct {
- Key StructKey
- *StructDesc
+ TypeCommon
+ Fields []Field
}
func (t *UnionType) String() string {
@@ -623,64 +623,80 @@ func (t *UnionType) isDefaultArg(arg Arg) bool {
return a.Index == 0 && isDefault(a.Option)
}
-type StructDesc struct {
- TypeCommon
- Fields []Field
- AlignAttr uint64
+type ConstValue struct {
+ Name string
+ Value uint64
}
-type StructKey struct {
- Name string
+type TypeCtx struct {
+ Meta *Syscall
+ Dir Dir
+ Ptr *Type
}
-type KeyedStruct struct {
- Key StructKey
- Desc *StructDesc
+func ForeachType(syscalls []*Syscall, f func(t Type, ctx TypeCtx)) {
+ for _, meta := range syscalls {
+ foreachTypeImpl(meta, true, f)
+ }
}
-type ConstValue struct {
- Name string
- Value uint64
+func ForeachTypePost(syscalls []*Syscall, f func(t Type, ctx TypeCtx)) {
+ for _, meta := range syscalls {
+ foreachTypeImpl(meta, false, f)
+ }
}
-type typeCtx struct {
- Dir Dir
+func ForeachCallType(meta *Syscall, f func(t Type, ctx TypeCtx)) {
+ foreachTypeImpl(meta, true, f)
}
-func foreachType(meta *Syscall, f func(t Type, ctx typeCtx)) {
- var rec func(t Type, dir Dir)
- seen := make(map[*StructDesc]bool)
- recStruct := func(desc *StructDesc, dir Dir) {
- if seen[desc] {
- return // prune recursion via pointers to structs/unions
- }
- seen[desc] = true
- for _, f := range desc.Fields {
- rec(f.Type, dir)
+func foreachTypeImpl(meta *Syscall, preorder bool, f func(t Type, ctx TypeCtx)) {
+ // Note: we specifically don't create seen in ForeachType.
+ // It would prune recursion more (across syscalls), but lots of users need to
+ // visit each struct per-syscall (e.g. prio, used resources).
+ seen := make(map[Type]bool)
+ var rec func(*Type, Dir)
+ rec = func(ptr *Type, dir Dir) {
+ if preorder {
+ f(*ptr, TypeCtx{Meta: meta, Dir: dir, Ptr: ptr})
}
- }
- rec = func(t Type, dir Dir) {
- f(t, typeCtx{Dir: dir})
- switch a := t.(type) {
+ switch a := (*ptr).(type) {
case *PtrType:
- rec(a.Elem, a.ElemDir)
+ rec(&a.Elem, a.ElemDir)
case *ArrayType:
- rec(a.Elem, dir)
+ rec(&a.Elem, dir)
case *StructType:
- recStruct(a.StructDesc, dir)
+ if seen[a] {
+ break // prune recursion via pointers to structs/unions
+ }
+ seen[a] = true
+ for i := range a.Fields {
+ rec(&a.Fields[i].Type, dir)
+ }
case *UnionType:
- recStruct(a.StructDesc, dir)
- case *ResourceType, *BufferType, *VmaType, *LenType,
- *FlagsType, *ConstType, *IntType, *ProcType, *CsumType:
+ if seen[a] {
+ break // prune recursion via pointers to structs/unions
+ }
+ seen[a] = true
+ for i := range a.Fields {
+ rec(&a.Fields[i].Type, dir)
+ }
+ case *ResourceType, *BufferType, *VmaType, *LenType, *FlagsType,
+ *ConstType, *IntType, *ProcType, *CsumType:
+ case Ref:
+ // This is only needed for pkg/compiler.
default:
panic("unknown type")
}
+ if !preorder {
+ f(*ptr, TypeCtx{Meta: meta, Dir: dir, Ptr: ptr})
+ }
}
- for _, field := range meta.Args {
- rec(field.Type, DirIn)
+ for i := range meta.Args {
+ rec(&meta.Args[i].Type, DirIn)
}
if meta.Ret != nil {
- rec(meta.Ret, DirOut)
+ rec(&meta.Ret, DirOut)
}
}
diff --git a/sys/syz-sysgen/sysgen.go b/sys/syz-sysgen/sysgen.go
index 1197a8cce..bf39ba9ee 100644
--- a/sys/syz-sysgen/sysgen.go
+++ b/sys/syz-sysgen/sysgen.go
@@ -189,20 +189,16 @@ func generate(target *targets.Target, prg *compiler.Prog, consts map[string]uint
fmt.Fprintf(out, "\tRegisterTarget(&Target{"+
"OS: %q, Arch: %q, Revision: revision_%v, PtrSize: %v, "+
"PageSize: %v, NumPages: %v, DataOffset: %v, Syscalls: syscalls_%v, "+
- "Resources: resources_%v, Structs: structDescs_%v, Types: types_%v, Consts: consts_%v}, "+
+ "Resources: resources_%v, Types: types_%v, Consts: consts_%v}, "+
"InitTarget)\n}\n\n",
target.OS, target.Arch, target.Arch, target.PtrSize,
target.PageSize, target.NumPages, target.DataOffset,
- target.Arch, target.Arch, target.Arch, target.Arch, target.Arch)
+ target.Arch, target.Arch, target.Arch, target.Arch)
fmt.Fprintf(out, "var resources_%v = ", target.Arch)
serializer.Write(out, prg.Resources)
fmt.Fprintf(out, "\n\n")
- fmt.Fprintf(out, "var structDescs_%v = ", target.Arch)
- serializer.Write(out, prg.StructDescs)
- fmt.Fprintf(out, "\n\n")
-
fmt.Fprintf(out, "var syscalls_%v = ", target.Arch)
serializer.Write(out, prg.Syscalls)
fmt.Fprintf(out, "\n\n")
diff --git a/tools/syz-check/check.go b/tools/syz-check/check.go
index 1d74584af..49152f481 100644
--- a/tools/syz-check/check.go
+++ b/tools/syz-check/check.go
@@ -89,7 +89,7 @@ func check(OS, arch, obj string, dwarf, netlink bool) ([]Warn, error) {
if obj == "" {
return nil, fmt.Errorf("no object file in -obj-%v flag", arch)
}
- structDescs, locs, warnings1, err := parseDescriptions(OS, arch)
+ structTypes, locs, warnings1, err := parseDescriptions(OS, arch)
if err != nil {
return nil, err
}
@@ -99,14 +99,14 @@ func check(OS, arch, obj string, dwarf, netlink bool) ([]Warn, error) {
if err != nil {
return nil, err
}
- warnings2, err := checkImpl(structs, structDescs, locs)
+ warnings2, err := checkImpl(structs, structTypes, locs)
if err != nil {
return nil, err
}
warnings = append(warnings, warnings2...)
}
if netlink {
- warnings3, err := checkNetlink(OS, arch, obj, structDescs, locs)
+ warnings3, err := checkNetlink(OS, arch, obj, structTypes, locs)
if err != nil {
return nil, err
}
@@ -197,16 +197,10 @@ func writeWarnings(OS string, narches int, warnings []Warn) error {
return nil
}
-func checkImpl(structs map[string]*dwarf.StructType, structDescs []*prog.KeyedStruct,
+func checkImpl(structs map[string]*dwarf.StructType, structTypes []*prog.StructType,
locs map[string]*ast.Struct) ([]Warn, error) {
var warnings []Warn
- checked := make(map[string]bool)
- for _, str := range structDescs {
- typ := str.Desc
- if checked[typ.Name()] {
- continue
- }
- checked[typ.Name()] = true
+ for _, typ := range structTypes {
name := typ.TemplateName()
astStruct := locs[name]
if astStruct == nil {
@@ -221,7 +215,7 @@ func checkImpl(structs map[string]*dwarf.StructType, structDescs []*prog.KeyedSt
return warnings, nil
}
-func checkStruct(typ *prog.StructDesc, astStruct *ast.Struct, str *dwarf.StructType) ([]Warn, error) {
+func checkStruct(typ *prog.StructType, astStruct *ast.Struct, str *dwarf.StructType) ([]Warn, error) {
var warnings []Warn
warn := func(pos ast.Pos, typ, msg string, args ...interface{}) {
warnings = append(warnings, Warn{pos: pos, typ: typ, msg: fmt.Sprintf(msg, args...)})
@@ -307,7 +301,7 @@ func checkStruct(typ *prog.StructDesc, astStruct *ast.Struct, str *dwarf.StructT
return warnings, nil
}
-func parseDescriptions(OS, arch string) ([]*prog.KeyedStruct, map[string]*ast.Struct, []Warn, error) {
+func parseDescriptions(OS, arch string) ([]*prog.StructType, map[string]*ast.Struct, []Warn, error) {
errorBuf := new(bytes.Buffer)
var warnings []Warn
eh := func(pos ast.Pos, msg string) {
@@ -326,7 +320,7 @@ func parseDescriptions(OS, arch string) ([]*prog.KeyedStruct, map[string]*ast.St
if prg == nil {
return nil, nil, nil, fmt.Errorf("failed to compile descriptions:\n%s", errorBuf.Bytes())
}
- prog.RestoreLinks(prg.Syscalls, prg.Resources, prg.StructDescs, prg.Types)
+ prog.RestoreLinks(prg.Syscalls, prg.Resources, prg.Types)
locs := make(map[string]*ast.Struct)
for _, decl := range top.Nodes {
switch n := decl.(type) {
@@ -338,7 +332,13 @@ func parseDescriptions(OS, arch string) ([]*prog.KeyedStruct, map[string]*ast.St
}
}
}
- return prg.StructDescs, locs, warnings, nil
+ var structs []*prog.StructType
+ for _, typ := range prg.Types {
+ if t, ok := typ.(*prog.StructType); ok {
+ structs = append(structs, t)
+ }
+ }
+ return structs, locs, warnings, nil
}
// Overall idea of netlink checking.
@@ -350,7 +350,7 @@ func parseDescriptions(OS, arch string) ([]*prog.KeyedStruct, map[string]*ast.St
// Then read in the symbol data, which is an array of nla_policy structs.
// These structs allow to easily figure out type/size of attributes.
// Finally we compare our descriptions with the kernel policy description.
-func checkNetlink(OS, arch, obj string, structDescs []*prog.KeyedStruct,
+func checkNetlink(OS, arch, obj string, structTypes []*prog.StructType,
locs map[string]*ast.Struct) ([]Warn, error) {
if arch != "amd64" {
// Netlink policies are arch-independent (?),
@@ -374,24 +374,18 @@ func checkNetlink(OS, arch, obj string, structDescs []*prog.KeyedStruct,
warn := func(pos ast.Pos, typ, msg string, args ...interface{}) {
warnings = append(warnings, Warn{pos: pos, typ: typ, msg: fmt.Sprintf(msg, args...)})
}
- structMap := make(map[string]*prog.StructDesc)
- for _, str := range structDescs {
- structMap[str.Desc.Name()] = str.Desc
+ structMap := make(map[string]*prog.StructType)
+ for _, typ := range structTypes {
+ structMap[typ.Name()] = typ
}
- checked := make(map[string]bool)
checkedAttrs := make(map[string]*checkAttr)
- for _, str := range structDescs {
- typ := str.Desc
- if checked[typ.Name()] {
- continue
- }
- checked[typ.Name()] = true
+ for _, typ := range structTypes {
name := typ.TemplateName()
astStruct := locs[name]
if astStruct == nil {
continue
}
- if !isNetlinkPolicy(typ) {
+ if !isNetlinkPolicy(typ.Fields) {
continue
}
kernelName := name
@@ -499,9 +493,9 @@ func checkMissingAttrs(checkedAttrs map[string]*checkAttr) []Warn {
return warnings
}
-func isNetlinkPolicy(typ *prog.StructDesc) bool {
+func isNetlinkPolicy(fields []prog.Field) bool {
haveAttr := false
- for _, fld := range typ.Fields {
+ for _, fld := range fields {
field := fld.Type
if prog.IsPad(field) {
continue
@@ -514,12 +508,12 @@ func isNetlinkPolicy(typ *prog.StructDesc) bool {
field = arr.Elem
}
if field1, ok := field.(*prog.StructType); ok {
- if isNetlinkPolicy(field1.StructDesc) {
+ if isNetlinkPolicy(field1.Fields) {
continue
}
}
if field1, ok := field.(*prog.UnionType); ok {
- if isNetlinkPolicy(field1.StructDesc) {
+ if isNetlinkPolicy(field1.Fields) {
continue
}
}
@@ -533,7 +527,7 @@ func isNlattr(typ prog.Type) bool {
return name == "nlattr_t" || name == "nlattr_tt"
}
-func checkNetlinkPolicy(structMap map[string]*prog.StructDesc, typ *prog.StructDesc,
+func checkNetlinkPolicy(structMap map[string]*prog.StructType, typ *prog.StructType,
astStruct *ast.Struct, policy []nlaPolicy) ([]Warn, map[int]bool, error) {
var warnings []Warn
warn := func(pos ast.Pos, typ, msg string, args ...interface{}) {
@@ -571,7 +565,7 @@ func checkNetlinkPolicy(structMap map[string]*prog.StructDesc, typ *prog.StructD
return warnings, checked, nil
}
-func checkNetlinkAttr(typ *prog.StructDesc, policy nlaPolicy) string {
+func checkNetlinkAttr(typ *prog.StructType, policy nlaPolicy) string {
payload := typ.Fields[2].Type
if typ.TemplateName() == "nlattr_tt" {
payload = typ.Fields[4].Type
@@ -637,7 +631,7 @@ func minTypeSize(typ prog.Type) int {
return -1
}
-func checkAttrType(typ *prog.StructDesc, payload prog.Type, policy nlaPolicy) string {
+func checkAttrType(typ *prog.StructType, payload prog.Type, policy nlaPolicy) string {
switch policy.typ {
case NLA_STRING, NLA_NUL_STRING:
if _, ok := payload.(*prog.BufferType); !ok {