aboutsummaryrefslogtreecommitdiffstats
path: root/pkg/declextract
diff options
context:
space:
mode:
authorDmitry Vyukov <dvyukov@google.com>2024-12-13 15:15:49 +0100
committerDmitry Vyukov <dvyukov@google.com>2024-12-17 13:44:24 +0000
commitc8c15bb214509bafc8fe1a1e3abb8ccf90b3306e (patch)
treeca722a71aff5a1566389f178d9c95d7d7e8caeed /pkg/declextract
parentbc1a1b50f942408a9139887b914f745d9fa02adc (diff)
tools/syz-declextract: infer argument/field types
Use data flow analysis to infer syscall argument, return value, and struct field types. See the comment in pkg/declextract/typing.go for more details.
Diffstat (limited to 'pkg/declextract')
-rw-r--r--pkg/declextract/declextract.go25
-rw-r--r--pkg/declextract/entity.go49
-rw-r--r--pkg/declextract/serialization.go7
-rw-r--r--pkg/declextract/typing.go268
4 files changed, 342 insertions, 7 deletions
diff --git a/pkg/declextract/declextract.go b/pkg/declextract/declextract.go
index 4edb6c867..1f4592523 100644
--- a/pkg/declextract/declextract.go
+++ b/pkg/declextract/declextract.go
@@ -7,6 +7,7 @@ import (
"bytes"
"errors"
"fmt"
+ "io"
"os"
"slices"
"strings"
@@ -14,16 +15,20 @@ import (
"github.com/google/syzkaller/pkg/ifaceprobe"
)
-func Run(out *Output, probe *ifaceprobe.Info, syscallRename map[string][]string) ([]byte, []*Interface, error) {
+func Run(out *Output, probe *ifaceprobe.Info, syscallRename map[string][]string, trace io.Writer) (
+ []byte, []*Interface, error) {
ctx := &context{
Output: out,
probe: probe,
syscallRename: syscallRename,
structs: make(map[string]*Struct),
funcs: make(map[string]*Function),
+ facts: make(map[string]*typingNode),
uniqualizer: make(map[string]int),
+ debugTrace: trace,
}
ctx.processFunctions()
+ ctx.processTypingFacts()
ctx.processIncludes()
ctx.processEnums()
ctx.processStructs()
@@ -41,9 +46,11 @@ type context struct {
syscallRename map[string][]string // syscall function -> syscall names
structs map[string]*Struct
funcs map[string]*Function
+ facts map[string]*typingNode
uniqualizer map[string]int
interfaces []*Interface
descriptions *bytes.Buffer
+ debugTrace io.Writer
errs []error
}
@@ -55,6 +62,12 @@ func (ctx *context) warn(msg string, args ...any) {
fmt.Fprintf(os.Stderr, msg+"\n", args...)
}
+func (ctx *context) trace(msg string, args ...any) {
+ if ctx.debugTrace != nil {
+ fmt.Fprintf(ctx.debugTrace, msg+"\n", args...)
+ }
+}
+
func (ctx *context) processIncludes() {
// These additional includes must be at the top, because other kernel headers
// are broken and won't compile without these additional ones included first.
@@ -88,6 +101,11 @@ func (ctx *context) processSyscalls() {
var syscalls []*Syscall
for _, call := range ctx.Syscalls {
ctx.processFields(call.Args, "", false)
+ call.returnType = ctx.inferReturnType(call.Func, call.SourceFile)
+ for i, arg := range call.Args {
+ typ := ctx.inferArgType(call.Func, call.SourceFile, i)
+ refineFieldType(arg, typ, false)
+ }
fn := strings.TrimPrefix(call.Func, "__do_sys_")
for _, name := range ctx.syscallRename[fn] {
ctx.noteInterface(&Interface{
@@ -129,6 +147,11 @@ func (ctx *context) processStructs() {
})
for _, str := range ctx.Structs {
ctx.processFields(str.Fields, str.Name, true)
+ name := strings.TrimSuffix(str.Name, autoSuffix)
+ for _, f := range str.Fields {
+ typ := ctx.inferFieldType(name, f.Name)
+ refineFieldType(f, typ, true)
+ }
}
}
diff --git a/pkg/declextract/entity.go b/pkg/declextract/entity.go
index ba45cc51c..266647ed8 100644
--- a/pkg/declextract/entity.go
+++ b/pkg/declextract/entity.go
@@ -24,14 +24,16 @@ type Output struct {
}
type Function struct {
- Name string `json:"name,omitempty"`
- File string `json:"file,omitempty"`
- IsStatic bool `json:"is_static,omitempty"`
- LOC int `json:"loc,omitempty"`
- Calls []string `json:"calls,omitempty"`
+ Name string `json:"name,omitempty"`
+ File string `json:"file,omitempty"`
+ IsStatic bool `json:"is_static,omitempty"`
+ LOC int `json:"loc,omitempty"`
+ Calls []string `json:"calls,omitempty"`
+ Facts []*TypingFact `json:"facts,omitempty"`
callers int
calls []*Function
+ facts map[string]*typingNode
}
type Define struct {
@@ -53,6 +55,8 @@ type Syscall struct {
Func string `json:"func,omitempty"`
Args []*Field `json:"args,omitempty"`
SourceFile string `json:"source_file,omitempty"`
+
+ returnType string
}
// FileOps describes one file_operations variable.
@@ -158,6 +162,41 @@ type BufferType struct {
IsNonTerminated bool `json:"is_non_terminated,omitempty"`
}
+type TypingFact struct {
+ Src *TypingEntity `json:"src,omitempty"`
+ Dst *TypingEntity `json:"dst,omitempty"`
+}
+
+type TypingEntity struct {
+ Return *EntityReturn `json:"return,omitempty"`
+ Argument *EntityArgument `json:"argument,omitempty"`
+ Field *EntityField `json:"field,omitempty"`
+ Local *EntityLocal `json:"local,omitempty"`
+ GlobalAddr *EntityGlobalAddr `json:"global_addr,omitempty"`
+}
+
+type EntityReturn struct {
+ Func string `json:"func,omitempty"`
+}
+
+type EntityArgument struct {
+ Func string `json:"func,omitempty"`
+ Arg int `json:"arg"`
+}
+
+type EntityField struct {
+ Struct string `json:"struct"`
+ Field string `json:"field"`
+}
+
+type EntityLocal struct {
+ Name string `json:"name"`
+}
+
+type EntityGlobalAddr struct {
+ Name string
+}
+
func (out *Output) Merge(other *Output) {
out.Functions = append(out.Functions, other.Functions...)
out.Includes = append(out.Includes, other.Includes...)
diff --git a/pkg/declextract/serialization.go b/pkg/declextract/serialization.go
index 6d27d2a13..d69358679 100644
--- a/pkg/declextract/serialization.go
+++ b/pkg/declextract/serialization.go
@@ -27,6 +27,11 @@ meta automatic
type auto_todo int8
+type auto_union[INFERRED, RAW] [
+ inferred INFERRED
+ raw RAW
+]
+
`
func (ctx *context) fmt(msg string, args ...any) {
@@ -53,7 +58,7 @@ func (ctx *context) serializeSyscalls() {
for i, arg := range call.Args {
ctx.fmt("%v%v %v", comma(i), arg.Name, arg.syzType)
}
- ctx.fmt(")\n")
+ ctx.fmt(") %v\n", call.returnType)
}
ctx.fmt("\n")
}
diff --git a/pkg/declextract/typing.go b/pkg/declextract/typing.go
new file mode 100644
index 000000000..7de22474d
--- /dev/null
+++ b/pkg/declextract/typing.go
@@ -0,0 +1,268 @@
+// Copyright 2024 syzkaller project authors. All rights reserved.
+// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
+
+package declextract
+
+import (
+ "bytes"
+ "fmt"
+ "slices"
+ "strings"
+)
+
+// Argument/field type inference based on data flow analysis.
+//
+// First, the clang tool produces data flow summary for each function.
+// The summary describes how data flows between function arguments, return values, local variables, and struct fields.
+// Then, the logic in this file tracks global data flow in the kernel to infer types for syscall arguments,
+// return values, and struct fields.
+// If data transitively flows from an argument to a known function that accepts a resource of a particular type
+// (e.g. __fget_light for file descriptors), then we infer that the original argument is an fd.
+// Similarly, if data flows from a known function that creates a resource (e.g. alloc_fd for file descriptors)
+// to a syscall return value, then we infer that the syscall returns an fd.
+// For struct fields we track data flow in both directions (to/from) to infer their types.
+//
+// If the inference produces multiple resources, currently we pick the one with the shortest flow path
+// (and then additionally pick lexicographically first among them for determinism). Potentially we could
+// use a more elaborate strategy that would somehow rank candidates and/or produce multiple candidates
+// (that we will then use as a union).
+//
+// Other potential improvements:
+// - Add more functions that consume/produce resources.
+// - Refine enum types. If we see an argument is used in bitops with an enum, it has that enum type.
+// - Infer pointer types when they flow to copy_from_user (sometimes they are declared as uint64).
+// - Infer that pointers are file names (they should flow to some known function for path resolution).
+// - Use SSA analysis to track flow via local variables better. Potentially we can just rename on every next use
+// and ignore backwards edges (it's unlikely that backwards edges are required for type inference).
+// - Infer ioctl commands in transitively called functions using data flow.
+// - Infer file_operations associated with an fd by tracking flow to alloc_file_pseudo and friends.
+// - Add context-sensitivity at least on switched arguments (ioctl commands).
+// - Infer other switched arguments besides ioctl commands.
+// - Infer netlink arg types by tracking flow from genl_info::attrs[ATTR_FOO].
+// - Infer simple constraints on arguments, e.g. "if (arg != 0) return -EINVAL".
+// - Use kernel typedefs for typing (e.g. pid_t). We can use them for uapi structs, but also for kernel
+// structs and function arguments during dataflow tracking (e.g. if int flows to a pid_t argument, it's a pid).
+// - Track side flows. E.g. dup2 argument newfd flows to the return value, and newfd can be inferred to be an fd,
+// but currently we don't infer that the return value is an fd. Potentially we could infer that.
+// - Detect cases where returned value is actually an error rather than a resource.
+// For example, these cases lead to false inference of fd type for returned value:
+// https://elixir.bootlin.com/linux/v6.13-rc2/source/net/core/sock.c#L1870
+// https://elixir.bootlin.com/linux/v6.13-rc2/source/net/socket.c#L1742
+
+var (
+ // Refines types based on data flows...
+ flowResources = [2]map[string]string{
+ // ...to function arguments.
+ {
+ "__fget_light:arg0": "fd",
+ "__fget_files_rcu:arg1": "fd",
+ "make_kuid:arg1": "uid",
+ "make_kgid:arg1": "gid",
+ "find_pid_ns:arg0": "pid",
+ "pidfd_get_pid:arg0": "fd_pidfd",
+ "__dev_get_by_index:arg1": "ifindex",
+ },
+ // ...from function return value.
+ {
+ "alloc_fd:ret": "fd",
+ "pid_nr_ns:ret": "pid",
+ "from_kuid:ret": "uid",
+ "from_kgid:ret": "gid",
+ },
+ }
+ // These functions/structs/files provide very high false connectivity between unrelated nodes.
+ flowIgnoreFuncs = map[string]bool{
+ "ptr_to_compat": true,
+ "compat_ptr": true,
+ }
+ flowIgnoreStructs = map[string]bool{
+ "pt_regs": true,
+ "io_cqe": true,
+ "inode": true,
+ }
+ flowIgnoreFiles = map[string]bool{
+ "include/linux/err.h": true, // PTR_ERR/ERR_PTR/ERR_CAST
+ "include/linux/byteorder": true, // ntohl/etc
+ "include/linux/uaccess.h": true, // copy_to/from_user
+ "fs/befs/endian.h": true, // cpu_to_fs32/etc
+ "fs/ufs/swab.h": true,
+ }
+)
+
+// Limit on the flow graph traversal depth to avoid false positives due to spurious connections.
+const maxTraversalDepth = 18
+
+type typingNode struct {
+ id string
+ flows [2]map[*typingNode]bool
+}
+
+const (
+ flowTo = iota
+ flowFrom
+)
+
+func (ctx *context) processTypingFacts() {
+ for _, fn := range ctx.Functions {
+ for _, fact := range fn.Facts {
+ src := ctx.canonicalNode(fn, fact.Src)
+ dst := ctx.canonicalNode(fn, fact.Dst)
+ if src == nil || dst == nil {
+ continue
+ }
+ src.flows[flowTo][dst] = true
+ dst.flows[flowFrom][src] = true
+ }
+ }
+}
+
+func (ctx *context) canonicalNode(fn *Function, ent *TypingEntity) *typingNode {
+ scope, id := ent.ID(fn)
+ fullID := id
+ facts := ctx.facts
+ if scope != "" {
+ if scope != fn.Name {
+ fn = ctx.findFunc(scope, fn.File)
+ if fn == nil {
+ return nil
+ }
+ }
+ if flowIgnoreFuncs[fn.Name] || flowIgnoreFiles[fn.File] {
+ return nil
+ }
+ if fn.facts == nil {
+ fn.facts = make(map[string]*typingNode)
+ }
+ facts = fn.facts
+ fullID = fmt.Sprintf("%v:%v", scope, id)
+ } else if ent.Field != nil && flowIgnoreStructs[ent.Field.Struct] {
+ return nil
+ }
+ n := facts[id]
+ if n != nil {
+ return n
+ }
+ n = &typingNode{
+ id: fullID,
+ }
+ for i := range n.flows {
+ n.flows[i] = make(map[*typingNode]bool)
+ }
+ facts[id] = n
+ return n
+}
+
+func (ent *TypingEntity) ID(fn *Function) (string, string) {
+ switch {
+ case ent.Return != nil:
+ return ent.Return.Func, "ret"
+ case ent.Argument != nil:
+ return ent.Argument.Func, fmt.Sprintf("arg%v", ent.Argument.Arg)
+ case ent.Local != nil:
+ return fn.Name, fmt.Sprintf("loc.%v", ent.Local.Name)
+ case ent.Field != nil:
+ return "", fmt.Sprintf("%v.%v", ent.Field.Struct, ent.Field.Field)
+ case ent.GlobalAddr != nil:
+ return "", ent.GlobalAddr.Name
+ default:
+ panic("unhandled type")
+ }
+}
+
+func (ctx *context) inferReturnType(name, file string) string {
+ return ctx.inferFuncNode(name, file, "ret")
+}
+
+func (ctx *context) inferArgType(name, file string, arg int) string {
+ return ctx.inferFuncNode(name, file, fmt.Sprintf("arg%v", arg))
+}
+
+func (ctx *context) inferFuncNode(name, file, node string) string {
+ fn := ctx.findFunc(name, file)
+ if fn == nil {
+ return ""
+ }
+ return ctx.inferNodeType(fn.facts[node], fmt.Sprintf("%v %v", name, node))
+}
+
+func (ctx *context) inferFieldType(structName, field string) string {
+ name := fmt.Sprintf("%v.%v", structName, field)
+ return ctx.inferNodeType(ctx.facts[name], name)
+}
+
+func (ctx *context) inferNodeType(n *typingNode, what string) string {
+ if n == nil {
+ return ""
+ }
+ ic := &inferContext{
+ visited: make(map[*typingNode]bool),
+ flowType: flowFrom,
+ maxDepth: maxTraversalDepth,
+ }
+ ic.walk(n)
+ ic.flowType = flowTo
+ ic.visited = make(map[*typingNode]bool)
+ ic.walk(n)
+ if ic.result != "" {
+ ctx.trace("inferred %v\n %v%v", what, ic.result, flowString(ic.resultPath, ic.resultFlow))
+ }
+ return ic.result
+}
+
+type inferContext struct {
+ path []*typingNode
+ visited map[*typingNode]bool
+ result string
+ resultPath []*typingNode
+ resultFlow int
+ flowType int
+ maxDepth int
+}
+
+func (ic *inferContext) walk(n *typingNode) {
+ if ic.visited[n] {
+ return
+ }
+ ic.visited[n] = true
+ ic.path = append(ic.path, n)
+ if result, ok := flowResources[ic.flowType][n.id]; ok {
+ // Use lexicographical order just to make the result stable.
+ if ic.result == "" || len(ic.path) < ic.maxDepth ||
+ len(ic.path) == ic.maxDepth && strings.Compare(result, ic.result) < 0 {
+ ic.result = result
+ ic.resultPath = slices.Clone(ic.path)
+ ic.resultFlow = ic.flowType
+ ic.maxDepth = len(ic.path)
+ }
+ }
+ if len(ic.path) < ic.maxDepth {
+ for e := range n.flows[ic.flowType] {
+ ic.walk(e)
+ }
+ }
+ ic.path = ic.path[:len(ic.path)-1]
+}
+
+func refineFieldType(f *Field, typ string, preserveSize bool) {
+ // If our manual heuristics have figured out a more precise fd subtype,
+ // don't replace it with generic fd.
+ if typ == "" || typ == f.syzType ||
+ typ == "fd" && (strings.HasPrefix(f.syzType, "fd_") || strings.HasPrefix(f.syzType, "sock")) {
+ return
+ }
+ // For struct fields we need to keep the original size.
+ // Sometimes fd is passed as uint64.
+ if preserveSize {
+ typ = fmt.Sprintf("auto_union[%v, %v]", typ, f.syzType)
+ }
+ f.syzType = typ
+}
+
+func flowString(path []*typingNode, flowType int) string {
+ w := new(bytes.Buffer)
+ dir := [2]string{"->", "<-"}[flowType]
+ for _, e := range path {
+ fmt.Fprintf(w, " %v %v", dir, e.id)
+ }
+ return w.String()
+}