diff options
| -rw-r--r-- | pkg/cover/report.go | 2 | ||||
| -rw-r--r-- | pkg/report/linux.go | 2 | ||||
| -rw-r--r-- | pkg/report/netbsd.go | 2 | ||||
| -rw-r--r-- | pkg/report/openbsd.go | 2 | ||||
| -rw-r--r-- | pkg/symbolizer/nm.go | 28 | ||||
| -rw-r--r-- | pkg/symbolizer/nm_test.go | 2 | ||||
| -rw-r--r-- | tools/syz-check/check.go | 314 |
7 files changed, 324 insertions, 28 deletions
diff --git a/pkg/cover/report.go b/pkg/cover/report.go index b92c5e982..d66d730b3 100644 --- a/pkg/cover/report.go +++ b/pkg/cover/report.go @@ -313,7 +313,7 @@ func (rg *ReportGenerator) findSymbol(pc uint64) uint64 { } func readSymbols(obj string) ([]symbol, error) { - raw, err := symbolizer.ReadSymbols(obj) + raw, err := symbolizer.ReadTextSymbols(obj) if err != nil { return nil, fmt.Errorf("failed to run nm on %v: %v", obj, err) } diff --git a/pkg/report/linux.go b/pkg/report/linux.go index ac9a78ee5..968bc1692 100644 --- a/pkg/report/linux.go +++ b/pkg/report/linux.go @@ -38,7 +38,7 @@ func ctorLinux(cfg *config) (Reporter, []string, error) { if cfg.kernelObj != "" { vmlinux = filepath.Join(cfg.kernelObj, cfg.target.KernelObject) var err error - symbols, err = symbolizer.ReadSymbols(vmlinux) + symbols, err = symbolizer.ReadTextSymbols(vmlinux) if err != nil { return nil, nil, err } diff --git a/pkg/report/netbsd.go b/pkg/report/netbsd.go index 036e9dc9c..69287ed58 100644 --- a/pkg/report/netbsd.go +++ b/pkg/report/netbsd.go @@ -37,7 +37,7 @@ func ctorNetbsd(cfg *config) (Reporter, []string, error) { if cfg.kernelObj != "" { kernelObject = filepath.Join(cfg.kernelObj, cfg.target.KernelObject) var err error - symbols, err = symbolizer.ReadSymbols(kernelObject) + symbols, err = symbolizer.ReadTextSymbols(kernelObject) if err != nil { return nil, nil, err } diff --git a/pkg/report/openbsd.go b/pkg/report/openbsd.go index 9cf1a593b..99e49c81c 100644 --- a/pkg/report/openbsd.go +++ b/pkg/report/openbsd.go @@ -36,7 +36,7 @@ func ctorOpenbsd(cfg *config) (Reporter, []string, error) { if cfg.kernelObj != "" { kernelObject = filepath.Join(cfg.kernelObj, cfg.target.KernelObject) var err error - symbols, err = symbolizer.ReadSymbols(kernelObject) + symbols, err = symbolizer.ReadTextSymbols(kernelObject) if err != nil { return nil, nil, err } diff --git a/pkg/symbolizer/nm.go b/pkg/symbolizer/nm.go index caa246629..09baf8119 100644 --- a/pkg/symbolizer/nm.go +++ b/pkg/symbolizer/nm.go @@ -16,8 +16,21 @@ type Symbol struct { Size int } -// ReadSymbols returns list of text symbols in the binary bin. -func ReadSymbols(bin string) (map[string][]Symbol, error) { +// ReadTextSymbols returns list of text symbols in the binary bin. +func ReadTextSymbols(bin string) (map[string][]Symbol, error) { + return read(bin, "t", "T") +} + +// ReadRodataSymbols returns list of rodata symbols in the binary bin. +func ReadRodataSymbols(bin string) (map[string][]Symbol, error) { + return read(bin, "r", "R") +} + +func read(bin string, types ...string) (map[string][]Symbol, error) { + if len(types) != 2 || len(types[0]) != 1 || len(types[1]) != 1 { + // We assume these things below. + panic("bad types") + } cmd := osutil.Command("nm", "-Ptx", bin) stdout, err := cmd.StdoutPipe() if err != nil { @@ -30,11 +43,14 @@ func ReadSymbols(bin string) (map[string][]Symbol, error) { defer cmd.Wait() symbols := make(map[string][]Symbol) s := bufio.NewScanner(stdout) - text := [][]byte{[]byte(" t "), []byte(" T ")} + var tt [][]byte + for _, typ := range types { + tt = append(tt, []byte(" "+typ+" ")) + } for s.Scan() { // A line looks as: "snb_uncore_msr_enable_box t ffffffff8104db90 0000000000000059" ln := s.Bytes() - if !bytes.Contains(ln, text[0]) && !bytes.Contains(ln, text[1]) { + if !bytes.Contains(ln, tt[0]) && !bytes.Contains(ln, tt[1]) { continue } @@ -42,11 +58,11 @@ func ReadSymbols(bin string) (map[string][]Symbol, error) { if sp1 == -1 { continue } - if !bytes.HasPrefix(ln[sp1:], text[0]) && !bytes.HasPrefix(ln[sp1:], text[1]) { + if !bytes.HasPrefix(ln[sp1:], tt[0]) && !bytes.HasPrefix(ln[sp1:], tt[1]) { continue } - sp2 := sp1 + len(text[0]) + sp2 := sp1 + len(tt[0]) sp3 := bytes.IndexByte(ln[sp2:], ' ') if sp3 == -1 { continue diff --git a/pkg/symbolizer/nm_test.go b/pkg/symbolizer/nm_test.go index 44a02f3cc..6061608ec 100644 --- a/pkg/symbolizer/nm_test.go +++ b/pkg/symbolizer/nm_test.go @@ -8,7 +8,7 @@ import ( ) func TestSymbols(t *testing.T) { - symbols, err := ReadSymbols("testdata/nm.test.out") + symbols, err := ReadTextSymbols("testdata/nm.test.out") if err != nil { t.Fatalf("failed to read symbols: %v", err) } diff --git a/tools/syz-check/check.go b/tools/syz-check/check.go index f5f0e7454..bd6c76c15 100644 --- a/tools/syz-check/check.go +++ b/tools/syz-check/check.go @@ -15,6 +15,7 @@ package main import ( "bytes" "debug/dwarf" + "debug/elf" "flag" "fmt" "os" @@ -23,10 +24,12 @@ import ( "runtime/pprof" "sort" "strings" + "unsafe" "github.com/google/syzkaller/pkg/ast" "github.com/google/syzkaller/pkg/compiler" "github.com/google/syzkaller/pkg/osutil" + "github.com/google/syzkaller/pkg/symbolizer" "github.com/google/syzkaller/prog" "github.com/google/syzkaller/sys/targets" ) @@ -86,10 +89,12 @@ func main() { } func check(OS, arch, obj string) ([]Warn, error) { + var warnings []Warn structDescs, locs, warnings1, err := parseDescriptions(OS, arch) if err != nil { return nil, err } + warnings = append(warnings, warnings1...) structs, err := parseKernelObject(obj) if err != nil { return nil, err @@ -98,7 +103,12 @@ func check(OS, arch, obj string) ([]Warn, error) { if err != nil { return nil, err } - warnings := append(warnings1, warnings2...) + warnings = append(warnings, warnings2...) + warnings3, err := checkNetlink(OS, arch, obj, structDescs, locs) + if err != nil { + return nil, err + } + warnings = append(warnings, warnings3...) for i := range warnings { warnings[i].arch = arch } @@ -106,13 +116,18 @@ func check(OS, arch, obj string) ([]Warn, error) { } const ( - WarnCompiler = "compiler" - WarnNoSuchStruct = "no-such-struct" - WarnBadStructSize = "bad-struct-size" - WarnBadFieldNumber = "bad-field-number" - WarnBadFieldSize = "bad-field-size" - WarnBadFieldOffset = "bad-field-offset" - WarnBadBitfield = "bad-bitfield" + WarnCompiler = "compiler" + WarnNoSuchStruct = "no-such-struct" + WarnBadStructSize = "bad-struct-size" + WarnBadFieldNumber = "bad-field-number" + WarnBadFieldSize = "bad-field-size" + WarnBadFieldOffset = "bad-field-offset" + WarnBadBitfield = "bad-bitfield" + WarnNoNetlinkPolicy = "no-such-netlink-policy" + WarnMultipleNetlinkPolicy = "multiple-netlink-policy" + WarnNetlinkBadSize = "bad-kernel-netlink-policy-size" + WarnNetlinkBadAttrType = "bad-netlink-attr-type" + WarnNetlinkBadAttr = "bad-netlink-attr" ) type Warn struct { @@ -192,7 +207,7 @@ func checkImpl(structs map[string]*dwarf.StructType, structDescs []*prog.KeyedSt continue } checked[typ.Name()] = true - name := templateName(typ.Name()) + name := typ.TemplateName() astStruct := locs[name] if astStruct == nil { continue @@ -206,19 +221,12 @@ func checkImpl(structs map[string]*dwarf.StructType, structDescs []*prog.KeyedSt return warnings, nil } -func templateName(name string) string { - if pos := strings.IndexByte(name, '['); pos != -1 { - name = name[:pos] - } - return name -} - func checkStruct(typ *prog.StructDesc, astStruct *ast.Struct, str *dwarf.StructType) ([]Warn, error) { var warnings []Warn warn := func(pos ast.Pos, typ, msg string, args ...interface{}) { warnings = append(warnings, Warn{pos: pos, typ: typ, msg: fmt.Sprintf(msg, args...)}) } - name := templateName(typ.Name()) + name := typ.TemplateName() if str == nil { warn(astStruct.Pos, WarnNoSuchStruct, "%v", name) return warnings, nil @@ -325,3 +333,275 @@ func parseDescriptions(OS, arch string) ([]*prog.KeyedStruct, map[string]*ast.St } return prg.StructDescs, locs, warnings, nil } + +// Overall idea of netlink checking. +// Currnetly we check netlink policies for common detectable mistakes. +// First, we detect what looks like a netlink policy in our descriptions +// (these are structs/unions only with nlattr/nlnext/nlnetw fields). +// Then we find corresponding symbols (offset/size) in vmlinux using nm. +// Then we read elf headers and locate where these symbols are in the rodata section. +// Then read in the symbol data, which is an array of nla_policy structs. +// These structs allow to easily figure out type/size of attributes. +// Finally we compare our descriptions with the kernel policy description. +func checkNetlink(OS, arch, obj string, structDescs []*prog.KeyedStruct, + locs map[string]*ast.Struct) ([]Warn, error) { + if arch != "amd64" { + // Netlink policies are arch-independent (?), + // so no need to check all arches. + // Also our definition of nlaPolicy below is 64-bit specific. + return nil, nil + } + ef, err := elf.Open(obj) + if err != nil { + return nil, err + } + rodata := ef.Section(".rodata") + if rodata == nil { + return nil, fmt.Errorf("object file %v does not contain .rodata section", obj) + } + symbols, err := symbolizer.ReadRodataSymbols(obj) + if err != nil { + return nil, err + } + var warnings []Warn + warn := func(pos ast.Pos, typ, msg string, args ...interface{}) { + warnings = append(warnings, Warn{pos: pos, typ: typ, msg: fmt.Sprintf(msg, args...)}) + } + structMap := make(map[string]*prog.StructDesc) + for _, str := range structDescs { + structMap[str.Desc.Name()] = str.Desc + } + checked := make(map[string]bool) + for _, str := range structDescs { + typ := str.Desc + if checked[typ.Name()] { + continue + } + checked[typ.Name()] = true + name := typ.TemplateName() + astStruct := locs[name] + if astStruct == nil { + continue + } + if !isNetlinkPolicy(typ) { + continue + } + ss := symbols[name] + if len(ss) == 0 { + warn(astStruct.Pos, WarnNoNetlinkPolicy, "%v", name) + continue + } + if len(ss) != 1 { + warn(astStruct.Pos, WarnMultipleNetlinkPolicy, "%v", name) + continue + } + if ss[0].Size == 0 || ss[0].Size%int(unsafe.Sizeof(nlaPolicy{})) != 0 { + warn(astStruct.Pos, WarnNetlinkBadSize, "%v, size %v", name, ss[0].Size) + } + binary := make([]byte, ss[0].Size) + addr := ss[0].Addr - rodata.Addr + if _, err := rodata.ReadAt(binary, int64(addr)); err != nil { + return nil, fmt.Errorf("failed to read policy %v at %v: %v", name, ss[0].Addr, err) + } + policy := (*[1e6]nlaPolicy)(unsafe.Pointer(&binary[0]))[:ss[0].Size/int(unsafe.Sizeof(nlaPolicy{}))] + warnings1, err := checkNetlinkPolicy(structMap, typ, astStruct, policy) + if err != nil { + return nil, err + } + warnings = append(warnings, warnings1...) + } + return warnings, nil +} + +func isNetlinkPolicy(typ *prog.StructDesc) bool { + for _, field := range typ.Fields { + if prog.IsPad(field) { + continue + } + name := field.TemplateName() + if name != "nlattr_t" && name != "nlattr_tt" { + return false + } + } + return true +} + +func checkNetlinkPolicy(structMap map[string]*prog.StructDesc, typ *prog.StructDesc, + astStruct *ast.Struct, policy []nlaPolicy) ([]Warn, error) { + var warnings []Warn + warn := func(pos ast.Pos, typ, msg string, args ...interface{}) { + warnings = append(warnings, Warn{pos: pos, typ: typ, msg: fmt.Sprintf(msg, args...)}) + } + ai := 0 + for _, field := range typ.Fields { + if prog.IsPad(field) { + continue + } + fld := astStruct.Fields[ai] + ai++ + ft := structMap[field.Name()] + attr := int(ft.Fields[1].(*prog.ConstType).Val) + if attr >= len(policy) { + warn(fld.Pos, WarnNetlinkBadAttrType, "%v.%v: type %v, kernel policy size %v", + typ.TemplateName(), field.FieldName(), attr, len(policy)) + } + w := checkNetlinkAttr(ft, policy[attr]) + if w != "" { + warn(fld.Pos, WarnNetlinkBadAttr, "%v.%v: %v", + typ.TemplateName(), field.FieldName(), w) + } + } + return warnings, nil +} + +func checkNetlinkAttr(typ *prog.StructDesc, policy nlaPolicy) string { + payload := typ.Fields[2] + if typ.TemplateName() == "nlattr_tt" { + payload = typ.Fields[4] + } + if warn := checkAttrType(typ, payload, policy); warn != "" { + return warn + } + size, minSize, maxSize := attrSize(policy) + payloadSize := -1 + if !payload.Varlen() { + payloadSize = int(payload.Size()) + } + if size != -1 && size != payloadSize { + return fmt.Sprintf("bad size %v, expect %v", payloadSize, size) + } + if minSize != -1 && minSize > payloadSize { + return fmt.Sprintf("bad size %v, expect min %v", payloadSize, minSize) + } + if maxSize != -1 && maxSize < payloadSize { + return fmt.Sprintf("bad size %v, expect max %v", payloadSize, maxSize) + } + + valMin, valMax, haveVal := typeMinMaxValue(payload) + if haveVal { + if policy.validation == NLA_VALIDATE_RANGE || policy.validation == NLA_VALIDATE_MIN { + if int64(valMin) < int64(policy.minVal) { + return fmt.Sprintf("bad min value %v, expect %v", + int64(valMin), policy.minVal) + } + } + if policy.validation == NLA_VALIDATE_RANGE || policy.validation == NLA_VALIDATE_MAX { + if int64(valMax) > int64(policy.maxVal) { + return fmt.Sprintf("bad max value %v, expect %v", + int64(valMax), policy.maxVal) + } + } + } + return "" +} + +func checkAttrType(typ *prog.StructDesc, payload prog.Type, policy nlaPolicy) string { + switch policy.typ { + case NLA_STRING, NLA_NUL_STRING: + if payload.Name() != "string" { + return fmt.Sprintf("expect string") + } + case NLA_NESTED: + if typ.TemplateName() != "nlattr_tt" || typ.Fields[3].(*prog.ConstType).Val != 1 { + return fmt.Sprintf("should be nlnest") + } + case NLA_NESTED_ARRAY, NLA_BITFIELD32, NLA_REJECT: + return fmt.Sprintf("unhandled type %v", policy.typ) + } + return "" +} + +func attrSize(policy nlaPolicy) (int, int, int) { + switch policy.typ { + case NLA_UNSPEC: + if policy.len != 0 { + return int(policy.len), -1, -1 + } + case NLA_MIN_LEN: + return -1, int(policy.len), -1 + case NLA_EXACT_LEN, NLA_EXACT_LEN_WARN: + return int(policy.len), -1, -1 + case NLA_U8, NLA_S8: + return 1, -1, -1 + case NLA_U16, NLA_S16: + return 2, -1, -1 + case NLA_U32, NLA_S32: + return 4, -1, -1 + case NLA_U64, NLA_S64, NLA_MSECS: + return 8, -1, -1 + case NLA_FLAG: + return 0, -1, -1 + case NLA_BINARY: + if policy.len != 0 { + return -1, -1, int(policy.len) + } + } + return -1, -1, -1 +} + +func typeMinMaxValue(payload prog.Type) (min, max uint64, ok bool) { + switch typ := payload.(type) { + case *prog.ConstType: + return typ.Val, typ.Val, true + case *prog.IntType: + if typ.Kind == prog.IntRange { + return typ.RangeBegin, typ.RangeEnd, true + } + return 0, ^uint64(0), true + case *prog.FlagsType: + min, max := ^uint64(0), uint64(0) + for _, v := range typ.Vals { + if min > v { + min = v + } + if max < v { + max = v + } + } + return min, max, true + } + return 0, 0, false +} + +type nlaPolicy struct { + typ uint8 + validation uint8 + len uint16 + _ uint32 + minVal int16 + maxVal int16 + _ int32 +} + +// nolint +const ( + NLA_UNSPEC = iota + NLA_U8 + NLA_U16 + NLA_U32 + NLA_U64 + NLA_STRING + NLA_FLAG + NLA_MSECS + NLA_NESTED + NLA_NESTED_ARRAY + NLA_NUL_STRING + NLA_BINARY + NLA_S8 + NLA_S16 + NLA_S32 + NLA_S64 + NLA_BITFIELD32 + NLA_REJECT + NLA_EXACT_LEN + NLA_EXACT_LEN_WARN + NLA_MIN_LEN +) + +// nolint +const ( + _ = iota + NLA_VALIDATE_RANGE + NLA_VALIDATE_MIN + NLA_VALIDATE_MAX +) |
