aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorDmitry Vyukov <dvyukov@google.com>2017-05-22 05:28:31 +0200
committerDmitry Vyukov <dvyukov@google.com>2017-08-18 11:26:50 +0200
commit127a9c2b65ae07f309e839c3b8e5ab2ee7983e56 (patch)
tree3a4dd2af0a2fc09b2bba1dad738c7657d1b0de1d
parent5809a8e05714bda367f3fd57f9b983a3403f04b0 (diff)
pkg/ast: new parser for sys descriptions
The old parser in sys/sysparser is too hacky, difficult to extend and drops debug info too early, so that we can't produce proper error messages. Add a new parser that is built like a proper language parser and preserves full debug info for every token.
-rw-r--r--pkg/ast/ast.go118
-rw-r--r--pkg/ast/format.go204
-rw-r--r--pkg/ast/parser.go423
-rw-r--r--pkg/ast/parser_test.go180
-rw-r--r--pkg/ast/scanner.go260
-rw-r--r--pkg/ast/testdata/all.txt28
-rw-r--r--tools/syz-fmt/syz-fmt.go72
7 files changed, 1285 insertions, 0 deletions
diff --git a/pkg/ast/ast.go b/pkg/ast/ast.go
new file mode 100644
index 000000000..27497b2a9
--- /dev/null
+++ b/pkg/ast/ast.go
@@ -0,0 +1,118 @@
+// Copyright 2017 syzkaller project authors. All rights reserved.
+// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
+
+// Package ast parses and formats sys files.
+package ast
+
// Pos represents source info for AST nodes.
type Pos struct {
	File string
	Off  int // byte offset, starting at 0
	Line int // line number, starting at 1
	Col  int // column number, starting at 1 (byte count)
}

// Top-level AST nodes:

// NewLine is an empty line separating declarations.
type NewLine struct {
	Pos Pos
}

// Comment is a '#'-comment; Text excludes the leading '#'.
type Comment struct {
	Pos  Pos
	Text string
}

// Include is an `include <file>` directive.
type Include struct {
	Pos  Pos
	File *String
}

// Incdir is an `incdir <dir>` directive.
type Incdir struct {
	Pos Pos
	Dir *String
}

// Define is a `define NAME VALUE` directive; Value holds an int
// constant, an identifier or a C expression.
type Define struct {
	Pos   Pos
	Name  *Ident
	Value *Int
}

// Resource is a `resource name[base]` declaration.
type Resource struct {
	Pos    Pos
	Name   *Ident
	Base   *Ident
	Values []*Int // optional special values listed after ':'
}

// Call is a syscall description `name(args...) ret`.
type Call struct {
	Pos  Pos
	Name *Ident
	Args []*Field
	Ret  *Type // nil if the call has no return type
}

// Struct is a struct `name { ... }` or union `name [ ... ]` declaration.
type Struct struct {
	Pos      Pos
	Name     *Ident
	Fields   []*Field
	Attrs    []*Ident   // attribute list after the closing brace
	Comments []*Comment // comments right before the closing brace
	IsUnion  bool
}

// IntFlags is a `name = v1, v2, ...` declaration with int values.
type IntFlags struct {
	Pos    Pos
	Name   *Ident
	Values []*Int
}

// StrFlags is a `name = "s1", "s2", ...` declaration with string values.
type StrFlags struct {
	Pos    Pos
	Name   *Ident
	Values []*String
}

// Not top-level AST nodes:

// Ident is an identifier.
type Ident struct {
	Pos  Pos
	Name string
}

// String is a string literal; Value excludes the surrounding quotes/brackets.
type String struct {
	Pos   Pos
	Value string
}

// Int is an integer value in one of three alternative forms.
type Int struct {
	Pos Pos
	// Only one of Value, Ident, CExpr is filled.
	Value    uint64
	ValueHex bool // says if value was in hex (for formatting)
	Ident    string
	CExpr    string
}

// Type describes the type of a field or a call return value.
type Type struct {
	Pos Pos
	// Only one of Value, Ident, String is filled.
	Value    uint64
	ValueHex bool
	Ident    string
	String   string
	// Part after COLON (for ranges and bitfields).
	Value2    uint64
	Value2Hex bool
	Ident2    string
	Args      []*Type // type arguments in square brackets
}

// Field is a named type: a call argument or a struct field.
type Field struct {
	Pos      Pos
	Name     *Ident
	Type     *Type
	NewBlock bool       // separated from previous fields by a new line
	Comments []*Comment // comments right before the field
}
diff --git a/pkg/ast/format.go b/pkg/ast/format.go
new file mode 100644
index 000000000..0eb9aa957
--- /dev/null
+++ b/pkg/ast/format.go
@@ -0,0 +1,204 @@
+// Copyright 2017 syzkaller project authors. All rights reserved.
+// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
+
+package ast
+
+import (
+ "bytes"
+ "fmt"
+ "io"
+)
+
// serializer is implemented by every AST node that can print itself
// back in source form.
type serializer interface {
	Serialize(w io.Writer)
}

// Format renders top-level declarations back into source text.
func Format(top []interface{}) []byte {
	var buf bytes.Buffer
	FormatWriter(&buf, top)
	return buf.Bytes()
}

