aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--pkg/cover/report.go2
-rw-r--r--pkg/report/linux.go2
-rw-r--r--pkg/report/netbsd.go2
-rw-r--r--pkg/report/openbsd.go2
-rw-r--r--pkg/symbolizer/nm.go28
-rw-r--r--pkg/symbolizer/nm_test.go2
-rw-r--r--tools/syz-check/check.go314
7 files changed, 324 insertions, 28 deletions
diff --git a/pkg/cover/report.go b/pkg/cover/report.go
index b92c5e982..d66d730b3 100644
--- a/pkg/cover/report.go
+++ b/pkg/cover/report.go
@@ -313,7 +313,7 @@ func (rg *ReportGenerator) findSymbol(pc uint64) uint64 {
}
func readSymbols(obj string) ([]symbol, error) {
- raw, err := symbolizer.ReadSymbols(obj)
+ raw, err := symbolizer.ReadTextSymbols(obj)
if err != nil {
return nil, fmt.Errorf("failed to run nm on %v: %v", obj, err)
}
diff --git a/pkg/report/linux.go b/pkg/report/linux.go
index ac9a78ee5..968bc1692 100644
--- a/pkg/report/linux.go
+++ b/pkg/report/linux.go
@@ -38,7 +38,7 @@ func ctorLinux(cfg *config) (Reporter, []string, error) {
if cfg.kernelObj != "" {
vmlinux = filepath.Join(cfg.kernelObj, cfg.target.KernelObject)
var err error
- symbols, err = symbolizer.ReadSymbols(vmlinux)
+ symbols, err = symbolizer.ReadTextSymbols(vmlinux)
if err != nil {
return nil, nil, err
}
diff --git a/pkg/report/netbsd.go b/pkg/report/netbsd.go
index 036e9dc9c..69287ed58 100644
--- a/pkg/report/netbsd.go
+++ b/pkg/report/netbsd.go
@@ -37,7 +37,7 @@ func ctorNetbsd(cfg *config) (Reporter, []string, error) {
if cfg.kernelObj != "" {
kernelObject = filepath.Join(cfg.kernelObj, cfg.target.KernelObject)
var err error
- symbols, err = symbolizer.ReadSymbols(kernelObject)
+ symbols, err = symbolizer.ReadTextSymbols(kernelObject)
if err != nil {
return nil, nil, err
}
diff --git a/pkg/report/openbsd.go b/pkg/report/openbsd.go
index 9cf1a593b..99e49c81c 100644
--- a/pkg/report/openbsd.go
+++ b/pkg/report/openbsd.go
@@ -36,7 +36,7 @@ func ctorOpenbsd(cfg *config) (Reporter, []string, error) {
if cfg.kernelObj != "" {
kernelObject = filepath.Join(cfg.kernelObj, cfg.target.KernelObject)
var err error
- symbols, err = symbolizer.ReadSymbols(kernelObject)
+ symbols, err = symbolizer.ReadTextSymbols(kernelObject)
if err != nil {
return nil, nil, err
}
diff --git a/pkg/symbolizer/nm.go b/pkg/symbolizer/nm.go
index caa246629..09baf8119 100644
--- a/pkg/symbolizer/nm.go
+++ b/pkg/symbolizer/nm.go
@@ -16,8 +16,21 @@ type Symbol struct {
Size int
}
-// ReadSymbols returns list of text symbols in the binary bin.
-func ReadSymbols(bin string) (map[string][]Symbol, error) {
+// ReadTextSymbols returns list of text symbols in the binary bin.
+func ReadTextSymbols(bin string) (map[string][]Symbol, error) {
+ return read(bin, "t", "T")
+}
+
+// ReadRodataSymbols returns list of rodata symbols in the binary bin.
+func ReadRodataSymbols(bin string) (map[string][]Symbol, error) {
+ return read(bin, "r", "R")
+}
+
+func read(bin string, types ...string) (map[string][]Symbol, error) {
+ if len(types) != 2 || len(types[0]) != 1 || len(types[1]) != 1 {
+ // We assume these things below.
+ panic("bad types")
+ }
cmd := osutil.Command("nm", "-Ptx", bin)
stdout, err := cmd.StdoutPipe()
if err != nil {
@@ -30,11 +43,14 @@ func ReadSymbols(bin string) (map[string][]Symbol, error) {
defer cmd.Wait()
symbols := make(map[string][]Symbol)
s := bufio.NewScanner(stdout)
- text := [][]byte{[]byte(" t "), []byte(" T ")}
+ var tt [][]byte
+ for _, typ := range types {
+ tt = append(tt, []byte(" "+typ+" "))
+ }
for s.Scan() {
// A line looks as: "snb_uncore_msr_enable_box t ffffffff8104db90 0000000000000059"
ln := s.Bytes()
- if !bytes.Contains(ln, text[0]) && !bytes.Contains(ln, text[1]) {
+ if !bytes.Contains(ln, tt[0]) && !bytes.Contains(ln, tt[1]) {
continue
}
@@ -42,11 +58,11 @@ func ReadSymbols(bin string) (map[string][]Symbol, error) {
if sp1 == -1 {
continue
}
- if !bytes.HasPrefix(ln[sp1:], text[0]) && !bytes.HasPrefix(ln[sp1:], text[1]) {
+ if !bytes.HasPrefix(ln[sp1:], tt[0]) && !bytes.HasPrefix(ln[sp1:], tt[1]) {
continue
}
- sp2 := sp1 + len(text[0])
+ sp2 := sp1 + len(tt[0])
sp3 := bytes.IndexByte(ln[sp2:], ' ')
if sp3 == -1 {
continue
diff --git a/pkg/symbolizer/nm_test.go b/pkg/symbolizer/nm_test.go
index 44a02f3cc..6061608ec 100644
--- a/pkg/symbolizer/nm_test.go
+++ b/pkg/symbolizer/nm_test.go
@@ -8,7 +8,7 @@ import (
)
func TestSymbols(t *testing.T) {
- symbols, err := ReadSymbols("testdata/nm.test.out")
+ symbols, err := ReadTextSymbols("testdata/nm.test.out")
if err != nil {
t.Fatalf("failed to read symbols: %v", err)
}
diff --git a/tools/syz-check/check.go b/tools/syz-check/check.go
index f5f0e7454..bd6c76c15 100644
--- a/tools/syz-check/check.go
+++ b/tools/syz-check/check.go
@@ -15,6 +15,7 @@ package main
import (
"bytes"
"debug/dwarf"
+ "debug/elf"
"flag"
"fmt"
"os"
@@ -23,10 +24,12 @@ import (
"runtime/pprof"
"sort"
"strings"
+ "unsafe"
"github.com/google/syzkaller/pkg/ast"
"github.com/google/syzkaller/pkg/compiler"
"github.com/google/syzkaller/pkg/osutil"
+ "github.com/google/syzkaller/pkg/symbolizer"
"github.com/google/syzkaller/prog"
"github.com/google/syzkaller/sys/targets"
)
@@ -86,10 +89,12 @@ func main() {
}
func check(OS, arch, obj string) ([]Warn, error) {
+ var warnings []Warn
structDescs, locs, warnings1, err := parseDescriptions(OS, arch)
if err != nil {
return nil, err
}
+ warnings = append(warnings, warnings1...)
structs, err := parseKernelObject(obj)
if err != nil {
return nil, err
@@ -98,7 +103,12 @@ func check(OS, arch, obj string) ([]Warn, error) {
if err != nil {
return nil, err
}
- warnings := append(warnings1, warnings2...)
+ warnings = append(warnings, warnings2...)
+ warnings3, err := checkNetlink(OS, arch, obj, structDescs, locs)
+ if err != nil {
+ return nil, err
+ }
+ warnings = append(warnings, warnings3...)
for i := range warnings {
warnings[i].arch = arch
}
@@ -106,13 +116,18 @@ func check(OS, arch, obj string) ([]Warn, error) {
}
const (
- WarnCompiler = "compiler"
- WarnNoSuchStruct = "no-such-struct"
- WarnBadStructSize = "bad-struct-size"
- WarnBadFieldNumber = "bad-field-number"
- WarnBadFieldSize = "bad-field-size"
- WarnBadFieldOffset = "bad-field-offset"
- WarnBadBitfield = "bad-bitfield"
+ WarnCompiler = "compiler"
+ WarnNoSuchStruct = "no-such-struct"
+ WarnBadStructSize = "bad-struct-size"
+ WarnBadFieldNumber = "bad-field-number"
+ WarnBadFieldSize = "bad-field-size"
+ WarnBadFieldOffset = "bad-field-offset"
+ WarnBadBitfield = "bad-bitfield"
+ WarnNoNetlinkPolicy = "no-such-netlink-policy"
+ WarnMultipleNetlinkPolicy = "multiple-netlink-policy"
+ WarnNetlinkBadSize = "bad-kernel-netlink-policy-size"
+ WarnNetlinkBadAttrType = "bad-netlink-attr-type"
+ WarnNetlinkBadAttr = "bad-netlink-attr"
)
type Warn struct {
@@ -192,7 +207,7 @@ func checkImpl(structs map[string]*dwarf.StructType, structDescs []*prog.KeyedSt
continue
}
checked[typ.Name()] = true
- name := templateName(typ.Name())
+ name := typ.TemplateName()
astStruct := locs[name]
if astStruct == nil {
continue
@@ -206,19 +221,12 @@ func checkImpl(structs map[string]*dwarf.StructType, structDescs []*prog.KeyedSt
return warnings, nil
}
-func templateName(name string) string {
- if pos := strings.IndexByte(name, '['); pos != -1 {
- name = name[:pos]
- }
- return name
-}
-
func checkStruct(typ *prog.StructDesc, astStruct *ast.Struct, str *dwarf.StructType) ([]Warn, error) {
var warnings []Warn
warn := func(pos ast.Pos, typ, msg string, args ...interface{}) {
warnings = append(warnings, Warn{pos: pos, typ: typ, msg: fmt.Sprintf(msg, args...)})
}
- name := templateName(typ.Name())
+ name := typ.TemplateName()
if str == nil {
warn(astStruct.Pos, WarnNoSuchStruct, "%v", name)
return warnings, nil
@@ -325,3 +333,275 @@ func parseDescriptions(OS, arch string) ([]*prog.KeyedStruct, map[string]*ast.St
}
return prg.StructDescs, locs, warnings, nil
}
+
+// Overall idea of netlink checking.
+// Currnetly we check netlink policies for common detectable mistakes.
+// First, we detect what looks like a netlink policy in our descriptions
+// (these are structs/unions only with nlattr/nlnext/nlnetw fields).
+// Then we find corresponding symbols (offset/size) in vmlinux using nm.
+// Then we read elf headers and locate where these symbols are in the rodata section.
+// Then read in the symbol data, which is an array of nla_policy structs.
+// These structs allow to easily figure out type/size of attributes.
+// Finally we compare our descriptions with the kernel policy description.
+func checkNetlink(OS, arch, obj string, structDescs []*prog.KeyedStruct,
+ locs map[string]*ast.Struct) ([]Warn, error) {
+ if arch != "amd64" {
+ // Netlink policies are arch-independent (?),
+ // so no need to check all arches.
+ // Also our definition of nlaPolicy below is 64-bit specific.
+ return nil, nil
+ }
+ ef, err := elf.Open(obj)
+ if err != nil {
+ return nil, err
+ }
+ rodata := ef.Section(".rodata")
+ if rodata == nil {
+ return nil, fmt.Errorf("object file %v does not contain .rodata section", obj)
+ }
+ symbols, err := symbolizer.ReadRodataSymbols(obj)
+ if err != nil {
+ return nil, err
+ }
+ var warnings []Warn
+ warn := func(pos ast.Pos, typ, msg string, args ...interface{}) {
+ warnings = append(warnings, Warn{pos: pos, typ: typ, msg: fmt.Sprintf(msg, args...)})
+ }
+ structMap := make(map[string]*prog.StructDesc)
+ for _, str := range structDescs {
+ structMap[str.Desc.Name()] = str.Desc
+ }
+ checked := make(map[string]bool)
+ for _, str := range structDescs {
+ typ := str.Desc
+ if checked[typ.Name()] {
+ continue
+ }
+ checked[typ.Name()] = true
+ name := typ.TemplateName()
+ astStruct := locs[name]
+ if astStruct == nil {
+ continue
+ }
+ if !isNetlinkPolicy(typ) {
+ continue
+ }
+ ss := symbols[name]
+ if len(ss) == 0 {
+ warn(astStruct.Pos, WarnNoNetlinkPolicy, "%v", name)
+ continue
+ }
+ if len(ss) != 1 {
+ warn(astStruct.Pos, WarnMultipleNetlinkPolicy, "%v", name)
+ continue
+ }
+ if ss[0].Size == 0 || ss[0].Size%int(unsafe.Sizeof(nlaPolicy{})) != 0 {
+ warn(astStruct.Pos, WarnNetlinkBadSize, "%v, size %v", name, ss[0].Size)
+ }
+ binary := make([]byte, ss[0].Size)
+ addr := ss[0].Addr - rodata.Addr
+ if _, err := rodata.ReadAt(binary, int64(addr)); err != nil {
+ return nil, fmt.Errorf("failed to read policy %v at %v: %v", name, ss[0].Addr, err)
+ }
+ policy := (*[1e6]nlaPolicy)(unsafe.Pointer(&binary[0]))[:ss[0].Size/int(unsafe.Sizeof(nlaPolicy{}))]
+ warnings1, err := checkNetlinkPolicy(structMap, typ, astStruct, policy)
+ if err != nil {
+ return nil, err
+ }
+ warnings = append(warnings, warnings1...)
+ }
+ return warnings, nil
+}
+
+func isNetlinkPolicy(typ *prog.StructDesc) bool {
+ for _, field := range typ.Fields {
+ if prog.IsPad(field) {
+ continue
+ }
+ name := field.TemplateName()
+ if name != "nlattr_t" && name != "nlattr_tt" {
+ return false
+ }
+ }
+ return true
+}
+
+func checkNetlinkPolicy(structMap map[string]*prog.StructDesc, typ *prog.StructDesc,
+ astStruct *ast.Struct, policy []nlaPolicy) ([]Warn, error) {
+ var warnings []Warn
+ warn := func(pos ast.Pos, typ, msg string, args ...interface{}) {
+ warnings = append(warnings, Warn{pos: pos, typ: typ, msg: fmt.Sprintf(msg, args...)})
+ }
+ ai := 0
+ for _, field := range typ.Fields {
+ if prog.IsPad(field) {
+ continue
+ }
+ fld := astStruct.Fields[ai]
+ ai++
+ ft := structMap[field.Name()]
+ attr := int(ft.Fields[1].(*prog.ConstType).Val)
+ if attr >= len(policy) {
+ warn(fld.Pos, WarnNetlinkBadAttrType, "%v.%v: type %v, kernel policy size %v",
+ typ.TemplateName(), field.FieldName(), attr, len(policy))
+ }
+ w := checkNetlinkAttr(ft, policy[attr])
+ if w != "" {
+ warn(fld.Pos, WarnNetlinkBadAttr, "%v.%v: %v",
+ typ.TemplateName(), field.FieldName(), w)
+ }
+ }
+ return warnings, nil
+}
+
+func checkNetlinkAttr(typ *prog.StructDesc, policy nlaPolicy) string {
+ payload := typ.Fields[2]
+ if typ.TemplateName() == "nlattr_tt" {
+ payload = typ.Fields[4]
+ }
+ if warn := checkAttrType(typ, payload, policy); warn != "" {
+ return warn
+ }
+ size, minSize, maxSize := attrSize(policy)
+ payloadSize := -1
+ if !payload.Varlen() {
+ payloadSize = int(payload.Size())
+ }
+ if size != -1 && size != payloadSize {
+ return fmt.Sprintf("bad size %v, expect %v", payloadSize, size)
+ }
+ if minSize != -1 && minSize > payloadSize {
+ return fmt.Sprintf("bad size %v, expect min %v", payloadSize, minSize)
+ }
+ if maxSize != -1 && maxSize < payloadSize {
+ return fmt.Sprintf("bad size %v, expect max %v", payloadSize, maxSize)
+ }
+
+ valMin, valMax, haveVal := typeMinMaxValue(payload)
+ if haveVal {
+ if policy.validation == NLA_VALIDATE_RANGE || policy.validation == NLA_VALIDATE_MIN {
+ if int64(valMin) < int64(policy.minVal) {
+ return fmt.Sprintf("bad min value %v, expect %v",
+ int64(valMin), policy.minVal)
+ }
+ }
+ if policy.validation == NLA_VALIDATE_RANGE || policy.validation == NLA_VALIDATE_MAX {
+ if int64(valMax) > int64(policy.maxVal) {
+ return fmt.Sprintf("bad max value %v, expect %v",
+ int64(valMax), policy.maxVal)
+ }
+ }
+ }
+ return ""
+}
+
+func checkAttrType(typ *prog.StructDesc, payload prog.Type, policy nlaPolicy) string {
+ switch policy.typ {
+ case NLA_STRING, NLA_NUL_STRING:
+ if payload.Name() != "string" {
+ return fmt.Sprintf("expect string")
+ }
+ case NLA_NESTED:
+ if typ.TemplateName() != "nlattr_tt" || typ.Fields[3].(*prog.ConstType).Val != 1 {
+ return fmt.Sprintf("should be nlnest")
+ }
+ case NLA_NESTED_ARRAY, NLA_BITFIELD32, NLA_REJECT:
+ return fmt.Sprintf("unhandled type %v", policy.typ)
+ }
+ return ""
+}
+
+func attrSize(policy nlaPolicy) (int, int, int) {
+ switch policy.typ {
+ case NLA_UNSPEC:
+ if policy.len != 0 {
+ return int(policy.len), -1, -1
+ }
+ case NLA_MIN_LEN:
+ return -1, int(policy.len), -1
+ case NLA_EXACT_LEN, NLA_EXACT_LEN_WARN:
+ return int(policy.len), -1, -1
+ case NLA_U8, NLA_S8:
+ return 1, -1, -1
+ case NLA_U16, NLA_S16:
+ return 2, -1, -1
+ case NLA_U32, NLA_S32:
+ return 4, -1, -1
+ case NLA_U64, NLA_S64, NLA_MSECS:
+ return 8, -1, -1
+ case NLA_FLAG:
+ return 0, -1, -1
+ case NLA_BINARY:
+ if policy.len != 0 {
+ return -1, -1, int(policy.len)
+ }
+ }
+ return -1, -1, -1
+}
+
+func typeMinMaxValue(payload prog.Type) (min, max uint64, ok bool) {
+ switch typ := payload.(type) {
+ case *prog.ConstType:
+ return typ.Val, typ.Val, true
+ case *prog.IntType:
+ if typ.Kind == prog.IntRange {
+ return typ.RangeBegin, typ.RangeEnd, true
+ }
+ return 0, ^uint64(0), true
+ case *prog.FlagsType:
+ min, max := ^uint64(0), uint64(0)
+ for _, v := range typ.Vals {
+ if min > v {
+ min = v
+ }
+ if max < v {
+ max = v
+ }
+ }
+ return min, max, true
+ }
+ return 0, 0, false
+}
+
+type nlaPolicy struct {
+ typ uint8
+ validation uint8
+ len uint16
+ _ uint32
+ minVal int16
+ maxVal int16
+ _ int32
+}
+
+// nolint
+const (
+ NLA_UNSPEC = iota
+ NLA_U8
+ NLA_U16
+ NLA_U32
+ NLA_U64
+ NLA_STRING
+ NLA_FLAG
+ NLA_MSECS
+ NLA_NESTED
+ NLA_NESTED_ARRAY
+ NLA_NUL_STRING
+ NLA_BINARY
+ NLA_S8
+ NLA_S16
+ NLA_S32
+ NLA_S64
+ NLA_BITFIELD32
+ NLA_REJECT
+ NLA_EXACT_LEN
+ NLA_EXACT_LEN_WARN
+ NLA_MIN_LEN
+)
+
+// nolint
+const (
+ _ = iota
+ NLA_VALIDATE_RANGE
+ NLA_VALIDATE_MIN
+ NLA_VALIDATE_MAX
+)