From 58da4c35b15200b7279f18ea15bc8644618aae78 Mon Sep 17 00:00:00 2001 From: Dmitry Vyukov Date: Fri, 1 May 2020 17:19:27 +0200 Subject: prog: introduce Field type Remvoe FieldName from Type and add a separate Field type that holds field name. Use Field for struct fields, union options and syscalls arguments, only these really have names. Reduces size of sys/linux/gen/amd64.go from 5665583 to 5201321 (-8.2%). Allows to not create new type for squashed any pointer. But main advantages will follow, e.g. removing StructDesc, using TypeRef in Arg, etc. Update #1580 --- pkg/compiler/check.go | 14 +++++------ pkg/compiler/compiler.go | 6 ++--- pkg/compiler/gen.go | 63 +++++++++++++++++++++++++--------------------- pkg/compiler/types.go | 12 ++++----- pkg/host/syscalls_linux.go | 20 +++++++-------- 5 files changed, 59 insertions(+), 56 deletions(-) (limited to 'pkg') diff --git a/pkg/compiler/check.go b/pkg/compiler/check.go index 2c0438086..0147fce50 100644 --- a/pkg/compiler/check.go +++ b/pkg/compiler/check.go @@ -323,7 +323,7 @@ func (comp *compiler) checkLenType(t0, t *ast.Type, parents []parentDesc, warned[parentName] = true return } - _, args, _ := comp.getArgsBase(t, "", isArg) + _, args, _ := comp.getArgsBase(t, isArg) for i, arg := range args { argDesc := desc.Args[i] if argDesc.Type == typeArgLenTarget { @@ -522,7 +522,7 @@ func (comp *compiler) collectUsedType(structs, flags, strflags map[string]bool, } return } - _, args, _ := comp.getArgsBase(t, "", isArg) + _, args, _ := comp.getArgsBase(t, isArg) for i, arg := range args { if desc.Args[i].Type == typeArgType { comp.collectUsedType(structs, flags, strflags, arg, desc.Args[i].IsArg) @@ -603,7 +603,7 @@ func (comp *compiler) checkTypeCtors(t *ast.Type, dir prog.Dir, isArg bool, if desc == typePtr { dir = genDir(t.Args[0]) } - _, args, _ := comp.getArgsBase(t, "", isArg) + _, args, _ := comp.getArgsBase(t, isArg) for i, arg := range args { if desc.Args[i].Type == typeArgType { comp.checkTypeCtors(arg, dir, desc.Args[i].IsArg, ctors, checked) @@ -684,7 +684,7 @@ func (comp *compiler) recurseField(checked map[string]bool, t *ast.Type, path [] comp.checkStructRecursion(checked, comp.structs[t.Ident], path) return } - _, args, base := comp.getArgsBase(t, "", false) + _, args, base := comp.getArgsBase(t, false) if desc == typePtr && base.IsOptional { return // optional pointers prune recursion } @@ -774,7 +774,7 @@ func (comp *compiler) checkType(ctx checkCtx, t *ast.Type, flags checkFlags) { return } if desc.Check != nil { - _, args, base := comp.getArgsBase(t, "", flags&checkIsArg != 0) + _, args, base := comp.getArgsBase(t, flags&checkIsArg != 0) desc.Check(comp, t, args, base) } } @@ -1098,12 +1098,12 @@ func (comp *compiler) checkVarlens() { } func (comp *compiler) isVarlen(t *ast.Type) bool { - desc, args, _ := comp.getArgsBase(t, "", false) + desc, args, _ := comp.getArgsBase(t, false) return desc.Varlen != nil && desc.Varlen(comp, t, args) } func (comp *compiler) isZeroSize(t *ast.Type) bool { - desc, args, _ := comp.getArgsBase(t, "", false) + desc, args, _ := comp.getArgsBase(t, false) return desc.ZeroSize != nil && desc.ZeroSize(comp, t, args) } diff --git a/pkg/compiler/compiler.go b/pkg/compiler/compiler.go index 8e5afddde..f61d02036 100644 --- a/pkg/compiler/compiler.go +++ b/pkg/compiler/compiler.go @@ -249,13 +249,13 @@ func (comp *compiler) getTypeDesc(t *ast.Type) *typeDesc { return nil } -func (comp *compiler) getArgsBase(t *ast.Type, field string, isArg bool) (*typeDesc, []*ast.Type, prog.IntTypeCommon) { +func (comp *compiler) getArgsBase(t *ast.Type, isArg bool) (*typeDesc, []*ast.Type, prog.IntTypeCommon) { desc := comp.getTypeDesc(t) if desc == nil { panic(fmt.Sprintf("no type desc for %#v", *t)) } args, opt := removeOpt(t) - com := genCommon(t.Ident, field, sizeUnassigned, opt != nil) + com := genCommon(t.Ident, sizeUnassigned, opt != nil) base := genIntCommon(com, 0, false) if desc.NeedBase { base.TypeSize = comp.ptrSize @@ -305,7 +305,7 @@ func (comp *compiler) foreachType(n0 ast.Node, func (comp *compiler) foreachSubType(t *ast.Type, isArg bool, cb func(*ast.Type, *typeDesc, []*ast.Type, prog.IntTypeCommon)) { - desc, args, base := comp.getArgsBase(t, "", isArg) + desc, args, base := comp.getArgsBase(t, isArg) cb(t, desc, args, base) for i, arg := range args { if desc.Args[i].Type == typeArgType { diff --git a/pkg/compiler/gen.go b/pkg/compiler/gen.go index 07f35d71f..748941bae 100644 --- a/pkg/compiler/gen.go +++ b/pkg/compiler/gen.go @@ -65,7 +65,7 @@ func (comp *compiler) collectCallArgSizes() map[string][]uint64 { if len(argSizes) <= i { argSizes = append(argSizes, comp.ptrSize) } - desc, _, _ := comp.getArgsBase(arg.Type, arg.Name.Name, true) + desc, _, _ := comp.getArgsBase(arg.Type, true) typ := comp.genField(arg, comp.ptrSize) // Ignore all types with base (const, flags). We don't have base in syscall args. // Also ignore resources and pointers because fd can be 32-bits and pointer 64-bits, @@ -112,7 +112,7 @@ func (comp *compiler) genSyscalls() []*prog.Syscall { func (comp *compiler) genSyscall(n *ast.Call, argSizes []uint64) *prog.Syscall { var ret prog.Type if n.Ret != nil { - ret = comp.genType(n.Ret, "ret", comp.ptrSize) + ret = comp.genType(n.Ret, comp.ptrSize) } var attrs prog.SyscallAttrs descAttrs := comp.parseAttrs(callAttrs, n, n.Attrs) @@ -147,7 +147,7 @@ func (comp *compiler) generateTypes(syscalls []*prog.Syscall, structs []*prog.Ke proxies := make(map[string]*typeProxy) for _, call := range syscalls { for i := range call.Args { - comp.collectTypes(proxies, &call.Args[i]) + comp.collectTypes(proxies, &call.Args[i].Type) } if call.Ret != nil { comp.collectTypes(proxies, &call.Ret) @@ -155,7 +155,7 @@ func (comp *compiler) generateTypes(syscalls []*prog.Syscall, structs []*prog.Ke } for _, str := range structs { for i := range str.Desc.Fields { - comp.collectTypes(proxies, &str.Desc.Fields[i]) + comp.collectTypes(proxies, &str.Desc.Fields[i].Type) } } array := make([]*typeProxy, 0, len(proxies)) @@ -219,7 +219,7 @@ func (comp *compiler) genStructDescs(syscalls []*prog.Syscall) []*prog.KeyedStru start := len(ctx.padded) for _, c := range syscalls { for _, a := range c.Args { - ctx.walk(a) + ctx.walk(a.Type) } if c.Ret != nil { ctx.walk(c.Ret) @@ -257,7 +257,7 @@ func (ctx *structGen) check(key prog.StructKey, descp **prog.StructDesc) bool { } ctx.padded[desc] = true for _, f := range desc.Fields { - ctx.walk(f) + ctx.walk(f.Type) if !f.Varlen() && f.Size() == sizeUnassigned { // An inner struct is not padded yet. // Leave this struct for next iteration. @@ -357,7 +357,7 @@ func (ctx *structGen) walkUnion(t *prog.UnionType) { if hasSize && sz > sizeAttr { comp.error(structNode.Fields[i].Pos, "union %v has size attribute %v"+ " which is less than field %v size %v", - structNode.Name.Name, sizeAttr, fld.Name(), sz) + structNode.Name.Name, sizeAttr, fld.Type.Name(), sz) } if t.TypeSize < sz { t.TypeSize = sz @@ -371,7 +371,7 @@ func (ctx *structGen) walkUnion(t *prog.UnionType) { func (comp *compiler) genStructDesc(res *prog.StructDesc, n *ast.Struct, varlen bool) { // Leave node for genStructDescs to calculate size/padding. comp.structNodes[res] = n - common := genCommon(n.Name.Name, "", sizeUnassigned, false) + common := genCommon(n.Name.Name, sizeUnassigned, false) common.IsVarlen = varlen *res = prog.StructDesc{ TypeCommon: common, @@ -380,9 +380,10 @@ func (comp *compiler) genStructDesc(res *prog.StructDesc, n *ast.Struct, varlen } func (comp *compiler) layoutStruct(t *prog.StructType, varlen, packed bool) { - var newFields []prog.Type + var newFields []prog.Field var structAlign, byteOffset, bitOffset uint64 - for i, f := range t.Fields { + for i, field := range t.Fields { + f := field.Type fieldAlign := uint64(1) if !packed { fieldAlign = comp.typeAlign(f) @@ -419,7 +420,7 @@ func (comp *compiler) layoutStruct(t *prog.StructType, varlen, packed bool) { pad := fieldOffset - byteOffset byteOffset += pad if i != 0 && t.Fields[i-1].IsBitfield() { - setBitfieldTypeSize(t.Fields[i-1], pad) + setBitfieldTypeSize(t.Fields[i-1].Type, pad) if bitOffset >= 8*pad { // The padding is due to bitfields, so consume the bitOffset. bitOffset -= 8 * pad @@ -428,7 +429,7 @@ func (comp *compiler) layoutStruct(t *prog.StructType, varlen, packed bool) { // But since we don't have any descriptions that trigger this, // let's just guard with the panic. panic(fmt.Sprintf("bad bitOffset: %v.%v pad=%v bitOffset=%v", - t.Name(), f.FieldName(), pad, bitOffset)) + t.Name(), field.Name, pad, bitOffset)) } } else { newFields = append(newFields, genPad(pad)) @@ -440,7 +441,7 @@ func (comp *compiler) layoutStruct(t *prog.StructType, varlen, packed bool) { setBitfieldUnitOffset(f, unitOffset) } } - newFields = append(newFields, f) + newFields = append(newFields, field) if f.IsBitfield() { bitOffset += f.BitfieldLength() } else if !f.Varlen() { @@ -454,7 +455,7 @@ func (comp *compiler) layoutStruct(t *prog.StructType, varlen, packed bool) { byteOffset += pad i := len(t.Fields) if i != 0 && t.Fields[i-1].IsBitfield() { - setBitfieldTypeSize(t.Fields[i-1], pad) + setBitfieldTypeSize(t.Fields[i-1].Type, pad) } else { newFields = append(newFields, genPad(pad)) } @@ -547,7 +548,7 @@ func (comp *compiler) typeAlign(t0 prog.Type) uint64 { } align := uint64(0) for _, f := range t.Fields { - if a := comp.typeAlign(f); align < a { + if a := comp.typeAlign(f.Type); align < a { align = a } } @@ -555,7 +556,7 @@ func (comp *compiler) typeAlign(t0 prog.Type) uint64 { case *prog.UnionType: align := uint64(0) for _, f := range t.Fields { - if a := comp.typeAlign(f); align < a { + if a := comp.typeAlign(f.Type); align < a { align = a } } @@ -565,29 +566,34 @@ func (comp *compiler) typeAlign(t0 prog.Type) uint64 { } } -func genPad(size uint64) prog.Type { - return &prog.ConstType{ - IntTypeCommon: genIntCommon(genCommon("pad", "", size, false), 0, false), - IsPad: true, +func genPad(size uint64) prog.Field { + return prog.Field{ + Type: &prog.ConstType{ + IntTypeCommon: genIntCommon(genCommon("pad", size, false), 0, false), + IsPad: true, + }, } } -func (comp *compiler) genFieldArray(fields []*ast.Field, argSizes []uint64) []prog.Type { - var res []prog.Type +func (comp *compiler) genFieldArray(fields []*ast.Field, argSizes []uint64) []prog.Field { + var res []prog.Field for i, f := range fields { res = append(res, comp.genField(f, argSizes[i])) } return res } -func (comp *compiler) genField(f *ast.Field, argSize uint64) prog.Type { - return comp.genType(f.Type, f.Name.Name, argSize) +func (comp *compiler) genField(f *ast.Field, argSize uint64) prog.Field { + return prog.Field{ + Name: f.Name.Name, + Type: comp.genType(f.Type, argSize), + } } -func (comp *compiler) genType(t *ast.Type, field string, argSize uint64) prog.Type { - desc, args, base := comp.getArgsBase(t, field, argSize != 0) +func (comp *compiler) genType(t *ast.Type, argSize uint64) prog.Type { + desc, args, base := comp.getArgsBase(t, argSize != 0) if desc.Gen == nil { - panic(fmt.Sprintf("no gen for %v %#v", field, t)) + panic(fmt.Sprintf("no gen for %v %#v", t.Ident, t)) } if argSize != 0 { // Now that we know a more precise size, patch the type. @@ -602,11 +608,10 @@ func (comp *compiler) genType(t *ast.Type, field string, argSize uint64) prog.Ty return desc.Gen(comp, t, args, base) } -func genCommon(name, field string, size uint64, opt bool) prog.TypeCommon { +func genCommon(name string, size uint64, opt bool) prog.TypeCommon { return prog.TypeCommon{ TypeName: name, TypeSize: size, - FldName: field, IsOptional: opt, } } diff --git a/pkg/compiler/types.go b/pkg/compiler/types.go index bf542c511..7ed12afd2 100644 --- a/pkg/compiler/types.go +++ b/pkg/compiler/types.go @@ -162,7 +162,7 @@ var typePtr = &typeDesc{ } return &prog.PtrType{ TypeCommon: base.TypeCommon, - Elem: comp.genType(args[1], "", 0), + Elem: comp.genType(args[1], 0), ElemDir: genDir(args[0]), } }, @@ -212,7 +212,7 @@ var typeArray = &typeDesc{ return comp.isZeroSize(args[0]) }, Gen: func(comp *compiler, t *ast.Type, args []*ast.Type, base prog.IntTypeCommon) prog.Type { - elemType := comp.genType(args[0], "", 0) + elemType := comp.genType(args[0], 0) kind, begin, end := prog.ArrayRandLen, uint64(0), uint64(0) if len(args) > 1 { kind, begin, end = prog.ArrayRangeLen, args[1].Value, args[1].Value @@ -696,7 +696,7 @@ var typeFmt = &typeDesc{ {Name: "value", Type: typeArgType, IsArg: true}, }, Check: func(comp *compiler, t *ast.Type, args []*ast.Type, base prog.IntTypeCommon) { - desc, _, _ := comp.getArgsBase(args[1], "", true) + desc, _, _ := comp.getArgsBase(args[1], true) switch desc { case typeResource, typeInt, typeLen, typeFlags, typeProc: default: @@ -718,7 +718,7 @@ var typeFmt = &typeDesc{ format = prog.FormatStrOct size = 23 } - typ := comp.genType(args[1], "", comp.ptrSize) + typ := comp.genType(args[1], comp.ptrSize) switch t := typ.(type) { case *prog.ResourceType: t.ArgFormat = format @@ -767,7 +767,7 @@ func init() { baseType = r.Base r = comp.resources[r.Base.Ident] } - baseProgType := comp.genType(baseType, "", 0) + baseProgType := comp.genType(baseType, 0) base.TypeSize = baseProgType.Size() return &prog.ResourceType{ TypeCommon: base.TypeCommon, @@ -829,13 +829,11 @@ func init() { if s.IsUnion { return &prog.UnionType{ Key: key, - FldName: base.FldName, StructDesc: desc, } } return &prog.StructType{ Key: key, - FldName: base.FldName, StructDesc: desc, } } diff --git a/pkg/host/syscalls_linux.go b/pkg/host/syscalls_linux.go index 184bf6410..17ea17dd4 100644 --- a/pkg/host/syscalls_linux.go +++ b/pkg/host/syscalls_linux.go @@ -166,11 +166,11 @@ var ( func isSupportedSyzkall(sandbox string, c *prog.Syscall) (bool, string) { switch c.CallName { case "syz_open_dev": - if _, ok := c.Args[0].(*prog.ConstType); ok { + if _, ok := c.Args[0].Type.(*prog.ConstType); ok { // This is for syz_open_dev$char/block. return true, "" } - fname, ok := extractStringConst(c.Args[0]) + fname, ok := extractStringConst(c.Args[0].Type) if !ok { panic("first open arg is not a pointer to string const") } @@ -249,7 +249,7 @@ func isSupportedSyzkall(sandbox string, c *prog.Syscall) (bool, string) { if ok, reason := onlySandboxNone(sandbox); !ok { return ok, reason } - fstype, ok := extractStringConst(c.Args[0]) + fstype, ok := extractStringConst(c.Args[0].Type) if !ok { panic("syz_mount_image arg is not string") } @@ -306,7 +306,7 @@ func onlySandboxNoneOrNamespace(sandbox string) (bool, string) { } func isSupportedSocket(c *prog.Syscall) (bool, string) { - af, ok := c.Args[0].(*prog.ConstType) + af, ok := c.Args[0].Type.(*prog.ConstType) if !ok { panic("socket family is not const") } @@ -320,14 +320,14 @@ func isSupportedSocket(c *prog.Syscall) (bool, string) { if err == syscall.EAFNOSUPPORT { return false, "socket family is not supported (EAFNOSUPPORT)" } - proto, ok := c.Args[2].(*prog.ConstType) + proto, ok := c.Args[2].Type.(*prog.ConstType) if !ok { return true, "" } var typ uint64 - if arg, ok := c.Args[1].(*prog.ConstType); ok { + if arg, ok := c.Args[1].Type.(*prog.ConstType); ok { typ = arg.Val - } else if arg, ok := c.Args[1].(*prog.FlagsType); ok { + } else if arg, ok := c.Args[1].Type.(*prog.FlagsType); ok { typ = arg.Vals[0] } else { return true, "" @@ -344,7 +344,7 @@ func isSupportedOpenAt(c *prog.Syscall) (bool, string) { var fd int var err error - fname, ok := extractStringConst(c.Args[1]) + fname, ok := extractStringConst(c.Args[1].Type) if !ok || len(fname) == 0 || fname[0] != '/' { return true, "" } @@ -352,7 +352,7 @@ func isSupportedOpenAt(c *prog.Syscall) (bool, string) { modes := []int{syscall.O_RDONLY, syscall.O_WRONLY, syscall.O_RDWR} // Attempt to extract flags from the syscall description - if mode, ok := c.Args[2].(*prog.ConstType); ok { + if mode, ok := c.Args[2].Type.(*prog.ConstType); ok { modes = []int{int(mode.Val)} } @@ -370,7 +370,7 @@ func isSupportedOpenAt(c *prog.Syscall) (bool, string) { } func isSupportedMount(c *prog.Syscall, sandbox string) (bool, string) { - fstype, ok := extractStringConst(c.Args[2]) + fstype, ok := extractStringConst(c.Args[2].Type) if !ok { panic(fmt.Sprintf("%v: filesystem is not string const", c.Name)) } -- cgit mrf-deployment