// FormatWriter writes the source form of top-level declarations to w.
// It panics on a value that is not a known AST node.
func FormatWriter(w io.Writer, top []interface{}) {
	for _, decl := range top {
		ser, isNode := decl.(serializer)
		if !isNode {
			panic(fmt.Sprintf("unknown top level decl: %#v", decl))
		}
		ser.Serialize(w)
	}
}
+
+func (incl *NewLine) Serialize(w io.Writer) {
+ fmt.Fprintf(w, "\n")
+}
+
+func (com *Comment) Serialize(w io.Writer) {
+ fmt.Fprintf(w, "#%v\n", com.Text)
+}
+
+func (incl *Include) Serialize(w io.Writer) {
+ fmt.Fprintf(w, "include <%v>\n", incl.File.Value)
+}
+
+func (inc *Incdir) Serialize(w io.Writer) {
+ fmt.Fprintf(w, "incdir <%v>\n", inc.Dir.Value)
+}
+
+func (def *Define) Serialize(w io.Writer) {
+ fmt.Fprintf(w, "define %v\t%v\n", def.Name.Name, fmtInt(def.Value))
+}
+
+func (res *Resource) Serialize(w io.Writer) {
+ fmt.Fprintf(w, "resource %v[%v]", res.Name.Name, res.Base.Name)
+ for i, v := range res.Values {
+ if i == 0 {
+ fmt.Fprintf(w, ": ")
+ } else {
+ fmt.Fprintf(w, ", ")
+ }
+ fmt.Fprintf(w, "%v", fmtInt(v))
+ }
+ fmt.Fprintf(w, "\n")
+}
+
+func (c *Call) Serialize(w io.Writer) {
+ fmt.Fprintf(w, "%v(", c.Name.Name)
+ for i, a := range c.Args {
+ if i != 0 {
+ fmt.Fprintf(w, ", ")
+ }
+ fmt.Fprintf(w, "%v", fmtField(a))
+ }
+ fmt.Fprintf(w, ")")
+ if c.Ret != nil {
+ fmt.Fprintf(w, " %v", fmtType(c.Ret))
+ }
+ fmt.Fprintf(w, "\n")
+}
+
// Serialize writes the struct/union declaration: one field per line,
// with all field types aligned to a common tab column computed from
// the longest field name.
func (str *Struct) Serialize(w io.Writer) {
	opening, closing := '{', '}'
	if str.IsUnion {
		opening, closing = '[', ']'
	}
	fmt.Fprintf(w, "%v %c\n", str.Name.Name, opening)
	// Align all field types to the same column.
	const tabWidth = 8
	maxTabs := 0
	for _, f := range str.Fields {
		// Tabs needed to get past this name; a name ending exactly on
		// a tab stop still needs one more tab.
		tabs := (len(f.Name.Name) + tabWidth) / tabWidth
		if maxTabs < tabs {
			maxTabs = tabs
		}
	}
	for _, f := range str.Fields {
		// A blank line separates field blocks.
		if f.NewBlock {
			fmt.Fprintf(w, "\n")
		}
		// Comments attached to the field go right before it.
		for _, com := range f.Comments {
			fmt.Fprintf(w, "#%v\n", com.Text)
		}
		fmt.Fprintf(w, "\t%v\t", f.Name.Name)
		// Pad with tabs up to the common column computed above.
		for tabs := len(f.Name.Name)/tabWidth + 1; tabs < maxTabs; tabs++ {
			fmt.Fprintf(w, "\t")
		}
		fmt.Fprintf(w, "%v\n", fmtType(f.Type))
	}
	// Comments that appeared right before the closing brace.
	for _, com := range str.Comments {
		fmt.Fprintf(w, "#%v\n", com.Text)
	}
	fmt.Fprintf(w, "%c", closing)
	// Optional attribute list: `} [attr1, attr2]`.
	if len(str.Attrs) != 0 {
		fmt.Fprintf(w, " [")
		for i, attr := range str.Attrs {
			fmt.Fprintf(w, "%v%v", comma(i), attr.Name)
		}
		fmt.Fprintf(w, "]")
	}
	fmt.Fprintf(w, "\n")
}
+
+func (flags *IntFlags) Serialize(w io.Writer) {
+ fmt.Fprintf(w, "%v = ", flags.Name.Name)
+ for i, v := range flags.Values {
+ if i != 0 {
+ fmt.Fprintf(w, ", ")
+ }
+ fmt.Fprintf(w, "%v", fmtInt(v))
+ }
+ fmt.Fprintf(w, "\n")
+}
+
+func (flags *StrFlags) Serialize(w io.Writer) {
+ fmt.Fprintf(w, "%v = ", flags.Name.Name)
+ for i, v := range flags.Values {
+ if i != 0 {
+ fmt.Fprintf(w, ", ")
+ }
+ fmt.Fprintf(w, "\"%v\"", v.Value)
+ }
+ fmt.Fprintf(w, "\n")
+}
+
+func fmtField(f *Field) string {
+ return fmt.Sprintf("%v %v", f.Name.Name, fmtType(f.Type))
+}
+
+func fmtType(t *Type) string {
+ v := ""
+ switch {
+ case t.Ident != "":
+ v = t.Ident
+ case t.String != "":
+ v = fmt.Sprintf("\"%v\"", t.String)
+ default:
+ v = fmtIntValue(t.Value, t.ValueHex)
+ }
+ switch {
+ case t.Ident2 != "":
+ v += fmt.Sprintf(":%v", t.Ident2)
+ case t.Value2 != 0:
+ v += fmt.Sprintf(":%v", fmtIntValue(t.Value2, t.Value2Hex))
+ }
+ v += fmtTypeList(t.Args)
+ return v
+}
+
+func fmtTypeList(args []*Type) string {
+ if len(args) == 0 {
+ return ""
+ }
+ w := new(bytes.Buffer)
+ fmt.Fprintf(w, "[")
+ for i, t := range args {
+ fmt.Fprintf(w, "%v%v", comma(i), fmtType(t))
+ }
+ fmt.Fprintf(w, "]")
+ return w.String()
+}
+
+func fmtInt(i *Int) string {
+ switch {
+ case i.Ident != "":
+ return i.Ident
+ case i.CExpr != "":
+ return fmt.Sprintf("%v", i.CExpr)
+ default:
+ return fmtIntValue(i.Value, i.ValueHex)
+ }
+}
+
// fmtIntValue renders v as decimal, or as a 0x-prefixed hex literal
// when the original source used hex.
func fmtIntValue(v uint64, hex bool) string {
	format := "%v"
	if hex {
		format = "0x%x"
	}
	return fmt.Sprintf(format, v)
}
+
// comma returns the separator to print before element i of a
// comma-separated list: nothing before the first element.
func comma(i int) string {
	switch i {
	case 0:
		return ""
	default:
		return ", "
	}
}
diff --git a/pkg/ast/parser.go b/pkg/ast/parser.go
new file mode 100644
index 000000000..737d08068
--- /dev/null
+++ b/pkg/ast/parser.go
@@ -0,0 +1,423 @@
+// Copyright 2017 syzkaller project authors. All rights reserved.
+// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
+
+package ast
+
+import (
+ "errors"
+ "fmt"
+ "strconv"
+ "strings"
+)
+
// Parse parses sys description into AST and returns top-level nodes.
// errorHandler is invoked for every error encountered; the returned
// ok is false if any error was reported.
func Parse(data []byte, filename string, errorHandler func(pos Pos, msg string)) (top []interface{}, ok bool) {
	p := &parser{s: newScanner(data, filename, errorHandler)}
	prevNewLine, prevComment := false, false
	for p.next(); p.tok != tokEOF; {
		decl := p.parseTopRecover()
		if decl == nil {
			// Parse error; parseTopRecover already skipped to the next line.
			continue
		}
		// Add new lines around structs, remove duplicate new lines.
		if _, ok := decl.(*NewLine); ok && prevNewLine {
			continue
		}
		if str, ok := decl.(*Struct); ok && !prevNewLine && !prevComment {
			// Ensure an empty line before the struct.
			top = append(top, &NewLine{Pos: str.Pos})
		}
		top = append(top, decl)
		if str, ok := decl.(*Struct); ok {
			// And an empty line after; decl is replaced so the
			// prevNewLine update below sees the NewLine, not the Struct.
			decl = &NewLine{Pos: str.Pos}
			top = append(top, decl)
		}
		_, prevNewLine = decl.(*NewLine)
		_, prevComment = decl.(*Comment)
	}
	if prevNewLine {
		// Drop the trailing empty line.
		top = top[:len(top)-1]
	}
	ok = p.s.Ok()
	return
}
+
// parser produces AST nodes from the token stream of a scanner.
type parser struct {
	s *scanner

	// Current token:
	tok token
	lit string
	pos Pos
}

