From a3857c4e90fa4a3fbe78bd4b53cdc77aa91533cf Mon Sep 17 00:00:00 2001 From: Dmitry Vyukov Date: Sat, 26 Aug 2017 21:36:08 +0200 Subject: pkg/compiler, sys/syz-sysgen: move const handling to pkg/compiler Now pkg/compiler deals with consts. --- pkg/ast/ast.go | 11 +++ pkg/ast/clone.go | 196 +++++++++++++++++++++++++++++++++++++++++++++++++ pkg/ast/format.go | 48 ++++++------ pkg/ast/parser.go | 42 ++++++----- pkg/ast/parser_test.go | 41 ++++++----- pkg/ast/scanner.go | 10 ++- pkg/ast/walk.go | 52 ++++++------- 7 files changed, 312 insertions(+), 88 deletions(-) create mode 100644 pkg/ast/clone.go (limited to 'pkg/ast') diff --git a/pkg/ast/ast.go b/pkg/ast/ast.go index e2ddc0224..b283ca5f8 100644 --- a/pkg/ast/ast.go +++ b/pkg/ast/ast.go @@ -12,6 +12,14 @@ type Pos struct { Col int // column number, starting at 1 (byte count) } +// Description contains top-level nodes of a parsed sys description. +type Description struct { + Nodes []Node +} + +// Node is AST node interface. +type Node interface{} + // Top-level AST nodes: type NewLine struct { @@ -50,6 +58,7 @@ type Call struct { Pos Pos Name *Ident CallName string + NR uint64 Args []*Field Ret *Type } @@ -104,6 +113,8 @@ type Type struct { Ident string String string // Part after COLON (for ranges and bitfields). + HasColon bool + Pos2 Pos Value2 uint64 Value2Hex bool Ident2 string diff --git a/pkg/ast/clone.go b/pkg/ast/clone.go new file mode 100644 index 000000000..5c2d773f7 --- /dev/null +++ b/pkg/ast/clone.go @@ -0,0 +1,196 @@ +// Copyright 2017 syzkaller project authors. All rights reserved. +// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file. + +package ast + +import ( + "fmt" +) + +func Clone(desc *Description) *Description { + desc1 := &Description{} + for _, n := range desc.Nodes { + c, ok := n.(cloner) + if !ok { + panic(fmt.Sprintf("unknown top level decl: %#v", n)) + } + desc1.Nodes = append(desc1.Nodes, c.clone()) + } + return desc1 +} + +type cloner interface { + clone() Node +} + +func (n *NewLine) clone() Node { + return &NewLine{ + Pos: n.Pos, + } +} + +func (n *Comment) clone() Node { + return &Comment{ + Pos: n.Pos, + Text: n.Text, + } +} + +func (n *Include) clone() Node { + return &Include{ + Pos: n.Pos, + File: n.File.clone(), + } +} + +func (n *Incdir) clone() Node { + return &Incdir{ + Pos: n.Pos, + Dir: n.Dir.clone(), + } +} + +func (n *Define) clone() Node { + return &Define{ + Pos: n.Pos, + Name: n.Name.clone(), + Value: n.Value.clone(), + } +} + +func (n *Resource) clone() Node { + var values []*Int + for _, v := range n.Values { + values = append(values, v.clone()) + } + return &Resource{ + Pos: n.Pos, + Name: n.Name.clone(), + Base: n.Base.clone(), + Values: values, + } +} + +func (n *Call) clone() Node { + var args []*Field + for _, a := range n.Args { + args = append(args, a.clone()) + } + var ret *Type + if n.Ret != nil { + ret = n.Ret.clone() + } + return &Call{ + Pos: n.Pos, + Name: n.Name.clone(), + CallName: n.CallName, + NR: n.NR, + Args: args, + Ret: ret, + } +} + +func (n *Struct) clone() Node { + var fields []*Field + for _, f := range n.Fields { + fields = append(fields, f.clone()) + } + var attrs []*Ident + for _, a := range n.Attrs { + attrs = append(attrs, a.clone()) + } + var comments []*Comment + for _, c := range n.Comments { + comments = append(comments, c.clone().(*Comment)) + } + return &Struct{ + Pos: n.Pos, + Name: n.Name.clone(), + Fields: fields, + Attrs: attrs, + Comments: comments, + IsUnion: n.IsUnion, + } +} + +func (n *IntFlags) clone() Node { + var values []*Int + for _, v := range n.Values { + values = append(values, v.clone()) + } + return &IntFlags{ + Pos: n.Pos, + Name: n.Name.clone(), + Values: values, + } +} + +func (n *StrFlags) clone() Node { + var values []*String + for _, v := range n.Values { + values = append(values, v.clone()) + } + return &StrFlags{ + Pos: n.Pos, + Name: n.Name.clone(), + Values: values, + } +} + +func (n *Ident) clone() *Ident { + return &Ident{ + Pos: n.Pos, + Name: n.Name, + } +} + +func (n *String) clone() *String { + return &String{ + Pos: n.Pos, + Value: n.Value, + } +} + +func (n *Int) clone() *Int { + return &Int{ + Pos: n.Pos, + Value: n.Value, + ValueHex: n.ValueHex, + Ident: n.Ident, + CExpr: n.CExpr, + } +} + +func (n *Type) clone() *Type { + var args []*Type + for _, a := range n.Args { + args = append(args, a.clone()) + } + return &Type{ + Pos: n.Pos, + Value: n.Value, + ValueHex: n.ValueHex, + Ident: n.Ident, + String: n.String, + HasColon: n.HasColon, + Pos2: n.Pos2, + Value2: n.Value2, + Value2Hex: n.Value2Hex, + Ident2: n.Ident2, + Args: args, + } +} + +func (n *Field) clone() *Field { + var comments []*Comment + for _, c := range n.Comments { + comments = append(comments, c.clone().(*Comment)) + } + return &Field{ + Pos: n.Pos, + Name: n.Name.clone(), + Type: n.Type.clone(), + NewBlock: n.NewBlock, + Comments: comments, + } +} diff --git a/pkg/ast/format.go b/pkg/ast/format.go index 0eb9aa957..e7e21dcdd 100644 --- a/pkg/ast/format.go +++ b/pkg/ast/format.go @@ -9,47 +9,47 @@ import ( "io" ) -func Format(top []interface{}) []byte { +func Format(desc *Description) []byte { buf := new(bytes.Buffer) - FormatWriter(buf, top) + FormatWriter(buf, desc) return buf.Bytes() } -func FormatWriter(w io.Writer, top []interface{}) { - for _, decl := range top { - s, ok := decl.(serializer) +func FormatWriter(w io.Writer, desc *Description) { + for _, n := range desc.Nodes { + s, ok := n.(serializer) if !ok { - panic(fmt.Sprintf("unknown top level decl: %#v", decl)) + panic(fmt.Sprintf("unknown top level decl: %#v", n)) } - s.Serialize(w) + s.serialize(w) } } type serializer interface { - Serialize(w io.Writer) + serialize(w io.Writer) } -func (incl *NewLine) Serialize(w io.Writer) { +func (nl *NewLine) serialize(w io.Writer) { fmt.Fprintf(w, "\n") } -func (com *Comment) Serialize(w io.Writer) { +func (com *Comment) serialize(w io.Writer) { fmt.Fprintf(w, "#%v\n", com.Text) } -func (incl *Include) Serialize(w io.Writer) { +func (incl *Include) serialize(w io.Writer) { fmt.Fprintf(w, "include <%v>\n", incl.File.Value) } -func (inc *Incdir) Serialize(w io.Writer) { +func (inc *Incdir) serialize(w io.Writer) { fmt.Fprintf(w, "incdir <%v>\n", inc.Dir.Value) } -func (def *Define) Serialize(w io.Writer) { +func (def *Define) serialize(w io.Writer) { fmt.Fprintf(w, "define %v\t%v\n", def.Name.Name, fmtInt(def.Value)) } -func (res *Resource) Serialize(w io.Writer) { +func (res *Resource) serialize(w io.Writer) { fmt.Fprintf(w, "resource %v[%v]", res.Name.Name, res.Base.Name) for i, v := range res.Values { if i == 0 { @@ -62,7 +62,7 @@ func (res *Resource) Serialize(w io.Writer) { fmt.Fprintf(w, "\n") } -func (c *Call) Serialize(w io.Writer) { +func (c *Call) serialize(w io.Writer) { fmt.Fprintf(w, "%v(", c.Name.Name) for i, a := range c.Args { if i != 0 { @@ -77,7 +77,7 @@ func (c *Call) Serialize(w io.Writer) { fmt.Fprintf(w, "\n") } -func (str *Struct) Serialize(w io.Writer) { +func (str *Struct) serialize(w io.Writer) { opening, closing := '{', '}' if str.IsUnion { opening, closing = '[', ']' @@ -119,7 +119,7 @@ func (str *Struct) Serialize(w io.Writer) { fmt.Fprintf(w, "\n") } -func (flags *IntFlags) Serialize(w io.Writer) { +func (flags *IntFlags) serialize(w io.Writer) { fmt.Fprintf(w, "%v = ", flags.Name.Name) for i, v := range flags.Values { if i != 0 { @@ -130,7 +130,7 @@ func (flags *IntFlags) Serialize(w io.Writer) { fmt.Fprintf(w, "\n") } -func (flags *StrFlags) Serialize(w io.Writer) { +func (flags *StrFlags) serialize(w io.Writer) { fmt.Fprintf(w, "%v = ", flags.Name.Name) for i, v := range flags.Values { if i != 0 { @@ -155,11 +155,13 @@ func fmtType(t *Type) string { default: v = fmtIntValue(t.Value, t.ValueHex) } - switch { - case t.Ident2 != "": - v += fmt.Sprintf(":%v", t.Ident2) - case t.Value2 != 0: - v += fmt.Sprintf(":%v", fmtIntValue(t.Value2, t.Value2Hex)) + if t.HasColon { + switch { + case t.Ident2 != "": + v += fmt.Sprintf(":%v", t.Ident2) + default: + v += fmt.Sprintf(":%v", fmtIntValue(t.Value2, t.Value2Hex)) + } } v += fmtTypeList(t.Args) return v diff --git a/pkg/ast/parser.go b/pkg/ast/parser.go index fc05378b2..ca2505e19 100644 --- a/pkg/ast/parser.go +++ b/pkg/ast/parser.go @@ -13,9 +13,11 @@ import ( ) // Parse parses sys description into AST and returns top-level nodes. -func Parse(data []byte, filename string, errorHandler func(pos Pos, msg string)) (top []interface{}, ok bool) { +// If any errors are encountered, returns nil. +func Parse(data []byte, filename string, errorHandler ErrorHandler) *Description { p := &parser{s: newScanner(data, filename, errorHandler)} prevNewLine, prevComment := false, false + var top []Node for p.next(); p.tok != tokEOF; { decl := p.parseTopRecover() if decl == nil { @@ -39,37 +41,41 @@ func Parse(data []byte, filename string, errorHandler func(pos Pos, msg string)) if prevNewLine { top = top[:len(top)-1] } - ok = p.s.Ok() - return + if !p.s.Ok() { + return nil + } + return &Description{top} } -func ParseGlob(glob string, errorHandler func(pos Pos, msg string)) (top []interface{}, ok bool) { +func ParseGlob(glob string, errorHandler ErrorHandler) *Description { if errorHandler == nil { - errorHandler = loggingHandler + errorHandler = LoggingHandler } files, err := filepath.Glob(glob) if err != nil { errorHandler(Pos{}, fmt.Sprintf("failed to find input files: %v", err)) - return nil, false + return nil } if len(files) == 0 { errorHandler(Pos{}, fmt.Sprintf("no files matched by glob %q", glob)) - return nil, false + return nil } - ok = true + desc := &Description{} for _, f := range files { data, err := ioutil.ReadFile(f) if err != nil { errorHandler(Pos{}, fmt.Sprintf("failed to read input file: %v", err)) - return nil, false + return nil + } + desc1 := Parse(data, filepath.Base(f), errorHandler) + if desc1 == nil { + desc = nil } - top1, ok1 := Parse(data, filepath.Base(f), errorHandler) - if !ok1 { - ok = false + if desc != nil { + desc.Nodes = append(desc.Nodes, desc1.Nodes...) } - top = append(top, top1...) } - return + return desc } type parser struct { @@ -84,7 +90,7 @@ type parser struct { // Skip parsing till the next NEWLINE, for error recovery. var skipLine = errors.New("") -func (p *parser) parseTopRecover() interface{} { +func (p *parser) parseTopRecover() Node { defer func() { switch err := recover(); err { case nil: @@ -106,7 +112,7 @@ func (p *parser) parseTopRecover() interface{} { return decl } -func (p *parser) parseTop() interface{} { +func (p *parser) parseTop() Node { switch p.tok { case tokNewLine: return &NewLine{Pos: p.pos} @@ -266,7 +272,7 @@ func callName(s string) string { return s[:pos] } -func (p *parser) parseFlags(name *Ident) interface{} { +func (p *parser) parseFlags(name *Ident) Node { p.consume(tokEq) switch p.tok { case tokInt, tokIdent: @@ -379,6 +385,8 @@ func (p *parser) parseType() *Type { } p.next() if allowColon && p.tryConsume(tokColon) { + arg.HasColon = true + arg.Pos2 = p.pos switch p.tok { case tokInt: arg.Value2, arg.Value2Hex = p.parseIntValue() diff --git a/pkg/ast/parser_test.go b/pkg/ast/parser_test.go index 43587d7a3..809a0beaf 100644 --- a/pkg/ast/parser_test.go +++ b/pkg/ast/parser_test.go @@ -30,23 +30,31 @@ func TestParseAll(t *testing.T) { errorHandler := func(pos Pos, msg string) { t.Fatalf("%v:%v:%v: %v", pos.File, pos.Line, pos.Col, msg) } - top, ok := Parse(data, file.Name(), errorHandler) - if !ok { + desc := Parse(data, file.Name(), errorHandler) + if desc == nil { t.Fatalf("parsing failed, but no error produced") } - data2 := Format(top) - top2, ok2 := Parse(data2, file.Name(), errorHandler) - if !ok2 { + data2 := Format(desc) + desc2 := Parse(data2, file.Name(), errorHandler) + if desc2 == nil { t.Fatalf("parsing failed, but no error produced") } - if len(top) != len(top2) { - t.Fatalf("formatting number of top level decls: %v/%v", len(top), len(top2)) + if len(desc.Nodes) != len(desc2.Nodes) { + t.Fatalf("formatting number of top level decls: %v/%v", + len(desc.Nodes), len(desc2.Nodes)) } - // While sys files are not formatted, formatting in fact changes it. - for i := range top { - if !reflect.DeepEqual(top[i], top2[i]) { - t.Fatalf("formatting changed code:\n%#v\nvs:\n%#v", top[i], top2[i]) + for i := range desc.Nodes { + n1, n2 := desc.Nodes[i], desc2.Nodes[i] + if n1 == nil { + t.Fatalf("got nil node") } + if !reflect.DeepEqual(n1, n2) { + t.Fatalf("formatting changed code:\n%#v\nvs:\n%#v", n1, n2) + } + } + data3 := Format(Clone(desc)) + if !bytes.Equal(data, data3) { + t.Fatalf("Clone lost data") } } } @@ -57,8 +65,7 @@ func TestParse(t *testing.T) { errorHandler := func(pos Pos, msg string) { t.Logf("%v:%v:%v: %v", pos.File, pos.Line, pos.Col, msg) } - toplev, ok := Parse([]byte(test.input), "foo", errorHandler) - _, _ = toplev, ok + Parse([]byte(test.input), "foo", errorHandler) }) } } @@ -134,17 +141,17 @@ func TestErrors(t *testing.T) { t.Fatalf("failed to scan input file: %v", err) } var got []*Error - top, ok := Parse(stripped, "test", func(pos Pos, msg string) { + desc := Parse(stripped, "test", func(pos Pos, msg string) { got = append(got, &Error{ Line: pos.Line, Col: pos.Col, Text: msg, }) }) - if ok && len(got) != 0 { + if desc != nil && len(got) != 0 { t.Fatalf("parsing succeed, but got errors: %v", got) } - if !ok && len(got) == 0 { + if desc == nil && len(got) == 0 { t.Fatalf("parsing failed, but got no errors") } nextErr: @@ -171,8 +178,6 @@ func TestErrors(t *testing.T) { } t.Errorf("not matched error: %v: %v", wantErr.Line, wantErr.Text) } - // Just to get more code coverage: - Format(top) }) } } diff --git a/pkg/ast/scanner.go b/pkg/ast/scanner.go index 372d2df3e..387a58529 100644 --- a/pkg/ast/scanner.go +++ b/pkg/ast/scanner.go @@ -88,7 +88,7 @@ func (tok token) String() string { type scanner struct { data []byte filename string - errorHandler func(pos Pos, msg string) + errorHandler ErrorHandler ch byte off int @@ -101,9 +101,9 @@ type scanner struct { errors int } -func newScanner(data []byte, filename string, errorHandler func(pos Pos, msg string)) *scanner { +func newScanner(data []byte, filename string, errorHandler ErrorHandler) *scanner { if errorHandler == nil { - errorHandler = loggingHandler + errorHandler = LoggingHandler } s := &scanner{ data: data, @@ -115,7 +115,9 @@ func newScanner(data []byte, filename string, errorHandler func(pos Pos, msg str return s } -func loggingHandler(pos Pos, msg string) { +type ErrorHandler func(pos Pos, msg string) + +func LoggingHandler(pos Pos, msg string) { fmt.Fprintf(os.Stderr, "%v:%v:%v: %v\n", pos.File, pos.Line, pos.Col, msg) } diff --git a/pkg/ast/walk.go b/pkg/ast/walk.go index f4daeaccd..90a92cf77 100644 --- a/pkg/ast/walk.go +++ b/pkg/ast/walk.go @@ -8,13 +8,13 @@ import ( ) // Walk calls callback cb for every node in AST. -func Walk(top []interface{}, cb func(n interface{})) { - for _, decl := range top { - walkNode(decl, cb) +func Walk(desc *Description, cb func(n Node)) { + for _, n := range desc.Nodes { + WalkNode(n, cb) } } -func walkNode(n0 interface{}, cb func(n interface{})) { +func WalkNode(n0 Node, cb func(n Node)) { switch n := n0.(type) { case *NewLine: cb(n) @@ -22,53 +22,53 @@ func walkNode(n0 interface{}, cb func(n interface{})) { cb(n) case *Include: cb(n) - walkNode(n.File, cb) + WalkNode(n.File, cb) case *Incdir: cb(n) - walkNode(n.Dir, cb) + WalkNode(n.Dir, cb) case *Define: cb(n) - walkNode(n.Name, cb) - walkNode(n.Value, cb) + WalkNode(n.Name, cb) + WalkNode(n.Value, cb) case *Resource: cb(n) - walkNode(n.Name, cb) - walkNode(n.Base, cb) + WalkNode(n.Name, cb) + WalkNode(n.Base, cb) for _, v := range n.Values { - walkNode(v, cb) + WalkNode(v, cb) } case *Call: cb(n) - walkNode(n.Name, cb) + WalkNode(n.Name, cb) for _, f := range n.Args { - walkNode(f, cb) + WalkNode(f, cb) } if n.Ret != nil { - walkNode(n.Ret, cb) + WalkNode(n.Ret, cb) } case *Struct: cb(n) - walkNode(n.Name, cb) + WalkNode(n.Name, cb) for _, f := range n.Fields { - walkNode(f, cb) + WalkNode(f, cb) } for _, a := range n.Attrs { - walkNode(a, cb) + WalkNode(a, cb) } for _, c := range n.Comments { - walkNode(c, cb) + WalkNode(c, cb) } case *IntFlags: cb(n) - walkNode(n.Name, cb) + WalkNode(n.Name, cb) for _, v := range n.Values { - walkNode(v, cb) + WalkNode(v, cb) } case *StrFlags: cb(n) - walkNode(n.Name, cb) + WalkNode(n.Name, cb) for _, v := range n.Values { - walkNode(v, cb) + WalkNode(v, cb) } case *Ident: cb(n) @@ -79,14 +79,14 @@ func walkNode(n0 interface{}, cb func(n interface{})) { case *Type: cb(n) for _, t := range n.Args { - walkNode(t, cb) + WalkNode(t, cb) } case *Field: cb(n) - walkNode(n.Name, cb) - walkNode(n.Type, cb) + WalkNode(n.Name, cb) + WalkNode(n.Type, cb) for _, c := range n.Comments { - walkNode(c, cb) + WalkNode(c, cb) } default: panic(fmt.Sprintf("unknown AST node: %#v", n)) -- cgit mrf-deployment