From b1df3115feca4b02661cb4a2fa45f0d5e37d35c0 Mon Sep 17 00:00:00 2001 From: Pimyn Girgis Date: Wed, 28 Aug 2024 09:43:55 +0000 Subject: tools/syz-declextract: extract genl_family and generate descriptions Extract genl_family and generate descriptions based on the attached policies and commands. Fix Issue with the go tool where the clang tool would fail and remain undetected. --- tools/syz-declextract/run.go | 77 +++++++++++-- tools/syz-declextract/syz-declextract.cpp | 180 ++++++++++++++++++++++++++---- 2 files changed, 229 insertions(+), 28 deletions(-) (limited to 'tools') diff --git a/tools/syz-declextract/run.go b/tools/syz-declextract/run.go index f637eb939..9e06603d7 100644 --- a/tools/syz-declextract/run.go +++ b/tools/syz-declextract/run.go @@ -67,6 +67,8 @@ func main() { var syscalls []*ast.Call var netlinks []*ast.Struct var includes []*ast.Include + var typeDefs []*ast.TypeDef + var resources []*ast.Resource syscallNames := readSyscallNames(filepath.Join(*kernelDir, "arch")) // some syscalls have different names and entry // points and thus need to be renamed. // e.g. SYSCALL_DEFINE1(setuid16, old_uid_t, uid) is referred to in the .tbl file with setuid. @@ -79,6 +81,7 @@ func main() { } parse := ast.Parse([]byte(out.stdout), "", eh) if parse == nil { + fmt.Println(out.stdout) tool.Failf("parsing error") } for _, node := range parse.Nodes { @@ -89,15 +92,22 @@ func main() { netlinks = append(netlinks, node) case *ast.Include: includes = append(includes, node) + case *ast.TypeDef: + typeDefs = append(typeDefs, node) + case *ast.Resource: + resources = append(resources, node) + case *ast.NewLine: + continue } } } close(files) - writeOutput(includes, syscalls, netlinks, *outFile) + writeOutput(includes, syscalls, netlinks, typeDefs, resources, *outFile) } -func writeOutput(includes []*ast.Include, syscalls []*ast.Call, netlinks []*ast.Struct, outFile string) { +func writeOutput(includes []*ast.Include, syscalls []*ast.Call, netlinks []*ast.Struct, types []*ast.TypeDef, + resources []*ast.Resource, outFile string) { slices.SortFunc(includes, func(a, b *ast.Include) int { return strings.Compare(a.File.Value, b.File.Value) }) @@ -117,23 +127,56 @@ func writeOutput(includes []*ast.Include, syscalls []*ast.Call, netlinks []*ast. return strings.Compare(a.Name.Name, b.Name.Name) }) - autoGeneratedNotice := "# Code generated by syz-declextract. DO NOT EDIT." + slices.SortFunc(resources, func(a, b *ast.Resource) int { + return strings.Compare(a.Name.Name, b.Name.Name) + }) + + slices.SortFunc(types, func(a, b *ast.TypeDef) int { + return strings.Compare(a.Name.Name, b.Name.Name) + }) + + autoGeneratedNotice := "# Code generated by syz-declextract. DO NOT EDIT.\n" + commonKernelHeaders := "include \ninclude " + var netlinkNames []string mmap2 := "_ = __NR_mmap2\n" eh := ast.LoggingHandler - desc := ast.Parse([]byte(autoGeneratedNotice), "", eh) + desc := ast.Parse([]byte(autoGeneratedNotice+commonKernelHeaders), "", eh) for _, node := range includes { desc.Nodes = append(desc.Nodes, node) } + for _, node := range resources { + desc.Nodes = append(desc.Nodes, node) + } + for _, node := range types { + desc.Nodes = append(desc.Nodes, node) + } + usedNetlink := make(map[string]bool) for _, node := range syscalls { + if node.CallName == "sendmsg" && len(node.Args[1].Type.Args) == 2 { + policy := node.Args[1].Type.Args[1].Args[1].Ident + usedNetlink[policy] = true + _, isDefined := slices.BinarySearchFunc(netlinks, policy, func(a *ast.Struct, b string) int { + return strings.Compare(a.Name.Name, b) + }) + if !isDefined { + continue + } + } desc.Nodes = append(desc.Nodes, node) } desc.Nodes = append(desc.Nodes, ast.Parse([]byte(mmap2), "", eh).Nodes...) - var netlinkNames []string - for i, node := range netlinks { + for _, node := range netlinks { desc.Nodes = append(desc.Nodes, node) - netlinkNames = append(netlinkNames, fmt.Sprintf("\tpolicy%v msghdr_auto[%v]\n", i, node.Name.Name)) + name := node.Name.Name + if !usedNetlink[name] { + netlinkNames = append(netlinkNames, name) + } + } + for i, netlink := range netlinkNames { + netlinkNames[i] = fmt.Sprintf("\tpolicy%v msghdr_auto[%v]\n", i, netlink) } - netlinkUnion := `type msghdr_auto[POLICY] msghdr_netlink[netlink_msg_t[autogenerated_netlink, genlmsghdr, POLICY]] + netlinkUnion := ` +type msghdr_auto[POLICY] msghdr_netlink[netlink_msg_t[autogenerated_netlink, genlmsghdr, POLICY]] resource autogenerated_netlink[int16] syz_genetlink_get_family_id$auto(name ptr[in, string], fd sock_nl_generic) autogenerated_netlink sendmsg$autorun(fd sock_nl_generic, msg ptr[in, auto_union], f flags[send_flags]) @@ -166,7 +209,11 @@ func worker(outputs chan output, files chan string, binary, compilationDatabase if err != nil { var error *exec.ExitError if errors.As(err, &error) { - stderr = string(error.Stderr) + if len(error.Stderr) != 0 { + stderr = string(error.Stderr) + } else { + stderr = fmt.Sprintf("%v: %v", file, error.String()) + } } else { stderr = err.Error() } @@ -176,6 +223,9 @@ func worker(outputs chan output, files chan string, binary, compilationDatabase } func renameSyscall(syscall *ast.Call, rename map[string][]string) []*ast.Call { + if !shouldRenameSyscall(syscall.CallName) { + return []*ast.Call{syscall} + } var renamed []*ast.Call toReplace := syscall.CallName if rename[toReplace] == nil { @@ -239,6 +289,15 @@ func readSyscallNames(kernelDir string) map[string][]string { return rename } +func shouldRenameSyscall(syscall string) bool { + switch syscall { + case "sendmsg", "syz_genetlink_get_family_id": + return false + default: + return true + } +} + func isProhibited(syscall string) bool { switch syscall { case "reboot", "utimesat": // utimesat is not defined for all arches. diff --git a/tools/syz-declextract/syz-declextract.cpp b/tools/syz-declextract/syz-declextract.cpp index 6a597974b..66ce053a9 100644 --- a/tools/syz-declextract/syz-declextract.cpp +++ b/tools/syz-declextract/syz-declextract.cpp @@ -13,6 +13,7 @@ #include "clang/ASTMatchers/ASTMatchers.h" #include "clang/ASTMatchers/ASTMatchersInternal.h" #include "clang/Basic/LLVM.h" +#include "clang/Lex/Lexer.h" #include "clang/Sema/Ownership.h" #include "clang/Tooling/CommonOptionsParser.h" #include "clang/Tooling/Tooling.h" @@ -20,8 +21,11 @@ #include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/raw_ostream.h" +#include #include #include +#include +#include #include #include #include @@ -31,7 +35,7 @@ using namespace clang::ast_matchers; struct EnumData { std::string name; - int value; + unsigned long long value; std::string file; }; @@ -40,6 +44,11 @@ struct Param { std::string name; }; +struct NetlinkOps { + std::string cmd; + std::optional policy; +}; + class EnumMatcher : public MatchFinder::MatchCallback { private: std::vector EnumDetails; @@ -48,12 +57,13 @@ public: std::vector getEnumData() { return EnumDetails; } virtual void run(const MatchFinder::MatchResult &Result) override { const auto *enumValue = Result.Nodes.getNodeAs("enum_value"); - if (!enumValue) + if (!enumValue) { return; + } const auto &name = enumValue->getEnumConstantDecl()->getNameAsString(); - const auto value = int(*enumValue->getAPValueResult().getInt().getRawData()); + const auto value = *enumValue->getAPValueResult().getInt().getRawData(); const auto &path = std::filesystem::relative( - Result.SourceManager->getFilename(enumValue->getEnumConstantDecl()->getSourceRange().getBegin()).data()); + Result.SourceManager->getFilename(enumValue->getEnumConstantDecl()->getSourceRange().getBegin()).str()); EnumDetails.push_back({std::move(name), value, std::move(path)}); } }; @@ -122,7 +132,7 @@ private: if (values.empty()) return; - int argc = *values[2]->getIntegerConstantExpr(*context).value().getRawData(); + int argc = *values[2]->getIntegerConstantExpr(*context)->getRawData(); std::vector args(argc); if (argc) { @@ -145,7 +155,7 @@ private: } } - printf("%s$auto(", values[0]->tryEvaluateString(*context).value().c_str() + 4); // name + printf("%s$auto(", values[0]->tryEvaluateString(*context)->c_str() + 4); // name const char *sep = ""; for (const auto &arg : args) { printf("%s%s %s", sep, swapIfReservedKeyword(arg.name).c_str(), getSyzType(arg.type).c_str()); @@ -161,42 +171,170 @@ private: if (!netlinkDecl) { return; } + std::vector> fields; + + const auto *init = netlinkDecl->getInit(); + if (!init) { + return; + } + for (const auto &policy : *llvm::dyn_cast(init)) { + fields.push_back(std::vector()); + for (const auto &member : policy->children()) { + fields.back().push_back(llvm::dyn_cast(member)); + } + } EnumMatcher enumMatcher; MatchFinder enumFinder; enumFinder.addMatcher( - decl(forEachDescendant(designatedInitExpr(has(constantExpr(has(declRefExpr())).bind("enum_value"))))), + decl(forEachDescendant(designatedInitExpr(optionally(has(constantExpr(has(declRefExpr())).bind("enum_value")))) + .bind("designated_init"))), &enumMatcher); enumFinder.match(*netlinkDecl, *context); // get enum details from the current subtree (nla_policy[]) - std::vector> fields; - for (const auto &policy : *llvm::dyn_cast(netlinkDecl->getInit())) { - // The array could have an implicitly initialized policy (i.e. empty) - if (policy->children().empty()) { - continue; - } - fields.push_back(std::vector()); - for (const auto &member : policy->children()) { - fields.back().push_back(llvm::dyn_cast(member)); - } + auto unorderedEnumData = enumMatcher.getEnumData(); + if (unorderedEnumData.empty()) { + return; + } + + std::vector enumData(fields.size()); + for (auto &data : unorderedEnumData) { + enumData.at(data.value) = std::move(data); } - const auto enumData = enumMatcher.getEnumData(); for (const auto &item : enumData) { + if (item.file.empty()) { + continue; + } + if (item.file.back() != 'h') { // only extract from "*.h" files + return; + } printf("include<%s>\n", item.file.c_str()); } - printf("%s$auto[\n", netlinkDecl->getDefinition()->getNameAsString().c_str()); + + printf("%s[\n", getPolicyName(Result, netlinkDecl)->c_str()); for (size_t i = 0; i < fields.size(); ++i) { + // The array could have an implicitly initialized policy (i.e. empty) or an unnamed attribute + if (fields[i].empty() || enumData[i].name.empty()) { + continue; + } printf("\t%s nlattr[%s, %s]\n", enumData[i].name.c_str(), enumData[i].name.c_str(), nlaToSyz(fields[i][0]).c_str()); } puts("] [varlen]"); } + std::map genlFamilyMember; + std::map> opsMember; + + std::optional getPolicyName(const MatchFinder::MatchResult &Result, const ValueDecl *decl) { + if (!decl) { + return std::nullopt; + } + std::string filename = + std::filesystem::path( + Result.SourceManager->getFilename(decl->getCanonicalDecl()->getSourceRange().getBegin()).str()) + .filename() + .stem() + .string(); + std::replace(filename.begin(), filename.end(), '-', '_'); + return decl->getNameAsString() + "$auto_" + filename; // filename is added to address ambiguity + // when multiple policies are named the same but have different definitions + } + + std::vector getOps(const MatchFinder::MatchResult &Result, const std::string opsName, + const InitListExpr *init) { + ASTContext *context = Result.Context; + std::vector ops; + const auto n_ops = init->getInit(genlFamilyMember["n_" + opsName])->getIntegerConstantExpr(*context); + const auto &opsRef = init->getInit(genlFamilyMember[opsName])->getAsBuiltinConstantDeclRef(*context); + if (!n_ops || !opsRef) { + return {}; + } + const auto *opsDecl = llvm::dyn_cast(opsRef); + if (!opsDecl->getInit()) { + // NOTE: This usually happens when the ops is defined as an extern variable + // TODO: Extract extern variables + return {}; + } + const auto *opsInit = llvm::dyn_cast(opsDecl->getInit()); + for (const auto &field : opsInit->getInit(0)->getType()->getAsRecordDecl()->fields()) { + opsMember[opsName][field->getNameAsString()] = field->getFieldIndex(); + } + for (int i = 0; i < n_ops; ++i) { + const auto &init = llvm::dyn_cast(opsInit->getInit(i)); + const auto &cmdInit = init->getInit(opsMember[opsName]["cmd"])->getEnumConstantDecl(); + if (!cmdInit) { + continue; + } + const auto &cmd = cmdInit->getNameAsString(); + const ValueDecl *policyDecl = nullptr; + if (opsName != "small_ops") { + policyDecl = init->getInit(opsMember[opsName]["policy"])->getAsBuiltinConstantDeclRef(*context); + } + ops.push_back({std::move(cmd), getPolicyName(Result, policyDecl)}); + } + return ops; + } + + void genlFamily(const MatchFinder::MatchResult &Result) { + ASTContext *context = Result.Context; + const auto *genlFamily = Result.Nodes.getNodeAs("genl_family"); + if (!genlFamily) { + return; + } + for (const auto &field : genlFamily->fields()) { + genlFamilyMember[field->getNameAsString()] = field->getFieldIndex(); + } + const auto *genlFamilyInit = Result.Nodes.getNodeAs("genl_family_init"); + if (!genlFamilyInit) { + return; + } + + auto name = llvm::dyn_cast(genlFamilyInit->getInit(genlFamilyMember["name"]))->getString().str(); + std::replace(name.begin(), name.end(), '.', '_'); // Illegal character. + std::replace(name.begin(), name.end(), ' ', '_'); // Don't leave space in name. + const auto &globalPolicyName = + genlFamilyInit->getInit(genlFamilyMember["policy"])->getAsBuiltinConstantDeclRef(*context); + + std::string familyPolicyName; + if (globalPolicyName) { + familyPolicyName = *getPolicyName(Result, globalPolicyName); + } + + std::string msghdr = "msghdr_" + name + "_auto"; + bool printedCmds = false; + for (const auto &opsType : {"ops", "small_ops", "split_ops"}) { + for (auto &ops : getOps(Result, opsType, genlFamilyInit)) { + const char *policyName; + if (ops.policy) { + policyName = ops.policy->c_str(); + } else if (globalPolicyName) { + policyName = familyPolicyName.c_str(); + } else { + continue; + } + printf("sendmsg$auto_%s(fd sock_nl_generic, msg ptr[in, %s[%s, %s]], f flags[send_flags]) (automatic)\n", + ops.cmd.c_str(), msghdr.c_str(), ops.cmd.c_str(), policyName); + printedCmds = true; + } + } + if (!printedCmds) { // Do not print resources and types if they're not used in any cmds + return; + } + std::string resourceName = "genl_" + name + "_family_id_auto"; + printf("resource %s[int16]\n", resourceName.c_str()); + printf("type %s[CMD, POLICY] msghdr_netlink[netlink_msg_t[%s, genlmsghdr_t[CMD], POLICY]]\n", msghdr.c_str(), + resourceName.c_str()); + printf("syz_genetlink_get_family_id$auto_%s(name ptr[in, string[\"%s\"]], fd sock_nl_generic) %s (automatic)\n", + name.c_str(), name.c_str(), resourceName.c_str()); + } + public: virtual void run(const MatchFinder::MatchResult &Result) override { syscall(Result); netlink(Result); + genlFamily(Result); }; }; @@ -224,6 +362,10 @@ int main(int argc, const char **argv) { isDefinition()) .bind("netlink"), &Printer); + Finder.addMatcher(varDecl(hasType(recordDecl(hasName("genl_family")).bind("genl_family")), + has(initListExpr().bind("genl_family_init"))) + .bind("genl_family_decl"), + &Printer); return Tool.run(clang::tooling::newFrontendActionFactory(&Finder).get()); } -- cgit mrf-deployment