// Skip parsing till the next NEWLINE, for error recovery.
// Thrown as a panic value and caught in parseTopRecover.
var skipLine = errors.New("")
+
// parseTopRecover parses one top-level declaration and recovers from
// parse errors: on a skipLine panic it consumes tokens up to and
// including the next NEWLINE and returns nil.
func (p *parser) parseTopRecover() interface{} {
	defer func() {
		switch err := recover(); err {
		case nil:
		case skipLine:
			// Try to recover by consuming everything until next NEWLINE.
			for p.tok != tokNewLine {
				p.next()
			}
			p.consume(tokNewLine)
		default:
			// Not our panic value; re-throw.
			panic(err)
		}
	}()
	decl := p.parseTop()
	if decl == nil {
		panic("not reachable")
	}
	p.consume(tokNewLine)
	return decl
}
+
// parseTop parses a single top-level declaration and returns the
// corresponding AST node. Panics with skipLine on a parse error.
func (p *parser) parseTop() interface{} {
	switch p.tok {
	case tokNewLine:
		return &NewLine{Pos: p.pos}
	case tokComment:
		return p.parseComment()
	case tokDefine:
		return p.parseDefine()
	case tokInclude:
		return p.parseInclude()
	case tokIncdir:
		return p.parseIncdir()
	case tokResource:
		return p.parseResource()
	case tokIdent:
		// The token after the identifier disambiguates
		// call / struct / union / flags.
		name := p.parseIdent()
		switch p.tok {
		case tokLParen:
			return p.parseCall(name)
		case tokLBrace, tokLBrack:
			return p.parseStruct(name)
		case tokEq:
			return p.parseFlags(name)
		default:
			p.expect(tokLParen, tokLBrace, tokLBrack, tokEq)
		}
	case tokIllegal:
		// Scanner has already produced an error for this one.
		panic(skipLine)
	default:
		// NOTE(review): tokIncdir is absent from this list, so the
		// "expecting ..." message never mentions incdir — confirm
		// whether that is intended (testdata currently matches it).
		p.expect(tokComment, tokDefine, tokInclude, tokResource, tokIdent)
	}
	panic("not reachable")
}
+
+func (p *parser) next() {
+ p.tok, p.lit, p.pos = p.s.Scan()
+}
+
+func (p *parser) consume(tok token) {
+ p.expect(tok)
+ p.next()
+}
+
+func (p *parser) tryConsume(tok token) bool {
+ if p.tok != tok {
+ return false
+ }
+ p.next()
+ return true
+}
+
+func (p *parser) expect(tokens ...token) {
+ for _, tok := range tokens {
+ if p.tok == tok {
+ return
+ }
+ }
+ var str []string
+ for _, tok := range tokens {
+ str = append(str, tok.String())
+ }
+ p.s.Error(p.pos, fmt.Sprintf("unexpected %v, expecting %v", p.tok, strings.Join(str, ", ")))
+ panic(skipLine)
+}
+
+func (p *parser) parseComment() *Comment {
+ c := &Comment{
+ Pos: p.pos,
+ Text: p.lit,
+ }
+ p.consume(tokComment)
+ return c
+}
+
+func (p *parser) parseDefine() *Define {
+ pos0 := p.pos
+ p.consume(tokDefine)
+ name := p.parseIdent()
+ p.expect(tokInt, tokIdent, tokCExpr)
+ var val *Int
+ if p.tok == tokCExpr {
+ val = p.parseCExpr()
+ } else {
+ val = p.parseInt()
+ }
+ return &Define{
+ Pos: pos0,
+ Name: name,
+ Value: val,
+ }
+}
+
+func (p *parser) parseInclude() *Include {
+ pos0 := p.pos
+ p.consume(tokInclude)
+ return &Include{
+ Pos: pos0,
+ File: p.parseString(),
+ }
+}
+
+func (p *parser) parseIncdir() *Incdir {
+ pos0 := p.pos
+ p.consume(tokIncdir)
+ return &Incdir{
+ Pos: pos0,
+ Dir: p.parseString(),
+ }
+}
+
+func (p *parser) parseResource() *Resource {
+ pos0 := p.pos
+ p.consume(tokResource)
+ name := p.parseIdent()
+ p.consume(tokLBrack)
+ base := p.parseIdent()
+ p.consume(tokRBrack)
+ var values []*Int
+ if p.tryConsume(tokColon) {
+ values = append(values, p.parseInt())
+ for p.tryConsume(tokComma) {
+ values = append(values, p.parseInt())
+ }
+ }
+ return &Resource{
+ Pos: pos0,
+ Name: name,
+ Base: base,
+ Values: values,
+ }
+}
+
+func (p *parser) parseCall(name *Ident) *Call {
+ c := &Call{
+ Pos: name.Pos,
+ Name: name,
+ }
+ p.consume(tokLParen)
+ for p.tok != tokRParen {
+ c.Args = append(c.Args, p.parseField())
+ p.expect(tokComma, tokRParen)
+ p.tryConsume(tokComma)
+ }
+ p.consume(tokRParen)
+ if p.tok != tokNewLine {
+ c.Ret = p.parseType()
+ }
+ return c
+}
+
+func (p *parser) parseFlags(name *Ident) interface{} {
+ p.consume(tokEq)
+ switch p.tok {
+ case tokInt, tokIdent:
+ return p.parseIntFlags(name)
+ case tokString:
+ return p.parseStrFlags(name)
+ default:
+ p.expect(tokInt, tokIdent, tokString)
+ return nil
+ }
+}
+
+func (p *parser) parseIntFlags(name *Ident) *IntFlags {
+ values := []*Int{p.parseInt()}
+ for p.tryConsume(tokComma) {
+ values = append(values, p.parseInt())
+ }
+ return &IntFlags{
+ Pos: name.Pos,
+ Name: name,
+ Values: values,
+ }
+}
+
+func (p *parser) parseStrFlags(name *Ident) *StrFlags {
+ values := []*String{p.parseString()}
+ for p.tryConsume(tokComma) {
+ values = append(values, p.parseString())
+ }
+ return &StrFlags{
+ Pos: name.Pos,
+ Name: name,
+ Values: values,
+ }
+}
+
// parseStruct parses a struct `name { ... }` or union `name [ ... ]`
// body (the name has already been consumed), including the optional
// trailing attribute list.
func (p *parser) parseStruct(name *Ident) *Struct {
	str := &Struct{
		Pos:  name.Pos,
		Name: name,
	}
	closing := tokRBrace
	if p.tok == tokLBrack {
		str.IsUnion = true
		closing = tokRBrack
	}
	p.next()
	p.consume(tokNewLine)
	for {
		// Empty lines between fields mark the start of a new block.
		newBlock := false
		for p.tok == tokNewLine {
			newBlock = true
			p.next()
		}
		comments := p.parseCommentBlock()
		if p.tryConsume(closing) {
			// Comments right before the closing brace belong to the
			// struct itself, not to any field.
			str.Comments = comments
			break
		}
		fld := p.parseField()
		fld.NewBlock = newBlock
		fld.Comments = comments
		str.Fields = append(str.Fields, fld)
		p.consume(tokNewLine)
	}
	// Optional attribute list: `} [attr1, attr2]`.
	if p.tryConsume(tokLBrack) {
		str.Attrs = append(str.Attrs, p.parseIdent())
		for p.tryConsume(tokComma) {
			str.Attrs = append(str.Attrs, p.parseIdent())
		}
		p.consume(tokRBrack)
	}
	return str
}

