aboutsummaryrefslogtreecommitdiffstats
path: root/pkg
diff options
context:
space:
mode:
authorDmitry Vyukov <dvyukov@google.com>2017-08-27 15:08:11 +0200
committerDmitry Vyukov <dvyukov@google.com>2017-08-27 15:28:49 +0200
commite2ffb4fc9111e28f1d8e0e987cb06172cbbd4e84 (patch)
tree8dc4d76063e3430321622cd81763a95a972f2e3a /pkg
parente71c87fbf52c83d8e514e4685d40da4d4d0f1a1c (diff)
pkg/compiler: move more const-processing code to compiler
Diffstat (limited to 'pkg')
-rw-r--r--pkg/ast/parser_test.go56
-rw-r--r--pkg/ast/testdata/all.txt2
-rw-r--r--pkg/ast/walk.go66
-rw-r--r--pkg/compiler/compiler.go196
-rw-r--r--pkg/compiler/compiler_test.go21
-rw-r--r--pkg/compiler/consts.go124
6 files changed, 328 insertions, 137 deletions
diff --git a/pkg/ast/parser_test.go b/pkg/ast/parser_test.go
index 809a0beaf..84a5fedf0 100644
--- a/pkg/ast/parser_test.go
+++ b/pkg/ast/parser_test.go
@@ -27,35 +27,37 @@ func TestParseAll(t *testing.T) {
if err != nil {
t.Fatalf("failed to read file: %v", err)
}
- errorHandler := func(pos Pos, msg string) {
- t.Fatalf("%v:%v:%v: %v", pos.File, pos.Line, pos.Col, msg)
- }
- desc := Parse(data, file.Name(), errorHandler)
- if desc == nil {
- t.Fatalf("parsing failed, but no error produced")
- }
- data2 := Format(desc)
- desc2 := Parse(data2, file.Name(), errorHandler)
- if desc2 == nil {
- t.Fatalf("parsing failed, but no error produced")
- }
- if len(desc.Nodes) != len(desc2.Nodes) {
- t.Fatalf("formatting number of top level decls: %v/%v",
- len(desc.Nodes), len(desc2.Nodes))
- }
- for i := range desc.Nodes {
- n1, n2 := desc.Nodes[i], desc2.Nodes[i]
- if n1 == nil {
- t.Fatalf("got nil node")
+ t.Run(file.Name(), func(t *testing.T) {
+ eh := func(pos Pos, msg string) {
+ t.Fatalf("%v:%v:%v: %v", pos.File, pos.Line, pos.Col, msg)
}
- if !reflect.DeepEqual(n1, n2) {
- t.Fatalf("formatting changed code:\n%#v\nvs:\n%#v", n1, n2)
+ desc := Parse(data, file.Name(), eh)
+ if desc == nil {
+ t.Fatalf("parsing failed, but no error produced")
}
- }
- data3 := Format(Clone(desc))
- if !bytes.Equal(data, data3) {
- t.Fatalf("Clone lost data")
- }
+ data2 := Format(desc)
+ desc2 := Parse(data2, file.Name(), eh)
+ if desc2 == nil {
+ t.Fatalf("parsing failed, but no error produced")
+ }
+ if len(desc.Nodes) != len(desc2.Nodes) {
+ t.Fatalf("formatting number of top level decls: %v/%v",
+ len(desc.Nodes), len(desc2.Nodes))
+ }
+ for i := range desc.Nodes {
+ n1, n2 := desc.Nodes[i], desc2.Nodes[i]
+ if n1 == nil {
+ t.Fatalf("got nil node")
+ }
+ if !reflect.DeepEqual(n1, n2) {
+ t.Fatalf("formatting changed code:\n%#v\nvs:\n%#v", n1, n2)
+ }
+ }
+ data3 := Format(Clone(desc))
+ if !bytes.Equal(data, data3) {
+ t.Fatalf("Clone lost data")
+ }
+ })
}
}
diff --git a/pkg/ast/testdata/all.txt b/pkg/ast/testdata/all.txt
index 443f26368..9ddf67844 100644
--- a/pkg/ast/testdata/all.txt
+++ b/pkg/ast/testdata/all.txt
@@ -22,6 +22,8 @@ call(foo int32:"bar") ### unexpected string, expecting int, identifier
define FOO `bar`
define FOO `bar ### C expression is not terminated
+foo(x int32[1:2:3, opt]) ### unexpected ':', expecting ']'
+
include <linux/foo.h>
include "linux/foo.h"
incdir </foo/bar>
diff --git a/pkg/ast/walk.go b/pkg/ast/walk.go
index 90a92cf77..af62884fe 100644
--- a/pkg/ast/walk.go
+++ b/pkg/ast/walk.go
@@ -8,85 +8,71 @@ import (
)
// Walk calls callback cb for every node in AST.
-func Walk(desc *Description, cb func(n Node)) {
+func Walk(desc *Description, cb func(n, parent Node)) {
for _, n := range desc.Nodes {
- WalkNode(n, cb)
+ WalkNode(n, nil, cb)
}
}
-func WalkNode(n0 Node, cb func(n Node)) {
+func WalkNode(n0, parent Node, cb func(n, parent Node)) {
+ cb(n0, parent)
switch n := n0.(type) {
case *NewLine:
- cb(n)
case *Comment:
- cb(n)
case *Include:
- cb(n)
- WalkNode(n.File, cb)
+ WalkNode(n.File, n, cb)
case *Incdir:
- cb(n)
- WalkNode(n.Dir, cb)
+ WalkNode(n.Dir, n, cb)
case *Define:
- cb(n)
- WalkNode(n.Name, cb)
- WalkNode(n.Value, cb)
+ WalkNode(n.Name, n, cb)
+ WalkNode(n.Value, n, cb)
case *Resource:
- cb(n)
- WalkNode(n.Name, cb)
- WalkNode(n.Base, cb)
+ WalkNode(n.Name, n, cb)
+ WalkNode(n.Base, n, cb)
for _, v := range n.Values {
- WalkNode(v, cb)
+ WalkNode(v, n, cb)
}
case *Call:
- cb(n)
- WalkNode(n.Name, cb)
+ WalkNode(n.Name, n, cb)
for _, f := range n.Args {
- WalkNode(f, cb)
+ WalkNode(f, n, cb)
}
if n.Ret != nil {
- WalkNode(n.Ret, cb)
+ WalkNode(n.Ret, n, cb)
}
case *Struct:
- cb(n)
- WalkNode(n.Name, cb)
+ WalkNode(n.Name, n, cb)
for _, f := range n.Fields {
- WalkNode(f, cb)
+ WalkNode(f, n, cb)
}
for _, a := range n.Attrs {
- WalkNode(a, cb)
+ WalkNode(a, n, cb)
}
for _, c := range n.Comments {
- WalkNode(c, cb)
+ WalkNode(c, n, cb)
}
case *IntFlags:
- cb(n)
- WalkNode(n.Name, cb)
+ WalkNode(n.Name, n, cb)
for _, v := range n.Values {
- WalkNode(v, cb)
+ WalkNode(v, n, cb)
}
case *StrFlags:
- cb(n)
- WalkNode(n.Name, cb)
+ WalkNode(n.Name, n, cb)
for _, v := range n.Values {
- WalkNode(v, cb)
+ WalkNode(v, n, cb)
}
case *Ident:
- cb(n)
case *String:
- cb(n)
case *Int:
- cb(n)
case *Type:
- cb(n)
for _, t := range n.Args {
- WalkNode(t, cb)
+ WalkNode(t, n, cb)
}
case *Field:
- cb(n)
- WalkNode(n.Name, cb)
- WalkNode(n.Type, cb)
+ WalkNode(n.Name, n, cb)
+ WalkNode(n.Type, n, cb)
for _, c := range n.Comments {
- WalkNode(c, cb)
+ WalkNode(c, n, cb)
}
default:
panic(fmt.Sprintf("unknown AST node: %#v", n))
diff --git a/pkg/compiler/compiler.go b/pkg/compiler/compiler.go
index 593d9e82c..e85fdf6e2 100644
--- a/pkg/compiler/compiler.go
+++ b/pkg/compiler/compiler.go
@@ -22,113 +22,169 @@ type Prog struct {
}
// Compile compiles sys description.
-func Compile(desc0 *ast.Description, consts map[string]uint64, eh ast.ErrorHandler) *Prog {
+func Compile(desc *ast.Description, consts map[string]uint64, eh ast.ErrorHandler) *Prog {
if eh == nil {
eh = ast.LoggingHandler
}
+ comp := &compiler{
+ desc: ast.Clone(desc),
+ eh: eh,
+ unsupported: make(map[string]bool),
+ }
- desc := ast.Clone(desc0)
- unsup, ok := patchConsts(desc, consts, eh)
- if !ok {
+ comp.assignSyscallNumbers(consts)
+ comp.patchConsts(consts)
+ if comp.errors != 0 {
return nil
}
return &Prog{
- Desc: desc,
- Unsupported: unsup,
+ Desc: comp.desc,
+ Unsupported: comp.unsupported,
}
}
+type compiler struct {
+ desc *ast.Description
+ eh ast.ErrorHandler
+ errors int
+
+ unsupported map[string]bool
+}
+
+func (comp *compiler) error(pos ast.Pos, msg string, args ...interface{}) {
+ comp.errors++
+ comp.warning(pos, msg, args...)
+}
+
+func (comp *compiler) warning(pos ast.Pos, msg string, args ...interface{}) {
+ comp.eh(pos, fmt.Sprintf(msg, args...))
+}
+
+// assignSyscallNumbers assigns syscall numbers, discards unsupported syscalls
+// and removes no longer irrelevant nodes from the tree (comments, new lines, etc).
+func (comp *compiler) assignSyscallNumbers(consts map[string]uint64) {
+ // Pseudo syscalls starting from syz_ are assigned numbers starting from syzbase.
+ // Note: the numbers must be stable (not depend on file reading order, etc),
+ // so we have to do it in 2 passes.
+ const syzbase = 1000000
+ syzcalls := make(map[string]bool)
+ for _, decl := range comp.desc.Nodes {
+ c, ok := decl.(*ast.Call)
+ if !ok {
+ continue
+ }
+ if strings.HasPrefix(c.CallName, "syz_") {
+ syzcalls[c.CallName] = true
+ }
+ }
+ syznr := make(map[string]uint64)
+ for i, name := range toArray(syzcalls) {
+ syznr[name] = syzbase + uint64(i)
+ }
+
+ var top []ast.Node
+ for _, decl := range comp.desc.Nodes {
+ switch decl.(type) {
+ case *ast.Call:
+ c := decl.(*ast.Call)
+ if strings.HasPrefix(c.CallName, "syz_") {
+ c.NR = syznr[c.CallName]
+ top = append(top, decl)
+ continue
+ }
+ // Lookup in consts.
+ str := "__NR_" + c.CallName
+ nr, ok := consts[str]
+ if ok {
+ c.NR = nr
+ top = append(top, decl)
+ continue
+ }
+ name := "syscall " + c.CallName
+ if !comp.unsupported[name] {
+ comp.unsupported[name] = true
+ comp.warning(c.Pos, "unsupported syscall: %v due to missing const %v",
+ c.CallName, str)
+ }
+ case *ast.IntFlags, *ast.Resource, *ast.Struct, *ast.StrFlags:
+ top = append(top, decl)
+ case *ast.NewLine, *ast.Comment, *ast.Include, *ast.Incdir, *ast.Define:
+ // These are not needed anymore.
+ default:
+ panic(fmt.Sprintf("unknown node type: %#v", decl))
+ }
+ }
+ comp.desc.Nodes = top
+}
+
// patchConsts replaces all symbolic consts with their numeric values taken from consts map.
// Updates desc and returns set of unsupported syscalls and flags.
// After this pass consts are not needed for compilation.
-func patchConsts(desc *ast.Description, consts map[string]uint64, eh ast.ErrorHandler) (map[string]bool, bool) {
- broken := false
- unsup := make(map[string]bool)
+func (comp *compiler) patchConsts(consts map[string]uint64) {
var top []ast.Node
- for _, decl := range desc.Nodes {
+ for _, decl := range comp.desc.Nodes {
switch decl.(type) {
case *ast.IntFlags:
// Unsupported flag values are dropped.
n := decl.(*ast.IntFlags)
var values []*ast.Int
for _, v := range n.Values {
- if patchIntConst(v.Pos, &v.Value, &v.Ident, consts, unsup, nil, eh) {
+ if comp.patchIntConst(v.Pos, &v.Value, &v.Ident, consts, nil) {
values = append(values, v)
}
}
n.Values = values
top = append(top, n)
+ case *ast.StrFlags:
+ top = append(top, decl)
case *ast.Resource, *ast.Struct, *ast.Call:
- if c, ok := decl.(*ast.Call); ok {
- // Extract syscall NR.
- str := "__NR_" + c.CallName
- nr, ok := consts[str]
- if !ok {
- if name := "syscall " + c.CallName; !unsup[name] {
- unsup[name] = true
- eh(c.Pos, fmt.Sprintf("unsupported syscall: %v due to missing const %v",
- c.CallName, str))
- }
- continue
- }
- c.NR = nr
- }
// Walk whole tree and replace consts in Int's and Type's.
missing := ""
- ast.WalkNode(decl, func(n0 ast.Node) {
+ ast.WalkNode(decl, nil, func(n0, _ ast.Node) {
switch n := n0.(type) {
case *ast.Int:
- patchIntConst(n.Pos, &n.Value, &n.Ident,
- consts, unsup, &missing, eh)
+ comp.patchIntConst(n.Pos, &n.Value, &n.Ident, consts, &missing)
case *ast.Type:
if c := typeConstIdentifier(n); c != nil {
- patchIntConst(c.Pos, &c.Value, &c.Ident,
- consts, unsup, &missing, eh)
+ comp.patchIntConst(c.Pos, &c.Value, &c.Ident,
+ consts, &missing)
if c.HasColon {
- patchIntConst(c.Pos2, &c.Value2, &c.Ident2,
- consts, unsup, &missing, eh)
+ comp.patchIntConst(c.Pos2, &c.Value2, &c.Ident2,
+ consts, &missing)
}
}
}
})
if missing == "" {
top = append(top, decl)
- } else {
- // Produce a warning about unsupported syscall/resource/struct.
- // Unsupported syscalls are discarded.
- // Unsupported resource/struct lead to compilation error.
- // Fixing that would require removing all uses of the resource/struct.
- typ, pos, name, fatal := "", ast.Pos{}, "", false
- switch n := decl.(type) {
- case *ast.Call:
- typ, pos, name, fatal = "syscall", n.Pos, n.Name.Name, false
- case *ast.Resource:
- typ, pos, name, fatal = "resource", n.Pos, n.Name.Name, true
- case *ast.Struct:
- typ, pos, name, fatal = "struct", n.Pos, n.Name.Name, true
- default:
- panic(fmt.Sprintf("unknown type: %#v", decl))
- }
- if id := typ + " " + name; !unsup[id] {
- unsup[id] = true
- eh(pos, fmt.Sprintf("unsupported %v: %v due to missing const %v",
- typ, name, missing))
- }
- if fatal {
- broken = true
- }
+ continue
+ }
+ // Produce a warning about unsupported syscall/resource/struct.
+ // Unsupported syscalls are discarded.
+ // Unsupported resource/struct lead to compilation error.
+ // Fixing that would require removing all uses of the resource/struct.
+ pos, typ, name := ast.Pos{}, "", ""
+ fn := comp.error
+ switch n := decl.(type) {
+ case *ast.Call:
+ pos, typ, name = n.Pos, "syscall", n.Name.Name
+ fn = comp.warning
+ case *ast.Resource:
+ pos, typ, name = n.Pos, "resource", n.Name.Name
+ case *ast.Struct:
+ pos, typ, name = n.Pos, "struct", n.Name.Name
+ default:
+ panic(fmt.Sprintf("unknown type: %#v", decl))
+ }
+ if id := typ + " " + name; !comp.unsupported[id] {
+ comp.unsupported[id] = true
+ fn(pos, "unsupported %v: %v due to missing const %v",
+ typ, name, missing)
}
- case *ast.StrFlags:
- top = append(top, decl)
- case *ast.NewLine, *ast.Comment, *ast.Include, *ast.Incdir, *ast.Define:
- // These are not needed anymore.
- default:
- panic(fmt.Sprintf("unknown node type: %#v", decl))
}
}
- desc.Nodes = top
- return unsup, !broken
+ comp.desc.Nodes = top
}
// ExtractConsts returns list of literal constants and other info required const value extraction.
@@ -136,7 +192,7 @@ func ExtractConsts(desc *ast.Description) (consts, includes, incdirs []string, d
constMap := make(map[string]bool)
defines = make(map[string]string)
- ast.Walk(desc, func(n1 ast.Node) {
+ ast.Walk(desc, func(n1, _ ast.Node) {
switch n := n1.(type) {
case *ast.Include:
includes = append(includes, n.File.Value)
@@ -170,17 +226,17 @@ func ExtractConsts(desc *ast.Description) (consts, includes, incdirs []string, d
return
}
-func patchIntConst(pos ast.Pos, val *uint64, id *string,
- consts map[string]uint64, unsup map[string]bool, missing *string, eh ast.ErrorHandler) bool {
+func (comp *compiler) patchIntConst(pos ast.Pos, val *uint64, id *string,
+ consts map[string]uint64, missing *string) bool {
if *id == "" {
return true
}
v, ok := consts[*id]
if !ok {
name := "const " + *id
- if !unsup[name] {
- unsup[name] = true
- eh(pos, fmt.Sprintf("unsupported const: %v", *id))
+ if !comp.unsupported[name] {
+ comp.unsupported[name] = true
+ comp.warning(pos, "unsupported const: %v", *id)
}
if missing != nil && *missing == "" {
*missing = *id
diff --git a/pkg/compiler/compiler_test.go b/pkg/compiler/compiler_test.go
index e68d97d3e..5415234c6 100644
--- a/pkg/compiler/compiler_test.go
+++ b/pkg/compiler/compiler_test.go
@@ -4,12 +4,33 @@
package compiler
import (
+ "path/filepath"
"reflect"
+ "runtime"
"testing"
"github.com/google/syzkaller/pkg/ast"
)
+func TestCompileAll(t *testing.T) {
+ eh := func(pos ast.Pos, msg string) {
+ t.Logf("%v:%v:%v: %v", pos.File, pos.Line, pos.Col, msg)
+ }
+ desc := ast.ParseGlob(filepath.Join("..", "..", "sys", "*.txt"), eh)
+ if desc == nil {
+ t.Fatalf("parsing failed")
+ }
+ glob := filepath.Join("..", "..", "sys", "*_"+runtime.GOARCH+".const")
+ consts := DeserializeConstsGlob(glob, eh)
+ if consts == nil {
+ t.Fatalf("reading consts failed")
+ }
+ prog := Compile(desc, consts, eh)
+ if prog == nil {
+ t.Fatalf("compilation failed")
+ }
+}
+
func TestExtractConsts(t *testing.T) {
desc := ast.Parse([]byte(extractConstsInput), "test", nil)
if desc == nil {
diff --git a/pkg/compiler/consts.go b/pkg/compiler/consts.go
new file mode 100644
index 000000000..78f8ada00
--- /dev/null
+++ b/pkg/compiler/consts.go
@@ -0,0 +1,124 @@
+// Copyright 2017 syzkaller project authors. All rights reserved.
+// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
+
+package compiler
+
+import (
+ "bufio"
+ "bytes"
+ "fmt"
+ "io/ioutil"
+ "path/filepath"
+ "sort"
+ "strconv"
+ "strings"
+
+ "github.com/google/syzkaller/pkg/ast"
+)
+
+func SerializeConsts(consts map[string]uint64) []byte {
+ var nv []nameValuePair
+ for k, v := range consts {
+ nv = append(nv, nameValuePair{k, v})
+ }
+ sort.Sort(nameValueArray(nv))
+
+ buf := new(bytes.Buffer)
+ fmt.Fprintf(buf, "# AUTOGENERATED FILE\n")
+ for _, x := range nv {
+ fmt.Fprintf(buf, "%v = %v\n", x.name, x.val)
+ }
+ return buf.Bytes()
+}
+
+func DeserializeConsts(data []byte, file string, eh ast.ErrorHandler) map[string]uint64 {
+ consts := make(map[string]uint64)
+ pos := ast.Pos{
+ File: file,
+ Line: 1,
+ }
+ ok := true
+ s := bufio.NewScanner(bytes.NewReader(data))
+ for ; s.Scan(); pos.Line++ {
+ line := s.Text()
+ if line == "" || line[0] == '#' {
+ continue
+ }
+ eq := strings.IndexByte(line, '=')
+ if eq == -1 {
+ eh(pos, "expect '='")
+ ok = false
+ continue
+ }
+ name := strings.TrimSpace(line[:eq])
+ val, err := strconv.ParseUint(strings.TrimSpace(line[eq+1:]), 0, 64)
+ if err != nil {
+ eh(pos, fmt.Sprintf("failed to parse int: %v", err))
+ ok = false
+ continue
+ }
+ if _, ok := consts[name]; ok {
+ eh(pos, fmt.Sprintf("duplicate const %q", name))
+ ok = false
+ continue
+ }
+ consts[name] = val
+ }
+ if err := s.Err(); err != nil {
+ eh(pos, fmt.Sprintf("failed to parse: %v", err))
+ ok = false
+ }
+ if !ok {
+ return nil
+ }
+ return consts
+}
+
+func DeserializeConstsGlob(glob string, eh ast.ErrorHandler) map[string]uint64 {
+ if eh == nil {
+ eh = ast.LoggingHandler
+ }
+ files, err := filepath.Glob(glob)
+ if err != nil {
+ eh(ast.Pos{}, fmt.Sprintf("failed to find const files: %v", err))
+ return nil
+ }
+ if len(files) == 0 {
+ eh(ast.Pos{}, fmt.Sprintf("no const files matched by glob %q", glob))
+ return nil
+ }
+ consts := make(map[string]uint64)
+ for _, f := range files {
+ data, err := ioutil.ReadFile(f)
+ if err != nil {
+ eh(ast.Pos{}, fmt.Sprintf("failed to read const file: %v", err))
+ return nil
+ }
+ consts1 := DeserializeConsts(data, filepath.Base(f), eh)
+ if consts1 == nil {
+ consts = nil
+ }
+ if consts != nil {
+ for n, v := range consts1 {
+ if old, ok := consts[n]; ok && old != v {
+ eh(ast.Pos{}, fmt.Sprintf(
+ "different values for const %q: %v vs %v", n, v, old))
+ return nil
+ }
+ consts[n] = v
+ }
+ }
+ }
+ return consts
+}
+
+type nameValuePair struct {
+ name string
+ val uint64
+}
+
+type nameValueArray []nameValuePair
+
+func (a nameValueArray) Len() int { return len(a) }
+func (a nameValueArray) Less(i, j int) bool { return a[i].name < a[j].name }
+func (a nameValueArray) Swap(i, j int) { a[i], a[j] = a[j], a[i] }