aboutsummaryrefslogtreecommitdiffstats
path: root/pkg/ast
diff options
context:
space:
mode:
authorDmitry Vyukov <dvyukov@google.com>2017-08-27 19:55:14 +0200
committerDmitry Vyukov <dvyukov@google.com>2017-08-27 20:19:41 +0200
commit4074aed7c0c28afc7d4a3522045196c3f39b5208 (patch)
tree8d2c2ce5f6767f8f4355e37e262f85223ee362e3 /pkg/ast
parent58579664687b203ff34fad8aa02bf470ef0bc981 (diff)
pkg/compiler: more static error checking
Update #217
Diffstat (limited to 'pkg/ast')
-rw-r--r--pkg/ast/ast.go70
-rw-r--r--pkg/ast/format.go2
-rw-r--r--pkg/ast/parser.go2
-rw-r--r--pkg/ast/parser_test.go83
-rw-r--r--pkg/ast/scanner.go6
-rw-r--r--pkg/ast/test_util.go99
-rw-r--r--pkg/ast/walk.go52
7 files changed, 211 insertions, 103 deletions
diff --git a/pkg/ast/ast.go b/pkg/ast/ast.go
index b283ca5f8..4c9101f79 100644
--- a/pkg/ast/ast.go
+++ b/pkg/ast/ast.go
@@ -18,7 +18,9 @@ type Description struct {
}
// Node is AST node interface.
-type Node interface{}
+type Node interface {
+ Info() (pos Pos, typ string, name string)
+}
// Top-level AST nodes:
@@ -26,34 +28,58 @@ type NewLine struct {
Pos Pos
}
+func (n *NewLine) Info() (Pos, string, string) {
+ return n.Pos, tok2str[tokNewLine], ""
+}
+
type Comment struct {
Pos Pos
Text string
}
+func (n *Comment) Info() (Pos, string, string) {
+ return n.Pos, tok2str[tokComment], ""
+}
+
type Include struct {
Pos Pos
File *String
}
+func (n *Include) Info() (Pos, string, string) {
+ return n.Pos, tok2str[tokInclude], ""
+}
+
type Incdir struct {
Pos Pos
Dir *String
}
+func (n *Incdir) Info() (Pos, string, string) {
+ return n.Pos, tok2str[tokInclude], ""
+}
+
type Define struct {
Pos Pos
Name *Ident
Value *Int
}
+func (n *Define) Info() (Pos, string, string) {
+ return n.Pos, tok2str[tokDefine], n.Name.Name
+}
+
type Resource struct {
Pos Pos
Name *Ident
- Base *Ident
+ Base *Type
Values []*Int
}
+func (n *Resource) Info() (Pos, string, string) {
+ return n.Pos, tok2str[tokResource], n.Name.Name
+}
+
type Call struct {
Pos Pos
Name *Ident
@@ -63,6 +89,10 @@ type Call struct {
Ret *Type
}
+func (n *Call) Info() (Pos, string, string) {
+ return n.Pos, "syscall", n.Name.Name
+}
+
type Struct struct {
Pos Pos
Name *Ident
@@ -72,18 +102,34 @@ type Struct struct {
IsUnion bool
}
+func (n *Struct) Info() (Pos, string, string) {
+ typ := "struct"
+ if n.IsUnion {
+ typ = "union"
+ }
+ return n.Pos, typ, n.Name.Name
+}
+
type IntFlags struct {
Pos Pos
Name *Ident
Values []*Int
}
+func (n *IntFlags) Info() (Pos, string, string) {
+ return n.Pos, "flags", n.Name.Name
+}
+
type StrFlags struct {
Pos Pos
Name *Ident
Values []*String
}
+func (n *StrFlags) Info() (Pos, string, string) {
+ return n.Pos, "string flags", n.Name.Name
+}
+
// Not top-level AST nodes:
type Ident struct {
@@ -91,11 +137,19 @@ type Ident struct {
Name string
}
+func (n *Ident) Info() (Pos, string, string) {
+ return n.Pos, tok2str[tokIdent], n.Name
+}
+
type String struct {
Pos Pos
Value string
}
+func (n *String) Info() (Pos, string, string) {
+ return n.Pos, tok2str[tokString], ""
+}
+
type Int struct {
Pos Pos
// Only one of Value, Ident, CExpr is filled.
@@ -105,6 +159,10 @@ type Int struct {
CExpr string
}
+func (n *Int) Info() (Pos, string, string) {
+ return n.Pos, tok2str[tokInt], ""
+}
+
type Type struct {
Pos Pos
// Only one of Value, Ident, String is filled.
@@ -121,6 +179,10 @@ type Type struct {
Args []*Type
}
+func (n *Type) Info() (Pos, string, string) {
+ return n.Pos, "type", n.Ident
+}
+
type Field struct {
Pos Pos
Name *Ident
@@ -128,3 +190,7 @@ type Field struct {
NewBlock bool // separated from previous fields by a new line
Comments []*Comment
}
+
+func (n *Field) Info() (Pos, string, string) {
+ return n.Pos, "arg/field", n.Name.Name
+}
diff --git a/pkg/ast/format.go b/pkg/ast/format.go
index e7e21dcdd..0f95d7ebf 100644
--- a/pkg/ast/format.go
+++ b/pkg/ast/format.go
@@ -50,7 +50,7 @@ func (def *Define) serialize(w io.Writer) {
}
func (res *Resource) serialize(w io.Writer) {
- fmt.Fprintf(w, "resource %v[%v]", res.Name.Name, res.Base.Name)
+ fmt.Fprintf(w, "resource %v[%v]", res.Name.Name, fmtType(res.Base))
for i, v := range res.Values {
if i == 0 {
fmt.Fprintf(w, ": ")
diff --git a/pkg/ast/parser.go b/pkg/ast/parser.go
index ca2505e19..fd7b9ad4f 100644
--- a/pkg/ast/parser.go
+++ b/pkg/ast/parser.go
@@ -228,7 +228,7 @@ func (p *parser) parseResource() *Resource {
p.consume(tokResource)
name := p.parseIdent()
p.consume(tokLBrack)
- base := p.parseIdent()
+ base := p.parseType()
p.consume(tokRBrack)
var values []*Int
if p.tryConsume(tokColon) {
diff --git a/pkg/ast/parser_test.go b/pkg/ast/parser_test.go
index 84a5fedf0..46ad1e5d3 100644
--- a/pkg/ast/parser_test.go
+++ b/pkg/ast/parser_test.go
@@ -4,7 +4,6 @@
package ast
import (
- "bufio"
"bytes"
"io/ioutil"
"path/filepath"
@@ -29,7 +28,7 @@ func TestParseAll(t *testing.T) {
}
t.Run(file.Name(), func(t *testing.T) {
eh := func(pos Pos, msg string) {
- t.Fatalf("%v:%v:%v: %v", pos.File, pos.Line, pos.Col, msg)
+ t.Fatalf("%v: %v", pos, msg)
}
desc := Parse(data, file.Name(), eh)
if desc == nil {
@@ -65,7 +64,7 @@ func TestParse(t *testing.T) {
for _, test := range parseTests {
t.Run(test.name, func(t *testing.T) {
errorHandler := func(pos Pos, msg string) {
- t.Logf("%v:%v:%v: %v", pos.File, pos.Line, pos.Col, msg)
+ t.Logf("%v: %v", pos, msg)
}
Parse([]byte(test.input), "foo", errorHandler)
})
@@ -96,13 +95,6 @@ var parseTests = []struct {
},
}
-type Error struct {
- Line int
- Col int
- Text string
- Matched bool
-}
-
func TestErrors(t *testing.T) {
files, err := ioutil.ReadDir("testdata")
if err != nil {
@@ -115,71 +107,18 @@ func TestErrors(t *testing.T) {
if !strings.HasSuffix(f.Name(), ".txt") {
continue
}
- t.Run(f.Name(), func(t *testing.T) {
- data, err := ioutil.ReadFile(filepath.Join("testdata", f.Name()))
- if err != nil {
- t.Fatalf("failed to open input file: %v", err)
- }
- var stripped []byte
- var errors []*Error
- s := bufio.NewScanner(bytes.NewReader(data))
- for i := 1; s.Scan(); i++ {
- ln := s.Bytes()
- for {
- pos := bytes.LastIndex(ln, []byte("###"))
- if pos == -1 {
- break
- }
- errors = append(errors, &Error{
- Line: i,
- Text: strings.TrimSpace(string(ln[pos+3:])),
- })
- ln = ln[:pos]
- }
- stripped = append(stripped, ln...)
- stripped = append(stripped, '\n')
- }
- if err := s.Err(); err != nil {
- t.Fatalf("failed to scan input file: %v", err)
+ name := f.Name()
+ t.Run(name, func(t *testing.T) {
+ em := NewErrorMatcher(t, filepath.Join("testdata", name))
+ desc := Parse(em.Data, name, em.ErrorHandler)
+ if desc != nil && em.Count() != 0 {
+ em.DumpErrors(t)
+ t.Fatalf("parsing succeed, but got errors")
}
- var got []*Error
- desc := Parse(stripped, "test", func(pos Pos, msg string) {
- got = append(got, &Error{
- Line: pos.Line,
- Col: pos.Col,
- Text: msg,
- })
- })
- if desc != nil && len(got) != 0 {
- t.Fatalf("parsing succeed, but got errors: %v", got)
- }
- if desc == nil && len(got) == 0 {
+ if desc == nil && em.Count() == 0 {
t.Fatalf("parsing failed, but got no errors")
}
- nextErr:
- for _, gotErr := range got {
- for _, wantErr := range errors {
- if wantErr.Matched {
- continue
- }
- if wantErr.Line != gotErr.Line {
- continue
- }
- if wantErr.Text != gotErr.Text {
- continue
- }
- wantErr.Matched = true
- continue nextErr
- }
- t.Errorf("unexpected error: %v:%v: %v",
- gotErr.Line, gotErr.Col, gotErr.Text)
- }
- for _, wantErr := range errors {
- if wantErr.Matched {
- continue
- }
- t.Errorf("not matched error: %v: %v", wantErr.Line, wantErr.Text)
- }
+ em.Check(t)
})
}
}
diff --git a/pkg/ast/scanner.go b/pkg/ast/scanner.go
index 387a58529..f1573350b 100644
--- a/pkg/ast/scanner.go
+++ b/pkg/ast/scanner.go
@@ -118,7 +118,11 @@ func newScanner(data []byte, filename string, errorHandler ErrorHandler) *scanne
type ErrorHandler func(pos Pos, msg string)
func LoggingHandler(pos Pos, msg string) {
- fmt.Fprintf(os.Stderr, "%v:%v:%v: %v\n", pos.File, pos.Line, pos.Col, msg)
+ fmt.Fprintf(os.Stderr, "%v: %v\n", pos, msg)
+}
+
+func (pos Pos) String() string {
+ return fmt.Sprintf("%v:%v:%v", pos.File, pos.Line, pos.Col)
}
func (s *scanner) Scan() (tok token, lit string, pos Pos) {
diff --git a/pkg/ast/test_util.go b/pkg/ast/test_util.go
new file mode 100644
index 000000000..0aed0a2dc
--- /dev/null
+++ b/pkg/ast/test_util.go
@@ -0,0 +1,99 @@
+// Copyright 2017 syzkaller project authors. All rights reserved.
+// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
+
+package ast
+
+import (
+ "bufio"
+ "bytes"
+ "io/ioutil"
+ "strings"
+ "testing"
+)
+
+type ErrorMatcher struct {
+ Data []byte
+ expect []*errorDesc
+ got []*errorDesc
+}
+
+type errorDesc struct {
+ file string
+ line int
+ col int
+ text string
+ matched bool
+}
+
+func NewErrorMatcher(t *testing.T, file string) *ErrorMatcher {
+ data, err := ioutil.ReadFile(file)
+ if err != nil {
+ t.Fatalf("failed to open input file: %v", err)
+ }
+ var stripped []byte
+ var errors []*errorDesc
+ s := bufio.NewScanner(bytes.NewReader(data))
+ for i := 1; s.Scan(); i++ {
+ ln := s.Bytes()
+ for {
+ pos := bytes.LastIndex(ln, []byte("###"))
+ if pos == -1 {
+ break
+ }
+ errors = append(errors, &errorDesc{
+ file: file,
+ line: i,
+ text: strings.TrimSpace(string(ln[pos+3:])),
+ })
+ ln = ln[:pos]
+ }
+ stripped = append(stripped, ln...)
+ stripped = append(stripped, '\n')
+ }
+ if err := s.Err(); err != nil {
+ t.Fatalf("failed to scan input file: %v", err)
+ }
+ return &ErrorMatcher{
+ Data: stripped,
+ expect: errors,
+ }
+}
+
+func (em *ErrorMatcher) ErrorHandler(pos Pos, msg string) {
+ em.got = append(em.got, &errorDesc{
+ file: pos.File,
+ line: pos.Line,
+ col: pos.Col,
+ text: msg,
+ })
+}
+
+func (em *ErrorMatcher) Count() int {
+ return len(em.got)
+}
+
+func (em *ErrorMatcher) Check(t *testing.T) {
+nextErr:
+ for _, e := range em.got {
+ for _, want := range em.expect {
+ if want.matched || want.line != e.line || want.text != e.text {
+ continue
+ }
+ want.matched = true
+ continue nextErr
+ }
+ t.Errorf("unexpected error: %v:%v:%v: %v", e.file, e.line, e.col, e.text)
+ }
+ for _, want := range em.expect {
+ if want.matched {
+ continue
+ }
+ t.Errorf("unmatched error: %v:%v: %v", want.file, want.line, want.text)
+ }
+}
+
+func (em *ErrorMatcher) DumpErrors(t *testing.T) {
+ for _, e := range em.got {
+ t.Logf("%v:%v:%v: %v", e.file, e.line, e.col, e.text)
+ }
+}
diff --git a/pkg/ast/walk.go b/pkg/ast/walk.go
index af62884fe..79e0b4bec 100644
--- a/pkg/ast/walk.go
+++ b/pkg/ast/walk.go
@@ -8,71 +8,71 @@ import (
)
// Walk calls callback cb for every node in AST.
-func Walk(desc *Description, cb func(n, parent Node)) {
+func Walk(desc *Description, cb func(n Node)) {
for _, n := range desc.Nodes {
- WalkNode(n, nil, cb)
+ WalkNode(n, cb)
}
}
-func WalkNode(n0, parent Node, cb func(n, parent Node)) {
- cb(n0, parent)
+func WalkNode(n0 Node, cb func(n Node)) {
+ cb(n0)
switch n := n0.(type) {
case *NewLine:
case *Comment:
case *Include:
- WalkNode(n.File, n, cb)
+ WalkNode(n.File, cb)
case *Incdir:
- WalkNode(n.Dir, n, cb)
+ WalkNode(n.Dir, cb)
case *Define:
- WalkNode(n.Name, n, cb)
- WalkNode(n.Value, n, cb)
+ WalkNode(n.Name, cb)
+ WalkNode(n.Value, cb)
case *Resource:
- WalkNode(n.Name, n, cb)
- WalkNode(n.Base, n, cb)
+ WalkNode(n.Name, cb)
+ WalkNode(n.Base, cb)
for _, v := range n.Values {
- WalkNode(v, n, cb)
+ WalkNode(v, cb)
}
case *Call:
- WalkNode(n.Name, n, cb)
+ WalkNode(n.Name, cb)
for _, f := range n.Args {
- WalkNode(f, n, cb)
+ WalkNode(f, cb)
}
if n.Ret != nil {
- WalkNode(n.Ret, n, cb)
+ WalkNode(n.Ret, cb)
}
case *Struct:
- WalkNode(n.Name, n, cb)
+ WalkNode(n.Name, cb)
for _, f := range n.Fields {
- WalkNode(f, n, cb)
+ WalkNode(f, cb)
}
for _, a := range n.Attrs {
- WalkNode(a, n, cb)
+ WalkNode(a, cb)
}
for _, c := range n.Comments {
- WalkNode(c, n, cb)
+ WalkNode(c, cb)
}
case *IntFlags:
- WalkNode(n.Name, n, cb)
+ WalkNode(n.Name, cb)
for _, v := range n.Values {
- WalkNode(v, n, cb)
+ WalkNode(v, cb)
}
case *StrFlags:
- WalkNode(n.Name, n, cb)
+ WalkNode(n.Name, cb)
for _, v := range n.Values {
- WalkNode(v, n, cb)
+ WalkNode(v, cb)
}
case *Ident:
case *String:
case *Int:
case *Type:
for _, t := range n.Args {
- WalkNode(t, n, cb)
+ WalkNode(t, cb)
}
case *Field:
- WalkNode(n.Name, n, cb)
- WalkNode(n.Type, n, cb)
+ WalkNode(n.Name, cb)
+ WalkNode(n.Type, cb)
for _, c := range n.Comments {
- WalkNode(c, n, cb)
+ WalkNode(c, cb)
}
default:
panic(fmt.Sprintf("unknown AST node: %#v", n))