// parseCommentBlock consumes consecutive comment lines.
func (p *parser) parseCommentBlock() []*Comment {
	var comments []*Comment
	for p.tok == tokComment {
		comments = append(comments, p.parseComment())
		p.consume(tokNewLine)
	}
	return comments
}
+
// parseField parses `name type` (a call argument or a struct field).
func (p *parser) parseField() *Field {
	name := p.parseIdent()
	return &Field{
		Pos:  name.Pos,
		Name: name,
		Type: p.parseType(),
	}
}

// parseType parses a type: an int/ident/string base, an optional
// ":"-suffix (only after int/ident; used for ranges and bitfields),
// and an optional [args] type-argument list.
func (p *parser) parseType() *Type {
	arg := &Type{
		Pos: p.pos,
	}
	allowColon := false
	switch p.tok {
	case tokInt:
		allowColon = true
		arg.Value, arg.ValueHex = p.parseIntValue()
	case tokIdent:
		allowColon = true
		arg.Ident = p.lit
	case tokString:
		arg.String = p.lit
	default:
		p.expect(tokInt, tokIdent, tokString)
	}
	p.next()
	if allowColon && p.tryConsume(tokColon) {
		// The part after ':' must itself be an int or an identifier.
		switch p.tok {
		case tokInt:
			arg.Value2, arg.Value2Hex = p.parseIntValue()
		case tokIdent:
			arg.Ident2 = p.lit
		default:
			p.expect(tokInt, tokIdent)
		}
		p.next()
	}
	arg.Args = p.parseTypeList()
	return arg
}

