From d81ee624aa8da9b9d6360798bab72156e1abe4ee Mon Sep 17 00:00:00 2001 From: Dmitry Vyukov Date: Mon, 11 Nov 2024 10:26:41 +0100 Subject: tools/syz-declextract: properly remove unused declarations Currently we have a number of hacks to deal with unused bits: - remove some netlink syscalls if no policy union present - generate the huge union for unused policies - manually remove 1 struct But we still got more unused structs after recent changes. Properly remove all unused bits using the compiler knowledge. This is both simpler and more reliable. --- tools/syz-declextract/run.go | 152 ++++++++++++------------------ tools/syz-declextract/syz-declextract.cpp | 8 +- 2 files changed, 66 insertions(+), 94 deletions(-) (limited to 'tools') diff --git a/tools/syz-declextract/run.go b/tools/syz-declextract/run.go index 8b37cc4a6..5794c2a45 100644 --- a/tools/syz-declextract/run.go +++ b/tools/syz-declextract/run.go @@ -29,10 +29,14 @@ import ( "github.com/google/syzkaller/sys/targets" ) +var ( + autoFile = filepath.FromSlash("sys/linux/auto.txt") + target = targets.Get(targets.Linux, targets.AMD64) +) + func main() { var ( binary = flag.String("binary", "syz-declextract", "path to binary") - outFile = flag.String("output", "sys/linux/auto.txt", "output file") sourceDir = flag.String("sourcedir", "", "kernel source directory") buildDir = flag.String("builddir", "", "kernel build directory (defaults to source directory)") ) @@ -52,7 +56,7 @@ func main() { tool.Fail(err) } - extractor := subsystem.MakeExtractor(subsystem.GetList(targets.Linux)) + extractor := subsystem.MakeExtractor(subsystem.GetList(target.OS)) var cmds []compileCommand if err := json.Unmarshal(fileData, &cmds); err != nil { @@ -106,20 +110,20 @@ func main() { appendNodes(&nodes, interfaces, parse.Nodes, syscallNames, *sourceDir, *buildDir, file) } - // New lines are added in the parsing step. This is why we need to Format (serialize the description), - // Parse, then Format again. desc := finishDescriptions(nodes) - output := ast.Format(ast.Parse(ast.Format(desc), "", ast.LoggingHandler)) - if err := osutil.WriteFile(*outFile, output); err != nil { - tool.Fail(err) - } - - ifaces := finishInterfaces(interfaces, extractor, *outFile) + writeDescriptions(desc) + // In order to remove unused bits of the descriptions, we need to write them out first, + // and then parse all descriptions back b/c auto descriptions use some types defined + // by manual descriptions (compiler.CollectUnused requires complete descriptions). + removeUnused(desc) + writeDescriptions(desc) + + ifaces := finishInterfaces(interfaces, extractor) data, err := json.MarshalIndent(ifaces, "", "\t") if err != nil { tool.Failf("failed to marshal interfaces: %v", err) } - if err := osutil.WriteFile(*outFile+".json", data); err != nil { + if err := osutil.WriteFile(autoFile+".json", data); err != nil { tool.Fail(err) } } @@ -151,7 +155,7 @@ func (iface *Interface) ID() string { return fmt.Sprintf("%v/%v", iface.Type, iface.Name) } -func finishInterfaces(m map[string]Interface, extractor *subsystem.Extractor, autoFile string) []Interface { +func finishInterfaces(m map[string]Interface, extractor *subsystem.Extractor) []Interface { var interfaces []Interface for _, iface := range m { slices.Sort(iface.Files) @@ -186,11 +190,11 @@ func mergeInterface(interfaces map[string]Interface, iface Interface) { } func checkDescriptionPresence(interfaces []Interface, autoFile string) { - desc := ast.ParseGlob(filepath.Join("sys", targets.Linux, "*.txt"), nil) + desc := ast.ParseGlob(filepath.Join("sys", target.OS, "*.txt"), nil) if desc == nil { tool.Failf("failed to parse descriptions") } - consts := compiler.ExtractConsts(desc, targets.List[targets.Linux][targets.AMD64], nil) + consts := compiler.ExtractConsts(desc, target, nil) auto := make(map[string]bool) manual := make(map[string]bool) for file, desc := range consts { @@ -213,7 +217,14 @@ func checkDescriptionPresence(interfaces []Interface, autoFile string) { } } -const sendmsg = "sendmsg" +func writeDescriptions(desc *ast.Description) { + // New lines are added in the parsing step. This is why we need to Format (serialize the description), + // Parse, then Format again. + output := ast.Format(ast.Parse(ast.Format(desc), "", ast.LoggingHandler)) + if err := osutil.WriteFile(autoFile, output); err != nil { + tool.Fail(err) + } +} func finishDescriptions(nodes []ast.Node) *ast.Description { slices.SortFunc(nodes, func(a, b ast.Node) int { @@ -226,80 +237,19 @@ func finishDescriptions(nodes []ast.Node) *ast.Description { return getTypeOrder(a) - getTypeOrder(b) }) - var syscalls []*ast.Call - var structs []*ast.Struct + prevCall, prevCallIndex := "", 0 for _, node := range nodes { - switch node := node.(type) { + switch n := node.(type) { case *ast.Call: - syscalls = append(syscalls, node) - case *ast.Struct: - // Special case for unsued struct. TODO: handle unused structs. - if node.Name.Name == "old_utimbuf32$auto_record" { - continue - } - structs = append(structs, node) - case *ast.Include, *ast.TypeDef, *ast.Resource, *ast.IntFlags, *ast.NewLine, *ast.Comment: - continue - default: - _, typ, _ := node.Info() - tool.Failf("unhandled Node type: %v", typ) - } - } - // NOTE: The -2 at the end is to account for one unused struct and one newline - nodes = nodes[:len(nodes)-len(structs)-len(syscalls)-2] - - sendmsgNo := 0 - // Some commands are executed for multiple policies. Ensure that they don't get deleted by the following compact call. - for i := 1; i < len(syscalls); i++ { - if syscalls[i].CallName == sendmsg && syscalls[i].Name.Name == syscalls[i-1].Name.Name { - syscalls[i].Name.Name += strconv.Itoa(sendmsgNo) - sendmsgNo++ - } - } - syscalls = slices.CompactFunc(syscalls, func(a, b *ast.Call) bool { - // We only compare the the system call names for cases where the same system call has different parameter names, - // but share the same syzkaller type. NOTE:Change when we have better type extraction. - return a.Name.Name == b.Name.Name - }) - - usedNetlink := make(map[string]bool) - for _, node := range syscalls { - if node.CallName == sendmsg && len(node.Args[1].Type.Args) == 2 && len(node.Args[1].Type.Args[1].Args) > 1 { - policy := node.Args[1].Type.Args[1].Args[1].Ident - usedNetlink[policy] = true - _, isDefined := slices.BinarySearchFunc(structs, policy, func(a *ast.Struct, b string) int { - return strings.Compare(a.Name.Name, b) - }) - if !isDefined { - continue + if n.Name.Name == prevCall { + n.Name.Name += strconv.Itoa(prevCallIndex) + prevCallIndex++ + } else { + prevCall = n.Name.Name + prevCallIndex = 0 } } - nodes = append(nodes, node) - } - var netlinkNames []string - for _, node := range structs { - nodes = append(nodes, node) - name := node.Name.Name - if !usedNetlink[name] && !strings.HasSuffix(name, "$auto_record") { - netlinkNames = append(netlinkNames, name) - } } - for i, structName := range netlinkNames { - netlinkNames[i] = fmt.Sprintf("\tpolicy%v msghdr_auto[%v]\n", i, structName) - } - netlinkUnion := ` -type msghdr_auto[POLICY] msghdr_netlink[netlink_msg_t[autogenerated_netlink, genlmsghdr, POLICY]] -resource autogenerated_netlink[int16] -syz_genetlink_get_family_id$auto(name ptr[in, string], fd sock_nl_generic) autogenerated_netlink -sendmsg$autorun(fd sock_nl_generic, msg ptr[in, auto_union], f flags[send_flags]) -auto_union [ -` + strings.Join(netlinkNames, "") + "]" - eh := ast.LoggingHandler - netlinkUnionParsed := ast.Parse([]byte(netlinkUnion), "", eh) - if netlinkUnionParsed == nil { - tool.Failf("parsing error") - } - nodes = append(nodes, netlinkUnionParsed.Nodes...) // These additional includes must be at the top (added after sorting), because other kernel headers // are broken and won't compile without these additional ones included first. @@ -308,9 +258,30 @@ auto_union [ include include ` - desc := ast.Parse([]byte(header), "", eh) - nodes = append(desc.Nodes, nodes...) - return &ast.Description{Nodes: nodes} + desc := ast.Parse([]byte(header), "", nil) + desc.Nodes = append(desc.Nodes, nodes...) + return desc +} + +func removeUnused(desc *ast.Description) { + all := ast.ParseGlob(filepath.Join("sys", target.OS, "*.txt"), nil) + if all == nil { + tool.Failf("failed to parse descriptions") + } + unusedNodes, err := compiler.CollectUnused(all, target, nil) + if err != nil { + tool.Failf("failed to typecheck descriptions: %v", err) + } + unused := make(map[string]bool) + for _, n := range unusedNodes { + if pos, typ, name := n.Info(); pos.File == autoFile { + unused[fmt.Sprintf("%v/%v", typ, name)] = true + } + } + desc.Nodes = slices.DeleteFunc(desc.Nodes, func(n ast.Node) bool { + _, typ, name := n.Info() + return unused[fmt.Sprintf("%v/%v", typ, name)] + }) } func worker(outputs chan *output, files chan string, binary, compilationDatabase string) { @@ -357,7 +328,7 @@ func readSyscallMap(sourceDir string) map[string][]string { is64bit bool } syscalls := make(map[string][]desc) - for _, arch := range targets.List[targets.Linux] { + for _, arch := range targets.List[target.OS] { filepath.Walk(filepath.Join(sourceDir, "arch", arch.KernelHeaderArch), func(path string, info fs.FileInfo, err error) error { if err != nil || !strings.HasSuffix(path, ".tbl") { @@ -403,11 +374,10 @@ func readSyscallMap(sourceDir string) map[string][]string { rename := map[string][]string{ "syz_genetlink_get_family_id": {"syz_genetlink_get_family_id"}, } - const mainArch = targets.AMD64 for syscall, descs := range syscalls { slices.SortFunc(descs, func(a, b desc) int { - if (a.arch == mainArch) != (b.arch == mainArch) { - if a.arch == mainArch { + if (a.arch == target.Arch) != (b.arch == target.Arch) { + if a.arch == target.Arch { return -1 } return 1 diff --git a/tools/syz-declextract/syz-declextract.cpp b/tools/syz-declextract/syz-declextract.cpp index f9b08eef5..0a24c6418 100644 --- a/tools/syz-declextract/syz-declextract.cpp +++ b/tools/syz-declextract/syz-declextract.cpp @@ -704,12 +704,12 @@ private: if (!netlinkDecl) { return; } - std::vector> fields; const auto *init = netlinkDecl->getInit(); if (!init) { return; } + std::vector> fields; for (const auto &policy : *llvm::dyn_cast(init)) { fields.push_back(std::vector()); for (const auto &member : policy->children()) { @@ -718,8 +718,10 @@ private: } auto enumData = extractDesignatedInitConsts(*context, *netlinkDecl); - // TODO: generate an empty message for these or something. if (enumData.empty()) { + // We need to emit at least some type for it. + // Ideally it should be void, but typedef to void currently does not work. + printf("type %s auto_todo\n", getPolicyName(Result, netlinkDecl)->c_str()); return; } for (const auto &[_, item] : enumData) { @@ -730,7 +732,7 @@ private: } RecordExtractor recordExtractor(Result.SourceManager); - printf("%s[\n", getPolicyName(Result, netlinkDecl)->c_str()); + printf("%s [\n", getPolicyName(Result, netlinkDecl)->c_str()); for (size_t i = 0; i < fields.size(); ++i) { // The array could have an implicitly initialized policy (i.e. empty) or an unnamed attribute if (fields[i].empty() || enumData[i].name.empty()) { -- cgit mrf-deployment