diff options
| -rw-r--r-- | tools/syz-declextract/run.go | 38 |
1 files changed, 32 insertions, 6 deletions
diff --git a/tools/syz-declextract/run.go b/tools/syz-declextract/run.go index 9e06603d7..9a21f1e7b 100644 --- a/tools/syz-declextract/run.go +++ b/tools/syz-declextract/run.go @@ -15,6 +15,7 @@ import ( "path/filepath" "runtime" "slices" + "strconv" "strings" "github.com/google/syzkaller/pkg/ast" @@ -22,6 +23,8 @@ import ( "github.com/google/syzkaller/sys/targets" ) +const sendmsg = "sendmsg" + type compileCommand struct { Arguments []string Directory string @@ -69,8 +72,8 @@ func main() { var includes []*ast.Include var typeDefs []*ast.TypeDef var resources []*ast.Resource - syscallNames := readSyscallNames(filepath.Join(*kernelDir, "arch")) // some syscalls have different names and entry - // points and thus need to be renamed. + syscallNames := readSyscallNames(filepath.Join(*kernelDir, "arch")) + // Some syscalls have different names and entry points and thus need to be renamed. // e.g. SYSCALL_DEFINE1(setuid16, old_uid_t, uid) is referred to in the .tbl file with setuid. eh := ast.LoggingHandler @@ -116,13 +119,36 @@ func writeOutput(includes []*ast.Include, syscalls []*ast.Call, netlinks []*ast. }) slices.SortFunc(syscalls, func(a, b *ast.Call) int { - return strings.Compare(a.Name.Name, b.Name.Name) + nameCmp := strings.Compare(a.Name.Name, b.Name.Name) + if nameCmp != 0 { + return nameCmp + } + if a.CallName == sendmsg { + // For sendmsg, compare by the policy name: sendmsg(_, msg ptr[_, msghdr_macsec_auto[_, PolicyName]], _). + return strings.Compare(a.Args[1].Type.Args[1].Args[1].Ident, b.Args[1].Type.Args[1].Args[1].Ident) + } + return slices.CompareFunc(a.Args, b.Args, func(a, b *ast.Field) int { + // Ensure deterministic output. Some system calls have the same name but share different parameter names; this + // guarantees that the compact function will always keep the same one. + return strings.Compare(a.Name.Name, b.Name.Name) + }) }) + + sendmsgNo := 0 + // Some commands are executed for multiple policies. Ensure that they don't get deleted by the following compact call. + for _, node := range syscalls { + if node.CallName == sendmsg { + node.Name.Name += strconv.Itoa(sendmsgNo) + sendmsgNo++ + } + } + syscalls = slices.CompactFunc(syscalls, func(a, b *ast.Call) bool { // We only compare the the system call names for cases where the same system call has different parameter names, // but share the same syzkaller type. NOTE:Change when we have better type extraction. return a.Name.Name == b.Name.Name }) + slices.SortFunc(netlinks, func(a, b *ast.Struct) int { return strings.Compare(a.Name.Name, b.Name.Name) }) @@ -152,7 +178,7 @@ func writeOutput(includes []*ast.Include, syscalls []*ast.Call, netlinks []*ast. } usedNetlink := make(map[string]bool) for _, node := range syscalls { - if node.CallName == "sendmsg" && len(node.Args[1].Type.Args) == 2 { + if node.CallName == sendmsg && len(node.Args[1].Type.Args) == 2 { policy := node.Args[1].Type.Args[1].Args[1].Ident usedNetlink[policy] = true _, isDefined := slices.BinarySearchFunc(netlinks, policy, func(a *ast.Struct, b string) int { @@ -291,7 +317,7 @@ func readSyscallNames(kernelDir string) map[string][]string { func shouldRenameSyscall(syscall string) bool { switch syscall { - case "sendmsg", "syz_genetlink_get_family_id": + case sendmsg, "syz_genetlink_get_family_id": return false default: return true @@ -300,7 +326,7 @@ func shouldRenameSyscall(syscall string) bool { func isProhibited(syscall string) bool { switch syscall { - case "reboot", "utimesat": // utimesat is not defined for all arches. + case "reboot", "utimesat": // `utimesat` is not defined for all arches. return true default: return false |