// parseTypeList parses an optional `[type, type, ...]` argument list.
func (p *parser) parseTypeList() []*Type {
	var args []*Type
	if p.tryConsume(tokLBrack) {
		args = append(args, p.parseType())
		for p.tryConsume(tokComma) {
			args = append(args, p.parseType())
		}
		p.consume(tokRBrack)
	}
	return args
}
+
+func (p *parser) parseIdent() *Ident {
+ p.expect(tokIdent)
+ ident := &Ident{
+ Pos: p.pos,
+ Name: p.lit,
+ }
+ p.next()
+ return ident
+}
+
+func (p *parser) parseString() *String {
+ p.expect(tokString)
+ str := &String{
+ Pos: p.pos,
+ Value: p.lit,
+ }
+ p.next()
+ return str
+}
+
+func (p *parser) parseInt() *Int {
+ i := &Int{
+ Pos: p.pos,
+ }
+ switch p.tok {
+ case tokInt:
+ i.Value, i.ValueHex = p.parseIntValue()
+ case tokIdent:
+ i.Ident = p.lit
+ default:
+ p.expect(tokInt, tokIdent)
+ }
+ p.next()
+ return i
+}
+
+func (p *parser) parseIntValue() (uint64, bool) {
+ if v, err := strconv.ParseUint(p.lit, 10, 64); err == nil {
+ return v, false
+ }
+ if len(p.lit) > 2 && p.lit[0] == '0' && p.lit[1] == 'x' {
+ if v, err := strconv.ParseUint(p.lit[2:], 16, 64); err == nil {
+ return v, true
+ }
+ }
+ panic(fmt.Sprintf("scanner returned bad integer %q", p.lit))
+}
+
+func (p *parser) parseCExpr() *Int {
+ i := &Int{
+ Pos: p.pos,
+ CExpr: p.lit,
+ }
+ p.consume(tokCExpr)
+ return i
+}
diff --git a/pkg/ast/parser_test.go b/pkg/ast/parser_test.go
new file mode 100644
index 000000000..521078805
--- /dev/null
+++ b/pkg/ast/parser_test.go
@@ -0,0 +1,180 @@
+// Copyright 2017 syzkaller project authors. All rights reserved.
+// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
+
+package ast
+
+import (
+ "bufio"
+ "bytes"
+ "io/ioutil"
+ "path/filepath"
+ "reflect"
+ "strings"
+ "testing"
+)
+
// TestParseAll parses every sys description file in sys/ and checks
// that formatting the result reparses to the same number of top-level
// declarations.
func TestParseAll(t *testing.T) {
	dir := filepath.Join("..", "..", "sys")
	files, err := ioutil.ReadDir(dir)
	if err != nil {
		t.Fatalf("failed to read sys dir: %v", err)
	}
	for _, file := range files {
		if file.IsDir() || !strings.HasSuffix(file.Name(), ".txt") {
			continue
		}
		data, err := ioutil.ReadFile(filepath.Join(dir, file.Name()))
		if err != nil {
			t.Fatalf("failed to read file: %v", err)
		}
		// Real sys files must parse without any errors.
		errorHandler := func(pos Pos, msg string) {
			t.Fatalf("%v:%v:%v: %v", pos.File, pos.Line, pos.Col, msg)
		}
		top, ok := Parse(data, file.Name(), errorHandler)
		if !ok {
			t.Fatalf("parsing failed, but no error produced")
		}
		// Round-trip: format and reparse.
		data2 := Format(top)
		top2, ok2 := Parse(data2, file.Name(), errorHandler)
		if !ok2 {
			t.Fatalf("parsing failed, but no error produced")
		}
		if len(top) != len(top2) {
			t.Fatalf("formatting number of top level decls: %v/%v", len(top), len(top2))
		}
		// Disabled: deep-equality check of the reparsed declarations.
		if false {
			// While sys files are not formatted, formatting in fact changes it.
			for i := range top {
				if !reflect.DeepEqual(top[i], top2[i]) {
					t.Fatalf("formatting changed code:\n%#v\nvs:\n%#v", top[i], top2[i])
				}
			}
		}
	}
}
+
// TestParse runs the parser over the small inputs in parseTests.
// Note: the parse results are not yet compared against
// parseTests.result; this is a smoke test only.
func TestParse(t *testing.T) {
	for _, test := range parseTests {
		t.Run(test.name, func(t *testing.T) {
			errorHandler := func(pos Pos, msg string) {
				t.Logf("%v:%v:%v: %v", pos.File, pos.Line, pos.Col, msg)
			}
			toplev, ok := Parse([]byte(test.input), "foo", errorHandler)
			// TODO: compare toplev/ok against test.result.
			_, _ = toplev, ok
		})
	}
}
+
// parseTests lists small inputs for TestParse.
// result is currently unused (see TestParse).
var parseTests = []struct {
	name   string
	input  string
	result []interface{}
}{
	{
		"empty",
		``,
		[]interface{}{},
	},
	{
		"new-line",
		`

`,
		[]interface{}{},
	},
	{
		"nil",
		"\x00",
		[]interface{}{},
	},
}
+
// Error describes an expected ("### msg" annotation) or actual parse
// error in TestErrors.
type Error struct {
	Line    int
	Col     int
	Text    string
	Matched bool // set when an actual error matched this expectation
}
+
+func TestErrors(t *testing.T) {
+ files, err := ioutil.ReadDir("testdata")
+ if err != nil {
+ t.Fatal(err)
+ }
+ if len(files) == 0 {
+ t.Fatal("no input files")
+ }
+ for _, f := range files {
+ if !strings.HasSuffix(f.Name(), ".txt") {
+ continue
+ }
+ t.Run(f.Name(), func(t *testing.T) {
+ data, err := ioutil.ReadFile(filepath.Join("testdata", f.Name()))
+ if err != nil {
+ t.Fatalf("failed to open input file: %v", err)
+ }
+ var stripped []byte
+ var errors []*Error
+ s := bufio.NewScanner(bytes.NewReader(data))
+ for i := 1; s.Scan(); i++ {
+ ln := s.Bytes()
+ for {
+ pos := bytes.LastIndex(ln, []byte("###"))
+ if pos == -1 {
+ break
+ }
+ errors = append(errors, &Error{
+ Line: i,
+ Text: strings.TrimSpace(string(ln[pos+3:])),
+ })
+ ln = ln[:pos]
+ }
+ stripped = append(stripped, ln...)
+ stripped = append(stripped, '\n')
+ }
+ if err := s.Err(); err != nil {
+ t.Fatalf("failed to scan input file: %v", err)
+ }
+ var got []*Error
+ top, ok := Parse(stripped, "test", func(pos Pos, msg string) {
+ got = append(got, &Error{
+ Line: pos.Line,
+ Col: pos.Col,
+ Text: msg,
+ })
+ })
+ if ok && len(got) != 0 {
+ t.Fatalf("parsing succeed, but got errors: %v", got)
+ }
+ if !ok && len(got) == 0 {
+ t.Fatalf("parsing failed, but got no errors")
+ }
+ nextErr:
+ for _, gotErr := range got {
+ for _, wantErr := range errors {
+ if wantErr.Matched {
+ continue
+ }
+ if wantErr.Line != gotErr.Line {
+ continue
+ }
+ if wantErr.Text != gotErr.Text {
+ continue
+ }
+ wantErr.Matched = true
+ continue nextErr
+ }
+ t.Errorf("unexpected error: %v:%v: %v",
+ gotErr.Line, gotErr.Col, gotErr.Text)
+ }
+ for _, wantErr := range errors {
+ if wantErr.Matched {
+ continue
+ }
+ t.Errorf("not matched error: %v: %v", wantErr.Line, wantErr.Text)
+ }
+ // Just to get more code coverage:
+ Format(top)
+ })
+ }
+}
diff --git a/pkg/ast/scanner.go b/pkg/ast/scanner.go
new file mode 100644
index 000000000..ee15cee03
--- /dev/null
+++ b/pkg/ast/scanner.go
@@ -0,0 +1,260 @@
+// Copyright 2017 syzkaller project authors. All rights reserved.
+// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
+
+package ast
+
+import (
+ "fmt"
+ "strconv"
+)
+
// token is the type of a single lexical token.
type token int

