aboutsummaryrefslogtreecommitdiffstats
path: root/tools/syz-check
diff options
context:
space:
mode:
authorDmitry Vyukov <dvyukov@google.com>2020-05-25 15:33:49 +0200
committerDmitry Vyukov <dvyukov@google.com>2020-05-25 18:06:29 +0200
commit0f54349fc6f9c0c5507604dca2df9aadbc660a8b (patch)
treeae58ad60ecdef7c9678aa26c6adb62214cad5ffc /tools/syz-check
parent82f3c7881f572ba32d304fbff6bc29cab9625174 (diff)
tools/syz-check: restore handling of unions
Unions were dropped accidentially during removal StructDesc.
Diffstat (limited to 'tools/syz-check')
-rw-r--r--tools/syz-check/check.go177
1 files changed, 96 insertions, 81 deletions
diff --git a/tools/syz-check/check.go b/tools/syz-check/check.go
index aeab35cd0..3c12ba3da 100644
--- a/tools/syz-check/check.go
+++ b/tools/syz-check/check.go
@@ -201,7 +201,7 @@ func writeWarnings(OS string, narches int, warnings []Warn) error {
return nil
}
-func checkImpl(structs map[string]*dwarf.StructType, structTypes []*prog.StructType,
+func checkImpl(structs map[string]*dwarf.StructType, structTypes []prog.Type,
locs map[string]*ast.Struct) ([]Warn, error) {
var warnings []Warn
for _, typ := range structTypes {
@@ -219,7 +219,7 @@ func checkImpl(structs map[string]*dwarf.StructType, structTypes []*prog.StructT
return warnings, nil
}
-func checkStruct(typ *prog.StructType, astStruct *ast.Struct, str *dwarf.StructType) ([]Warn, error) {
+func checkStruct(typ prog.Type, astStruct *ast.Struct, str *dwarf.StructType) ([]Warn, error) {
var warnings []Warn
warn := func(pos ast.Pos, typ, msg string, args ...interface{}) {
warnings = append(warnings, Warn{pos: pos, typ: typ, msg: fmt.Sprintf(msg, args...)})
@@ -236,7 +236,7 @@ func checkStruct(typ *prog.StructType, astStruct *ast.Struct, str *dwarf.StructT
warn(astStruct.Pos, WarnBadStructSize, "%v: syz=%v kernel=%v", name, typ.Size(), str.ByteSize)
}
// TODO: handle unions, currently we should report some false errors.
- if str.Kind == "union" || astStruct.IsUnion {
+ if _, ok := typ.(*prog.UnionType); ok || str.Kind == "union" {
return warnings, nil
}
// TODO: we could also check enums (elements match corresponding flags in syzkaller).
@@ -255,7 +255,7 @@ func checkStruct(typ *prog.StructType, astStruct *ast.Struct, str *dwarf.StructT
// e.g. if a name contains filedes/uid/pid/gid that may be the corresponding resource.
ai := 0
offset := uint64(0)
- for _, field := range typ.Fields {
+ for _, field := range typ.(*prog.StructType).Fields {
if field.Type.Varlen() {
ai = len(str.Field)
break
@@ -305,7 +305,7 @@ func checkStruct(typ *prog.StructType, astStruct *ast.Struct, str *dwarf.StructT
return warnings, nil
}
-func parseDescriptions(OS, arch string) ([]*prog.StructType, map[string]*ast.Struct, []Warn, error) {
+func parseDescriptions(OS, arch string) ([]prog.Type, map[string]*ast.Struct, []Warn, error) {
errorBuf := new(bytes.Buffer)
var warnings []Warn
eh := func(pos ast.Pos, msg string) {
@@ -336,10 +336,11 @@ func parseDescriptions(OS, arch string) ([]*prog.StructType, map[string]*ast.Str
}
}
}
- var structs []*prog.StructType
+ var structs []prog.Type
for _, typ := range prg.Types {
- if t, ok := typ.(*prog.StructType); ok {
- structs = append(structs, t)
+ switch typ.(type) {
+ case *prog.StructType, *prog.UnionType:
+ structs = append(structs, typ)
}
}
return structs, locs, warnings, nil
@@ -354,7 +355,7 @@ func parseDescriptions(OS, arch string) ([]*prog.StructType, map[string]*ast.Str
// Then read in the symbol data, which is an array of nla_policy structs.
// These structs allow to easily figure out type/size of attributes.
// Finally we compare our descriptions with the kernel policy description.
-func checkNetlink(OS, arch, obj string, structTypes []*prog.StructType,
+func checkNetlink(OS, arch, obj string, structTypes []prog.Type,
locs map[string]*ast.Struct) ([]Warn, error) {
if arch != "amd64" {
// Netlink policies are arch-independent (?),
@@ -375,89 +376,103 @@ func checkNetlink(OS, arch, obj string, structTypes []*prog.StructType,
return nil, err
}
var warnings []Warn
- warn := func(pos ast.Pos, typ, msg string, args ...interface{}) {
- warnings = append(warnings, Warn{pos: pos, typ: typ, msg: fmt.Sprintf(msg, args...)})
- }
- structMap := make(map[string]*prog.StructType)
+ structMap := make(map[string]prog.Type)
for _, typ := range structTypes {
structMap[typ.Name()] = typ
}
checkedAttrs := make(map[string]*checkAttr)
for _, typ := range structTypes {
- name := typ.TemplateName()
- astStruct := locs[name]
- if astStruct == nil {
- continue
+ warnings1, err := checkNetlinkStruct(locs, symbols, rodata, structMap, checkedAttrs, typ)
+ if err != nil {
+ return nil, err
}
- if !isNetlinkPolicy(typ.Fields) {
- continue
+ warnings = append(warnings, warnings1...)
+ }
+ warnings = append(warnings, checkMissingAttrs(checkedAttrs)...)
+ return warnings, nil
+}
+
+func checkNetlinkStruct(locs map[string]*ast.Struct, symbols map[string][]symbolizer.Symbol, rodata *elf.Section,
+ structMap map[string]prog.Type, checkedAttrs map[string]*checkAttr, typ prog.Type) ([]Warn, error) {
+ name := typ.TemplateName()
+ astStruct := locs[name]
+ if astStruct == nil {
+ return nil, nil
+ }
+ var fields []prog.Field
+ switch t := typ.(type) {
+ case *prog.StructType:
+ fields = t.Fields
+ case *prog.UnionType:
+ fields = t.Fields
+ }
+ if !isNetlinkPolicy(fields) {
+ return nil, nil
+ }
+ kernelName := name
+ var ss []symbolizer.Symbol
+ // In some cases we split a single policy into multiple ones
+ // (more precise description), so try to match our foo_bar_baz
+ // with kernel foo_bar and foo as well.
+ for kernelName != "" {
+ ss = symbols[kernelName]
+ if len(ss) != 0 {
+ break
}
- kernelName := name
- var ss []symbolizer.Symbol
- // In some cases we split a single policy into multiple ones
- // (more precise description), so try to match our foo_bar_baz
- // with kernel foo_bar and foo as well.
- for kernelName != "" {
- ss = symbols[kernelName]
- if len(ss) != 0 {
- break
- }
- underscore := strings.LastIndexByte(kernelName, '_')
- if underscore == -1 {
- break
- }
- kernelName = kernelName[:underscore]
+ underscore := strings.LastIndexByte(kernelName, '_')
+ if underscore == -1 {
+ break
}
- if len(ss) == 0 {
- warn(astStruct.Pos, WarnNoNetlinkPolicy, "%v", name)
+ kernelName = kernelName[:underscore]
+ }
+ if len(ss) == 0 {
+ return []Warn{{pos: astStruct.Pos, typ: WarnNoNetlinkPolicy, msg: name}}, nil
+ }
+ var warnings []Warn
+ var warnings1 *[]Warn
+ var policy1 []nlaPolicy
+ var attrs1 map[int]bool
+ // We may have several symbols with the same name (they frequently have internal linking),
+ // in such case we choose the one that produces fewer warnings.
+ for _, symb := range ss {
+ if symb.Size == 0 || symb.Size%int(unsafe.Sizeof(nlaPolicy{})) != 0 {
+ warnings = append(warnings, Warn{pos: astStruct.Pos, typ: WarnNetlinkBadSize,
+ msg: fmt.Sprintf("%v (%v), size %v", kernelName, name, ss[0].Size)})
continue
}
- var warnings1 *[]Warn
- var policy1 []nlaPolicy
- var attrs1 map[int]bool
- // We may have several symbols with the same name (they frequently have internal linking),
- // in such case we choose the one that produces fewer warnings.
- for _, symb := range ss {
- if symb.Size == 0 || symb.Size%int(unsafe.Sizeof(nlaPolicy{})) != 0 {
- warn(astStruct.Pos, WarnNetlinkBadSize, "%v (%v), size %v",
- kernelName, name, ss[0].Size)
- continue
- }
- binary := make([]byte, symb.Size)
- addr := symb.Addr - rodata.Addr
- if _, err := rodata.ReadAt(binary, int64(addr)); err != nil {
- return nil, fmt.Errorf("failed to read policy %v (%v) at %v: %v",
- kernelName, name, symb.Addr, err)
- }
- policy := (*[1e6]nlaPolicy)(unsafe.Pointer(&binary[0]))[:symb.Size/int(unsafe.Sizeof(nlaPolicy{}))]
- warnings2, attrs2, err := checkNetlinkPolicy(structMap, typ, astStruct, policy)
- if err != nil {
- return nil, err
- }
- if warnings1 == nil || len(*warnings1) > len(warnings2) {
- warnings1 = &warnings2
- policy1 = policy
- attrs1 = attrs2
- }
+ binary := make([]byte, symb.Size)
+ addr := symb.Addr - rodata.Addr
+ if _, err := rodata.ReadAt(binary, int64(addr)); err != nil {
+ return nil, fmt.Errorf("failed to read policy %v (%v) at %v: %v",
+ kernelName, name, symb.Addr, err)
}
- if warnings1 != nil {
- warnings = append(warnings, *warnings1...)
- ca := checkedAttrs[kernelName]
- if ca == nil {
- ca = &checkAttr{
- pos: astStruct.Pos,
- name: name,
- policy: policy1,
- attrs: make(map[int]bool),
- }
- checkedAttrs[kernelName] = ca
- }
- for attr := range attrs1 {
- ca.attrs[attr] = true
+ policy := (*[1e6]nlaPolicy)(unsafe.Pointer(&binary[0]))[:symb.Size/int(unsafe.Sizeof(nlaPolicy{}))]
+ warnings2, attrs2, err := checkNetlinkPolicy(structMap, typ, fields, astStruct, policy)
+ if err != nil {
+ return nil, err
+ }
+ if warnings1 == nil || len(*warnings1) > len(warnings2) {
+ warnings1 = &warnings2
+ policy1 = policy
+ attrs1 = attrs2
+ }
+ }
+ if warnings1 != nil {
+ warnings = append(warnings, *warnings1...)
+ ca := checkedAttrs[kernelName]
+ if ca == nil {
+ ca = &checkAttr{
+ pos: astStruct.Pos,
+ name: name,
+ policy: policy1,
+ attrs: make(map[int]bool),
}
+ checkedAttrs[kernelName] = ca
+ }
+ for attr := range attrs1 {
+ ca.attrs[attr] = true
}
}
- warnings = append(warnings, checkMissingAttrs(checkedAttrs)...)
return warnings, nil
}
@@ -531,7 +546,7 @@ func isNlattr(typ prog.Type) bool {
return name == "nlattr_t" || name == "nlattr_tt"
}
-func checkNetlinkPolicy(structMap map[string]*prog.StructType, typ *prog.StructType,
+func checkNetlinkPolicy(structMap map[string]prog.Type, typ prog.Type, fields []prog.Field,
astStruct *ast.Struct, policy []nlaPolicy) ([]Warn, map[int]bool, error) {
var warnings []Warn
warn := func(pos ast.Pos, typ, msg string, args ...interface{}) {
@@ -539,7 +554,7 @@ func checkNetlinkPolicy(structMap map[string]*prog.StructType, typ *prog.StructT
}
checked := make(map[int]bool)
ai := 0
- for _, field := range typ.Fields {
+ for _, field := range fields {
if prog.IsPad(field.Type) {
continue
}
@@ -548,7 +563,7 @@ func checkNetlinkPolicy(structMap map[string]*prog.StructType, typ *prog.StructT
if !isNlattr(field.Type) {
continue
}
- ft := structMap[field.Type.Name()]
+ ft := field.Type.(*prog.StructType)
attr := int(ft.Fields[1].Type.(*prog.ConstType).Val)
if attr >= len(policy) {
warn(fld.Pos, WarnNetlinkBadAttrType, "%v.%v: type %v, kernel policy size %v",