aboutsummaryrefslogtreecommitdiffstats
path: root/tools
diff options
context:
space:
mode:
authorPimyn Girgis <bemenboshra2001@gmail.com>2024-08-28 09:43:55 +0000
committerAleksandr Nogikh <nogikh@google.com>2024-09-03 08:45:44 +0000
commitb1df3115feca4b02661cb4a2fa45f0d5e37d35c0 (patch)
treecd088bf2641910d0c3a0ced490c8cf6a829d4664 /tools
parent512328ba4c7c2d32f3ce13ad568a44b25d7e8df5 (diff)
tools/syz-declextract: extract genl_family and generate descriptions
Extract genl_family and generate descriptions based on the attached policies and commands. Fix Issue with the go tool where the clang tool would fail and remain undetected.
Diffstat (limited to 'tools')
-rw-r--r--tools/syz-declextract/run.go77
-rw-r--r--tools/syz-declextract/syz-declextract.cpp180
2 files changed, 229 insertions, 28 deletions
diff --git a/tools/syz-declextract/run.go b/tools/syz-declextract/run.go
index f637eb939..9e06603d7 100644
--- a/tools/syz-declextract/run.go
+++ b/tools/syz-declextract/run.go
@@ -67,6 +67,8 @@ func main() {
var syscalls []*ast.Call
var netlinks []*ast.Struct
var includes []*ast.Include
+ var typeDefs []*ast.TypeDef
+ var resources []*ast.Resource
syscallNames := readSyscallNames(filepath.Join(*kernelDir, "arch")) // some syscalls have different names and entry
// points and thus need to be renamed.
// e.g. SYSCALL_DEFINE1(setuid16, old_uid_t, uid) is referred to in the .tbl file with setuid.
@@ -79,6 +81,7 @@ func main() {
}
parse := ast.Parse([]byte(out.stdout), "", eh)
if parse == nil {
+ fmt.Println(out.stdout)
tool.Failf("parsing error")
}
for _, node := range parse.Nodes {
@@ -89,15 +92,22 @@ func main() {
netlinks = append(netlinks, node)
case *ast.Include:
includes = append(includes, node)
+ case *ast.TypeDef:
+ typeDefs = append(typeDefs, node)
+ case *ast.Resource:
+ resources = append(resources, node)
+ case *ast.NewLine:
+ continue
}
}
}
close(files)
- writeOutput(includes, syscalls, netlinks, *outFile)
+ writeOutput(includes, syscalls, netlinks, typeDefs, resources, *outFile)
}
-func writeOutput(includes []*ast.Include, syscalls []*ast.Call, netlinks []*ast.Struct, outFile string) {
+func writeOutput(includes []*ast.Include, syscalls []*ast.Call, netlinks []*ast.Struct, types []*ast.TypeDef,
+ resources []*ast.Resource, outFile string) {
slices.SortFunc(includes, func(a, b *ast.Include) int {
return strings.Compare(a.File.Value, b.File.Value)
})
@@ -117,23 +127,56 @@ func writeOutput(includes []*ast.Include, syscalls []*ast.Call, netlinks []*ast.
return strings.Compare(a.Name.Name, b.Name.Name)
})
- autoGeneratedNotice := "# Code generated by syz-declextract. DO NOT EDIT."
+ slices.SortFunc(resources, func(a, b *ast.Resource) int {
+ return strings.Compare(a.Name.Name, b.Name.Name)
+ })
+
+ slices.SortFunc(types, func(a, b *ast.TypeDef) int {
+ return strings.Compare(a.Name.Name, b.Name.Name)
+ })
+
+ autoGeneratedNotice := "# Code generated by syz-declextract. DO NOT EDIT.\n"
+ commonKernelHeaders := "include <include/vdso/bits.h>\ninclude <include/linux/types.h>"
+ var netlinkNames []string
mmap2 := "_ = __NR_mmap2\n"
eh := ast.LoggingHandler
- desc := ast.Parse([]byte(autoGeneratedNotice), "", eh)
+ desc := ast.Parse([]byte(autoGeneratedNotice+commonKernelHeaders), "", eh)
for _, node := range includes {
desc.Nodes = append(desc.Nodes, node)
}
+ for _, node := range resources {
+ desc.Nodes = append(desc.Nodes, node)
+ }
+ for _, node := range types {
+ desc.Nodes = append(desc.Nodes, node)
+ }
+ usedNetlink := make(map[string]bool)
for _, node := range syscalls {
+ if node.CallName == "sendmsg" && len(node.Args[1].Type.Args) == 2 {
+ policy := node.Args[1].Type.Args[1].Args[1].Ident
+ usedNetlink[policy] = true
+ _, isDefined := slices.BinarySearchFunc(netlinks, policy, func(a *ast.Struct, b string) int {
+ return strings.Compare(a.Name.Name, b)
+ })
+ if !isDefined {
+ continue
+ }
+ }
desc.Nodes = append(desc.Nodes, node)
}
desc.Nodes = append(desc.Nodes, ast.Parse([]byte(mmap2), "", eh).Nodes...)
- var netlinkNames []string
- for i, node := range netlinks {
+ for _, node := range netlinks {
desc.Nodes = append(desc.Nodes, node)
- netlinkNames = append(netlinkNames, fmt.Sprintf("\tpolicy%v msghdr_auto[%v]\n", i, node.Name.Name))
+ name := node.Name.Name
+ if !usedNetlink[name] {
+ netlinkNames = append(netlinkNames, name)
+ }
+ }
+ for i, netlink := range netlinkNames {
+ netlinkNames[i] = fmt.Sprintf("\tpolicy%v msghdr_auto[%v]\n", i, netlink)
}
- netlinkUnion := `type msghdr_auto[POLICY] msghdr_netlink[netlink_msg_t[autogenerated_netlink, genlmsghdr, POLICY]]
+ netlinkUnion := `
+type msghdr_auto[POLICY] msghdr_netlink[netlink_msg_t[autogenerated_netlink, genlmsghdr, POLICY]]
resource autogenerated_netlink[int16]
syz_genetlink_get_family_id$auto(name ptr[in, string], fd sock_nl_generic) autogenerated_netlink
sendmsg$autorun(fd sock_nl_generic, msg ptr[in, auto_union], f flags[send_flags])
@@ -166,7 +209,11 @@ func worker(outputs chan output, files chan string, binary, compilationDatabase
if err != nil {
var error *exec.ExitError
if errors.As(err, &error) {
- stderr = string(error.Stderr)
+ if len(error.Stderr) != 0 {
+ stderr = string(error.Stderr)
+ } else {
+ stderr = fmt.Sprintf("%v: %v", file, error.String())
+ }
} else {
stderr = err.Error()
}
@@ -176,6 +223,9 @@ func worker(outputs chan output, files chan string, binary, compilationDatabase
}
func renameSyscall(syscall *ast.Call, rename map[string][]string) []*ast.Call {
+ if !shouldRenameSyscall(syscall.CallName) {
+ return []*ast.Call{syscall}
+ }
var renamed []*ast.Call
toReplace := syscall.CallName
if rename[toReplace] == nil {
@@ -239,6 +289,15 @@ func readSyscallNames(kernelDir string) map[string][]string {
return rename
}
+func shouldRenameSyscall(syscall string) bool {
+ switch syscall {
+ case "sendmsg", "syz_genetlink_get_family_id":
+ return false
+ default:
+ return true
+ }
+}
+
func isProhibited(syscall string) bool {
switch syscall {
case "reboot", "utimesat": // utimesat is not defined for all arches.
diff --git a/tools/syz-declextract/syz-declextract.cpp b/tools/syz-declextract/syz-declextract.cpp
index 6a597974b..66ce053a9 100644
--- a/tools/syz-declextract/syz-declextract.cpp
+++ b/tools/syz-declextract/syz-declextract.cpp
@@ -13,6 +13,7 @@
#include "clang/ASTMatchers/ASTMatchers.h"
#include "clang/ASTMatchers/ASTMatchersInternal.h"
#include "clang/Basic/LLVM.h"
+#include "clang/Lex/Lexer.h"
#include "clang/Sema/Ownership.h"
#include "clang/Tooling/CommonOptionsParser.h"
#include "clang/Tooling/Tooling.h"
@@ -20,8 +21,11 @@
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/raw_ostream.h"
+#include <algorithm>
#include <cstddef>
#include <filesystem>
+#include <map>
+#include <optional>
#include <stdio.h>
#include <string>
#include <vector>
@@ -31,7 +35,7 @@ using namespace clang::ast_matchers;
struct EnumData {
std::string name;
- int value;
+ unsigned long long value;
std::string file;
};
@@ -40,6 +44,11 @@ struct Param {
std::string name;
};
+struct NetlinkOps {
+ std::string cmd;
+ std::optional<std::string> policy;
+};
+
class EnumMatcher : public MatchFinder::MatchCallback {
private:
std::vector<EnumData> EnumDetails;
@@ -48,12 +57,13 @@ public:
std::vector<EnumData> getEnumData() { return EnumDetails; }
virtual void run(const MatchFinder::MatchResult &Result) override {
const auto *enumValue = Result.Nodes.getNodeAs<ConstantExpr>("enum_value");
- if (!enumValue)
+ if (!enumValue) {
return;
+ }
const auto &name = enumValue->getEnumConstantDecl()->getNameAsString();
- const auto value = int(*enumValue->getAPValueResult().getInt().getRawData());
+ const auto value = *enumValue->getAPValueResult().getInt().getRawData();
const auto &path = std::filesystem::relative(
- Result.SourceManager->getFilename(enumValue->getEnumConstantDecl()->getSourceRange().getBegin()).data());
+ Result.SourceManager->getFilename(enumValue->getEnumConstantDecl()->getSourceRange().getBegin()).str());
EnumDetails.push_back({std::move(name), value, std::move(path)});
}
};
@@ -122,7 +132,7 @@ private:
if (values.empty())
return;
- int argc = *values[2]->getIntegerConstantExpr(*context).value().getRawData();
+ int argc = *values[2]->getIntegerConstantExpr(*context)->getRawData();
std::vector<Param> args(argc);
if (argc) {
@@ -145,7 +155,7 @@ private:
}
}
- printf("%s$auto(", values[0]->tryEvaluateString(*context).value().c_str() + 4); // name
+ printf("%s$auto(", values[0]->tryEvaluateString(*context)->c_str() + 4); // name
const char *sep = "";
for (const auto &arg : args) {
printf("%s%s %s", sep, swapIfReservedKeyword(arg.name).c_str(), getSyzType(arg.type).c_str());
@@ -161,42 +171,170 @@ private:
if (!netlinkDecl) {
return;
}
+ std::vector<std::vector<Expr *>> fields;
+
+ const auto *init = netlinkDecl->getInit();
+ if (!init) {
+ return;
+ }
+ for (const auto &policy : *llvm::dyn_cast<InitListExpr>(init)) {
+ fields.push_back(std::vector<Expr *>());
+ for (const auto &member : policy->children()) {
+ fields.back().push_back(llvm::dyn_cast<Expr>(member));
+ }
+ }
EnumMatcher enumMatcher;
MatchFinder enumFinder;
enumFinder.addMatcher(
- decl(forEachDescendant(designatedInitExpr(has(constantExpr(has(declRefExpr())).bind("enum_value"))))),
+ decl(forEachDescendant(designatedInitExpr(optionally(has(constantExpr(has(declRefExpr())).bind("enum_value"))))
+ .bind("designated_init"))),
&enumMatcher);
enumFinder.match(*netlinkDecl, *context); // get enum details from the current subtree (nla_policy[])
- std::vector<std::vector<Expr *>> fields;
- for (const auto &policy : *llvm::dyn_cast<InitListExpr>(netlinkDecl->getInit())) {
- // The array could have an implicitly initialized policy (i.e. empty)
- if (policy->children().empty()) {
- continue;
- }
- fields.push_back(std::vector<Expr *>());
- for (const auto &member : policy->children()) {
- fields.back().push_back(llvm::dyn_cast<Expr>(member));
- }
+ auto unorderedEnumData = enumMatcher.getEnumData();
+ if (unorderedEnumData.empty()) {
+ return;
+ }
+
+ std::vector<EnumData> enumData(fields.size());
+ for (auto &data : unorderedEnumData) {
+ enumData.at(data.value) = std::move(data);
}
- const auto enumData = enumMatcher.getEnumData();
for (const auto &item : enumData) {
+ if (item.file.empty()) {
+ continue;
+ }
+ if (item.file.back() != 'h') { // only extract from "*.h" files
+ return;
+ }
printf("include<%s>\n", item.file.c_str());
}
- printf("%s$auto[\n", netlinkDecl->getDefinition()->getNameAsString().c_str());
+
+ printf("%s[\n", getPolicyName(Result, netlinkDecl)->c_str());
for (size_t i = 0; i < fields.size(); ++i) {
+ // The array could have an implicitly initialized policy (i.e. empty) or an unnamed attribute
+ if (fields[i].empty() || enumData[i].name.empty()) {
+ continue;
+ }
printf("\t%s nlattr[%s, %s]\n", enumData[i].name.c_str(), enumData[i].name.c_str(),
nlaToSyz(fields[i][0]).c_str());
}
puts("] [varlen]");
}
+ std::map<std::string, unsigned> genlFamilyMember;
+ std::map<std::string, std::map<std::string, unsigned>> opsMember;
+
+ std::optional<std::string> getPolicyName(const MatchFinder::MatchResult &Result, const ValueDecl *decl) {
+ if (!decl) {
+ return std::nullopt;
+ }
+ std::string filename =
+ std::filesystem::path(
+ Result.SourceManager->getFilename(decl->getCanonicalDecl()->getSourceRange().getBegin()).str())
+ .filename()
+ .stem()
+ .string();
+ std::replace(filename.begin(), filename.end(), '-', '_');
+ return decl->getNameAsString() + "$auto_" + filename; // filename is added to address ambiguity
+ // when multiple policies are named the same but have different definitions
+ }
+
+ std::vector<NetlinkOps> getOps(const MatchFinder::MatchResult &Result, const std::string opsName,
+ const InitListExpr *init) {
+ ASTContext *context = Result.Context;
+ std::vector<NetlinkOps> ops;
+ const auto n_ops = init->getInit(genlFamilyMember["n_" + opsName])->getIntegerConstantExpr(*context);
+ const auto &opsRef = init->getInit(genlFamilyMember[opsName])->getAsBuiltinConstantDeclRef(*context);
+ if (!n_ops || !opsRef) {
+ return {};
+ }
+ const auto *opsDecl = llvm::dyn_cast<VarDecl>(opsRef);
+ if (!opsDecl->getInit()) {
+ // NOTE: This usually happens when the ops is defined as an extern variable
+ // TODO: Extract extern variables
+ return {};
+ }
+ const auto *opsInit = llvm::dyn_cast<InitListExpr>(opsDecl->getInit());
+ for (const auto &field : opsInit->getInit(0)->getType()->getAsRecordDecl()->fields()) {
+ opsMember[opsName][field->getNameAsString()] = field->getFieldIndex();
+ }
+ for (int i = 0; i < n_ops; ++i) {
+ const auto &init = llvm::dyn_cast<InitListExpr>(opsInit->getInit(i));
+ const auto &cmdInit = init->getInit(opsMember[opsName]["cmd"])->getEnumConstantDecl();
+ if (!cmdInit) {
+ continue;
+ }
+ const auto &cmd = cmdInit->getNameAsString();
+ const ValueDecl *policyDecl = nullptr;
+ if (opsName != "small_ops") {
+ policyDecl = init->getInit(opsMember[opsName]["policy"])->getAsBuiltinConstantDeclRef(*context);
+ }
+ ops.push_back({std::move(cmd), getPolicyName(Result, policyDecl)});
+ }
+ return ops;
+ }
+
+ void genlFamily(const MatchFinder::MatchResult &Result) {
+ ASTContext *context = Result.Context;
+ const auto *genlFamily = Result.Nodes.getNodeAs<RecordDecl>("genl_family");
+ if (!genlFamily) {
+ return;
+ }
+ for (const auto &field : genlFamily->fields()) {
+ genlFamilyMember[field->getNameAsString()] = field->getFieldIndex();
+ }
+ const auto *genlFamilyInit = Result.Nodes.getNodeAs<InitListExpr>("genl_family_init");
+ if (!genlFamilyInit) {
+ return;
+ }
+
+ auto name = llvm::dyn_cast<StringLiteral>(genlFamilyInit->getInit(genlFamilyMember["name"]))->getString().str();
+ std::replace(name.begin(), name.end(), '.', '_'); // Illegal character.
+ std::replace(name.begin(), name.end(), ' ', '_'); // Don't leave space in name.
+ const auto &globalPolicyName =
+ genlFamilyInit->getInit(genlFamilyMember["policy"])->getAsBuiltinConstantDeclRef(*context);
+
+ std::string familyPolicyName;
+ if (globalPolicyName) {
+ familyPolicyName = *getPolicyName(Result, globalPolicyName);
+ }
+
+ std::string msghdr = "msghdr_" + name + "_auto";
+ bool printedCmds = false;
+ for (const auto &opsType : {"ops", "small_ops", "split_ops"}) {
+ for (auto &ops : getOps(Result, opsType, genlFamilyInit)) {
+ const char *policyName;
+ if (ops.policy) {
+ policyName = ops.policy->c_str();
+ } else if (globalPolicyName) {
+ policyName = familyPolicyName.c_str();
+ } else {
+ continue;
+ }
+ printf("sendmsg$auto_%s(fd sock_nl_generic, msg ptr[in, %s[%s, %s]], f flags[send_flags]) (automatic)\n",
+ ops.cmd.c_str(), msghdr.c_str(), ops.cmd.c_str(), policyName);
+ printedCmds = true;
+ }
+ }
+ if (!printedCmds) { // Do not print resources and types if they're not used in any cmds
+ return;
+ }
+ std::string resourceName = "genl_" + name + "_family_id_auto";
+ printf("resource %s[int16]\n", resourceName.c_str());
+ printf("type %s[CMD, POLICY] msghdr_netlink[netlink_msg_t[%s, genlmsghdr_t[CMD], POLICY]]\n", msghdr.c_str(),
+ resourceName.c_str());
+ printf("syz_genetlink_get_family_id$auto_%s(name ptr[in, string[\"%s\"]], fd sock_nl_generic) %s (automatic)\n",
+ name.c_str(), name.c_str(), resourceName.c_str());
+ }
+
public:
virtual void run(const MatchFinder::MatchResult &Result) override {
syscall(Result);
netlink(Result);
+ genlFamily(Result);
};
};
@@ -224,6 +362,10 @@ int main(int argc, const char **argv) {
isDefinition())
.bind("netlink"),
&Printer);
+ Finder.addMatcher(varDecl(hasType(recordDecl(hasName("genl_family")).bind("genl_family")),
+ has(initListExpr().bind("genl_family_init")))
+ .bind("genl_family_decl"),
+ &Printer);
return Tool.run(clang::tooling::newFrontendActionFactory(&Finder).get());
}