const (
	tokIllegal token = iota
	tokComment
	tokIdent
	tokInclude
	tokIncdir
	tokDefine
	tokResource
	tokString
	tokCExpr
	tokInt

	// Single-character punctuation tokens (see the punctuation table).
	tokNewLine
	tokLParen
	tokRParen
	tokLBrack
	tokRBrack
	tokLBrace
	tokRBrace
	tokEq
	tokComma
	tokColon

	tokEOF
)
+
// punctuation maps a single-byte character to its token.
// Entries left at the zero value (tokIllegal) are not punctuation.
var punctuation = [256]token{
	'\n': tokNewLine,
	'(':  tokLParen,
	')':  tokRParen,
	'[':  tokLBrack,
	']':  tokRBrack,
	'{':  tokLBrace,
	'}':  tokRBrace,
	'=':  tokEq,
	',':  tokComma,
	':':  tokColon,
}
+
// tok2str gives human-readable token names for error messages.
var tok2str = [...]string{
	tokIllegal:  "ILLEGAL",
	tokComment:  "comment",
	tokIdent:    "identifier",
	tokInclude:  "include",
	tokIncdir:   "incdir",
	tokDefine:   "define",
	tokResource: "resource",
	tokString:   "string",
	tokCExpr:    "CEXPR",
	tokInt:      "int",
	tokNewLine:  "NEWLINE",
	tokEOF:      "EOF",
}

