diff options
Diffstat (limited to 'pkg')
| -rw-r--r-- | pkg/ast/parser_test.go | 56 | ||||
| -rw-r--r-- | pkg/ast/testdata/all.txt | 2 | ||||
| -rw-r--r-- | pkg/ast/walk.go | 66 | ||||
| -rw-r--r-- | pkg/compiler/compiler.go | 196 | ||||
| -rw-r--r-- | pkg/compiler/compiler_test.go | 21 | ||||
| -rw-r--r-- | pkg/compiler/consts.go | 124 |
6 files changed, 328 insertions, 137 deletions
diff --git a/pkg/ast/parser_test.go b/pkg/ast/parser_test.go index 809a0beaf..84a5fedf0 100644 --- a/pkg/ast/parser_test.go +++ b/pkg/ast/parser_test.go @@ -27,35 +27,37 @@ func TestParseAll(t *testing.T) { if err != nil { t.Fatalf("failed to read file: %v", err) } - errorHandler := func(pos Pos, msg string) { - t.Fatalf("%v:%v:%v: %v", pos.File, pos.Line, pos.Col, msg) - } - desc := Parse(data, file.Name(), errorHandler) - if desc == nil { - t.Fatalf("parsing failed, but no error produced") - } - data2 := Format(desc) - desc2 := Parse(data2, file.Name(), errorHandler) - if desc2 == nil { - t.Fatalf("parsing failed, but no error produced") - } - if len(desc.Nodes) != len(desc2.Nodes) { - t.Fatalf("formatting number of top level decls: %v/%v", - len(desc.Nodes), len(desc2.Nodes)) - } - for i := range desc.Nodes { - n1, n2 := desc.Nodes[i], desc2.Nodes[i] - if n1 == nil { - t.Fatalf("got nil node") + t.Run(file.Name(), func(t *testing.T) { + eh := func(pos Pos, msg string) { + t.Fatalf("%v:%v:%v: %v", pos.File, pos.Line, pos.Col, msg) } - if !reflect.DeepEqual(n1, n2) { - t.Fatalf("formatting changed code:\n%#v\nvs:\n%#v", n1, n2) + desc := Parse(data, file.Name(), eh) + if desc == nil { + t.Fatalf("parsing failed, but no error produced") } - } - data3 := Format(Clone(desc)) - if !bytes.Equal(data, data3) { - t.Fatalf("Clone lost data") - } + data2 := Format(desc) + desc2 := Parse(data2, file.Name(), eh) + if desc2 == nil { + t.Fatalf("parsing failed, but no error produced") + } + if len(desc.Nodes) != len(desc2.Nodes) { + t.Fatalf("formatting number of top level decls: %v/%v", + len(desc.Nodes), len(desc2.Nodes)) + } + for i := range desc.Nodes { + n1, n2 := desc.Nodes[i], desc2.Nodes[i] + if n1 == nil { + t.Fatalf("got nil node") + } + if !reflect.DeepEqual(n1, n2) { + t.Fatalf("formatting changed code:\n%#v\nvs:\n%#v", n1, n2) + } + } + data3 := Format(Clone(desc)) + if !bytes.Equal(data, data3) { + t.Fatalf("Clone lost data") + } + }) } } diff --git a/pkg/ast/testdata/all.txt b/pkg/ast/testdata/all.txt index 443f26368..9ddf67844 100644 --- a/pkg/ast/testdata/all.txt +++ b/pkg/ast/testdata/all.txt @@ -22,6 +22,8 @@ call(foo int32:"bar") ### unexpected string, expecting int, identifier define FOO `bar` define FOO `bar ### C expression is not terminated +foo(x int32[1:2:3, opt]) ### unexpected ':', expecting ']' + include <linux/foo.h> include "linux/foo.h" incdir </foo/bar> diff --git a/pkg/ast/walk.go b/pkg/ast/walk.go index 90a92cf77..af62884fe 100644 --- a/pkg/ast/walk.go +++ b/pkg/ast/walk.go @@ -8,85 +8,71 @@ import ( ) // Walk calls callback cb for every node in AST. -func Walk(desc *Description, cb func(n Node)) { +func Walk(desc *Description, cb func(n, parent Node)) { for _, n := range desc.Nodes { - WalkNode(n, cb) + WalkNode(n, nil, cb) } } -func WalkNode(n0 Node, cb func(n Node)) { +func WalkNode(n0, parent Node, cb func(n, parent Node)) { + cb(n0, parent) switch n := n0.(type) { case *NewLine: - cb(n) case *Comment: - cb(n) case *Include: - cb(n) - WalkNode(n.File, cb) + WalkNode(n.File, n, cb) case *Incdir: - cb(n) - WalkNode(n.Dir, cb) + WalkNode(n.Dir, n, cb) case *Define: - cb(n) - WalkNode(n.Name, cb) - WalkNode(n.Value, cb) + WalkNode(n.Name, n, cb) + WalkNode(n.Value, n, cb) case *Resource: - cb(n) - WalkNode(n.Name, cb) - WalkNode(n.Base, cb) + WalkNode(n.Name, n, cb) + WalkNode(n.Base, n, cb) for _, v := range n.Values { - WalkNode(v, cb) + WalkNode(v, n, cb) } case *Call: - cb(n) - WalkNode(n.Name, cb) + WalkNode(n.Name, n, cb) for _, f := range n.Args { - WalkNode(f, cb) + WalkNode(f, n, cb) } if n.Ret != nil { - WalkNode(n.Ret, cb) + WalkNode(n.Ret, n, cb) } case *Struct: - cb(n) - WalkNode(n.Name, cb) + WalkNode(n.Name, n, cb) for _, f := range n.Fields { - WalkNode(f, cb) + WalkNode(f, n, cb) } for _, a := range n.Attrs { - WalkNode(a, cb) + WalkNode(a, n, cb) } for _, c := range n.Comments { - WalkNode(c, cb) + WalkNode(c, n, cb) } case *IntFlags: - cb(n) - WalkNode(n.Name, cb) + WalkNode(n.Name, n, cb) for _, v := range n.Values { - WalkNode(v, cb) + WalkNode(v, n, cb) } case *StrFlags: - cb(n) - WalkNode(n.Name, cb) + WalkNode(n.Name, n, cb) for _, v := range n.Values { - WalkNode(v, cb) + WalkNode(v, n, cb) } case *Ident: - cb(n) case *String: - cb(n) case *Int: - cb(n) case *Type: - cb(n) for _, t := range n.Args { - WalkNode(t, cb) + WalkNode(t, n, cb) } case *Field: - cb(n) - WalkNode(n.Name, cb) - WalkNode(n.Type, cb) + WalkNode(n.Name, n, cb) + WalkNode(n.Type, n, cb) for _, c := range n.Comments { - WalkNode(c, cb) + WalkNode(c, n, cb) } default: panic(fmt.Sprintf("unknown AST node: %#v", n)) diff --git a/pkg/compiler/compiler.go b/pkg/compiler/compiler.go index 593d9e82c..e85fdf6e2 100644 --- a/pkg/compiler/compiler.go +++ b/pkg/compiler/compiler.go @@ -22,113 +22,169 @@ type Prog struct { } // Compile compiles sys description. -func Compile(desc0 *ast.Description, consts map[string]uint64, eh ast.ErrorHandler) *Prog { +func Compile(desc *ast.Description, consts map[string]uint64, eh ast.ErrorHandler) *Prog { if eh == nil { eh = ast.LoggingHandler } + comp := &compiler{ + desc: ast.Clone(desc), + eh: eh, + unsupported: make(map[string]bool), + } - desc := ast.Clone(desc0) - unsup, ok := patchConsts(desc, consts, eh) - if !ok { + comp.assignSyscallNumbers(consts) + comp.patchConsts(consts) + if comp.errors != 0 { return nil } return &Prog{ - Desc: desc, - Unsupported: unsup, + Desc: comp.desc, + Unsupported: comp.unsupported, } } +type compiler struct { + desc *ast.Description + eh ast.ErrorHandler + errors int + + unsupported map[string]bool +} + +func (comp *compiler) error(pos ast.Pos, msg string, args ...interface{}) { + comp.errors++ + comp.warning(pos, msg, args...) +} + +func (comp *compiler) warning(pos ast.Pos, msg string, args ...interface{}) { + comp.eh(pos, fmt.Sprintf(msg, args...)) +} + +// assignSyscallNumbers assigns syscall numbers, discards unsupported syscalls +// and removes no longer irrelevant nodes from the tree (comments, new lines, etc). +func (comp *compiler) assignSyscallNumbers(consts map[string]uint64) { + // Pseudo syscalls starting from syz_ are assigned numbers starting from syzbase. + // Note: the numbers must be stable (not depend on file reading order, etc), + // so we have to do it in 2 passes. + const syzbase = 1000000 + syzcalls := make(map[string]bool) + for _, decl := range comp.desc.Nodes { + c, ok := decl.(*ast.Call) + if !ok { + continue + } + if strings.HasPrefix(c.CallName, "syz_") { + syzcalls[c.CallName] = true + } + } + syznr := make(map[string]uint64) + for i, name := range toArray(syzcalls) { + syznr[name] = syzbase + uint64(i) + } + + var top []ast.Node + for _, decl := range comp.desc.Nodes { + switch decl.(type) { + case *ast.Call: + c := decl.(*ast.Call) + if strings.HasPrefix(c.CallName, "syz_") { + c.NR = syznr[c.CallName] + top = append(top, decl) + continue + } + // Lookup in consts. + str := "__NR_" + c.CallName + nr, ok := consts[str] + if ok { + c.NR = nr + top = append(top, decl) + continue + } + name := "syscall " + c.CallName + if !comp.unsupported[name] { + comp.unsupported[name] = true + comp.warning(c.Pos, "unsupported syscall: %v due to missing const %v", + c.CallName, str) + } + case *ast.IntFlags, *ast.Resource, *ast.Struct, *ast.StrFlags: + top = append(top, decl) + case *ast.NewLine, *ast.Comment, *ast.Include, *ast.Incdir, *ast.Define: + // These are not needed anymore. + default: + panic(fmt.Sprintf("unknown node type: %#v", decl)) + } + } + comp.desc.Nodes = top +} + // patchConsts replaces all symbolic consts with their numeric values taken from consts map. // Updates desc and returns set of unsupported syscalls and flags. // After this pass consts are not needed for compilation. -func patchConsts(desc *ast.Description, consts map[string]uint64, eh ast.ErrorHandler) (map[string]bool, bool) { - broken := false - unsup := make(map[string]bool) +func (comp *compiler) patchConsts(consts map[string]uint64) { var top []ast.Node - for _, decl := range desc.Nodes { + for _, decl := range comp.desc.Nodes { switch decl.(type) { case *ast.IntFlags: // Unsupported flag values are dropped. n := decl.(*ast.IntFlags) var values []*ast.Int for _, v := range n.Values { - if patchIntConst(v.Pos, &v.Value, &v.Ident, consts, unsup, nil, eh) { + if comp.patchIntConst(v.Pos, &v.Value, &v.Ident, consts, nil) { values = append(values, v) } } n.Values = values top = append(top, n) + case *ast.StrFlags: + top = append(top, decl) case *ast.Resource, *ast.Struct, *ast.Call: - if c, ok := decl.(*ast.Call); ok { - // Extract syscall NR. - str := "__NR_" + c.CallName - nr, ok := consts[str] - if !ok { - if name := "syscall " + c.CallName; !unsup[name] { - unsup[name] = true - eh(c.Pos, fmt.Sprintf("unsupported syscall: %v due to missing const %v", - c.CallName, str)) - } - continue - } - c.NR = nr - } // Walk whole tree and replace consts in Int's and Type's. missing := "" - ast.WalkNode(decl, func(n0 ast.Node) { + ast.WalkNode(decl, nil, func(n0, _ ast.Node) { switch n := n0.(type) { case *ast.Int: - patchIntConst(n.Pos, &n.Value, &n.Ident, - consts, unsup, &missing, eh) + comp.patchIntConst(n.Pos, &n.Value, &n.Ident, consts, &missing) case *ast.Type: if c := typeConstIdentifier(n); c != nil { - patchIntConst(c.Pos, &c.Value, &c.Ident, - consts, unsup, &missing, eh) + comp.patchIntConst(c.Pos, &c.Value, &c.Ident, + consts, &missing) if c.HasColon { - patchIntConst(c.Pos2, &c.Value2, &c.Ident2, - consts, unsup, &missing, eh) + comp.patchIntConst(c.Pos2, &c.Value2, &c.Ident2, + consts, &missing) } } } }) if missing == "" { top = append(top, decl) - } else { - // Produce a warning about unsupported syscall/resource/struct. - // Unsupported syscalls are discarded. - // Unsupported resource/struct lead to compilation error. - // Fixing that would require removing all uses of the resource/struct. - typ, pos, name, fatal := "", ast.Pos{}, "", false - switch n := decl.(type) { - case *ast.Call: - typ, pos, name, fatal = "syscall", n.Pos, n.Name.Name, false - case *ast.Resource: - typ, pos, name, fatal = "resource", n.Pos, n.Name.Name, true - case *ast.Struct: - typ, pos, name, fatal = "struct", n.Pos, n.Name.Name, true - default: - panic(fmt.Sprintf("unknown type: %#v", decl)) - } - if id := typ + " " + name; !unsup[id] { - unsup[id] = true - eh(pos, fmt.Sprintf("unsupported %v: %v due to missing const %v", - typ, name, missing)) - } - if fatal { - broken = true - } + continue + } + // Produce a warning about unsupported syscall/resource/struct. + // Unsupported syscalls are discarded. + // Unsupported resource/struct lead to compilation error. + // Fixing that would require removing all uses of the resource/struct. + pos, typ, name := ast.Pos{}, "", "" + fn := comp.error + switch n := decl.(type) { + case *ast.Call: + pos, typ, name = n.Pos, "syscall", n.Name.Name + fn = comp.warning + case *ast.Resource: + pos, typ, name = n.Pos, "resource", n.Name.Name + case *ast.Struct: + pos, typ, name = n.Pos, "struct", n.Name.Name + default: + panic(fmt.Sprintf("unknown type: %#v", decl)) + } + if id := typ + " " + name; !comp.unsupported[id] { + comp.unsupported[id] = true + fn(pos, "unsupported %v: %v due to missing const %v", + typ, name, missing) } - case *ast.StrFlags: - top = append(top, decl) - case *ast.NewLine, *ast.Comment, *ast.Include, *ast.Incdir, *ast.Define: - // These are not needed anymore. - default: - panic(fmt.Sprintf("unknown node type: %#v", decl)) } } - desc.Nodes = top - return unsup, !broken + comp.desc.Nodes = top } // ExtractConsts returns list of literal constants and other info required const value extraction. @@ -136,7 +192,7 @@ func ExtractConsts(desc *ast.Description) (consts, includes, incdirs []string, d constMap := make(map[string]bool) defines = make(map[string]string) - ast.Walk(desc, func(n1 ast.Node) { + ast.Walk(desc, func(n1, _ ast.Node) { switch n := n1.(type) { case *ast.Include: includes = append(includes, n.File.Value) @@ -170,17 +226,17 @@ func ExtractConsts(desc *ast.Description) (consts, includes, incdirs []string, d return } -func patchIntConst(pos ast.Pos, val *uint64, id *string, - consts map[string]uint64, unsup map[string]bool, missing *string, eh ast.ErrorHandler) bool { +func (comp *compiler) patchIntConst(pos ast.Pos, val *uint64, id *string, + consts map[string]uint64, missing *string) bool { if *id == "" { return true } v, ok := consts[*id] if !ok { name := "const " + *id - if !unsup[name] { - unsup[name] = true - eh(pos, fmt.Sprintf("unsupported const: %v", *id)) + if !comp.unsupported[name] { + comp.unsupported[name] = true + comp.warning(pos, "unsupported const: %v", *id) } if missing != nil && *missing == "" { *missing = *id diff --git a/pkg/compiler/compiler_test.go b/pkg/compiler/compiler_test.go index e68d97d3e..5415234c6 100644 --- a/pkg/compiler/compiler_test.go +++ b/pkg/compiler/compiler_test.go @@ -4,12 +4,33 @@ package compiler import ( + "path/filepath" "reflect" + "runtime" "testing" "github.com/google/syzkaller/pkg/ast" ) +func TestCompileAll(t *testing.T) { + eh := func(pos ast.Pos, msg string) { + t.Logf("%v:%v:%v: %v", pos.File, pos.Line, pos.Col, msg) + } + desc := ast.ParseGlob(filepath.Join("..", "..", "sys", "*.txt"), eh) + if desc == nil { + t.Fatalf("parsing failed") + } + glob := filepath.Join("..", "..", "sys", "*_"+runtime.GOARCH+".const") + consts := DeserializeConstsGlob(glob, eh) + if consts == nil { + t.Fatalf("reading consts failed") + } + prog := Compile(desc, consts, eh) + if prog == nil { + t.Fatalf("compilation failed") + } +} + func TestExtractConsts(t *testing.T) { desc := ast.Parse([]byte(extractConstsInput), "test", nil) if desc == nil { diff --git a/pkg/compiler/consts.go b/pkg/compiler/consts.go new file mode 100644 index 000000000..78f8ada00 --- /dev/null +++ b/pkg/compiler/consts.go @@ -0,0 +1,124 @@ +// Copyright 2017 syzkaller project authors. All rights reserved. +// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file. + +package compiler + +import ( + "bufio" + "bytes" + "fmt" + "io/ioutil" + "path/filepath" + "sort" + "strconv" + "strings" + + "github.com/google/syzkaller/pkg/ast" +) + +func SerializeConsts(consts map[string]uint64) []byte { + var nv []nameValuePair + for k, v := range consts { + nv = append(nv, nameValuePair{k, v}) + } + sort.Sort(nameValueArray(nv)) + + buf := new(bytes.Buffer) + fmt.Fprintf(buf, "# AUTOGENERATED FILE\n") + for _, x := range nv { + fmt.Fprintf(buf, "%v = %v\n", x.name, x.val) + } + return buf.Bytes() +} + +func DeserializeConsts(data []byte, file string, eh ast.ErrorHandler) map[string]uint64 { + consts := make(map[string]uint64) + pos := ast.Pos{ + File: file, + Line: 1, + } + ok := true + s := bufio.NewScanner(bytes.NewReader(data)) + for ; s.Scan(); pos.Line++ { + line := s.Text() + if line == "" || line[0] == '#' { + continue + } + eq := strings.IndexByte(line, '=') + if eq == -1 { + eh(pos, "expect '='") + ok = false + continue + } + name := strings.TrimSpace(line[:eq]) + val, err := strconv.ParseUint(strings.TrimSpace(line[eq+1:]), 0, 64) + if err != nil { + eh(pos, fmt.Sprintf("failed to parse int: %v", err)) + ok = false + continue + } + if _, ok := consts[name]; ok { + eh(pos, fmt.Sprintf("duplicate const %q", name)) + ok = false + continue + } + consts[name] = val + } + if err := s.Err(); err != nil { + eh(pos, fmt.Sprintf("failed to parse: %v", err)) + ok = false + } + if !ok { + return nil + } + return consts +} + +func DeserializeConstsGlob(glob string, eh ast.ErrorHandler) map[string]uint64 { + if eh == nil { + eh = ast.LoggingHandler + } + files, err := filepath.Glob(glob) + if err != nil { + eh(ast.Pos{}, fmt.Sprintf("failed to find const files: %v", err)) + return nil + } + if len(files) == 0 { + eh(ast.Pos{}, fmt.Sprintf("no const files matched by glob %q", glob)) + return nil + } + consts := make(map[string]uint64) + for _, f := range files { + data, err := ioutil.ReadFile(f) + if err != nil { + eh(ast.Pos{}, fmt.Sprintf("failed to read const file: %v", err)) + return nil + } + consts1 := DeserializeConsts(data, filepath.Base(f), eh) + if consts1 == nil { + consts = nil + } + if consts != nil { + for n, v := range consts1 { + if old, ok := consts[n]; ok && old != v { + eh(ast.Pos{}, fmt.Sprintf( + "different values for const %q: %v vs %v", n, v, old)) + return nil + } + consts[n] = v + } + } + } + return consts +} + +type nameValuePair struct { + name string + val uint64 +} + +type nameValueArray []nameValuePair + +func (a nameValueArray) Len() int { return len(a) } +func (a nameValueArray) Less(i, j int) bool { return a[i].name < a[j].name } +func (a nameValueArray) Swap(i, j int) { a[i], a[j] = a[j], a[i] } |
