diff options
| author | Dmitry Vyukov <dvyukov@google.com> | 2017-08-27 19:55:14 +0200 |
|---|---|---|
| committer | Dmitry Vyukov <dvyukov@google.com> | 2017-08-27 20:19:41 +0200 |
| commit | 4074aed7c0c28afc7d4a3522045196c3f39b5208 (patch) | |
| tree | 8d2c2ce5f6767f8f4355e37e262f85223ee362e3 /pkg | |
| parent | 58579664687b203ff34fad8aa02bf470ef0bc981 (diff) | |
pkg/compiler: more static error checking
Update #217
Diffstat (limited to 'pkg')
| -rw-r--r-- | pkg/ast/ast.go | 70 | ||||
| -rw-r--r-- | pkg/ast/format.go | 2 | ||||
| -rw-r--r-- | pkg/ast/parser.go | 2 | ||||
| -rw-r--r-- | pkg/ast/parser_test.go | 83 | ||||
| -rw-r--r-- | pkg/ast/scanner.go | 6 | ||||
| -rw-r--r-- | pkg/ast/test_util.go | 99 | ||||
| -rw-r--r-- | pkg/ast/walk.go | 52 | ||||
| -rw-r--r-- | pkg/compiler/compiler.go | 288 | ||||
| -rw-r--r-- | pkg/compiler/compiler_test.go | 69 | ||||
| -rw-r--r-- | pkg/compiler/consts.go | 77 | ||||
| -rw-r--r-- | pkg/compiler/consts_test.go | 61 | ||||
| -rw-r--r-- | pkg/compiler/testdata/consts.txt | 24 | ||||
| -rw-r--r-- | pkg/compiler/testdata/consts_errors.txt | 10 | ||||
| -rw-r--r-- | pkg/compiler/testdata/errors.txt | 63 |
14 files changed, 696 insertions, 210 deletions
diff --git a/pkg/ast/ast.go b/pkg/ast/ast.go index b283ca5f8..4c9101f79 100644 --- a/pkg/ast/ast.go +++ b/pkg/ast/ast.go @@ -18,7 +18,9 @@ type Description struct { } // Node is AST node interface. -type Node interface{} +type Node interface { + Info() (pos Pos, typ string, name string) +} // Top-level AST nodes: @@ -26,34 +28,58 @@ type NewLine struct { Pos Pos } +func (n *NewLine) Info() (Pos, string, string) { + return n.Pos, tok2str[tokNewLine], "" +} + type Comment struct { Pos Pos Text string } +func (n *Comment) Info() (Pos, string, string) { + return n.Pos, tok2str[tokComment], "" +} + type Include struct { Pos Pos File *String } +func (n *Include) Info() (Pos, string, string) { + return n.Pos, tok2str[tokInclude], "" +} + type Incdir struct { Pos Pos Dir *String } +func (n *Incdir) Info() (Pos, string, string) { + return n.Pos, tok2str[tokInclude], "" +} + type Define struct { Pos Pos Name *Ident Value *Int } +func (n *Define) Info() (Pos, string, string) { + return n.Pos, tok2str[tokDefine], n.Name.Name +} + type Resource struct { Pos Pos Name *Ident - Base *Ident + Base *Type Values []*Int } +func (n *Resource) Info() (Pos, string, string) { + return n.Pos, tok2str[tokResource], n.Name.Name +} + type Call struct { Pos Pos Name *Ident @@ -63,6 +89,10 @@ type Call struct { Ret *Type } +func (n *Call) Info() (Pos, string, string) { + return n.Pos, "syscall", n.Name.Name +} + type Struct struct { Pos Pos Name *Ident @@ -72,18 +102,34 @@ type Struct struct { IsUnion bool } +func (n *Struct) Info() (Pos, string, string) { + typ := "struct" + if n.IsUnion { + typ = "union" + } + return n.Pos, typ, n.Name.Name +} + type IntFlags struct { Pos Pos Name *Ident Values []*Int } +func (n *IntFlags) Info() (Pos, string, string) { + return n.Pos, "flags", n.Name.Name +} + type StrFlags struct { Pos Pos Name *Ident Values []*String } +func (n *StrFlags) Info() (Pos, string, string) { + return n.Pos, "string flags", n.Name.Name +} + // Not top-level AST nodes: type Ident struct { @@ -91,11 +137,19 @@ type Ident struct { Name string } +func (n *Ident) Info() (Pos, string, string) { + return n.Pos, tok2str[tokIdent], n.Name +} + type String struct { Pos Pos Value string } +func (n *String) Info() (Pos, string, string) { + return n.Pos, tok2str[tokString], "" +} + type Int struct { Pos Pos // Only one of Value, Ident, CExpr is filled. @@ -105,6 +159,10 @@ type Int struct { CExpr string } +func (n *Int) Info() (Pos, string, string) { + return n.Pos, tok2str[tokInt], "" +} + type Type struct { Pos Pos // Only one of Value, Ident, String is filled. @@ -121,6 +179,10 @@ type Type struct { Args []*Type } +func (n *Type) Info() (Pos, string, string) { + return n.Pos, "type", n.Ident +} + type Field struct { Pos Pos Name *Ident @@ -128,3 +190,7 @@ type Field struct { NewBlock bool // separated from previous fields by a new line Comments []*Comment } + +func (n *Field) Info() (Pos, string, string) { + return n.Pos, "arg/field", n.Name.Name +} diff --git a/pkg/ast/format.go b/pkg/ast/format.go index e7e21dcdd..0f95d7ebf 100644 --- a/pkg/ast/format.go +++ b/pkg/ast/format.go @@ -50,7 +50,7 @@ func (def *Define) serialize(w io.Writer) { } func (res *Resource) serialize(w io.Writer) { - fmt.Fprintf(w, "resource %v[%v]", res.Name.Name, res.Base.Name) + fmt.Fprintf(w, "resource %v[%v]", res.Name.Name, fmtType(res.Base)) for i, v := range res.Values { if i == 0 { fmt.Fprintf(w, ": ") diff --git a/pkg/ast/parser.go b/pkg/ast/parser.go index ca2505e19..fd7b9ad4f 100644 --- a/pkg/ast/parser.go +++ b/pkg/ast/parser.go @@ -228,7 +228,7 @@ func (p *parser) parseResource() *Resource { p.consume(tokResource) name := p.parseIdent() p.consume(tokLBrack) - base := p.parseIdent() + base := p.parseType() p.consume(tokRBrack) var values []*Int if p.tryConsume(tokColon) { diff --git a/pkg/ast/parser_test.go b/pkg/ast/parser_test.go index 84a5fedf0..46ad1e5d3 100644 --- a/pkg/ast/parser_test.go +++ b/pkg/ast/parser_test.go @@ -4,7 +4,6 @@ package ast import ( - "bufio" "bytes" "io/ioutil" "path/filepath" @@ -29,7 +28,7 @@ func TestParseAll(t *testing.T) { } t.Run(file.Name(), func(t *testing.T) { eh := func(pos Pos, msg string) { - t.Fatalf("%v:%v:%v: %v", pos.File, pos.Line, pos.Col, msg) + t.Fatalf("%v: %v", pos, msg) } desc := Parse(data, file.Name(), eh) if desc == nil { @@ -65,7 +64,7 @@ func TestParse(t *testing.T) { for _, test := range parseTests { t.Run(test.name, func(t *testing.T) { errorHandler := func(pos Pos, msg string) { - t.Logf("%v:%v:%v: %v", pos.File, pos.Line, pos.Col, msg) + t.Logf("%v: %v", pos, msg) } Parse([]byte(test.input), "foo", errorHandler) }) @@ -96,13 +95,6 @@ var parseTests = []struct { }, } -type Error struct { - Line int - Col int - Text string - Matched bool -} - func TestErrors(t *testing.T) { files, err := ioutil.ReadDir("testdata") if err != nil { @@ -115,71 +107,18 @@ func TestErrors(t *testing.T) { if !strings.HasSuffix(f.Name(), ".txt") { continue } - t.Run(f.Name(), func(t *testing.T) { - data, err := ioutil.ReadFile(filepath.Join("testdata", f.Name())) - if err != nil { - t.Fatalf("failed to open input file: %v", err) - } - var stripped []byte - var errors []*Error - s := bufio.NewScanner(bytes.NewReader(data)) - for i := 1; s.Scan(); i++ { - ln := s.Bytes() - for { - pos := bytes.LastIndex(ln, []byte("###")) - if pos == -1 { - break - } - errors = append(errors, &Error{ - Line: i, - Text: strings.TrimSpace(string(ln[pos+3:])), - }) - ln = ln[:pos] - } - stripped = append(stripped, ln...) - stripped = append(stripped, '\n') - } - if err := s.Err(); err != nil { - t.Fatalf("failed to scan input file: %v", err) + name := f.Name() + t.Run(name, func(t *testing.T) { + em := NewErrorMatcher(t, filepath.Join("testdata", name)) + desc := Parse(em.Data, name, em.ErrorHandler) + if desc != nil && em.Count() != 0 { + em.DumpErrors(t) + t.Fatalf("parsing succeed, but got errors") } - var got []*Error - desc := Parse(stripped, "test", func(pos Pos, msg string) { - got = append(got, &Error{ - Line: pos.Line, - Col: pos.Col, - Text: msg, - }) - }) - if desc != nil && len(got) != 0 { - t.Fatalf("parsing succeed, but got errors: %v", got) - } - if desc == nil && len(got) == 0 { + if desc == nil && em.Count() == 0 { t.Fatalf("parsing failed, but got no errors") } - nextErr: - for _, gotErr := range got { - for _, wantErr := range errors { - if wantErr.Matched { - continue - } - if wantErr.Line != gotErr.Line { - continue - } - if wantErr.Text != gotErr.Text { - continue - } - wantErr.Matched = true - continue nextErr - } - t.Errorf("unexpected error: %v:%v: %v", - gotErr.Line, gotErr.Col, gotErr.Text) - } - for _, wantErr := range errors { - if wantErr.Matched { - continue - } - t.Errorf("not matched error: %v: %v", wantErr.Line, wantErr.Text) - } + em.Check(t) }) } } diff --git a/pkg/ast/scanner.go b/pkg/ast/scanner.go index 387a58529..f1573350b 100644 --- a/pkg/ast/scanner.go +++ b/pkg/ast/scanner.go @@ -118,7 +118,11 @@ func newScanner(data []byte, filename string, errorHandler ErrorHandler) *scanne type ErrorHandler func(pos Pos, msg string) func LoggingHandler(pos Pos, msg string) { - fmt.Fprintf(os.Stderr, "%v:%v:%v: %v\n", pos.File, pos.Line, pos.Col, msg) + fmt.Fprintf(os.Stderr, "%v: %v\n", pos, msg) +} + +func (pos Pos) String() string { + return fmt.Sprintf("%v:%v:%v", pos.File, pos.Line, pos.Col) } func (s *scanner) Scan() (tok token, lit string, pos Pos) { diff --git a/pkg/ast/test_util.go b/pkg/ast/test_util.go new file mode 100644 index 000000000..0aed0a2dc --- /dev/null +++ b/pkg/ast/test_util.go @@ -0,0 +1,99 @@ +// Copyright 2017 syzkaller project authors. All rights reserved. +// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file. + +package ast + +import ( + "bufio" + "bytes" + "io/ioutil" + "strings" + "testing" +) + +type ErrorMatcher struct { + Data []byte + expect []*errorDesc + got []*errorDesc +} + +type errorDesc struct { + file string + line int + col int + text string + matched bool +} + +func NewErrorMatcher(t *testing.T, file string) *ErrorMatcher { + data, err := ioutil.ReadFile(file) + if err != nil { + t.Fatalf("failed to open input file: %v", err) + } + var stripped []byte + var errors []*errorDesc + s := bufio.NewScanner(bytes.NewReader(data)) + for i := 1; s.Scan(); i++ { + ln := s.Bytes() + for { + pos := bytes.LastIndex(ln, []byte("###")) + if pos == -1 { + break + } + errors = append(errors, &errorDesc{ + file: file, + line: i, + text: strings.TrimSpace(string(ln[pos+3:])), + }) + ln = ln[:pos] + } + stripped = append(stripped, ln...) + stripped = append(stripped, '\n') + } + if err := s.Err(); err != nil { + t.Fatalf("failed to scan input file: %v", err) + } + return &ErrorMatcher{ + Data: stripped, + expect: errors, + } +} + +func (em *ErrorMatcher) ErrorHandler(pos Pos, msg string) { + em.got = append(em.got, &errorDesc{ + file: pos.File, + line: pos.Line, + col: pos.Col, + text: msg, + }) +} + +func (em *ErrorMatcher) Count() int { + return len(em.got) +} + +func (em *ErrorMatcher) Check(t *testing.T) { +nextErr: + for _, e := range em.got { + for _, want := range em.expect { + if want.matched || want.line != e.line || want.text != e.text { + continue + } + want.matched = true + continue nextErr + } + t.Errorf("unexpected error: %v:%v:%v: %v", e.file, e.line, e.col, e.text) + } + for _, want := range em.expect { + if want.matched { + continue + } + t.Errorf("unmatched error: %v:%v: %v", want.file, want.line, want.text) + } +} + +func (em *ErrorMatcher) DumpErrors(t *testing.T) { + for _, e := range em.got { + t.Logf("%v:%v:%v: %v", e.file, e.line, e.col, e.text) + } +} diff --git a/pkg/ast/walk.go b/pkg/ast/walk.go index af62884fe..79e0b4bec 100644 --- a/pkg/ast/walk.go +++ b/pkg/ast/walk.go @@ -8,71 +8,71 @@ import ( ) // Walk calls callback cb for every node in AST. -func Walk(desc *Description, cb func(n, parent Node)) { +func Walk(desc *Description, cb func(n Node)) { for _, n := range desc.Nodes { - WalkNode(n, nil, cb) + WalkNode(n, cb) } } -func WalkNode(n0, parent Node, cb func(n, parent Node)) { - cb(n0, parent) +func WalkNode(n0 Node, cb func(n Node)) { + cb(n0) switch n := n0.(type) { case *NewLine: case *Comment: case *Include: - WalkNode(n.File, n, cb) + WalkNode(n.File, cb) case *Incdir: - WalkNode(n.Dir, n, cb) + WalkNode(n.Dir, cb) case *Define: - WalkNode(n.Name, n, cb) - WalkNode(n.Value, n, cb) + WalkNode(n.Name, cb) + WalkNode(n.Value, cb) case *Resource: - WalkNode(n.Name, n, cb) - WalkNode(n.Base, n, cb) + WalkNode(n.Name, cb) + WalkNode(n.Base, cb) for _, v := range n.Values { - WalkNode(v, n, cb) + WalkNode(v, cb) } case *Call: - WalkNode(n.Name, n, cb) + WalkNode(n.Name, cb) for _, f := range n.Args { - WalkNode(f, n, cb) + WalkNode(f, cb) } if n.Ret != nil { - WalkNode(n.Ret, n, cb) + WalkNode(n.Ret, cb) } case *Struct: - WalkNode(n.Name, n, cb) + WalkNode(n.Name, cb) for _, f := range n.Fields { - WalkNode(f, n, cb) + WalkNode(f, cb) } for _, a := range n.Attrs { - WalkNode(a, n, cb) + WalkNode(a, cb) } for _, c := range n.Comments { - WalkNode(c, n, cb) + WalkNode(c, cb) } case *IntFlags: - WalkNode(n.Name, n, cb) + WalkNode(n.Name, cb) for _, v := range n.Values { - WalkNode(v, n, cb) + WalkNode(v, cb) } case *StrFlags: - WalkNode(n.Name, n, cb) + WalkNode(n.Name, cb) for _, v := range n.Values { - WalkNode(v, n, cb) + WalkNode(v, cb) } case *Ident: case *String: case *Int: case *Type: for _, t := range n.Args { - WalkNode(t, n, cb) + WalkNode(t, cb) } case *Field: - WalkNode(n.Name, n, cb) - WalkNode(n.Type, n, cb) + WalkNode(n.Name, cb) + WalkNode(n.Type, cb) for _, c := range n.Comments { - WalkNode(c, n, cb) + WalkNode(c, cb) } default: panic(fmt.Sprintf("unknown AST node: %#v", n)) diff --git a/pkg/compiler/compiler.go b/pkg/compiler/compiler.go index e85fdf6e2..3afab96d9 100644 --- a/pkg/compiler/compiler.go +++ b/pkg/compiler/compiler.go @@ -38,6 +38,14 @@ func Compile(desc *ast.Description, consts map[string]uint64, eh ast.ErrorHandle return nil } + comp.check() + + if comp.errors != 0 { + return nil + } + for _, w := range comp.warnings { + eh(w.pos, w.msg) + } return &Prog{ Desc: comp.desc, Unsupported: comp.unsupported, @@ -45,20 +53,237 @@ func Compile(desc *ast.Description, consts map[string]uint64, eh ast.ErrorHandle } type compiler struct { - desc *ast.Description - eh ast.ErrorHandler - errors int + desc *ast.Description + eh ast.ErrorHandler + errors int + warnings []warn unsupported map[string]bool + + udt map[string]ast.Node // structs, unions and resources + flags map[string]ast.Node // int and string flags +} + +type warn struct { + pos ast.Pos + msg string } func (comp *compiler) error(pos ast.Pos, msg string, args ...interface{}) { comp.errors++ - comp.warning(pos, msg, args...) + comp.eh(pos, fmt.Sprintf(msg, args...)) } func (comp *compiler) warning(pos ast.Pos, msg string, args ...interface{}) { - comp.eh(pos, fmt.Sprintf(msg, args...)) + comp.warnings = append(comp.warnings, warn{pos, fmt.Sprintf(msg, args...)}) +} + +type typeDesc struct { + Names []string + CanBeArg bool + NeedBase bool + AllowColon bool + ResourceBase bool + Args []*typeDesc +} + +var ( + typeDir = &typeDesc{ + Names: []string{"in", "out", "inout"}, + } + + topTypes = []*typeDesc{ + &typeDesc{ + Names: []string{"int8", "int16", "int32", "int64", + "int16be", "int32be", "int64be", "intptr"}, + CanBeArg: true, + AllowColon: true, + ResourceBase: true, + }, + &typeDesc{ + Names: []string{"fileoff"}, + CanBeArg: true, + NeedBase: true, + }, + &typeDesc{ + Names: []string{"buffer"}, + CanBeArg: true, + Args: []*typeDesc{typeDir}, + }, + &typeDesc{ + Names: []string{"string"}, + //Args: []*typeDesc{typeDir}, + }, + } + + builtinTypes = make(map[string]bool) +) + +func init() { + for _, desc := range topTypes { + for _, name := range desc.Names { + if builtinTypes[name] { + panic(fmt.Sprintf("duplicate builtin type %q", name)) + } + builtinTypes[name] = true + } + } +} + +var typeCheck bool + +func (comp *compiler) check() { + // TODO: check len in syscall arguments referring to parent. + // TODO: incorrect name is referenced in len type + // TODO: infinite recursion via struct pointers (e.g. a linked list) + // TODO: no constructor for a resource + // TODO: typo of intour instead of inout + + comp.checkNames() + comp.checkFields() + + if typeCheck { + for _, decl := range comp.desc.Nodes { + switch n := decl.(type) { + case *ast.Resource: + comp.checkType(n.Base, false, true) + case *ast.Struct: + for _, f := range n.Fields { + comp.checkType(f.Type, false, false) + } + case *ast.Call: + for _, a := range n.Args { + comp.checkType(a.Type, true, false) + } + if n.Ret != nil { + comp.checkType(n.Ret, true, false) + } + } + } + } +} + +func (comp *compiler) checkNames() { + comp.udt = make(map[string]ast.Node) + comp.flags = make(map[string]ast.Node) + calls := make(map[string]*ast.Call) + for _, decl := range comp.desc.Nodes { + switch decl.(type) { + case *ast.Resource, *ast.Struct: + pos, typ, name := decl.Info() + if builtinTypes[name] { + comp.error(pos, "%v name %v conflicts with builtin type", typ, name) + continue + } + if prev := comp.udt[name]; prev != nil { + pos1, typ1, _ := prev.Info() + comp.error(pos, "type %v redeclared, previously declared as %v at %v", + name, typ1, pos1) + continue + } + comp.udt[name] = decl + case *ast.IntFlags, *ast.StrFlags: + pos, typ, name := decl.Info() + if prev := comp.flags[name]; prev != nil { + pos1, typ1, _ := prev.Info() + comp.error(pos, "%v %v redeclared, previously declared as %v at %v", + typ, name, typ1, pos1) + continue + } + comp.flags[name] = decl + case *ast.Call: + c := decl.(*ast.Call) + name := c.Name.Name + if prev := calls[name]; prev != nil { + comp.error(c.Pos, "syscall %v redeclared, previously declared at %v", + name, prev.Pos) + } + calls[name] = c + } + } +} + +func (comp *compiler) checkFields() { + const maxArgs = 9 // executor does not support more + for _, decl := range comp.desc.Nodes { + switch n := decl.(type) { + case *ast.Struct: + _, typ, name := n.Info() + fields := make(map[string]bool) + for _, f := range n.Fields { + fn := f.Name.Name + if fn == "parent" { + comp.error(f.Pos, "reserved field name %v in %v %v", fn, typ, name) + } + if fields[fn] { + comp.error(f.Pos, "duplicate field %v in %v %v", fn, typ, name) + } + fields[fn] = true + } + if !n.IsUnion && len(n.Fields) < 1 { + comp.error(n.Pos, "struct %v has no fields, need at least 1 field", name) + } + if n.IsUnion && len(n.Fields) < 2 { + comp.error(n.Pos, "union %v has only %v field, need at least 2 fields", + name, len(n.Fields)) + } + case *ast.Call: + name := n.Name.Name + args := make(map[string]bool) + for _, a := range n.Args { + an := a.Name.Name + if an == "parent" { + comp.error(a.Pos, "reserved argument name %v in syscall %v", + an, name) + } + if args[an] { + comp.error(a.Pos, "duplicate argument %v in syscall %v", + an, name) + } + args[an] = true + } + if len(n.Args) > maxArgs { + comp.error(n.Pos, "syscall %v has %v arguments, allowed maximum is %v", + name, len(n.Args), maxArgs) + } + } + } +} + +func (comp *compiler) checkType(t *ast.Type, isArg, isResourceBase bool) { + if t.String != "" { + comp.error(t.Pos, "unexpected string %q, expecting type", t.String) + return + } + if t.Ident == "" { + comp.error(t.Pos, "unexpected integer %v, expecting type", t.Value) + return + } + var desc *typeDesc + for _, desc1 := range topTypes { + for _, name := range desc1.Names { + if name == t.Ident { + desc = desc1 + break + } + } + } + if desc == nil { + comp.error(t.Pos, "unknown type %q", t.Ident) + return + } + if !desc.AllowColon && t.HasColon { + comp.error(t.Pos2, "unexpected ':'") + return + } + if isArg && !desc.CanBeArg { + comp.error(t.Pos, "%v can't be syscall argument/return", t.Ident) + return + } + if isResourceBase && !desc.ResourceBase { + comp.error(t.Pos, "%v can't be resource base (int types can)", t.Ident) + return + } } // assignSyscallNumbers assigns syscall numbers, discards unsupported syscalls @@ -141,7 +366,7 @@ func (comp *compiler) patchConsts(consts map[string]uint64) { case *ast.Resource, *ast.Struct, *ast.Call: // Walk whole tree and replace consts in Int's and Type's. missing := "" - ast.WalkNode(decl, nil, func(n0, _ ast.Node) { + ast.WalkNode(decl, func(n0 ast.Node) { switch n := n0.(type) { case *ast.Int: comp.patchIntConst(n.Pos, &n.Value, &n.Ident, consts, &missing) @@ -164,18 +389,10 @@ func (comp *compiler) patchConsts(consts map[string]uint64) { // Unsupported syscalls are discarded. // Unsupported resource/struct lead to compilation error. // Fixing that would require removing all uses of the resource/struct. - pos, typ, name := ast.Pos{}, "", "" + pos, typ, name := decl.Info() fn := comp.error - switch n := decl.(type) { - case *ast.Call: - pos, typ, name = n.Pos, "syscall", n.Name.Name + if _, ok := decl.(*ast.Call); ok { fn = comp.warning - case *ast.Resource: - pos, typ, name = n.Pos, "resource", n.Name.Name - case *ast.Struct: - pos, typ, name = n.Pos, "struct", n.Name.Name - default: - panic(fmt.Sprintf("unknown type: %#v", decl)) } if id := typ + " " + name; !comp.unsupported[id] { comp.unsupported[id] = true @@ -187,45 +404,6 @@ func (comp *compiler) patchConsts(consts map[string]uint64) { comp.desc.Nodes = top } -// ExtractConsts returns list of literal constants and other info required const value extraction. -func ExtractConsts(desc *ast.Description) (consts, includes, incdirs []string, defines map[string]string) { - constMap := make(map[string]bool) - defines = make(map[string]string) - - ast.Walk(desc, func(n1, _ ast.Node) { - switch n := n1.(type) { - case *ast.Include: - includes = append(includes, n.File.Value) - case *ast.Incdir: - incdirs = append(incdirs, n.Dir.Value) - case *ast.Define: - v := fmt.Sprint(n.Value.Value) - switch { - case n.Value.CExpr != "": - v = n.Value.CExpr - case n.Value.Ident != "": - v = n.Value.Ident - } - defines[n.Name.Name] = v - constMap[n.Name.Name] = true - case *ast.Call: - if !strings.HasPrefix(n.CallName, "syz_") { - constMap["__NR_"+n.CallName] = true - } - case *ast.Type: - if c := typeConstIdentifier(n); c != nil { - constMap[c.Ident] = true - constMap[c.Ident2] = true - } - case *ast.Int: - constMap[n.Ident] = true - } - }) - - consts = toArray(constMap) - return -} - func (comp *compiler) patchIntConst(pos ast.Pos, val *uint64, id *string, consts map[string]uint64, missing *string) bool { if *id == "" { diff --git a/pkg/compiler/compiler_test.go b/pkg/compiler/compiler_test.go index 5415234c6..1ef77a19c 100644 --- a/pkg/compiler/compiler_test.go +++ b/pkg/compiler/compiler_test.go @@ -5,7 +5,6 @@ package compiler import ( "path/filepath" - "reflect" "runtime" "testing" @@ -13,8 +12,9 @@ import ( ) func TestCompileAll(t *testing.T) { + t.Skip() eh := func(pos ast.Pos, msg string) { - t.Logf("%v:%v:%v: %v", pos.File, pos.Line, pos.Col, msg) + t.Logf("%v: %v", pos, msg) } desc := ast.ParseGlob(filepath.Join("..", "..", "sys", "*.txt"), eh) if desc == nil { @@ -31,56 +31,21 @@ func TestCompileAll(t *testing.T) { } } -func TestExtractConsts(t *testing.T) { - desc := ast.Parse([]byte(extractConstsInput), "test", nil) - if desc == nil { - t.Fatalf("failed to parse input") - } - consts, includes, incdirs, defines := ExtractConsts(desc) - wantConsts := []string{"CONST1", "CONST10", "CONST11", "CONST12", "CONST13", - "CONST14", "CONST15", "CONST16", - "CONST2", "CONST3", "CONST4", "CONST5", - "CONST6", "CONST7", "CONST8", "CONST9", "__NR_bar", "__NR_foo"} - if !reflect.DeepEqual(consts, wantConsts) { - t.Fatalf("got consts:\n%q\nwant:\n%q", consts, wantConsts) - } - wantIncludes := []string{"foo/bar.h", "bar/foo.h"} - if !reflect.DeepEqual(includes, wantIncludes) { - t.Fatalf("got includes:\n%q\nwant:\n%q", includes, wantIncludes) - } - wantIncdirs := []string{"/foo", "/bar"} - if !reflect.DeepEqual(incdirs, wantIncdirs) { - t.Fatalf("got incdirs:\n%q\nwant:\n%q", incdirs, wantIncdirs) - } - wantDefines := map[string]string{ - "CONST1": "1", - "CONST2": "FOOBAR + 1", - } - if !reflect.DeepEqual(defines, wantDefines) { - t.Fatalf("got defines:\n%q\nwant:\n%q", defines, wantDefines) - } +func init() { + typeCheck = true } -const extractConstsInput = ` -include <foo/bar.h> -incdir </foo> -include <bar/foo.h> -incdir </bar> - -flags = CONST3, CONST2, CONST1 - -define CONST1 1 -define CONST2 FOOBAR + 1 - -foo(x const[CONST4]) ptr[out, array[int32, CONST5]] -bar$BAR() - -str { - f1 const[CONST6, int32] - f2 array[array[int8, CONST7]] +func TestErrors(t *testing.T) { + consts := map[string]uint64{ + "__NR_foo": 1, + } + name := "errors.txt" + em := ast.NewErrorMatcher(t, filepath.Join("testdata", name)) + desc := ast.Parse(em.Data, name, em.ErrorHandler) + if desc == nil { + em.DumpErrors(t) + t.Fatalf("parsing failed") + } + Compile(desc, consts, em.ErrorHandler) + em.Check(t) } - -bar$BAZ(x vma[opt], y vma[CONST8], z vma[CONST9:CONST10]) -bar$QUX(s ptr[in, string["foo", CONST11]], x csum[s, pseudo, CONST12]) -bar$FOO(x int8[8:CONST13], y int16be[CONST14:10], z intptr[CONST15:CONST16]) -` diff --git a/pkg/compiler/consts.go b/pkg/compiler/consts.go index 78f8ada00..dd56d99c7 100644 --- a/pkg/compiler/consts.go +++ b/pkg/compiler/consts.go @@ -16,6 +16,83 @@ import ( "github.com/google/syzkaller/pkg/ast" ) +type ConstInfo struct { + Consts []string + Includes []string + Incdirs []string + Defines map[string]string +} + +// ExtractConsts returns list of literal constants and other info required const value extraction. +func ExtractConsts(desc *ast.Description, eh0 ast.ErrorHandler) *ConstInfo { + errors := 0 + eh := func(pos ast.Pos, msg string, args ...interface{}) { + errors++ + msg = fmt.Sprintf(msg, args...) + if eh0 != nil { + eh0(pos, msg) + } else { + ast.LoggingHandler(pos, msg) + } + } + info := &ConstInfo{ + Defines: make(map[string]string), + } + includeMap := make(map[string]bool) + incdirMap := make(map[string]bool) + constMap := make(map[string]bool) + + ast.Walk(desc, func(n1 ast.Node) { + switch n := n1.(type) { + case *ast.Include: + file := n.File.Value + if includeMap[file] { + eh(n.Pos, "duplicate include %q", file) + } + includeMap[file] = true + info.Includes = append(info.Includes, file) + case *ast.Incdir: + dir := n.Dir.Value + if incdirMap[dir] { + eh(n.Pos, "duplicate incdir %q", dir) + } + incdirMap[dir] = true + info.Incdirs = append(info.Incdirs, dir) + case *ast.Define: + v := fmt.Sprint(n.Value.Value) + switch { + case n.Value.CExpr != "": + v = n.Value.CExpr + case n.Value.Ident != "": + v = n.Value.Ident + } + name := n.Name.Name + if info.Defines[name] != "" { + eh(n.Pos, "duplicate define %v", name) + } + info.Defines[name] = v + constMap[name] = true + case *ast.Call: + if !strings.HasPrefix(n.CallName, "syz_") { + constMap["__NR_"+n.CallName] = true + } + case *ast.Type: + if c := typeConstIdentifier(n); c != nil { + constMap[c.Ident] = true + constMap[c.Ident2] = true + } + case *ast.Int: + constMap[n.Ident] = true + } + }) + + if errors != 0 { + return nil + } + info.Consts = toArray(constMap) + return info +} + func SerializeConsts(consts map[string]uint64) []byte { var nv []nameValuePair for k, v := range consts { diff --git a/pkg/compiler/consts_test.go b/pkg/compiler/consts_test.go new file mode 100644 index 000000000..647b2f1f4 --- /dev/null +++ b/pkg/compiler/consts_test.go @@ -0,0 +1,61 @@ +// Copyright 2017 syzkaller project authors. All rights reserved. +// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file. + +package compiler + +import ( + "io/ioutil" + "path/filepath" + "reflect" + "testing" + + "github.com/google/syzkaller/pkg/ast" +) + +func TestExtractConsts(t *testing.T) { + data, err := ioutil.ReadFile(filepath.Join("testdata", "consts.txt")) + if err != nil { + t.Fatalf("failed to read input file: %v", err) + } + desc := ast.Parse(data, "test", nil) + if desc == nil { + t.Fatalf("failed to parse input") + } + info := ExtractConsts(desc, func(pos ast.Pos, msg string) { + t.Fatalf("%v: %v", pos, msg) + }) + wantConsts := []string{"CONST1", "CONST10", "CONST11", "CONST12", "CONST13", + "CONST14", "CONST15", "CONST16", + "CONST2", "CONST3", "CONST4", "CONST5", + "CONST6", "CONST7", "CONST8", "CONST9", "__NR_bar", "__NR_foo"} + if !reflect.DeepEqual(info.Consts, wantConsts) { + t.Fatalf("got consts:\n%q\nwant:\n%q", info.Consts, wantConsts) + } + wantIncludes := []string{"foo/bar.h", "bar/foo.h"} + if !reflect.DeepEqual(info.Includes, wantIncludes) { + t.Fatalf("got includes:\n%q\nwant:\n%q", info.Includes, wantIncludes) + } + wantIncdirs := []string{"/foo", "/bar"} + if !reflect.DeepEqual(info.Incdirs, wantIncdirs) { + t.Fatalf("got incdirs:\n%q\nwant:\n%q", info.Incdirs, wantIncdirs) + } + wantDefines := map[string]string{ + "CONST1": "1", + "CONST2": "FOOBAR + 1", + } + if !reflect.DeepEqual(info.Defines, wantDefines) { + t.Fatalf("got defines:\n%q\nwant:\n%q", info.Defines, wantDefines) + } +} + +func TestConstErrors(t *testing.T) { + name := "consts_errors.txt" + em := ast.NewErrorMatcher(t, filepath.Join("testdata", name)) + desc := ast.Parse(em.Data, name, em.ErrorHandler) + if desc == nil { + em.DumpErrors(t) + t.Fatalf("parsing failed") + } + ExtractConsts(desc, em.ErrorHandler) + em.Check(t) +} diff --git a/pkg/compiler/testdata/consts.txt b/pkg/compiler/testdata/consts.txt new file mode 100644 index 000000000..179efe081 --- /dev/null +++ b/pkg/compiler/testdata/consts.txt @@ -0,0 +1,24 @@ +# Copyright 2017 syzkaller project authors. All rights reserved. +# Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file. + +include <foo/bar.h> +incdir </foo> +include <bar/foo.h> +incdir </bar> + +flags = CONST3, CONST2, CONST1 + +define CONST1 1 +define CONST2 FOOBAR + 1 + +foo(x const[CONST4]) ptr[out, array[int32, CONST5]] +bar$BAR() + +str { + f1 const[CONST6, int32] + f2 array[array[int8, CONST7]] +} + +bar$BAZ(x vma[opt], y vma[CONST8], z vma[CONST9:CONST10]) +bar$QUX(s ptr[in, string["foo", CONST11]], x csum[s, pseudo, CONST12]) +bar$FOO(x int8[8:CONST13], y int16be[CONST14:10], z intptr[CONST15:CONST16]) diff --git a/pkg/compiler/testdata/consts_errors.txt b/pkg/compiler/testdata/consts_errors.txt new file mode 100644 index 000000000..4771777ec --- /dev/null +++ b/pkg/compiler/testdata/consts_errors.txt @@ -0,0 +1,10 @@ +# Copyright 2017 syzkaller project authors. All rights reserved. +# Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file. + +include <foo/bar.h> +incdir </foo> +include <foo/bar.h> ### duplicate include "foo/bar.h" +incdir </foo> ### duplicate incdir "/foo" + +define D0 0 +define D0 1 ### duplicate define D0 diff --git a/pkg/compiler/testdata/errors.txt b/pkg/compiler/testdata/errors.txt new file mode 100644 index 000000000..7af998316 --- /dev/null +++ b/pkg/compiler/testdata/errors.txt @@ -0,0 +1,63 @@ +# Copyright 2017 syzkaller project authors. All rights reserved. +# Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file. + +foo$0(x fileoff, y int8, z buffer[in]) +foo$1(x "bar") ### unexpected string "bar", expecting type +foo$2(x 123, y "bar") ### unexpected integer 123, expecting type ### unexpected string "bar", expecting type +foo$3(x string) ### string can't be syscall argument/return + +resource r0[int32]: 0, 0x1 +resource r1[string["foo"]] ### string can't be resource base (int types can) +resource r1[int32] ### type r1 redeclared, previously declared as resource at errors.txt:10:1 +resource int32[int32] ### resource name int32 conflicts with builtin type +resource fileoff[intptr] ### resource name fileoff conflicts with builtin type + +s1 { + f1 int32 +} + +s1 { ### type s1 redeclared, previously declared as struct at errors.txt:15:1 + f1 int32 + f1 intptr ### duplicate field f1 in struct s1 + parent int8 ### reserved field name parent in struct s1 +} + +s2 { ### struct s2 has no fields, need at least 1 field +} + +int32 { ### struct name int32 conflicts with builtin type + f1 int32 +} + +r0 { ### type r0 redeclared, previously declared as resource at errors.txt:9:1 + f1 int32 +} + +u0 [ + f1 int32 + f2 fileoff +] + +u1 [ ### union u1 has only 1 field, need at least 2 fields + f1 int32 +] + +u2 [ + f1 int8 + f1 int16 ### duplicate field f1 in union u2 + parent int32 ### reserved field name parent in union u2 +] + +foo$4(a int8, a int16) ### duplicate argument a in syscall foo$4 +foo$4() ### syscall foo$4 redeclared, previously declared at errors.txt:51:1 +foo() +foo() ### syscall foo redeclared, previously declared at errors.txt:53:1 +foo$5(a0 int8, a1 int8, a2 int8, a3 int8, a4 int8, a5 int8, a6 int8, a7 int8, a8 int8, a9 int8) ### syscall foo$5 has 10 arguments, allowed maximum is 9 +foo$6(parent int8) ### reserved argument name parent in syscall foo$6 + +#s1 { +# f1 int32:8 +# f2 int32:12 +#} + + |