func init() {
	// Fill in names for punctuation tokens as quoted characters,
	// e.g. tokLParen gets the name '('.
	for ch, tok := range punctuation {
		if tok == tokIllegal {
			continue
		}
		tok2str[tok] = fmt.Sprintf("%q", ch)
	}
}
+
// keywords maps identifiers that are reserved words to their tokens.
var keywords = map[string]token{
	"include":  tokInclude,
	"incdir":   tokIncdir,
	"define":   tokDefine,
	"resource": tokResource,
}

// String returns a human-readable token name for error messages.
func (tok token) String() string {
	return tok2str[tok]
}
+
// scanner splits sys description source into tokens.
type scanner struct {
	data         []byte
	filename     string
	errorHandler func(pos Pos, msg string)

	ch   byte // current character; 0 means end of input
	off  int  // byte offset of ch in data
	line int  // line of ch, starting at 1
	col  int  // column of ch, starting at 1

	// The two most recently returned tokens; used by Scan to
	// recognize old-style define values.
	prev1 token
	prev2 token

	errors int // number of errors reported so far
}

// newScanner creates a scanner positioned at the first character of data.
func newScanner(data []byte, filename string, errorHandler func(pos Pos, msg string)) *scanner {
	s := &scanner{
		data:         data,
		filename:     filename,
		errorHandler: errorHandler,
		off:          -1, // the initial next() advances to offset 0
	}
	s.next()
	return s
}
+
+func (s *scanner) Scan() (tok token, lit string, pos Pos) {
+ s.skipWhitespace()
+ pos = s.pos()
+ switch {
+ case s.ch == 0:
+ tok = tokEOF
+ s.next()
+ case s.ch == '`':
+ tok = tokCExpr
+ for s.next(); s.ch != '`'; s.next() {
+ if s.ch == 0 || s.ch == '\n' {
+ s.Error(pos, "C expression is not terminated")
+ break
+ }
+ }
+ lit = string(s.data[pos.Off+1 : s.off])
+ s.next()
+ case s.prev2 == tokDefine && s.prev1 == tokIdent:
+ // Note: the old form for C expressions, not really lexable.
+ // TODO(dvyukov): get rid of this eventually.
+ tok = tokCExpr
+ for s.next(); s.ch != '\n'; s.next() {
+ }
+ lit = string(s.data[pos.Off:s.off])
+ case s.ch == '#':
+ tok = tokComment
+ for s.next(); s.ch != '\n'; s.next() {
+ }
+ lit = string(s.data[pos.Off+1 : s.off])
+ case s.ch == '"' || s.ch == '<':
+ // TODO(dvyukov): get rid of <...> strings, that's only includes
+ tok = tokString
+ closing := byte('"')
+ if s.ch == '<' {
+ closing = '>'
+ }
+ for s.next(); s.ch != closing; s.next() {
+ if s.ch == 0 || s.ch == '\n' {
+ s.Error(pos, "string literal is not terminated")
+ return
+ }
+ }
+ lit = string(s.data[pos.Off+1 : s.off])
+ for i := 0; i < len(lit); i++ {
+ if lit[i] < 0x20 || lit[i] >= 0x80 {
+ pos1 := pos
+ pos1.Col += i + 1
+ pos1.Off += i + 1
+ s.Error(pos1, "illegal character %#U in string literal", lit[i])
+ break
+ }
+ }
+ s.next()
+ case s.ch >= '0' && s.ch <= '9':
+ tok = tokInt
+ for s.ch >= '0' && s.ch <= '9' ||
+ s.ch >= 'a' && s.ch <= 'f' ||
+ s.ch >= 'A' && s.ch <= 'F' || s.ch == 'x' {
+ s.next()
+ }
+ lit = string(s.data[pos.Off:s.off])
+ bad := false
+ if _, err := strconv.ParseUint(lit, 10, 64); err != nil {
+ if len(lit) > 2 && lit[0] == '0' && lit[1] == 'x' {
+ if _, err := strconv.ParseUint(lit[2:], 16, 64); err != nil {
+ bad = true
+ }
+ } else {
+ bad = true
+ }
+ }
+ if bad {
+ s.Error(pos, fmt.Sprintf("bad integer %q", lit))
+ lit = "0"
+ }
+ case s.ch == '_' || s.ch >= 'a' && s.ch <= 'z' || s.ch >= 'A' && s.ch <= 'Z':
+ tok = tokIdent
+ for s.ch == '_' || s.ch == '$' ||
+ s.ch >= 'a' && s.ch <= 'z' ||
+ s.ch >= 'A' && s.ch <= 'Z' ||
+ s.ch >= '0' && s.ch <= '9' {
+ s.next()
+ }
+ lit = string(s.data[pos.Off:s.off])
+ if key, ok := keywords[lit]; ok {
+ tok = key
+ }
+ default:
+ tok = punctuation[s.ch]
+ if tok == tokIllegal {
+ s.Error(pos, "illegal character %#U", s.ch)
+ }
+ s.next()
+ }
+ s.prev2 = s.prev1
+ s.prev1 = tok
+ return
+}
+
// Error reports an error at pos via the error handler and records it
// so that Ok returns false. args are Sprintf-interpolated into msg.
func (s *scanner) Error(pos Pos, msg string, args ...interface{}) {
	s.errors++
	s.errorHandler(pos, fmt.Sprintf(msg, args...))
}

