aboutsummaryrefslogtreecommitdiffstats
path: root/tools/syz-check
diff options
context:
space:
mode:
authorDmitry Vyukov <dvyukov@google.com>2020-01-22 12:17:20 +0100
committerDmitry Vyukov <dvyukov@google.com>2020-01-22 12:19:53 +0100
commit02754a8f9af246f440492295487282e55dc09cc0 (patch)
tree728b59573a119421b2a5e2b89f965edc9ff6f93a /tools/syz-check
parent8eda0b957e5b39c0c525e74f51d6b39ab8c5b1ac (diff)
tools/syz-check: check netlink policy descriptions
Overall idea of netlink checking. Currnetly we check netlink policies for common detectable mistakes. First, we detect what looks like a netlink policy in our descriptions (these are structs/unions only with nlattr/nlnext/nlnetw fields). Then we find corresponding symbols (offset/size) in vmlinux using nm. Then we read elf headers and locate where these symbols are in the rodata section. Then read in the symbol data, which is an array of nla_policy structs. These structs allow to easily figure out type/size of attributes. Finally we compare our descriptions with the kernel policy description. Update #590
Diffstat (limited to 'tools/syz-check')
-rw-r--r--tools/syz-check/check.go314
1 files changed, 297 insertions, 17 deletions
diff --git a/tools/syz-check/check.go b/tools/syz-check/check.go
index f5f0e7454..bd6c76c15 100644
--- a/tools/syz-check/check.go
+++ b/tools/syz-check/check.go
@@ -15,6 +15,7 @@ package main
import (
"bytes"
"debug/dwarf"
+ "debug/elf"
"flag"
"fmt"
"os"
@@ -23,10 +24,12 @@ import (
"runtime/pprof"
"sort"
"strings"
+ "unsafe"
"github.com/google/syzkaller/pkg/ast"
"github.com/google/syzkaller/pkg/compiler"
"github.com/google/syzkaller/pkg/osutil"
+ "github.com/google/syzkaller/pkg/symbolizer"
"github.com/google/syzkaller/prog"
"github.com/google/syzkaller/sys/targets"
)
@@ -86,10 +89,12 @@ func main() {
}
func check(OS, arch, obj string) ([]Warn, error) {
+ var warnings []Warn
structDescs, locs, warnings1, err := parseDescriptions(OS, arch)
if err != nil {
return nil, err
}
+ warnings = append(warnings, warnings1...)
structs, err := parseKernelObject(obj)
if err != nil {
return nil, err
@@ -98,7 +103,12 @@ func check(OS, arch, obj string) ([]Warn, error) {
if err != nil {
return nil, err
}
- warnings := append(warnings1, warnings2...)
+ warnings = append(warnings, warnings2...)
+ warnings3, err := checkNetlink(OS, arch, obj, structDescs, locs)
+ if err != nil {
+ return nil, err
+ }
+ warnings = append(warnings, warnings3...)
for i := range warnings {
warnings[i].arch = arch
}
@@ -106,13 +116,18 @@ func check(OS, arch, obj string) ([]Warn, error) {
}
const (
- WarnCompiler = "compiler"
- WarnNoSuchStruct = "no-such-struct"
- WarnBadStructSize = "bad-struct-size"
- WarnBadFieldNumber = "bad-field-number"
- WarnBadFieldSize = "bad-field-size"
- WarnBadFieldOffset = "bad-field-offset"
- WarnBadBitfield = "bad-bitfield"
+ WarnCompiler = "compiler"
+ WarnNoSuchStruct = "no-such-struct"
+ WarnBadStructSize = "bad-struct-size"
+ WarnBadFieldNumber = "bad-field-number"
+ WarnBadFieldSize = "bad-field-size"
+ WarnBadFieldOffset = "bad-field-offset"
+ WarnBadBitfield = "bad-bitfield"
+ WarnNoNetlinkPolicy = "no-such-netlink-policy"
+ WarnMultipleNetlinkPolicy = "multiple-netlink-policy"
+ WarnNetlinkBadSize = "bad-kernel-netlink-policy-size"
+ WarnNetlinkBadAttrType = "bad-netlink-attr-type"
+ WarnNetlinkBadAttr = "bad-netlink-attr"
)
type Warn struct {
@@ -192,7 +207,7 @@ func checkImpl(structs map[string]*dwarf.StructType, structDescs []*prog.KeyedSt
continue
}
checked[typ.Name()] = true
- name := templateName(typ.Name())
+ name := typ.TemplateName()
astStruct := locs[name]
if astStruct == nil {
continue
@@ -206,19 +221,12 @@ func checkImpl(structs map[string]*dwarf.StructType, structDescs []*prog.KeyedSt
return warnings, nil
}
-func templateName(name string) string {
- if pos := strings.IndexByte(name, '['); pos != -1 {
- name = name[:pos]
- }
- return name
-}
-
func checkStruct(typ *prog.StructDesc, astStruct *ast.Struct, str *dwarf.StructType) ([]Warn, error) {
var warnings []Warn
warn := func(pos ast.Pos, typ, msg string, args ...interface{}) {
warnings = append(warnings, Warn{pos: pos, typ: typ, msg: fmt.Sprintf(msg, args...)})
}
- name := templateName(typ.Name())
+ name := typ.TemplateName()
if str == nil {
warn(astStruct.Pos, WarnNoSuchStruct, "%v", name)
return warnings, nil
@@ -325,3 +333,275 @@ func parseDescriptions(OS, arch string) ([]*prog.KeyedStruct, map[string]*ast.St
}
return prg.StructDescs, locs, warnings, nil
}
+
+// Overall idea of netlink checking.
+// Currnetly we check netlink policies for common detectable mistakes.
+// First, we detect what looks like a netlink policy in our descriptions
+// (these are structs/unions only with nlattr/nlnext/nlnetw fields).
+// Then we find corresponding symbols (offset/size) in vmlinux using nm.
+// Then we read elf headers and locate where these symbols are in the rodata section.
+// Then read in the symbol data, which is an array of nla_policy structs.
+// These structs allow to easily figure out type/size of attributes.
+// Finally we compare our descriptions with the kernel policy description.
+func checkNetlink(OS, arch, obj string, structDescs []*prog.KeyedStruct,
+ locs map[string]*ast.Struct) ([]Warn, error) {
+ if arch != "amd64" {
+ // Netlink policies are arch-independent (?),
+ // so no need to check all arches.
+ // Also our definition of nlaPolicy below is 64-bit specific.
+ return nil, nil
+ }
+ ef, err := elf.Open(obj)
+ if err != nil {
+ return nil, err
+ }
+ rodata := ef.Section(".rodata")
+ if rodata == nil {
+ return nil, fmt.Errorf("object file %v does not contain .rodata section", obj)
+ }
+ symbols, err := symbolizer.ReadRodataSymbols(obj)
+ if err != nil {
+ return nil, err
+ }
+ var warnings []Warn
+ warn := func(pos ast.Pos, typ, msg string, args ...interface{}) {
+ warnings = append(warnings, Warn{pos: pos, typ: typ, msg: fmt.Sprintf(msg, args...)})
+ }
+ structMap := make(map[string]*prog.StructDesc)
+ for _, str := range structDescs {
+ structMap[str.Desc.Name()] = str.Desc
+ }
+ checked := make(map[string]bool)
+ for _, str := range structDescs {
+ typ := str.Desc
+ if checked[typ.Name()] {
+ continue
+ }
+ checked[typ.Name()] = true
+ name := typ.TemplateName()
+ astStruct := locs[name]
+ if astStruct == nil {
+ continue
+ }
+ if !isNetlinkPolicy(typ) {
+ continue
+ }
+ ss := symbols[name]
+ if len(ss) == 0 {
+ warn(astStruct.Pos, WarnNoNetlinkPolicy, "%v", name)
+ continue
+ }
+ if len(ss) != 1 {
+ warn(astStruct.Pos, WarnMultipleNetlinkPolicy, "%v", name)
+ continue
+ }
+ if ss[0].Size == 0 || ss[0].Size%int(unsafe.Sizeof(nlaPolicy{})) != 0 {
+ warn(astStruct.Pos, WarnNetlinkBadSize, "%v, size %v", name, ss[0].Size)
+ }
+ binary := make([]byte, ss[0].Size)
+ addr := ss[0].Addr - rodata.Addr
+ if _, err := rodata.ReadAt(binary, int64(addr)); err != nil {
+ return nil, fmt.Errorf("failed to read policy %v at %v: %v", name, ss[0].Addr, err)
+ }
+ policy := (*[1e6]nlaPolicy)(unsafe.Pointer(&binary[0]))[:ss[0].Size/int(unsafe.Sizeof(nlaPolicy{}))]
+ warnings1, err := checkNetlinkPolicy(structMap, typ, astStruct, policy)
+ if err != nil {
+ return nil, err
+ }
+ warnings = append(warnings, warnings1...)
+ }
+ return warnings, nil
+}
+
+func isNetlinkPolicy(typ *prog.StructDesc) bool {
+ for _, field := range typ.Fields {
+ if prog.IsPad(field) {
+ continue
+ }
+ name := field.TemplateName()
+ if name != "nlattr_t" && name != "nlattr_tt" {
+ return false
+ }
+ }
+ return true
+}
+
+func checkNetlinkPolicy(structMap map[string]*prog.StructDesc, typ *prog.StructDesc,
+ astStruct *ast.Struct, policy []nlaPolicy) ([]Warn, error) {
+ var warnings []Warn
+ warn := func(pos ast.Pos, typ, msg string, args ...interface{}) {
+ warnings = append(warnings, Warn{pos: pos, typ: typ, msg: fmt.Sprintf(msg, args...)})
+ }
+ ai := 0
+ for _, field := range typ.Fields {
+ if prog.IsPad(field) {
+ continue
+ }
+ fld := astStruct.Fields[ai]
+ ai++
+ ft := structMap[field.Name()]
+ attr := int(ft.Fields[1].(*prog.ConstType).Val)
+ if attr >= len(policy) {
+ warn(fld.Pos, WarnNetlinkBadAttrType, "%v.%v: type %v, kernel policy size %v",
+ typ.TemplateName(), field.FieldName(), attr, len(policy))
+ }
+ w := checkNetlinkAttr(ft, policy[attr])
+ if w != "" {
+ warn(fld.Pos, WarnNetlinkBadAttr, "%v.%v: %v",
+ typ.TemplateName(), field.FieldName(), w)
+ }
+ }
+ return warnings, nil
+}
+
+func checkNetlinkAttr(typ *prog.StructDesc, policy nlaPolicy) string {
+ payload := typ.Fields[2]
+ if typ.TemplateName() == "nlattr_tt" {
+ payload = typ.Fields[4]
+ }
+ if warn := checkAttrType(typ, payload, policy); warn != "" {
+ return warn
+ }
+ size, minSize, maxSize := attrSize(policy)
+ payloadSize := -1
+ if !payload.Varlen() {
+ payloadSize = int(payload.Size())
+ }
+ if size != -1 && size != payloadSize {
+ return fmt.Sprintf("bad size %v, expect %v", payloadSize, size)
+ }
+ if minSize != -1 && minSize > payloadSize {
+ return fmt.Sprintf("bad size %v, expect min %v", payloadSize, minSize)
+ }
+ if maxSize != -1 && maxSize < payloadSize {
+ return fmt.Sprintf("bad size %v, expect max %v", payloadSize, maxSize)
+ }
+
+ valMin, valMax, haveVal := typeMinMaxValue(payload)
+ if haveVal {
+ if policy.validation == NLA_VALIDATE_RANGE || policy.validation == NLA_VALIDATE_MIN {
+ if int64(valMin) < int64(policy.minVal) {
+ return fmt.Sprintf("bad min value %v, expect %v",
+ int64(valMin), policy.minVal)
+ }
+ }
+ if policy.validation == NLA_VALIDATE_RANGE || policy.validation == NLA_VALIDATE_MAX {
+ if int64(valMax) > int64(policy.maxVal) {
+ return fmt.Sprintf("bad max value %v, expect %v",
+ int64(valMax), policy.maxVal)
+ }
+ }
+ }
+ return ""
+}
+
+func checkAttrType(typ *prog.StructDesc, payload prog.Type, policy nlaPolicy) string {
+ switch policy.typ {
+ case NLA_STRING, NLA_NUL_STRING:
+ if payload.Name() != "string" {
+ return fmt.Sprintf("expect string")
+ }
+ case NLA_NESTED:
+ if typ.TemplateName() != "nlattr_tt" || typ.Fields[3].(*prog.ConstType).Val != 1 {
+ return fmt.Sprintf("should be nlnest")
+ }
+ case NLA_NESTED_ARRAY, NLA_BITFIELD32, NLA_REJECT:
+ return fmt.Sprintf("unhandled type %v", policy.typ)
+ }
+ return ""
+}
+
+func attrSize(policy nlaPolicy) (int, int, int) {
+ switch policy.typ {
+ case NLA_UNSPEC:
+ if policy.len != 0 {
+ return int(policy.len), -1, -1
+ }
+ case NLA_MIN_LEN:
+ return -1, int(policy.len), -1
+ case NLA_EXACT_LEN, NLA_EXACT_LEN_WARN:
+ return int(policy.len), -1, -1
+ case NLA_U8, NLA_S8:
+ return 1, -1, -1
+ case NLA_U16, NLA_S16:
+ return 2, -1, -1
+ case NLA_U32, NLA_S32:
+ return 4, -1, -1
+ case NLA_U64, NLA_S64, NLA_MSECS:
+ return 8, -1, -1
+ case NLA_FLAG:
+ return 0, -1, -1
+ case NLA_BINARY:
+ if policy.len != 0 {
+ return -1, -1, int(policy.len)
+ }
+ }
+ return -1, -1, -1
+}
+
+func typeMinMaxValue(payload prog.Type) (min, max uint64, ok bool) {
+ switch typ := payload.(type) {
+ case *prog.ConstType:
+ return typ.Val, typ.Val, true
+ case *prog.IntType:
+ if typ.Kind == prog.IntRange {
+ return typ.RangeBegin, typ.RangeEnd, true
+ }
+ return 0, ^uint64(0), true
+ case *prog.FlagsType:
+ min, max := ^uint64(0), uint64(0)
+ for _, v := range typ.Vals {
+ if min > v {
+ min = v
+ }
+ if max < v {
+ max = v
+ }
+ }
+ return min, max, true
+ }
+ return 0, 0, false
+}
+
+type nlaPolicy struct {
+ typ uint8
+ validation uint8
+ len uint16
+ _ uint32
+ minVal int16
+ maxVal int16
+ _ int32
+}
+
+// nolint
+const (
+ NLA_UNSPEC = iota
+ NLA_U8
+ NLA_U16
+ NLA_U32
+ NLA_U64
+ NLA_STRING
+ NLA_FLAG
+ NLA_MSECS
+ NLA_NESTED
+ NLA_NESTED_ARRAY
+ NLA_NUL_STRING
+ NLA_BINARY
+ NLA_S8
+ NLA_S16
+ NLA_S32
+ NLA_S64
+ NLA_BITFIELD32
+ NLA_REJECT
+ NLA_EXACT_LEN
+ NLA_EXACT_LEN_WARN
+ NLA_MIN_LEN
+)
+
+// nolint
+const (
+ _ = iota
+ NLA_VALIDATE_RANGE
+ NLA_VALIDATE_MIN
+ NLA_VALIDATE_MAX
+)