diff options
Diffstat (limited to 'pkg/compiler')
| -rw-r--r-- | pkg/compiler/compiler.go | 182 | ||||
| -rw-r--r-- | pkg/compiler/compiler_test.go | 16 |
2 files changed, 186 insertions, 12 deletions
diff --git a/pkg/compiler/compiler.go b/pkg/compiler/compiler.go index 7892661fa..593d9e82c 100644 --- a/pkg/compiler/compiler.go +++ b/pkg/compiler/compiler.go @@ -9,14 +9,134 @@ import ( "strings" "github.com/google/syzkaller/pkg/ast" + "github.com/google/syzkaller/sys" ) +// Prog is description compilation result. +type Prog struct { + // Processed AST (temporal measure, remove later). + Desc *ast.Description + Resources []*sys.ResourceDesc + // Set of unsupported syscalls/flags. + Unsupported map[string]bool +} + +// Compile compiles sys description. +func Compile(desc0 *ast.Description, consts map[string]uint64, eh ast.ErrorHandler) *Prog { + if eh == nil { + eh = ast.LoggingHandler + } + + desc := ast.Clone(desc0) + unsup, ok := patchConsts(desc, consts, eh) + if !ok { + return nil + } + + return &Prog{ + Desc: desc, + Unsupported: unsup, + } +} + +// patchConsts replaces all symbolic consts with their numeric values taken from consts map. +// Updates desc and returns set of unsupported syscalls and flags. +// After this pass consts are not needed for compilation. +func patchConsts(desc *ast.Description, consts map[string]uint64, eh ast.ErrorHandler) (map[string]bool, bool) { + broken := false + unsup := make(map[string]bool) + var top []ast.Node + for _, decl := range desc.Nodes { + switch decl.(type) { + case *ast.IntFlags: + // Unsupported flag values are dropped. + n := decl.(*ast.IntFlags) + var values []*ast.Int + for _, v := range n.Values { + if patchIntConst(v.Pos, &v.Value, &v.Ident, consts, unsup, nil, eh) { + values = append(values, v) + } + } + n.Values = values + top = append(top, n) + case *ast.Resource, *ast.Struct, *ast.Call: + if c, ok := decl.(*ast.Call); ok { + // Extract syscall NR. + str := "__NR_" + c.CallName + nr, ok := consts[str] + if !ok { + if name := "syscall " + c.CallName; !unsup[name] { + unsup[name] = true + eh(c.Pos, fmt.Sprintf("unsupported syscall: %v due to missing const %v", + c.CallName, str)) + } + continue + } + c.NR = nr + } + // Walk whole tree and replace consts in Int's and Type's. + missing := "" + ast.WalkNode(decl, func(n0 ast.Node) { + switch n := n0.(type) { + case *ast.Int: + patchIntConst(n.Pos, &n.Value, &n.Ident, + consts, unsup, &missing, eh) + case *ast.Type: + if c := typeConstIdentifier(n); c != nil { + patchIntConst(c.Pos, &c.Value, &c.Ident, + consts, unsup, &missing, eh) + if c.HasColon { + patchIntConst(c.Pos2, &c.Value2, &c.Ident2, + consts, unsup, &missing, eh) + } + } + } + }) + if missing == "" { + top = append(top, decl) + } else { + // Produce a warning about unsupported syscall/resource/struct. + // Unsupported syscalls are discarded. + // Unsupported resource/struct lead to compilation error. + // Fixing that would require removing all uses of the resource/struct. + typ, pos, name, fatal := "", ast.Pos{}, "", false + switch n := decl.(type) { + case *ast.Call: + typ, pos, name, fatal = "syscall", n.Pos, n.Name.Name, false + case *ast.Resource: + typ, pos, name, fatal = "resource", n.Pos, n.Name.Name, true + case *ast.Struct: + typ, pos, name, fatal = "struct", n.Pos, n.Name.Name, true + default: + panic(fmt.Sprintf("unknown type: %#v", decl)) + } + if id := typ + " " + name; !unsup[id] { + unsup[id] = true + eh(pos, fmt.Sprintf("unsupported %v: %v due to missing const %v", + typ, name, missing)) + } + if fatal { + broken = true + } + } + case *ast.StrFlags: + top = append(top, decl) + case *ast.NewLine, *ast.Comment, *ast.Include, *ast.Incdir, *ast.Define: + // These are not needed anymore. + default: + panic(fmt.Sprintf("unknown node type: %#v", decl)) + } + } + desc.Nodes = top + return unsup, !broken +} + // ExtractConsts returns list of literal constants and other info required const value extraction. -func ExtractConsts(top []interface{}) (consts, includes, incdirs []string, defines map[string]string) { +func ExtractConsts(desc *ast.Description) (consts, includes, incdirs []string, defines map[string]string) { constMap := make(map[string]bool) defines = make(map[string]string) - ast.Walk(top, func(n1 interface{}) { + ast.Walk(desc, func(n1 ast.Node) { switch n := n1.(type) { case *ast.Include: includes = append(includes, n.File.Value) @@ -37,11 +157,9 @@ func ExtractConsts(top []interface{}) (consts, includes, incdirs []string, defin constMap["__NR_"+n.CallName] = true } case *ast.Type: - if n.Ident == "const" && len(n.Args) > 0 { - constMap[n.Args[0].Ident] = true - } - if n.Ident == "array" && len(n.Args) > 1 { - constMap[n.Args[1].Ident] = true + if c := typeConstIdentifier(n); c != nil { + constMap[c.Ident] = true + constMap[c.Ident2] = true } case *ast.Int: constMap[n.Ident] = true @@ -52,6 +170,56 @@ func ExtractConsts(top []interface{}) (consts, includes, incdirs []string, defin return } +func patchIntConst(pos ast.Pos, val *uint64, id *string, + consts map[string]uint64, unsup map[string]bool, missing *string, eh ast.ErrorHandler) bool { + if *id == "" { + return true + } + v, ok := consts[*id] + if !ok { + name := "const " + *id + if !unsup[name] { + unsup[name] = true + eh(pos, fmt.Sprintf("unsupported const: %v", *id)) + } + if missing != nil && *missing == "" { + *missing = *id + } + } + *val = v + *id = "" + return ok +} + +// typeConstIdentifier returns type arg that is an integer constant (subject for const patching), if any. +func typeConstIdentifier(n *ast.Type) *ast.Type { + if n.Ident == "const" && len(n.Args) > 0 { + return n.Args[0] + } + if n.Ident == "array" && len(n.Args) > 1 && n.Args[1].Ident != "opt" { + return n.Args[1] + } + if n.Ident == "vma" && len(n.Args) > 0 && n.Args[0].Ident != "opt" { + return n.Args[0] + } + if n.Ident == "vma" && len(n.Args) > 0 && n.Args[0].Ident != "opt" { + return n.Args[0] + } + if n.Ident == "string" && len(n.Args) > 1 && n.Args[1].Ident != "opt" { + return n.Args[1] + } + if n.Ident == "csum" && len(n.Args) > 2 && n.Args[1].Ident == "pseudo" { + return n.Args[2] + } + switch n.Ident { + case "int8", "int16", "int16be", "int32", "int32be", "int64", "int64be", "intptr": + if len(n.Args) > 0 && n.Args[0].Ident != "opt" { + return n.Args[0] + } + } + return nil +} + func toArray(m map[string]bool) []string { delete(m, "") var res []string diff --git a/pkg/compiler/compiler_test.go b/pkg/compiler/compiler_test.go index 13ed97b4c..e68d97d3e 100644 --- a/pkg/compiler/compiler_test.go +++ b/pkg/compiler/compiler_test.go @@ -11,13 +11,15 @@ import ( ) func TestExtractConsts(t *testing.T) { - top, ok := ast.Parse([]byte(extractConstsInput), "test", nil) - if !ok { + desc := ast.Parse([]byte(extractConstsInput), "test", nil) + if desc == nil { t.Fatalf("failed to parse input") } - consts, includes, incdirs, defines := ExtractConsts(top) - wantConsts := []string{"CONST1", "CONST2", "CONST3", "CONST4", "CONST5", - "CONST6", "CONST7", "__NR_bar", "__NR_foo"} + consts, includes, incdirs, defines := ExtractConsts(desc) + wantConsts := []string{"CONST1", "CONST10", "CONST11", "CONST12", "CONST13", + "CONST14", "CONST15", "CONST16", + "CONST2", "CONST3", "CONST4", "CONST5", + "CONST6", "CONST7", "CONST8", "CONST9", "__NR_bar", "__NR_foo"} if !reflect.DeepEqual(consts, wantConsts) { t.Fatalf("got consts:\n%q\nwant:\n%q", consts, wantConsts) } @@ -56,4 +58,8 @@ str { f1 const[CONST6, int32] f2 array[array[int8, CONST7]] } + +bar$BAZ(x vma[opt], y vma[CONST8], z vma[CONST9:CONST10]) +bar$QUX(s ptr[in, string["foo", CONST11]], x csum[s, pseudo, CONST12]) +bar$FOO(x int8[8:CONST13], y int16be[CONST14:10], z intptr[CONST15:CONST16]) ` |