// Ok reports whether no errors were encountered.
func (s *scanner) Ok() bool {
	return s.errors == 0
}
+
+func (s *scanner) next() {
+ s.off++
+ if s.off == len(s.data) {
+ // Always emit NEWLINE before EOF.
+ // Makes lots of things simpler as we always
+ // want to treat EOF as NEWLINE as well.
+ s.ch = '\n'
+ s.off++
+ return
+ }
+ if s.off > len(s.data) {
+ s.ch = 0
+ return
+ }
+ if s.off == 0 || s.data[s.off-1] == '\n' {
+ s.line++
+ s.col = 0
+ }
+ s.ch = s.data[s.off]
+ s.col++
+ if s.ch == 0 {
+ s.Error(s.pos(), "illegal character \\x00")
+ }
+}
+
// skipWhitespace skips spaces and tabs (newlines are significant
// tokens and are not skipped).
func (s *scanner) skipWhitespace() {
	for s.ch == ' ' || s.ch == '\t' {
		s.next()
	}
}

// pos returns the source position of the current character.
func (s *scanner) pos() Pos {
	return Pos{
		File: s.filename,
		Off:  s.off,
		Line: s.line,
		Col:  s.col,
	}
}
diff --git a/pkg/ast/testdata/all.txt b/pkg/ast/testdata/all.txt
new file mode 100644
index 000000000..443f26368
--- /dev/null
+++ b/pkg/ast/testdata/all.txt
@@ -0,0 +1,28 @@
+# Copyright 2017 syzkaller project authors. All rights reserved.
+# Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
+
+0x42 ### unexpected int, expecting comment, define, include, resource, identifier
+foo ### unexpected '\n', expecting '(', '{', '[', '='
+% ### illegal character U+0025 '%'
+
+int_flags0 = 0, 0x1, 0xab
+int_flags1 = 123ab0x ### bad integer "123ab0x"
+int_flags1 == 0, 1 ### unexpected '=', expecting int, identifier, string
+int_flags = 0, "foo" ### unexpected string, expecting int, identifier
+
+str_flags0 = "foo", "bar"
+str_flags1 = "non terminated ### string literal is not terminated
+str_flags2 = "bad chars здесь" ### illegal character U+00D0 'Ð' in string literal
+str_flags3 = "string", not a string ### unexpected identifier, expecting string
+str_flags4 = "string", 42 ### unexpected int, expecting string
+
+call(foo ,int32 , bar int32) ### unexpected ',', expecting int, identifier, string
+call(foo int32:"bar") ### unexpected string, expecting int, identifier
+
+define FOO `bar`
+define FOO `bar ### C expression is not terminated
+
+include <linux/foo.h>
+include "linux/foo.h"
+incdir </foo/bar>
+incdir "/foo/bar"
diff --git a/tools/syz-fmt/syz-fmt.go b/tools/syz-fmt/syz-fmt.go
new file mode 100644
index 000000000..f78e8b8dc
--- /dev/null
+++ b/tools/syz-fmt/syz-fmt.go
@@ -0,0 +1,72 @@
+// Copyright 2017 syzkaller project authors. All rights reserved.
+// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
+
+// syz-fmt re-formats sys files into standard form.
+package main
+
+import (
+ "bytes"
+ "fmt"
+ "io/ioutil"
+ "os"
+ "path/filepath"
+ "strings"
+
+ "github.com/google/syzkaller/pkg/ast"
+)
+
// main formats each file given on the command line; for a directory
// argument it formats all .txt files directly inside it.
func main() {
	if len(os.Args) < 2 {
		fmt.Fprintf(os.Stderr, "usage: syz-fmt files... or dirs...\n")
		os.Exit(1)
	}
	for _, arg := range os.Args[1:] {
		st, err := os.Stat(arg)
		if err != nil {
			fmt.Fprintf(os.Stderr, "failed to stat %v: %v\n", arg, err)
			os.Exit(1)
		}
		if st.IsDir() {
			files, err := ioutil.ReadDir(arg)
			if err != nil {
				fmt.Fprintf(os.Stderr, "failed to read dir %v: %v\n", arg, err)
				os.Exit(1)
			}
			for _, file := range files {
				if !strings.HasSuffix(file.Name(), ".txt") {
					continue
				}
				processFile(filepath.Join(arg, file.Name()), file.Mode())
			}
		} else {
			processFile(arg, st.Mode())
		}
	}
}
+
+func processFile(file string, mode os.FileMode) {
+ data, err := ioutil.ReadFile(file)
+ if err != nil {
+ fmt.Fprintf(os.Stderr, "failed to read file %v: %v\n", file, err)
+ os.Exit(1)
+ }
+ errorHandler := func(pos ast.Pos, msg string) {
+ fmt.Fprintf(os.Stderr, "%v:%v:%v: %v", pos.File, pos.Line, pos.Col, msg)
+ }
+ top, ok := ast.Parse(data, filepath.Base(file), errorHandler)
+ if !ok {
+ os.Exit(1)
+ }
+ formatted := ast.Format(top)
+ if bytes.Equal(data, formatted) {
+ return
+ }
+ if err := os.Rename(file, file+"~"); err != nil {
+ fmt.Fprintf(os.Stderr, "%v\n", err)
+ os.Exit(1)
+ }
+ if err := ioutil.WriteFile(file, formatted, mode); err != nil {
+ fmt.Fprintf(os.Stderr, "%v\n", err)
+ os.Exit(1)
+ }
+}