From 0f54349fc6f9c0c5507604dca2df9aadbc660a8b Mon Sep 17 00:00:00 2001 From: Dmitry Vyukov Date: Mon, 25 May 2020 15:33:49 +0200 Subject: tools/syz-check: restore handling of unions Unions were dropped accidentially during removal StructDesc. --- tools/syz-check/check.go | 177 +++++++++++++++++++++++++---------------------- 1 file changed, 96 insertions(+), 81 deletions(-) (limited to 'tools') diff --git a/tools/syz-check/check.go b/tools/syz-check/check.go index aeab35cd0..3c12ba3da 100644 --- a/tools/syz-check/check.go +++ b/tools/syz-check/check.go @@ -201,7 +201,7 @@ func writeWarnings(OS string, narches int, warnings []Warn) error { return nil } -func checkImpl(structs map[string]*dwarf.StructType, structTypes []*prog.StructType, +func checkImpl(structs map[string]*dwarf.StructType, structTypes []prog.Type, locs map[string]*ast.Struct) ([]Warn, error) { var warnings []Warn for _, typ := range structTypes { @@ -219,7 +219,7 @@ func checkImpl(structs map[string]*dwarf.StructType, structTypes []*prog.StructT return warnings, nil } -func checkStruct(typ *prog.StructType, astStruct *ast.Struct, str *dwarf.StructType) ([]Warn, error) { +func checkStruct(typ prog.Type, astStruct *ast.Struct, str *dwarf.StructType) ([]Warn, error) { var warnings []Warn warn := func(pos ast.Pos, typ, msg string, args ...interface{}) { warnings = append(warnings, Warn{pos: pos, typ: typ, msg: fmt.Sprintf(msg, args...)}) @@ -236,7 +236,7 @@ func checkStruct(typ *prog.StructType, astStruct *ast.Struct, str *dwarf.StructT warn(astStruct.Pos, WarnBadStructSize, "%v: syz=%v kernel=%v", name, typ.Size(), str.ByteSize) } // TODO: handle unions, currently we should report some false errors. - if str.Kind == "union" || astStruct.IsUnion { + if _, ok := typ.(*prog.UnionType); ok || str.Kind == "union" { return warnings, nil } // TODO: we could also check enums (elements match corresponding flags in syzkaller). @@ -255,7 +255,7 @@ func checkStruct(typ *prog.StructType, astStruct *ast.Struct, str *dwarf.StructT // e.g. if a name contains filedes/uid/pid/gid that may be the corresponding resource. ai := 0 offset := uint64(0) - for _, field := range typ.Fields { + for _, field := range typ.(*prog.StructType).Fields { if field.Type.Varlen() { ai = len(str.Field) break @@ -305,7 +305,7 @@ func checkStruct(typ *prog.StructType, astStruct *ast.Struct, str *dwarf.StructT return warnings, nil } -func parseDescriptions(OS, arch string) ([]*prog.StructType, map[string]*ast.Struct, []Warn, error) { +func parseDescriptions(OS, arch string) ([]prog.Type, map[string]*ast.Struct, []Warn, error) { errorBuf := new(bytes.Buffer) var warnings []Warn eh := func(pos ast.Pos, msg string) { @@ -336,10 +336,11 @@ func parseDescriptions(OS, arch string) ([]*prog.StructType, map[string]*ast.Str } } } - var structs []*prog.StructType + var structs []prog.Type for _, typ := range prg.Types { - if t, ok := typ.(*prog.StructType); ok { - structs = append(structs, t) + switch typ.(type) { + case *prog.StructType, *prog.UnionType: + structs = append(structs, typ) } } return structs, locs, warnings, nil @@ -354,7 +355,7 @@ func parseDescriptions(OS, arch string) ([]*prog.StructType, map[string]*ast.Str // Then read in the symbol data, which is an array of nla_policy structs. // These structs allow to easily figure out type/size of attributes. // Finally we compare our descriptions with the kernel policy description. -func checkNetlink(OS, arch, obj string, structTypes []*prog.StructType, +func checkNetlink(OS, arch, obj string, structTypes []prog.Type, locs map[string]*ast.Struct) ([]Warn, error) { if arch != "amd64" { // Netlink policies are arch-independent (?), @@ -375,89 +376,103 @@ func checkNetlink(OS, arch, obj string, structTypes []*prog.StructType, return nil, err } var warnings []Warn - warn := func(pos ast.Pos, typ, msg string, args ...interface{}) { - warnings = append(warnings, Warn{pos: pos, typ: typ, msg: fmt.Sprintf(msg, args...)}) - } - structMap := make(map[string]*prog.StructType) + structMap := make(map[string]prog.Type) for _, typ := range structTypes { structMap[typ.Name()] = typ } checkedAttrs := make(map[string]*checkAttr) for _, typ := range structTypes { - name := typ.TemplateName() - astStruct := locs[name] - if astStruct == nil { - continue + warnings1, err := checkNetlinkStruct(locs, symbols, rodata, structMap, checkedAttrs, typ) + if err != nil { + return nil, err } - if !isNetlinkPolicy(typ.Fields) { - continue + warnings = append(warnings, warnings1...) + } + warnings = append(warnings, checkMissingAttrs(checkedAttrs)...) + return warnings, nil +} + +func checkNetlinkStruct(locs map[string]*ast.Struct, symbols map[string][]symbolizer.Symbol, rodata *elf.Section, + structMap map[string]prog.Type, checkedAttrs map[string]*checkAttr, typ prog.Type) ([]Warn, error) { + name := typ.TemplateName() + astStruct := locs[name] + if astStruct == nil { + return nil, nil + } + var fields []prog.Field + switch t := typ.(type) { + case *prog.StructType: + fields = t.Fields + case *prog.UnionType: + fields = t.Fields + } + if !isNetlinkPolicy(fields) { + return nil, nil + } + kernelName := name + var ss []symbolizer.Symbol + // In some cases we split a single policy into multiple ones + // (more precise description), so try to match our foo_bar_baz + // with kernel foo_bar and foo as well. + for kernelName != "" { + ss = symbols[kernelName] + if len(ss) != 0 { + break } - kernelName := name - var ss []symbolizer.Symbol - // In some cases we split a single policy into multiple ones - // (more precise description), so try to match our foo_bar_baz - // with kernel foo_bar and foo as well. - for kernelName != "" { - ss = symbols[kernelName] - if len(ss) != 0 { - break - } - underscore := strings.LastIndexByte(kernelName, '_') - if underscore == -1 { - break - } - kernelName = kernelName[:underscore] + underscore := strings.LastIndexByte(kernelName, '_') + if underscore == -1 { + break } - if len(ss) == 0 { - warn(astStruct.Pos, WarnNoNetlinkPolicy, "%v", name) + kernelName = kernelName[:underscore] + } + if len(ss) == 0 { + return []Warn{{pos: astStruct.Pos, typ: WarnNoNetlinkPolicy, msg: name}}, nil + } + var warnings []Warn + var warnings1 *[]Warn + var policy1 []nlaPolicy + var attrs1 map[int]bool + // We may have several symbols with the same name (they frequently have internal linking), + // in such case we choose the one that produces fewer warnings. + for _, symb := range ss { + if symb.Size == 0 || symb.Size%int(unsafe.Sizeof(nlaPolicy{})) != 0 { + warnings = append(warnings, Warn{pos: astStruct.Pos, typ: WarnNetlinkBadSize, + msg: fmt.Sprintf("%v (%v), size %v", kernelName, name, ss[0].Size)}) continue } - var warnings1 *[]Warn - var policy1 []nlaPolicy - var attrs1 map[int]bool - // We may have several symbols with the same name (they frequently have internal linking), - // in such case we choose the one that produces fewer warnings. - for _, symb := range ss { - if symb.Size == 0 || symb.Size%int(unsafe.Sizeof(nlaPolicy{})) != 0 { - warn(astStruct.Pos, WarnNetlinkBadSize, "%v (%v), size %v", - kernelName, name, ss[0].Size) - continue - } - binary := make([]byte, symb.Size) - addr := symb.Addr - rodata.Addr - if _, err := rodata.ReadAt(binary, int64(addr)); err != nil { - return nil, fmt.Errorf("failed to read policy %v (%v) at %v: %v", - kernelName, name, symb.Addr, err) - } - policy := (*[1e6]nlaPolicy)(unsafe.Pointer(&binary[0]))[:symb.Size/int(unsafe.Sizeof(nlaPolicy{}))] - warnings2, attrs2, err := checkNetlinkPolicy(structMap, typ, astStruct, policy) - if err != nil { - return nil, err - } - if warnings1 == nil || len(*warnings1) > len(warnings2) { - warnings1 = &warnings2 - policy1 = policy - attrs1 = attrs2 - } + binary := make([]byte, symb.Size) + addr := symb.Addr - rodata.Addr + if _, err := rodata.ReadAt(binary, int64(addr)); err != nil { + return nil, fmt.Errorf("failed to read policy %v (%v) at %v: %v", + kernelName, name, symb.Addr, err) } - if warnings1 != nil { - warnings = append(warnings, *warnings1...) - ca := checkedAttrs[kernelName] - if ca == nil { - ca = &checkAttr{ - pos: astStruct.Pos, - name: name, - policy: policy1, - attrs: make(map[int]bool), - } - checkedAttrs[kernelName] = ca - } - for attr := range attrs1 { - ca.attrs[attr] = true + policy := (*[1e6]nlaPolicy)(unsafe.Pointer(&binary[0]))[:symb.Size/int(unsafe.Sizeof(nlaPolicy{}))] + warnings2, attrs2, err := checkNetlinkPolicy(structMap, typ, fields, astStruct, policy) + if err != nil { + return nil, err + } + if warnings1 == nil || len(*warnings1) > len(warnings2) { + warnings1 = &warnings2 + policy1 = policy + attrs1 = attrs2 + } + } + if warnings1 != nil { + warnings = append(warnings, *warnings1...) + ca := checkedAttrs[kernelName] + if ca == nil { + ca = &checkAttr{ + pos: astStruct.Pos, + name: name, + policy: policy1, + attrs: make(map[int]bool), } + checkedAttrs[kernelName] = ca + } + for attr := range attrs1 { + ca.attrs[attr] = true } } - warnings = append(warnings, checkMissingAttrs(checkedAttrs)...) return warnings, nil } @@ -531,7 +546,7 @@ func isNlattr(typ prog.Type) bool { return name == "nlattr_t" || name == "nlattr_tt" } -func checkNetlinkPolicy(structMap map[string]*prog.StructType, typ *prog.StructType, +func checkNetlinkPolicy(structMap map[string]prog.Type, typ prog.Type, fields []prog.Field, astStruct *ast.Struct, policy []nlaPolicy) ([]Warn, map[int]bool, error) { var warnings []Warn warn := func(pos ast.Pos, typ, msg string, args ...interface{}) { @@ -539,7 +554,7 @@ func checkNetlinkPolicy(structMap map[string]*prog.StructType, typ *prog.StructT } checked := make(map[int]bool) ai := 0 - for _, field := range typ.Fields { + for _, field := range fields { if prog.IsPad(field.Type) { continue } @@ -548,7 +563,7 @@ func checkNetlinkPolicy(structMap map[string]*prog.StructType, typ *prog.StructT if !isNlattr(field.Type) { continue } - ft := structMap[field.Type.Name()] + ft := field.Type.(*prog.StructType) attr := int(ft.Fields[1].Type.(*prog.ConstType).Val) if attr >= len(policy) { warn(fld.Pos, WarnNetlinkBadAttrType, "%v.%v: type %v, kernel policy size %v", -- cgit mrf-deployment