diff options
| author | Dmitry Vyukov <dvyukov@google.com> | 2017-08-26 21:36:08 +0200 |
|---|---|---|
| committer | Dmitry Vyukov <dvyukov@google.com> | 2017-08-27 11:51:40 +0200 |
| commit | a3857c4e90fa4a3fbe78bd4b53cdc77aa91533cf (patch) | |
| tree | 8bc28379a29112de7bc11c57f3d91d0baba84594 /pkg | |
| parent | 9ec49e082f811482ecdccc837c27961d68247d25 (diff) | |
pkg/compiler, sys/syz-sysgen: move const handling to pkg/compiler
Now pkg/compiler deals with consts.
Diffstat (limited to 'pkg')
| -rw-r--r-- | pkg/ast/ast.go | 11 | ||||
| -rw-r--r-- | pkg/ast/clone.go | 196 | ||||
| -rw-r--r-- | pkg/ast/format.go | 48 | ||||
| -rw-r--r-- | pkg/ast/parser.go | 42 | ||||
| -rw-r--r-- | pkg/ast/parser_test.go | 41 | ||||
| -rw-r--r-- | pkg/ast/scanner.go | 10 | ||||
| -rw-r--r-- | pkg/ast/walk.go | 52 | ||||
| -rw-r--r-- | pkg/compiler/compiler.go | 182 | ||||
| -rw-r--r-- | pkg/compiler/compiler_test.go | 16 |
9 files changed, 498 insertions, 100 deletions
diff --git a/pkg/ast/ast.go b/pkg/ast/ast.go index e2ddc0224..b283ca5f8 100644 --- a/pkg/ast/ast.go +++ b/pkg/ast/ast.go @@ -12,6 +12,14 @@ type Pos struct { Col int // column number, starting at 1 (byte count) } +// Description contains top-level nodes of a parsed sys description. +type Description struct { + Nodes []Node +} + +// Node is AST node interface. +type Node interface{} + // Top-level AST nodes: type NewLine struct { @@ -50,6 +58,7 @@ type Call struct { Pos Pos Name *Ident CallName string + NR uint64 Args []*Field Ret *Type } @@ -104,6 +113,8 @@ type Type struct { Ident string String string // Part after COLON (for ranges and bitfields). + HasColon bool + Pos2 Pos Value2 uint64 Value2Hex bool Ident2 string diff --git a/pkg/ast/clone.go b/pkg/ast/clone.go new file mode 100644 index 000000000..5c2d773f7 --- /dev/null +++ b/pkg/ast/clone.go @@ -0,0 +1,196 @@ +// Copyright 2017 syzkaller project authors. All rights reserved. +// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file. + +package ast + +import ( + "fmt" +) + +func Clone(desc *Description) *Description { + desc1 := &Description{} + for _, n := range desc.Nodes { + c, ok := n.(cloner) + if !ok { + panic(fmt.Sprintf("unknown top level decl: %#v", n)) + } + desc1.Nodes = append(desc1.Nodes, c.clone()) + } + return desc1 +} + +type cloner interface { + clone() Node +} + +func (n *NewLine) clone() Node { + return &NewLine{ + Pos: n.Pos, + } +} + +func (n *Comment) clone() Node { + return &Comment{ + Pos: n.Pos, + Text: n.Text, + } +} + +func (n *Include) clone() Node { + return &Include{ + Pos: n.Pos, + File: n.File.clone(), + } +} + +func (n *Incdir) clone() Node { + return &Incdir{ + Pos: n.Pos, + Dir: n.Dir.clone(), + } +} + +func (n *Define) clone() Node { + return &Define{ + Pos: n.Pos, + Name: n.Name.clone(), + Value: n.Value.clone(), + } +} + +func (n *Resource) clone() Node { + var values []*Int + for _, v := range n.Values { + values = append(values, v.clone()) + } + return &Resource{ + Pos: n.Pos, + Name: n.Name.clone(), + Base: n.Base.clone(), + Values: values, + } +} + +func (n *Call) clone() Node { + var args []*Field + for _, a := range n.Args { + args = append(args, a.clone()) + } + var ret *Type + if n.Ret != nil { + ret = n.Ret.clone() + } + return &Call{ + Pos: n.Pos, + Name: n.Name.clone(), + CallName: n.CallName, + NR: n.NR, + Args: args, + Ret: ret, + } +} + +func (n *Struct) clone() Node { + var fields []*Field + for _, f := range n.Fields { + fields = append(fields, f.clone()) + } + var attrs []*Ident + for _, a := range n.Attrs { + attrs = append(attrs, a.clone()) + } + var comments []*Comment + for _, c := range n.Comments { + comments = append(comments, c.clone().(*Comment)) + } + return &Struct{ + Pos: n.Pos, + Name: n.Name.clone(), + Fields: fields, + Attrs: attrs, + Comments: comments, + IsUnion: n.IsUnion, + } +} + +func (n *IntFlags) clone() Node { + var values []*Int + for _, v := range n.Values { + values = append(values, v.clone()) + } + return &IntFlags{ + Pos: n.Pos, + Name: n.Name.clone(), + Values: values, + } +} + +func (n *StrFlags) clone() Node { + var values []*String + for _, v := range n.Values { + values = append(values, v.clone()) + } + return &StrFlags{ + Pos: n.Pos, + Name: n.Name.clone(), + Values: values, + } +} + +func (n *Ident) clone() *Ident { + return &Ident{ + Pos: n.Pos, + Name: n.Name, + } +} + +func (n *String) clone() *String { + return &String{ + Pos: n.Pos, + Value: n.Value, + } +} + +func (n *Int) clone() *Int { + return &Int{ + Pos: n.Pos, + Value: n.Value, + ValueHex: n.ValueHex, + Ident: n.Ident, + CExpr: n.CExpr, + } +} + +func (n *Type) clone() *Type { + var args []*Type + for _, a := range n.Args { + args = append(args, a.clone()) + } + return &Type{ + Pos: n.Pos, + Value: n.Value, + ValueHex: n.ValueHex, + Ident: n.Ident, + String: n.String, + HasColon: n.HasColon, + Pos2: n.Pos2, + Value2: n.Value2, + Value2Hex: n.Value2Hex, + Ident2: n.Ident2, + Args: args, + } +} + +func (n *Field) clone() *Field { + var comments []*Comment + for _, c := range n.Comments { + comments = append(comments, c.clone().(*Comment)) + } + return &Field{ + Pos: n.Pos, + Name: n.Name.clone(), + Type: n.Type.clone(), + NewBlock: n.NewBlock, + Comments: comments, + } +} diff --git a/pkg/ast/format.go b/pkg/ast/format.go index 0eb9aa957..e7e21dcdd 100644 --- a/pkg/ast/format.go +++ b/pkg/ast/format.go @@ -9,47 +9,47 @@ import ( "io" ) -func Format(top []interface{}) []byte { +func Format(desc *Description) []byte { buf := new(bytes.Buffer) - FormatWriter(buf, top) + FormatWriter(buf, desc) return buf.Bytes() } -func FormatWriter(w io.Writer, top []interface{}) { - for _, decl := range top { - s, ok := decl.(serializer) +func FormatWriter(w io.Writer, desc *Description) { + for _, n := range desc.Nodes { + s, ok := n.(serializer) if !ok { - panic(fmt.Sprintf("unknown top level decl: %#v", decl)) + panic(fmt.Sprintf("unknown top level decl: %#v", n)) } - s.Serialize(w) + s.serialize(w) } } type serializer interface { - Serialize(w io.Writer) + serialize(w io.Writer) } -func (incl *NewLine) Serialize(w io.Writer) { +func (nl *NewLine) serialize(w io.Writer) { fmt.Fprintf(w, "\n") } -func (com *Comment) Serialize(w io.Writer) { +func (com *Comment) serialize(w io.Writer) { fmt.Fprintf(w, "#%v\n", com.Text) } -func (incl *Include) Serialize(w io.Writer) { +func (incl *Include) serialize(w io.Writer) { fmt.Fprintf(w, "include <%v>\n", incl.File.Value) } -func (inc *Incdir) Serialize(w io.Writer) { +func (inc *Incdir) serialize(w io.Writer) { fmt.Fprintf(w, "incdir <%v>\n", inc.Dir.Value) } -func (def *Define) Serialize(w io.Writer) { +func (def *Define) serialize(w io.Writer) { fmt.Fprintf(w, "define %v\t%v\n", def.Name.Name, fmtInt(def.Value)) } -func (res *Resource) Serialize(w io.Writer) { +func (res *Resource) serialize(w io.Writer) { fmt.Fprintf(w, "resource %v[%v]", res.Name.Name, res.Base.Name) for i, v := range res.Values { if i == 0 { @@ -62,7 +62,7 @@ func (res *Resource) Serialize(w io.Writer) { fmt.Fprintf(w, "\n") } -func (c *Call) Serialize(w io.Writer) { +func (c *Call) serialize(w io.Writer) { fmt.Fprintf(w, "%v(", c.Name.Name) for i, a := range c.Args { if i != 0 { @@ -77,7 +77,7 @@ func (c *Call) Serialize(w io.Writer) { fmt.Fprintf(w, "\n") } -func (str *Struct) Serialize(w io.Writer) { +func (str *Struct) serialize(w io.Writer) { opening, closing := '{', '}' if str.IsUnion { opening, closing = '[', ']' @@ -119,7 +119,7 @@ func (str *Struct) Serialize(w io.Writer) { fmt.Fprintf(w, "\n") } -func (flags *IntFlags) Serialize(w io.Writer) { +func (flags *IntFlags) serialize(w io.Writer) { fmt.Fprintf(w, "%v = ", flags.Name.Name) for i, v := range flags.Values { if i != 0 { @@ -130,7 +130,7 @@ func (flags *IntFlags) Serialize(w io.Writer) { fmt.Fprintf(w, "\n") } -func (flags *StrFlags) Serialize(w io.Writer) { +func (flags *StrFlags) serialize(w io.Writer) { fmt.Fprintf(w, "%v = ", flags.Name.Name) for i, v := range flags.Values { if i != 0 { @@ -155,11 +155,13 @@ func fmtType(t *Type) string { default: v = fmtIntValue(t.Value, t.ValueHex) } - switch { - case t.Ident2 != "": - v += fmt.Sprintf(":%v", t.Ident2) - case t.Value2 != 0: - v += fmt.Sprintf(":%v", fmtIntValue(t.Value2, t.Value2Hex)) + if t.HasColon { + switch { + case t.Ident2 != "": + v += fmt.Sprintf(":%v", t.Ident2) + default: + v += fmt.Sprintf(":%v", fmtIntValue(t.Value2, t.Value2Hex)) + } } v += fmtTypeList(t.Args) return v diff --git a/pkg/ast/parser.go b/pkg/ast/parser.go index fc05378b2..ca2505e19 100644 --- a/pkg/ast/parser.go +++ b/pkg/ast/parser.go @@ -13,9 +13,11 @@ import ( ) // Parse parses sys description into AST and returns top-level nodes. -func Parse(data []byte, filename string, errorHandler func(pos Pos, msg string)) (top []interface{}, ok bool) { +// If any errors are encountered, returns nil. +func Parse(data []byte, filename string, errorHandler ErrorHandler) *Description { p := &parser{s: newScanner(data, filename, errorHandler)} prevNewLine, prevComment := false, false + var top []Node for p.next(); p.tok != tokEOF; { decl := p.parseTopRecover() if decl == nil { @@ -39,37 +41,41 @@ func Parse(data []byte, filename string, errorHandler func(pos Pos, msg string)) if prevNewLine { top = top[:len(top)-1] } - ok = p.s.Ok() - return + if !p.s.Ok() { + return nil + } + return &Description{top} } -func ParseGlob(glob string, errorHandler func(pos Pos, msg string)) (top []interface{}, ok bool) { +func ParseGlob(glob string, errorHandler ErrorHandler) *Description { if errorHandler == nil { - errorHandler = loggingHandler + errorHandler = LoggingHandler } files, err := filepath.Glob(glob) if err != nil { errorHandler(Pos{}, fmt.Sprintf("failed to find input files: %v", err)) - return nil, false + return nil } if len(files) == 0 { errorHandler(Pos{}, fmt.Sprintf("no files matched by glob %q", glob)) - return nil, false + return nil } - ok = true + desc := &Description{} for _, f := range files { data, err := ioutil.ReadFile(f) if err != nil { errorHandler(Pos{}, fmt.Sprintf("failed to read input file: %v", err)) - return nil, false + return nil + } + desc1 := Parse(data, filepath.Base(f), errorHandler) + if desc1 == nil { + desc = nil } - top1, ok1 := Parse(data, filepath.Base(f), errorHandler) - if !ok1 { - ok = false + if desc != nil { + desc.Nodes = append(desc.Nodes, desc1.Nodes...) } - top = append(top, top1...) } - return + return desc } type parser struct { @@ -84,7 +90,7 @@ type parser struct { // Skip parsing till the next NEWLINE, for error recovery. var skipLine = errors.New("") -func (p *parser) parseTopRecover() interface{} { +func (p *parser) parseTopRecover() Node { defer func() { switch err := recover(); err { case nil: @@ -106,7 +112,7 @@ func (p *parser) parseTopRecover() interface{} { return decl } -func (p *parser) parseTop() interface{} { +func (p *parser) parseTop() Node { switch p.tok { case tokNewLine: return &NewLine{Pos: p.pos} @@ -266,7 +272,7 @@ func callName(s string) string { return s[:pos] } -func (p *parser) parseFlags(name *Ident) interface{} { +func (p *parser) parseFlags(name *Ident) Node { p.consume(tokEq) switch p.tok { case tokInt, tokIdent: @@ -379,6 +385,8 @@ func (p *parser) parseType() *Type { } p.next() if allowColon && p.tryConsume(tokColon) { + arg.HasColon = true + arg.Pos2 = p.pos switch p.tok { case tokInt: arg.Value2, arg.Value2Hex = p.parseIntValue() diff --git a/pkg/ast/parser_test.go b/pkg/ast/parser_test.go index 43587d7a3..809a0beaf 100644 --- a/pkg/ast/parser_test.go +++ b/pkg/ast/parser_test.go @@ -30,23 +30,31 @@ func TestParseAll(t *testing.T) { errorHandler := func(pos Pos, msg string) { t.Fatalf("%v:%v:%v: %v", pos.File, pos.Line, pos.Col, msg) } - top, ok := Parse(data, file.Name(), errorHandler) - if !ok { + desc := Parse(data, file.Name(), errorHandler) + if desc == nil { t.Fatalf("parsing failed, but no error produced") } - data2 := Format(top) - top2, ok2 := Parse(data2, file.Name(), errorHandler) - if !ok2 { + data2 := Format(desc) + desc2 := Parse(data2, file.Name(), errorHandler) + if desc2 == nil { t.Fatalf("parsing failed, but no error produced") } - if len(top) != len(top2) { - t.Fatalf("formatting number of top level decls: %v/%v", len(top), len(top2)) + if len(desc.Nodes) != len(desc2.Nodes) { + t.Fatalf("formatting number of top level decls: %v/%v", + len(desc.Nodes), len(desc2.Nodes)) } - // While sys files are not formatted, formatting in fact changes it. - for i := range top { - if !reflect.DeepEqual(top[i], top2[i]) { - t.Fatalf("formatting changed code:\n%#v\nvs:\n%#v", top[i], top2[i]) + for i := range desc.Nodes { + n1, n2 := desc.Nodes[i], desc2.Nodes[i] + if n1 == nil { + t.Fatalf("got nil node") } + if !reflect.DeepEqual(n1, n2) { + t.Fatalf("formatting changed code:\n%#v\nvs:\n%#v", n1, n2) + } + } + data3 := Format(Clone(desc)) + if !bytes.Equal(data, data3) { + t.Fatalf("Clone lost data") } } } @@ -57,8 +65,7 @@ func TestParse(t *testing.T) { errorHandler := func(pos Pos, msg string) { t.Logf("%v:%v:%v: %v", pos.File, pos.Line, pos.Col, msg) } - toplev, ok := Parse([]byte(test.input), "foo", errorHandler) - _, _ = toplev, ok + Parse([]byte(test.input), "foo", errorHandler) }) } } @@ -134,17 +141,17 @@ func TestErrors(t *testing.T) { t.Fatalf("failed to scan input file: %v", err) } var got []*Error - top, ok := Parse(stripped, "test", func(pos Pos, msg string) { + desc := Parse(stripped, "test", func(pos Pos, msg string) { got = append(got, &Error{ Line: pos.Line, Col: pos.Col, Text: msg, }) }) - if ok && len(got) != 0 { + if desc != nil && len(got) != 0 { t.Fatalf("parsing succeed, but got errors: %v", got) } - if !ok && len(got) == 0 { + if desc == nil && len(got) == 0 { t.Fatalf("parsing failed, but got no errors") } nextErr: @@ -171,8 +178,6 @@ func TestErrors(t *testing.T) { } t.Errorf("not matched error: %v: %v", wantErr.Line, wantErr.Text) } - // Just to get more code coverage: - Format(top) }) } } diff --git a/pkg/ast/scanner.go b/pkg/ast/scanner.go index 372d2df3e..387a58529 100644 --- a/pkg/ast/scanner.go +++ b/pkg/ast/scanner.go @@ -88,7 +88,7 @@ func (tok token) String() string { type scanner struct { data []byte filename string - errorHandler func(pos Pos, msg string) + errorHandler ErrorHandler ch byte off int @@ -101,9 +101,9 @@ type scanner struct { errors int } -func newScanner(data []byte, filename string, errorHandler func(pos Pos, msg string)) *scanner { +func newScanner(data []byte, filename string, errorHandler ErrorHandler) *scanner { if errorHandler == nil { - errorHandler = loggingHandler + errorHandler = LoggingHandler } s := &scanner{ data: data, @@ -115,7 +115,9 @@ func newScanner(data []byte, filename string, errorHandler func(pos Pos, msg str return s } -func loggingHandler(pos Pos, msg string) { +type ErrorHandler func(pos Pos, msg string) + +func LoggingHandler(pos Pos, msg string) { fmt.Fprintf(os.Stderr, "%v:%v:%v: %v\n", pos.File, pos.Line, pos.Col, msg) } diff --git a/pkg/ast/walk.go b/pkg/ast/walk.go index f4daeaccd..90a92cf77 100644 --- a/pkg/ast/walk.go +++ b/pkg/ast/walk.go @@ -8,13 +8,13 @@ import ( ) // Walk calls callback cb for every node in AST. -func Walk(top []interface{}, cb func(n interface{})) { - for _, decl := range top { - walkNode(decl, cb) +func Walk(desc *Description, cb func(n Node)) { + for _, n := range desc.Nodes { + WalkNode(n, cb) } } -func walkNode(n0 interface{}, cb func(n interface{})) { +func WalkNode(n0 Node, cb func(n Node)) { switch n := n0.(type) { case *NewLine: cb(n) @@ -22,53 +22,53 @@ func walkNode(n0 interface{}, cb func(n interface{})) { cb(n) case *Include: cb(n) - walkNode(n.File, cb) + WalkNode(n.File, cb) case *Incdir: cb(n) - walkNode(n.Dir, cb) + WalkNode(n.Dir, cb) case *Define: cb(n) - walkNode(n.Name, cb) - walkNode(n.Value, cb) + WalkNode(n.Name, cb) + WalkNode(n.Value, cb) case *Resource: cb(n) - walkNode(n.Name, cb) - walkNode(n.Base, cb) + WalkNode(n.Name, cb) + WalkNode(n.Base, cb) for _, v := range n.Values { - walkNode(v, cb) + WalkNode(v, cb) } case *Call: cb(n) - walkNode(n.Name, cb) + WalkNode(n.Name, cb) for _, f := range n.Args { - walkNode(f, cb) + WalkNode(f, cb) } if n.Ret != nil { - walkNode(n.Ret, cb) + WalkNode(n.Ret, cb) } case *Struct: cb(n) - walkNode(n.Name, cb) + WalkNode(n.Name, cb) for _, f := range n.Fields { - walkNode(f, cb) + WalkNode(f, cb) } for _, a := range n.Attrs { - walkNode(a, cb) + WalkNode(a, cb) } for _, c := range n.Comments { - walkNode(c, cb) + WalkNode(c, cb) } case *IntFlags: cb(n) - walkNode(n.Name, cb) + WalkNode(n.Name, cb) for _, v := range n.Values { - walkNode(v, cb) + WalkNode(v, cb) } case *StrFlags: cb(n) - walkNode(n.Name, cb) + WalkNode(n.Name, cb) for _, v := range n.Values { - walkNode(v, cb) + WalkNode(v, cb) } case *Ident: cb(n) @@ -79,14 +79,14 @@ func walkNode(n0 interface{}, cb func(n interface{})) { case *Type: cb(n) for _, t := range n.Args { - walkNode(t, cb) + WalkNode(t, cb) } case *Field: cb(n) - walkNode(n.Name, cb) - walkNode(n.Type, cb) + WalkNode(n.Name, cb) + WalkNode(n.Type, cb) for _, c := range n.Comments { - walkNode(c, cb) + WalkNode(c, cb) } default: panic(fmt.Sprintf("unknown AST node: %#v", n)) diff --git a/pkg/compiler/compiler.go b/pkg/compiler/compiler.go index 7892661fa..593d9e82c 100644 --- a/pkg/compiler/compiler.go +++ b/pkg/compiler/compiler.go @@ -9,14 +9,134 @@ import ( "strings" "github.com/google/syzkaller/pkg/ast" + "github.com/google/syzkaller/sys" ) +// Prog is description compilation result. +type Prog struct { + // Processed AST (temporal measure, remove later). + Desc *ast.Description + Resources []*sys.ResourceDesc + // Set of unsupported syscalls/flags. + Unsupported map[string]bool +} + +// Compile compiles sys description. +func Compile(desc0 *ast.Description, consts map[string]uint64, eh ast.ErrorHandler) *Prog { + if eh == nil { + eh = ast.LoggingHandler + } + + desc := ast.Clone(desc0) + unsup, ok := patchConsts(desc, consts, eh) + if !ok { + return nil + } + + return &Prog{ + Desc: desc, + Unsupported: unsup, + } +} + +// patchConsts replaces all symbolic consts with their numeric values taken from consts map. +// Updates desc and returns set of unsupported syscalls and flags. +// After this pass consts are not needed for compilation. +func patchConsts(desc *ast.Description, consts map[string]uint64, eh ast.ErrorHandler) (map[string]bool, bool) { + broken := false + unsup := make(map[string]bool) + var top []ast.Node + for _, decl := range desc.Nodes { + switch decl.(type) { + case *ast.IntFlags: + // Unsupported flag values are dropped. + n := decl.(*ast.IntFlags) + var values []*ast.Int + for _, v := range n.Values { + if patchIntConst(v.Pos, &v.Value, &v.Ident, consts, unsup, nil, eh) { + values = append(values, v) + } + } + n.Values = values + top = append(top, n) + case *ast.Resource, *ast.Struct, *ast.Call: + if c, ok := decl.(*ast.Call); ok { + // Extract syscall NR. + str := "__NR_" + c.CallName + nr, ok := consts[str] + if !ok { + if name := "syscall " + c.CallName; !unsup[name] { + unsup[name] = true + eh(c.Pos, fmt.Sprintf("unsupported syscall: %v due to missing const %v", + c.CallName, str)) + } + continue + } + c.NR = nr + } + // Walk whole tree and replace consts in Int's and Type's. + missing := "" + ast.WalkNode(decl, func(n0 ast.Node) { + switch n := n0.(type) { + case *ast.Int: + patchIntConst(n.Pos, &n.Value, &n.Ident, + consts, unsup, &missing, eh) + case *ast.Type: + if c := typeConstIdentifier(n); c != nil { + patchIntConst(c.Pos, &c.Value, &c.Ident, + consts, unsup, &missing, eh) + if c.HasColon { + patchIntConst(c.Pos2, &c.Value2, &c.Ident2, + consts, unsup, &missing, eh) + } + } + } + }) + if missing == "" { + top = append(top, decl) + } else { + // Produce a warning about unsupported syscall/resource/struct. + // Unsupported syscalls are discarded. + // Unsupported resource/struct lead to compilation error. + // Fixing that would require removing all uses of the resource/struct. + typ, pos, name, fatal := "", ast.Pos{}, "", false + switch n := decl.(type) { + case *ast.Call: + typ, pos, name, fatal = "syscall", n.Pos, n.Name.Name, false + case *ast.Resource: + typ, pos, name, fatal = "resource", n.Pos, n.Name.Name, true + case *ast.Struct: + typ, pos, name, fatal = "struct", n.Pos, n.Name.Name, true + default: + panic(fmt.Sprintf("unknown type: %#v", decl)) + } + if id := typ + " " + name; !unsup[id] { + unsup[id] = true + eh(pos, fmt.Sprintf("unsupported %v: %v due to missing const %v", + typ, name, missing)) + } + if fatal { + broken = true + } + } + case *ast.StrFlags: + top = append(top, decl) + case *ast.NewLine, *ast.Comment, *ast.Include, *ast.Incdir, *ast.Define: + // These are not needed anymore. + default: + panic(fmt.Sprintf("unknown node type: %#v", decl)) + } + } + desc.Nodes = top + return unsup, !broken +} + // ExtractConsts returns list of literal constants and other info required const value extraction. -func ExtractConsts(top []interface{}) (consts, includes, incdirs []string, defines map[string]string) { +func ExtractConsts(desc *ast.Description) (consts, includes, incdirs []string, defines map[string]string) { constMap := make(map[string]bool) defines = make(map[string]string) - ast.Walk(top, func(n1 interface{}) { + ast.Walk(desc, func(n1 ast.Node) { switch n := n1.(type) { case *ast.Include: includes = append(includes, n.File.Value) @@ -37,11 +157,9 @@ func ExtractConsts(top []interface{}) (consts, includes, incdirs []string, defin constMap["__NR_"+n.CallName] = true } case *ast.Type: - if n.Ident == "const" && len(n.Args) > 0 { - constMap[n.Args[0].Ident] = true - } - if n.Ident == "array" && len(n.Args) > 1 { - constMap[n.Args[1].Ident] = true + if c := typeConstIdentifier(n); c != nil { + constMap[c.Ident] = true + constMap[c.Ident2] = true } case *ast.Int: constMap[n.Ident] = true @@ -52,6 +170,56 @@ func ExtractConsts(top []interface{}) (consts, includes, incdirs []string, defin return } +func patchIntConst(pos ast.Pos, val *uint64, id *string, + consts map[string]uint64, unsup map[string]bool, missing *string, eh ast.ErrorHandler) bool { + if *id == "" { + return true + } + v, ok := consts[*id] + if !ok { + name := "const " + *id + if !unsup[name] { + unsup[name] = true + eh(pos, fmt.Sprintf("unsupported const: %v", *id)) + } + if missing != nil && *missing == "" { + *missing = *id + } + } + *val = v + *id = "" + return ok +} + +// typeConstIdentifier returns type arg that is an integer constant (subject for const patching), if any. +func typeConstIdentifier(n *ast.Type) *ast.Type { + if n.Ident == "const" && len(n.Args) > 0 { + return n.Args[0] + } + if n.Ident == "array" && len(n.Args) > 1 && n.Args[1].Ident != "opt" { + return n.Args[1] + } + if n.Ident == "vma" && len(n.Args) > 0 && n.Args[0].Ident != "opt" { + return n.Args[0] + } + if n.Ident == "vma" && len(n.Args) > 0 && n.Args[0].Ident != "opt" { + return n.Args[0] + } + if n.Ident == "string" && len(n.Args) > 1 && n.Args[1].Ident != "opt" { + return n.Args[1] + } + if n.Ident == "csum" && len(n.Args) > 2 && n.Args[1].Ident == "pseudo" { + return n.Args[2] + } + switch n.Ident { + case "int8", "int16", "int16be", "int32", "int32be", "int64", "int64be", "intptr": + if len(n.Args) > 0 && n.Args[0].Ident != "opt" { + return n.Args[0] + } + } + return nil +} + func toArray(m map[string]bool) []string { delete(m, "") var res []string diff --git a/pkg/compiler/compiler_test.go b/pkg/compiler/compiler_test.go index 13ed97b4c..e68d97d3e 100644 --- a/pkg/compiler/compiler_test.go +++ b/pkg/compiler/compiler_test.go @@ -11,13 +11,15 @@ import ( ) func TestExtractConsts(t *testing.T) { - top, ok := ast.Parse([]byte(extractConstsInput), "test", nil) - if !ok { + desc := ast.Parse([]byte(extractConstsInput), "test", nil) + if desc == nil { t.Fatalf("failed to parse input") } - consts, includes, incdirs, defines := ExtractConsts(top) - wantConsts := []string{"CONST1", "CONST2", "CONST3", "CONST4", "CONST5", - "CONST6", "CONST7", "__NR_bar", "__NR_foo"} + consts, includes, incdirs, defines := ExtractConsts(desc) + wantConsts := []string{"CONST1", "CONST10", "CONST11", "CONST12", "CONST13", + "CONST14", "CONST15", "CONST16", + "CONST2", "CONST3", "CONST4", "CONST5", + "CONST6", "CONST7", "CONST8", "CONST9", "__NR_bar", "__NR_foo"} if !reflect.DeepEqual(consts, wantConsts) { t.Fatalf("got consts:\n%q\nwant:\n%q", consts, wantConsts) } @@ -56,4 +58,8 @@ str { f1 const[CONST6, int32] f2 array[array[int8, CONST7]] } + +bar$BAZ(x vma[opt], y vma[CONST8], z vma[CONST9:CONST10]) +bar$QUX(s ptr[in, string["foo", CONST11]], x csum[s, pseudo, CONST12]) +bar$FOO(x int8[8:CONST13], y int16be[CONST14:10], z intptr[CONST15:CONST16]) ` |
