From 807588d2dc47323949cd5b7266fbf44b18b480a4 Mon Sep 17 00:00:00 2001 From: Dmitry Vyukov Date: Tue, 12 Nov 2024 15:47:39 +0100 Subject: tools/syz-declextract: add entry function and access level For now for netlink only. --- tools/syz-declextract/run.go | 14 +++++++- tools/syz-declextract/syz-declextract.cpp | 53 +++++++++++++++++++++++++------ 2 files changed, 56 insertions(+), 11 deletions(-) (limited to 'tools') diff --git a/tools/syz-declextract/run.go b/tools/syz-declextract/run.go index 5794c2a45..2feddfec8 100644 --- a/tools/syz-declextract/run.go +++ b/tools/syz-declextract/run.go @@ -144,7 +144,9 @@ type Interface struct { Type string `json:"type"` Name string `json:"name"` Files []string `json:"files"` - Subsystems []string `json:"subsystems"` + Func string `json:"func,omitempty"` + Access string `json:"access,omitempty"` + Subsystems []string `json:"subsystems,omitempty"` ManualDescriptions bool `json:"has_manual_descriptions"` AutoDescriptions bool `json:"has_auto_descriptions"` @@ -415,11 +417,21 @@ func appendNodes(slice *[]ast.Node, interfaces map[string]Interface, nodes []ast continue } fields := strings.Fields(node.Text) + if len(fields) != 6 { + tool.Failf("%q has wrong number of fields", node.Text) + } + for i := range fields { + if fields[i] == "-" { + fields[i] = "" + } + } iface := Interface{ Type: fields[1], Name: fields[2], Files: []string{file}, identifyingConst: fields[3], + Func: fields[4], + Access: fields[5], } if iface.Type == "SYSCALL" { for _, name := range syscallNames[iface.Name] { diff --git a/tools/syz-declextract/syz-declextract.cpp b/tools/syz-declextract/syz-declextract.cpp index 0a24c6418..392875c3d 100644 --- a/tools/syz-declextract/syz-declextract.cpp +++ b/tools/syz-declextract/syz-declextract.cpp @@ -43,6 +43,11 @@ using namespace clang; using namespace clang::ast_matchers; +const char *const AccessUnknown = "-"; +const char *const AccessUser = "user"; +const char *const AccessNsAdmin = "ns_admin"; +const char *const AccessAdmin = "admin"; + struct Param { std::string type; std::string name; @@ -50,6 +55,8 @@ struct Param { struct NetlinkOps { std::string cmd; + std::string func; + const char *access; std::optional policy; }; @@ -64,8 +71,12 @@ struct StructMember { unsigned int countedBy; }; -void emitInterface(const char *type, std::string_view name, std::string_view identifying_const) { - printf("\n#INTERFACE: %s %s %s\n\n", type, std::string(name).c_str(), std::string(identifying_const).c_str()); +void emitInterface(const char *type, std::string_view name, std::string_view identifying_const, + std::string_view entry_func = "", const char *access = AccessUnknown) { + if (entry_func.empty()) + entry_func = "-"; + printf("\n#INTERFACE: %s %s %s %s %s\n\n", type, std::string(name).c_str(), std::string(identifying_const).c_str(), + std::string(entry_func).c_str(), access); } struct SyzRecordDecl { @@ -96,6 +107,14 @@ struct SyzRecordDecl { } }; +// If expression refers to some identifier, returns the identifier name. +// Otherwise returns an empty string. +// For example, if the expression is `function_name`, returns "function_name" string. +std::string getDeclName(const clang::Expr *expr) { + auto *decl = llvm::dyn_cast(expr->IgnoreCasts()); + return decl ? decl->getDecl()->getNameAsString() : ""; +} + bool endsWith(const std::string_view &str, const std::string_view end) { size_t substrBegin = str.rfind(end); return substrBegin != std::string::npos && str.substr(substrBegin) == end; @@ -758,7 +777,6 @@ private: } std::map genlFamilyMember; - std::map> opsMember; std::optional getPolicyName(const MatchFinder::MatchResult &Result, const ValueDecl *decl) { if (!decl) { @@ -775,10 +793,9 @@ private: // when multiple policies are named the same but have different definitions } - std::vector getOps(const MatchFinder::MatchResult &Result, const std::string opsName, + std::vector getOps(const MatchFinder::MatchResult &Result, const std::string &opsName, const InitListExpr *init) { ASTContext *context = Result.Context; - std::vector ops; const auto n_ops = init->getInit(genlFamilyMember["n_" + opsName])->getIntegerConstantExpr(*context); const auto &opsRef = init->getInit(genlFamilyMember[opsName])->getAsBuiltinConstantDeclRef(*context); if (!n_ops || !opsRef) { @@ -791,21 +808,37 @@ private: return {}; } const auto *opsInit = llvm::dyn_cast(opsDecl->getInit()); + std::map opsMember; for (const auto &field : opsInit->getInit(0)->getType()->getAsRecordDecl()->fields()) { - opsMember[opsName][field->getNameAsString()] = field->getFieldIndex(); + opsMember[field->getNameAsString()] = field->getFieldIndex(); } + std::vector ops; for (int i = 0; i < n_ops; ++i) { const auto &init = llvm::dyn_cast(opsInit->getInit(i)); - const auto &cmdInit = init->getInit(opsMember[opsName]["cmd"])->getEnumConstantDecl(); + const auto &cmdInit = init->getInit(opsMember["cmd"])->getEnumConstantDecl(); if (!cmdInit) { continue; } const auto &cmd = cmdInit->getNameAsString(); const ValueDecl *policyDecl = nullptr; if (opsName != "small_ops") { - policyDecl = init->getInit(opsMember[opsName]["policy"])->getAsBuiltinConstantDeclRef(*context); + policyDecl = init->getInit(opsMember["policy"])->getAsBuiltinConstantDeclRef(*context); } - ops.push_back({std::move(cmd), getPolicyName(Result, policyDecl)}); + std::string func = getDeclName(init->getInit(opsMember["doit"])); + if (func.empty()) + func = getDeclName(init->getInit(opsMember["dumpit"])); + const Expr *flagsDecl = init->getInit(opsMember["flags"]); + Expr::EvalResult flags; + flagsDecl->EvaluateAsConstantExpr(flags, *context); + auto flagsVal = flags.Val.getInt().getExtValue(); + const char *access = AccessUser; + constexpr int GENL_ADMIN_PERM = 0x01; + constexpr int GENL_UNS_ADMIN_PERM = 0x10; + if (flagsVal & GENL_ADMIN_PERM) + access = AccessAdmin; + else if (flagsVal & GENL_UNS_ADMIN_PERM) + access = AccessNsAdmin; + ops.push_back({std::move(cmd), func, access, getPolicyName(Result, policyDecl)}); } return ops; } @@ -847,7 +880,7 @@ private: } else { continue; } - emitInterface("NETLINK", ops.cmd, ops.cmd); + emitInterface("NETLINK", ops.cmd, ops.cmd, ops.func, ops.access); printf("sendmsg$auto_%s(fd sock_nl_generic, msg ptr[in, %s[%s, %s]], f flags[send_flags]) (automatic)\n", ops.cmd.c_str(), msghdr.c_str(), ops.cmd.c_str(), policyName); printedCmds = true; -- cgit mrf-deployment