diff options
| -rw-r--r-- | tools/syz-declextract/run.go | 185 |
1 files changed, 116 insertions, 69 deletions
diff --git a/tools/syz-declextract/run.go b/tools/syz-declextract/run.go index 5e5e546fd..e93723c95 100644 --- a/tools/syz-declextract/run.go +++ b/tools/syz-declextract/run.go @@ -38,8 +38,10 @@ var ( func main() { var ( - flagConfig = flag.String("config", "", "manager config file") - flagBinary = flag.String("binary", "syz-declextract", "path to syz-declextract binary") + flagConfig = flag.String("config", "", "manager config file") + flagBinary = flag.String("binary", "syz-declextract", "path to syz-declextract binary") + flagCacheExtract = flag.Bool("cache-extract", false, "use cached extract results if present"+ + " (cached in manager.workdir/declextract.cache)") ) defer tool.Init()() cfg, err := mgrconfig.LoadFile(*flagConfig) @@ -53,12 +55,20 @@ func main() { tool.Failf("failed to load compile commands: %v", err) } - extractor := subsystem.MakeExtractor(subsystem.GetList(target.OS)) + ctx := &context{ + cfg: cfg, + clangTool: *flagBinary, + compilationDatabase: compilationDatabase, + compileCommands: cmds, + extractor: subsystem.MakeExtractor(subsystem.GetList(target.OS)), + syscallNameMap: readSyscallMap(cfg.KernelSrc), + interfaces: make(map[string]Interface), + } outputs := make(chan *output, len(cmds)) files := make(chan string, len(cmds)) for w := 0; w < runtime.NumCPU(); w++ { - go worker(outputs, files, *flagBinary, compilationDatabase) + go ctx.worker(outputs, files, *flagCacheExtract) } for _, cmd := range cmds { @@ -66,11 +76,6 @@ func main() { } close(files) - syscallNames := readSyscallMap(cfg.KernelSrc) - - var nodes []ast.Node - interfaces := make(map[string]Interface) - eh := ast.LoggingHandler for range cmds { out := <-outputs if out == nil { @@ -83,14 +88,17 @@ func main() { if out.err != nil { tool.Failf("%v: %v", file, out.err) } - parse := ast.Parse(out.output, "", eh) + parse := ast.Parse(out.output, "", nil) if parse == nil { tool.Failf("%v: parsing error:\n%s", file, out.output) } - appendNodes(&nodes, interfaces, parse.Nodes, syscallNames, cfg.KernelSrc, cfg.KernelObj, file) + ctx.appendNodes(parse.Nodes, file) } + ctx.finishDescriptions() - desc := finishDescriptions(nodes) + desc := &ast.Description{ + Nodes: ctx.nodes, + } writeDescriptions(desc) // In order to remove unused bits of the descriptions, we need to write them out first, // and then parse all descriptions back b/c auto descriptions use some types defined @@ -98,13 +106,24 @@ func main() { removeUnused(desc) writeDescriptions(desc) - ifaces := finishInterfaces(interfaces, extractor) + ifaces := ctx.finishInterfaces() ifacesData := serializeInterfaces(ifaces) if err := osutil.WriteFile(autoFile+".info", ifacesData); err != nil { tool.Fail(err) } } +type context struct { + cfg *mgrconfig.Config + clangTool string + compilationDatabase string + compileCommands []compileCommand + extractor *subsystem.Extractor + syscallNameMap map[string][]string + interfaces map[string]Interface + nodes []ast.Node +} + type compileCommand struct { Command string Directory string @@ -179,16 +198,16 @@ func serializeInterfaces(ifaces []Interface) []byte { return w.Bytes() } -func finishInterfaces(m map[string]Interface, extractor *subsystem.Extractor) []Interface { +func (ctx *context) finishInterfaces() []Interface { var interfaces []Interface - for _, iface := range m { + for _, iface := range ctx.interfaces { slices.Sort(iface.Files) iface.Files = slices.Compact(iface.Files) var crashes []*subsystem.Crash for _, file := range iface.Files { crashes = append(crashes, &subsystem.Crash{GuiltyPath: file}) } - for _, s := range extractor.Extract(crashes) { + for _, s := range ctx.extractor.Extract(crashes) { iface.Subsystems = append(iface.Subsystems, s.Name) } slices.Sort(iface.Subsystems) @@ -204,8 +223,8 @@ func finishInterfaces(m map[string]Interface, extractor *subsystem.Extractor) [] return interfaces } -func mergeInterface(interfaces map[string]Interface, iface Interface) { - prev, ok := interfaces[iface.ID()] +func (ctx *context) mergeInterface(iface Interface) { + prev, ok := ctx.interfaces[iface.ID()] if ok { if iface.identifyingConst != prev.identifyingConst { tool.Failf("interface %v has different identifying consts: %v vs %v", @@ -213,7 +232,7 @@ func mergeInterface(interfaces map[string]Interface, iface Interface) { } iface.Files = append(iface.Files, prev.Files...) } - interfaces[iface.ID()] = iface + ctx.interfaces[iface.ID()] = iface } func checkDescriptionPresence(interfaces []Interface, autoFile string) { @@ -253,19 +272,19 @@ func writeDescriptions(desc *ast.Description) { } } -func finishDescriptions(nodes []ast.Node) *ast.Description { - slices.SortFunc(nodes, func(a, b ast.Node) int { +func (ctx *context) finishDescriptions() { + slices.SortFunc(ctx.nodes, func(a, b ast.Node) int { return strings.Compare(ast.SerializeNode(a), ast.SerializeNode(b)) }) - nodes = slices.CompactFunc(nodes, func(a, b ast.Node) bool { + ctx.nodes = slices.CompactFunc(ctx.nodes, func(a, b ast.Node) bool { return ast.SerializeNode(a) == ast.SerializeNode(b) }) - slices.SortStableFunc(nodes, func(a, b ast.Node) int { + slices.SortStableFunc(ctx.nodes, func(a, b ast.Node) int { return getTypeOrder(a) - getTypeOrder(b) }) prevCall, prevCallIndex := "", 0 - for _, node := range nodes { + for _, node := range ctx.nodes { switch n := node.(type) { case *ast.Call: if n.Name.Name == prevCall { @@ -286,8 +305,7 @@ include <include/vdso/bits.h> include <include/linux/types.h> ` desc := ast.Parse([]byte(header), "", nil) - desc.Nodes = append(desc.Nodes, nodes...) - return desc + ctx.nodes = append(desc.Nodes, ctx.nodes...) } func removeUnused(desc *ast.Description) { @@ -311,21 +329,35 @@ func removeUnused(desc *ast.Description) { }) } -func worker(outputs chan *output, files chan string, binary, compilationDatabase string) { +func (ctx *context) worker(outputs chan *output, files chan string, cache bool) { for file := range files { + cacheFile := filepath.Join(ctx.cfg.Workdir, "declextract.cache", + strings.TrimPrefix(strings.TrimPrefix(filepath.Clean(file), + ctx.cfg.KernelSrc), ctx.cfg.KernelObj)) + if cache { + out, err := os.ReadFile(cacheFile) + if err == nil { + outputs <- &output{file, out, nil} + continue + } + } // Suppress warning since we may build the tool on a different clang // version that produces more warnings. - out, err := exec.Command(binary, "-p", compilationDatabase, file, "--extra-arg=-w").Output() + out, err := exec.Command(ctx.clangTool, "-p", ctx.compilationDatabase, file, "--extra-arg=-w").Output() var exitErr *exec.ExitError if err != nil && errors.As(err, &exitErr) && len(exitErr.Stderr) != 0 { err = fmt.Errorf("%s", exitErr.Stderr) } + if err == nil { + osutil.MkdirAll(filepath.Dir(cacheFile)) + osutil.WriteFile(cacheFile, out) + } outputs <- &output{file, out, err} } } -func renameSyscall(syscall *ast.Call, rename map[string][]string) []ast.Node { - names := rename[syscall.CallName] +func (ctx *context) renameSyscall(syscall *ast.Call) []ast.Node { + names := ctx.syscallNameMap[syscall.CallName] if len(names) == 0 { // Syscall has no record in the tables for the architectures we support. return nil @@ -425,74 +457,89 @@ func readSyscallMap(sourceDir string) map[string][]string { return rename } -func appendNodes(slice *[]ast.Node, interfaces map[string]Interface, nodes []ast.Node, - syscallNames map[string][]string, sourceDir, buildDir, file string) { +func (ctx *context) appendNodes(nodes []ast.Node, file string) { for _, node := range nodes { switch node := node.(type) { case *ast.Call: // Some syscalls have different names and entry points and thus need to be renamed. // e.g. SYSCALL_DEFINE1(setuid16, old_uid_t, uid) is referred to in the .tbl file with setuid. - *slice = append(*slice, renameSyscall(node, syscallNames)...) + ctx.nodes = append(ctx.nodes, ctx.renameSyscall(node)...) case *ast.Include: - if file, err := filepath.Rel(sourceDir, filepath.Join(buildDir, node.File.Value)); err == nil { + if file, err := filepath.Rel(ctx.cfg.KernelSrc, filepath.Join(ctx.cfg.KernelObj, node.File.Value)); err == nil { node.File.Value = file } - *slice = append(*slice, node) - case *ast.Comment: - if !strings.HasPrefix(node.Text, "INTERFACE:") { - *slice = append(*slice, node) - continue + if replace := includeReplaces[node.File.Value]; replace != "" { + node.File.Value = replace } - fields := strings.Fields(node.Text) - if len(fields) != 6 { - tool.Failf("%q has wrong number of fields", node.Text) - } - for i := range fields { - if fields[i] == "-" { - fields[i] = "" + ctx.nodes = append(ctx.nodes, node) + case *ast.Comment: + switch { + case strings.HasPrefix(node.Text, "INTERFACE:"): + fields := strings.Fields(node.Text) + if len(fields) != 6 { + tool.Failf("%q has wrong number of fields", node.Text) } - } - iface := Interface{ - Type: fields[1], - Name: fields[2], - Files: []string{file}, - identifyingConst: fields[3], - Func: fields[4], - Access: fields[5], - } - if iface.Type == "SYSCALL" { - for _, name := range syscallNames[iface.Name] { - iface.Name = name - iface.identifyingConst = "__NR_" + name - mergeInterface(interfaces, iface) + for i := range fields { + if fields[i] == "-" { + fields[i] = "" + } } - } else { - mergeInterface(interfaces, iface) + iface := Interface{ + Type: fields[1], + Name: fields[2], + Files: []string{file}, + identifyingConst: fields[3], + Func: fields[4], + Access: fields[5], + } + if iface.Type == "SYSCALL" { + for _, name := range ctx.syscallNameMap[iface.Name] { + iface.Name = name + iface.identifyingConst = "__NR_" + name + ctx.mergeInterface(iface) + } + } else { + ctx.mergeInterface(iface) + } + default: + ctx.nodes = append(ctx.nodes, node) } default: - *slice = append(*slice, node) + ctx.nodes = append(ctx.nodes, node) } } } +// Replace these includes in the tool output. +var includeReplaces = map[string]string{ + // Arches may use some includes from asm-generic and some from arch/arm. + // If the arch used for extract used asm-generic for a header, + // other arches may need arch/asm version of the header. So switch to + // a more generic file name that should resolve correctly for all arches. + "include/uapi/asm-generic/ioctls.h": "asm/ioctls.h", + "include/uapi/asm-generic/sockios.h": "asm/sockios.h", +} + func getTypeOrder(a ast.Node) int { switch a.(type) { case *ast.Comment: return 0 case *ast.Include: return 1 - case *ast.IntFlags: + case *ast.Define: return 2 - case *ast.Resource: + case *ast.IntFlags: return 3 - case *ast.TypeDef: + case *ast.Resource: return 4 - case *ast.Call: + case *ast.TypeDef: return 5 - case *ast.Struct: + case *ast.Call: return 6 - case *ast.NewLine: + case *ast.Struct: return 7 + case *ast.NewLine: + return 8 default: panic(fmt.Sprintf("unhandled type %T", a)) } |
