aboutsummaryrefslogtreecommitdiffstats
path: root/pkg
diff options
context:
space:
mode:
authorDmitry Vyukov <dvyukov@google.com>2017-08-26 21:36:08 +0200
committerDmitry Vyukov <dvyukov@google.com>2017-08-27 11:51:40 +0200
commita3857c4e90fa4a3fbe78bd4b53cdc77aa91533cf (patch)
tree8bc28379a29112de7bc11c57f3d91d0baba84594 /pkg
parent9ec49e082f811482ecdccc837c27961d68247d25 (diff)
pkg/compiler, sys/syz-sysgen: move const handling to pkg/compiler
Now pkg/compiler deals with consts.
Diffstat (limited to 'pkg')
-rw-r--r--pkg/ast/ast.go11
-rw-r--r--pkg/ast/clone.go196
-rw-r--r--pkg/ast/format.go48
-rw-r--r--pkg/ast/parser.go42
-rw-r--r--pkg/ast/parser_test.go41
-rw-r--r--pkg/ast/scanner.go10
-rw-r--r--pkg/ast/walk.go52
-rw-r--r--pkg/compiler/compiler.go182
-rw-r--r--pkg/compiler/compiler_test.go16
9 files changed, 498 insertions, 100 deletions
diff --git a/pkg/ast/ast.go b/pkg/ast/ast.go
index e2ddc0224..b283ca5f8 100644
--- a/pkg/ast/ast.go
+++ b/pkg/ast/ast.go
@@ -12,6 +12,14 @@ type Pos struct {
Col int // column number, starting at 1 (byte count)
}
+// Description contains top-level nodes of a parsed sys description.
+type Description struct {
+ Nodes []Node
+}
+
+// Node is AST node interface.
+type Node interface{}
+
// Top-level AST nodes:
type NewLine struct {
@@ -50,6 +58,7 @@ type Call struct {
Pos Pos
Name *Ident
CallName string
+ NR uint64
Args []*Field
Ret *Type
}
@@ -104,6 +113,8 @@ type Type struct {
Ident string
String string
// Part after COLON (for ranges and bitfields).
+ HasColon bool
+ Pos2 Pos
Value2 uint64
Value2Hex bool
Ident2 string
diff --git a/pkg/ast/clone.go b/pkg/ast/clone.go
new file mode 100644
index 000000000..5c2d773f7
--- /dev/null
+++ b/pkg/ast/clone.go
@@ -0,0 +1,196 @@
+// Copyright 2017 syzkaller project authors. All rights reserved.
+// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
+
+package ast
+
+import (
+ "fmt"
+)
+
+func Clone(desc *Description) *Description {
+ desc1 := &Description{}
+ for _, n := range desc.Nodes {
+ c, ok := n.(cloner)
+ if !ok {
+ panic(fmt.Sprintf("unknown top level decl: %#v", n))
+ }
+ desc1.Nodes = append(desc1.Nodes, c.clone())
+ }
+ return desc1
+}
+
+type cloner interface {
+ clone() Node
+}
+
+func (n *NewLine) clone() Node {
+ return &NewLine{
+ Pos: n.Pos,
+ }
+}
+
+func (n *Comment) clone() Node {
+ return &Comment{
+ Pos: n.Pos,
+ Text: n.Text,
+ }
+}
+
+func (n *Include) clone() Node {
+ return &Include{
+ Pos: n.Pos,
+ File: n.File.clone(),
+ }
+}
+
+func (n *Incdir) clone() Node {
+ return &Incdir{
+ Pos: n.Pos,
+ Dir: n.Dir.clone(),
+ }
+}
+
+func (n *Define) clone() Node {
+ return &Define{
+ Pos: n.Pos,
+ Name: n.Name.clone(),
+ Value: n.Value.clone(),
+ }
+}
+
+func (n *Resource) clone() Node {
+ var values []*Int
+ for _, v := range n.Values {
+ values = append(values, v.clone())
+ }
+ return &Resource{
+ Pos: n.Pos,
+ Name: n.Name.clone(),
+ Base: n.Base.clone(),
+ Values: values,
+ }
+}
+
+func (n *Call) clone() Node {
+ var args []*Field
+ for _, a := range n.Args {
+ args = append(args, a.clone())
+ }
+ var ret *Type
+ if n.Ret != nil {
+ ret = n.Ret.clone()
+ }
+ return &Call{
+ Pos: n.Pos,
+ Name: n.Name.clone(),
+ CallName: n.CallName,
+ NR: n.NR,
+ Args: args,
+ Ret: ret,
+ }
+}
+
+func (n *Struct) clone() Node {
+ var fields []*Field
+ for _, f := range n.Fields {
+ fields = append(fields, f.clone())
+ }
+ var attrs []*Ident
+ for _, a := range n.Attrs {
+ attrs = append(attrs, a.clone())
+ }
+ var comments []*Comment
+ for _, c := range n.Comments {
+ comments = append(comments, c.clone().(*Comment))
+ }
+ return &Struct{
+ Pos: n.Pos,
+ Name: n.Name.clone(),
+ Fields: fields,
+ Attrs: attrs,
+ Comments: comments,
+ IsUnion: n.IsUnion,
+ }
+}
+
+func (n *IntFlags) clone() Node {
+ var values []*Int
+ for _, v := range n.Values {
+ values = append(values, v.clone())
+ }
+ return &IntFlags{
+ Pos: n.Pos,
+ Name: n.Name.clone(),
+ Values: values,
+ }
+}
+
+func (n *StrFlags) clone() Node {
+ var values []*String
+ for _, v := range n.Values {
+ values = append(values, v.clone())
+ }
+ return &StrFlags{
+ Pos: n.Pos,
+ Name: n.Name.clone(),
+ Values: values,
+ }
+}
+
+func (n *Ident) clone() *Ident {
+ return &Ident{
+ Pos: n.Pos,
+ Name: n.Name,
+ }
+}
+
+func (n *String) clone() *String {
+ return &String{
+ Pos: n.Pos,
+ Value: n.Value,
+ }
+}
+
+func (n *Int) clone() *Int {
+ return &Int{
+ Pos: n.Pos,
+ Value: n.Value,
+ ValueHex: n.ValueHex,
+ Ident: n.Ident,
+ CExpr: n.CExpr,
+ }
+}
+
+func (n *Type) clone() *Type {
+ var args []*Type
+ for _, a := range n.Args {
+ args = append(args, a.clone())
+ }
+ return &Type{
+ Pos: n.Pos,
+ Value: n.Value,
+ ValueHex: n.ValueHex,
+ Ident: n.Ident,
+ String: n.String,
+ HasColon: n.HasColon,
+ Pos2: n.Pos2,
+ Value2: n.Value2,
+ Value2Hex: n.Value2Hex,
+ Ident2: n.Ident2,
+ Args: args,
+ }
+}
+
+func (n *Field) clone() *Field {
+ var comments []*Comment
+ for _, c := range n.Comments {
+ comments = append(comments, c.clone().(*Comment))
+ }
+ return &Field{
+ Pos: n.Pos,
+ Name: n.Name.clone(),
+ Type: n.Type.clone(),
+ NewBlock: n.NewBlock,
+ Comments: comments,
+ }
+}
diff --git a/pkg/ast/format.go b/pkg/ast/format.go
index 0eb9aa957..e7e21dcdd 100644
--- a/pkg/ast/format.go
+++ b/pkg/ast/format.go
@@ -9,47 +9,47 @@ import (
"io"
)
-func Format(top []interface{}) []byte {
+func Format(desc *Description) []byte {
buf := new(bytes.Buffer)
- FormatWriter(buf, top)
+ FormatWriter(buf, desc)
return buf.Bytes()
}
-func FormatWriter(w io.Writer, top []interface{}) {
- for _, decl := range top {
- s, ok := decl.(serializer)
+func FormatWriter(w io.Writer, desc *Description) {
+ for _, n := range desc.Nodes {
+ s, ok := n.(serializer)
if !ok {
- panic(fmt.Sprintf("unknown top level decl: %#v", decl))
+ panic(fmt.Sprintf("unknown top level decl: %#v", n))
}
- s.Serialize(w)
+ s.serialize(w)
}
}
type serializer interface {
- Serialize(w io.Writer)
+ serialize(w io.Writer)
}
-func (incl *NewLine) Serialize(w io.Writer) {
+func (nl *NewLine) serialize(w io.Writer) {
fmt.Fprintf(w, "\n")
}
-func (com *Comment) Serialize(w io.Writer) {
+func (com *Comment) serialize(w io.Writer) {
fmt.Fprintf(w, "#%v\n", com.Text)
}
-func (incl *Include) Serialize(w io.Writer) {
+func (incl *Include) serialize(w io.Writer) {
fmt.Fprintf(w, "include <%v>\n", incl.File.Value)
}
-func (inc *Incdir) Serialize(w io.Writer) {
+func (inc *Incdir) serialize(w io.Writer) {
fmt.Fprintf(w, "incdir <%v>\n", inc.Dir.Value)
}
-func (def *Define) Serialize(w io.Writer) {
+func (def *Define) serialize(w io.Writer) {
fmt.Fprintf(w, "define %v\t%v\n", def.Name.Name, fmtInt(def.Value))
}
-func (res *Resource) Serialize(w io.Writer) {
+func (res *Resource) serialize(w io.Writer) {
fmt.Fprintf(w, "resource %v[%v]", res.Name.Name, res.Base.Name)
for i, v := range res.Values {
if i == 0 {
@@ -62,7 +62,7 @@ func (res *Resource) Serialize(w io.Writer) {
fmt.Fprintf(w, "\n")
}
-func (c *Call) Serialize(w io.Writer) {
+func (c *Call) serialize(w io.Writer) {
fmt.Fprintf(w, "%v(", c.Name.Name)
for i, a := range c.Args {
if i != 0 {
@@ -77,7 +77,7 @@ func (c *Call) Serialize(w io.Writer) {
fmt.Fprintf(w, "\n")
}
-func (str *Struct) Serialize(w io.Writer) {
+func (str *Struct) serialize(w io.Writer) {
opening, closing := '{', '}'
if str.IsUnion {
opening, closing = '[', ']'
@@ -119,7 +119,7 @@ func (str *Struct) Serialize(w io.Writer) {
fmt.Fprintf(w, "\n")
}
-func (flags *IntFlags) Serialize(w io.Writer) {
+func (flags *IntFlags) serialize(w io.Writer) {
fmt.Fprintf(w, "%v = ", flags.Name.Name)
for i, v := range flags.Values {
if i != 0 {
@@ -130,7 +130,7 @@ func (flags *IntFlags) Serialize(w io.Writer) {
fmt.Fprintf(w, "\n")
}
-func (flags *StrFlags) Serialize(w io.Writer) {
+func (flags *StrFlags) serialize(w io.Writer) {
fmt.Fprintf(w, "%v = ", flags.Name.Name)
for i, v := range flags.Values {
if i != 0 {
@@ -155,11 +155,13 @@ func fmtType(t *Type) string {
default:
v = fmtIntValue(t.Value, t.ValueHex)
}
- switch {
- case t.Ident2 != "":
- v += fmt.Sprintf(":%v", t.Ident2)
- case t.Value2 != 0:
- v += fmt.Sprintf(":%v", fmtIntValue(t.Value2, t.Value2Hex))
+ if t.HasColon {
+ switch {
+ case t.Ident2 != "":
+ v += fmt.Sprintf(":%v", t.Ident2)
+ default:
+ v += fmt.Sprintf(":%v", fmtIntValue(t.Value2, t.Value2Hex))
+ }
}
v += fmtTypeList(t.Args)
return v
diff --git a/pkg/ast/parser.go b/pkg/ast/parser.go
index fc05378b2..ca2505e19 100644
--- a/pkg/ast/parser.go
+++ b/pkg/ast/parser.go
@@ -13,9 +13,11 @@ import (
)
// Parse parses sys description into AST and returns top-level nodes.
-func Parse(data []byte, filename string, errorHandler func(pos Pos, msg string)) (top []interface{}, ok bool) {
+// If any errors are encountered, returns nil.
+func Parse(data []byte, filename string, errorHandler ErrorHandler) *Description {
p := &parser{s: newScanner(data, filename, errorHandler)}
prevNewLine, prevComment := false, false
+ var top []Node
for p.next(); p.tok != tokEOF; {
decl := p.parseTopRecover()
if decl == nil {
@@ -39,37 +41,41 @@ func Parse(data []byte, filename string, errorHandler func(pos Pos, msg string))
if prevNewLine {
top = top[:len(top)-1]
}
- ok = p.s.Ok()
- return
+ if !p.s.Ok() {
+ return nil
+ }
+ return &Description{top}
}
-func ParseGlob(glob string, errorHandler func(pos Pos, msg string)) (top []interface{}, ok bool) {
+func ParseGlob(glob string, errorHandler ErrorHandler) *Description {
if errorHandler == nil {
- errorHandler = loggingHandler
+ errorHandler = LoggingHandler
}
files, err := filepath.Glob(glob)
if err != nil {
errorHandler(Pos{}, fmt.Sprintf("failed to find input files: %v", err))
- return nil, false
+ return nil
}
if len(files) == 0 {
errorHandler(Pos{}, fmt.Sprintf("no files matched by glob %q", glob))
- return nil, false
+ return nil
}
- ok = true
+ desc := &Description{}
for _, f := range files {
data, err := ioutil.ReadFile(f)
if err != nil {
errorHandler(Pos{}, fmt.Sprintf("failed to read input file: %v", err))
- return nil, false
+ return nil
+ }
+ desc1 := Parse(data, filepath.Base(f), errorHandler)
+ if desc1 == nil {
+ desc = nil
}
- top1, ok1 := Parse(data, filepath.Base(f), errorHandler)
- if !ok1 {
- ok = false
+ if desc != nil {
+ desc.Nodes = append(desc.Nodes, desc1.Nodes...)
}
- top = append(top, top1...)
}
- return
+ return desc
}
type parser struct {
@@ -84,7 +90,7 @@ type parser struct {
// Skip parsing till the next NEWLINE, for error recovery.
var skipLine = errors.New("")
-func (p *parser) parseTopRecover() interface{} {
+func (p *parser) parseTopRecover() Node {
defer func() {
switch err := recover(); err {
case nil:
@@ -106,7 +112,7 @@ func (p *parser) parseTopRecover() interface{} {
return decl
}
-func (p *parser) parseTop() interface{} {
+func (p *parser) parseTop() Node {
switch p.tok {
case tokNewLine:
return &NewLine{Pos: p.pos}
@@ -266,7 +272,7 @@ func callName(s string) string {
return s[:pos]
}
-func (p *parser) parseFlags(name *Ident) interface{} {
+func (p *parser) parseFlags(name *Ident) Node {
p.consume(tokEq)
switch p.tok {
case tokInt, tokIdent:
@@ -379,6 +385,8 @@ func (p *parser) parseType() *Type {
}
p.next()
if allowColon && p.tryConsume(tokColon) {
+ arg.HasColon = true
+ arg.Pos2 = p.pos
switch p.tok {
case tokInt:
arg.Value2, arg.Value2Hex = p.parseIntValue()
diff --git a/pkg/ast/parser_test.go b/pkg/ast/parser_test.go
index 43587d7a3..809a0beaf 100644
--- a/pkg/ast/parser_test.go
+++ b/pkg/ast/parser_test.go
@@ -30,23 +30,31 @@ func TestParseAll(t *testing.T) {
errorHandler := func(pos Pos, msg string) {
t.Fatalf("%v:%v:%v: %v", pos.File, pos.Line, pos.Col, msg)
}
- top, ok := Parse(data, file.Name(), errorHandler)
- if !ok {
+ desc := Parse(data, file.Name(), errorHandler)
+ if desc == nil {
t.Fatalf("parsing failed, but no error produced")
}
- data2 := Format(top)
- top2, ok2 := Parse(data2, file.Name(), errorHandler)
- if !ok2 {
+ data2 := Format(desc)
+ desc2 := Parse(data2, file.Name(), errorHandler)
+ if desc2 == nil {
t.Fatalf("parsing failed, but no error produced")
}
- if len(top) != len(top2) {
- t.Fatalf("formatting number of top level decls: %v/%v", len(top), len(top2))
+ if len(desc.Nodes) != len(desc2.Nodes) {
+ t.Fatalf("formatting number of top level decls: %v/%v",
+ len(desc.Nodes), len(desc2.Nodes))
}
- // While sys files are not formatted, formatting in fact changes it.
- for i := range top {
- if !reflect.DeepEqual(top[i], top2[i]) {
- t.Fatalf("formatting changed code:\n%#v\nvs:\n%#v", top[i], top2[i])
+ for i := range desc.Nodes {
+ n1, n2 := desc.Nodes[i], desc2.Nodes[i]
+ if n1 == nil {
+ t.Fatalf("got nil node")
}
+ if !reflect.DeepEqual(n1, n2) {
+ t.Fatalf("formatting changed code:\n%#v\nvs:\n%#v", n1, n2)
+ }
+ }
+ data3 := Format(Clone(desc))
+ if !bytes.Equal(data, data3) {
+ t.Fatalf("Clone lost data")
}
}
}
@@ -57,8 +65,7 @@ func TestParse(t *testing.T) {
errorHandler := func(pos Pos, msg string) {
t.Logf("%v:%v:%v: %v", pos.File, pos.Line, pos.Col, msg)
}
- toplev, ok := Parse([]byte(test.input), "foo", errorHandler)
- _, _ = toplev, ok
+ Parse([]byte(test.input), "foo", errorHandler)
})
}
}
@@ -134,17 +141,17 @@ func TestErrors(t *testing.T) {
t.Fatalf("failed to scan input file: %v", err)
}
var got []*Error
- top, ok := Parse(stripped, "test", func(pos Pos, msg string) {
+ desc := Parse(stripped, "test", func(pos Pos, msg string) {
got = append(got, &Error{
Line: pos.Line,
Col: pos.Col,
Text: msg,
})
})
- if ok && len(got) != 0 {
+ if desc != nil && len(got) != 0 {
t.Fatalf("parsing succeed, but got errors: %v", got)
}
- if !ok && len(got) == 0 {
+ if desc == nil && len(got) == 0 {
t.Fatalf("parsing failed, but got no errors")
}
nextErr:
@@ -171,8 +178,6 @@ func TestErrors(t *testing.T) {
}
t.Errorf("not matched error: %v: %v", wantErr.Line, wantErr.Text)
}
- // Just to get more code coverage:
- Format(top)
})
}
}
diff --git a/pkg/ast/scanner.go b/pkg/ast/scanner.go
index 372d2df3e..387a58529 100644
--- a/pkg/ast/scanner.go
+++ b/pkg/ast/scanner.go
@@ -88,7 +88,7 @@ func (tok token) String() string {
type scanner struct {
data []byte
filename string
- errorHandler func(pos Pos, msg string)
+ errorHandler ErrorHandler
ch byte
off int
@@ -101,9 +101,9 @@ type scanner struct {
errors int
}
-func newScanner(data []byte, filename string, errorHandler func(pos Pos, msg string)) *scanner {
+func newScanner(data []byte, filename string, errorHandler ErrorHandler) *scanner {
if errorHandler == nil {
- errorHandler = loggingHandler
+ errorHandler = LoggingHandler
}
s := &scanner{
data: data,
@@ -115,7 +115,9 @@ func newScanner(data []byte, filename string, errorHandler func(pos Pos, msg str
return s
}
-func loggingHandler(pos Pos, msg string) {
+type ErrorHandler func(pos Pos, msg string)
+
+func LoggingHandler(pos Pos, msg string) {
fmt.Fprintf(os.Stderr, "%v:%v:%v: %v\n", pos.File, pos.Line, pos.Col, msg)
}
diff --git a/pkg/ast/walk.go b/pkg/ast/walk.go
index f4daeaccd..90a92cf77 100644
--- a/pkg/ast/walk.go
+++ b/pkg/ast/walk.go
@@ -8,13 +8,13 @@ import (
)
// Walk calls callback cb for every node in AST.
-func Walk(top []interface{}, cb func(n interface{})) {
- for _, decl := range top {
- walkNode(decl, cb)
+func Walk(desc *Description, cb func(n Node)) {
+ for _, n := range desc.Nodes {
+ WalkNode(n, cb)
}
}
-func walkNode(n0 interface{}, cb func(n interface{})) {
+func WalkNode(n0 Node, cb func(n Node)) {
switch n := n0.(type) {
case *NewLine:
cb(n)
@@ -22,53 +22,53 @@ func walkNode(n0 interface{}, cb func(n interface{})) {
cb(n)
case *Include:
cb(n)
- walkNode(n.File, cb)
+ WalkNode(n.File, cb)
case *Incdir:
cb(n)
- walkNode(n.Dir, cb)
+ WalkNode(n.Dir, cb)
case *Define:
cb(n)
- walkNode(n.Name, cb)
- walkNode(n.Value, cb)
+ WalkNode(n.Name, cb)
+ WalkNode(n.Value, cb)
case *Resource:
cb(n)
- walkNode(n.Name, cb)
- walkNode(n.Base, cb)
+ WalkNode(n.Name, cb)
+ WalkNode(n.Base, cb)
for _, v := range n.Values {
- walkNode(v, cb)
+ WalkNode(v, cb)
}
case *Call:
cb(n)
- walkNode(n.Name, cb)
+ WalkNode(n.Name, cb)
for _, f := range n.Args {
- walkNode(f, cb)
+ WalkNode(f, cb)
}
if n.Ret != nil {
- walkNode(n.Ret, cb)
+ WalkNode(n.Ret, cb)
}
case *Struct:
cb(n)
- walkNode(n.Name, cb)
+ WalkNode(n.Name, cb)
for _, f := range n.Fields {
- walkNode(f, cb)
+ WalkNode(f, cb)
}
for _, a := range n.Attrs {
- walkNode(a, cb)
+ WalkNode(a, cb)
}
for _, c := range n.Comments {
- walkNode(c, cb)
+ WalkNode(c, cb)
}
case *IntFlags:
cb(n)
- walkNode(n.Name, cb)
+ WalkNode(n.Name, cb)
for _, v := range n.Values {
- walkNode(v, cb)
+ WalkNode(v, cb)
}
case *StrFlags:
cb(n)
- walkNode(n.Name, cb)
+ WalkNode(n.Name, cb)
for _, v := range n.Values {
- walkNode(v, cb)
+ WalkNode(v, cb)
}
case *Ident:
cb(n)
@@ -79,14 +79,14 @@ func walkNode(n0 interface{}, cb func(n interface{})) {
case *Type:
cb(n)
for _, t := range n.Args {
- walkNode(t, cb)
+ WalkNode(t, cb)
}
case *Field:
cb(n)
- walkNode(n.Name, cb)
- walkNode(n.Type, cb)
+ WalkNode(n.Name, cb)
+ WalkNode(n.Type, cb)
for _, c := range n.Comments {
- walkNode(c, cb)
+ WalkNode(c, cb)
}
default:
panic(fmt.Sprintf("unknown AST node: %#v", n))
diff --git a/pkg/compiler/compiler.go b/pkg/compiler/compiler.go
index 7892661fa..593d9e82c 100644
--- a/pkg/compiler/compiler.go
+++ b/pkg/compiler/compiler.go
@@ -9,14 +9,134 @@ import (
"strings"
"github.com/google/syzkaller/pkg/ast"
+ "github.com/google/syzkaller/sys"
)
+// Prog is description compilation result.
+type Prog struct {
+ // Processed AST (temporal measure, remove later).
+ Desc *ast.Description
+ Resources []*sys.ResourceDesc
+ // Set of unsupported syscalls/flags.
+ Unsupported map[string]bool
+}
+
+// Compile compiles sys description.
+func Compile(desc0 *ast.Description, consts map[string]uint64, eh ast.ErrorHandler) *Prog {
+ if eh == nil {
+ eh = ast.LoggingHandler
+ }
+
+ desc := ast.Clone(desc0)
+ unsup, ok := patchConsts(desc, consts, eh)
+ if !ok {
+ return nil
+ }
+
+ return &Prog{
+ Desc: desc,
+ Unsupported: unsup,
+ }
+}
+
+// patchConsts replaces all symbolic consts with their numeric values taken from consts map.
+// Updates desc and returns set of unsupported syscalls and flags.
+// After this pass consts are not needed for compilation.
+func patchConsts(desc *ast.Description, consts map[string]uint64, eh ast.ErrorHandler) (map[string]bool, bool) {
+ broken := false
+ unsup := make(map[string]bool)
+ var top []ast.Node
+ for _, decl := range desc.Nodes {
+ switch decl.(type) {
+ case *ast.IntFlags:
+ // Unsupported flag values are dropped.
+ n := decl.(*ast.IntFlags)
+ var values []*ast.Int
+ for _, v := range n.Values {
+ if patchIntConst(v.Pos, &v.Value, &v.Ident, consts, unsup, nil, eh) {
+ values = append(values, v)
+ }
+ }
+ n.Values = values
+ top = append(top, n)
+ case *ast.Resource, *ast.Struct, *ast.Call:
+ if c, ok := decl.(*ast.Call); ok {
+ // Extract syscall NR.
+ str := "__NR_" + c.CallName
+ nr, ok := consts[str]
+ if !ok {
+ if name := "syscall " + c.CallName; !unsup[name] {
+ unsup[name] = true
+ eh(c.Pos, fmt.Sprintf("unsupported syscall: %v due to missing const %v",
+ c.CallName, str))
+ }
+ continue
+ }
+ c.NR = nr
+ }
+ // Walk whole tree and replace consts in Int's and Type's.
+ missing := ""
+ ast.WalkNode(decl, func(n0 ast.Node) {
+ switch n := n0.(type) {
+ case *ast.Int:
+ patchIntConst(n.Pos, &n.Value, &n.Ident,
+ consts, unsup, &missing, eh)
+ case *ast.Type:
+ if c := typeConstIdentifier(n); c != nil {
+ patchIntConst(c.Pos, &c.Value, &c.Ident,
+ consts, unsup, &missing, eh)
+ if c.HasColon {
+ patchIntConst(c.Pos2, &c.Value2, &c.Ident2,
+ consts, unsup, &missing, eh)
+ }
+ }
+ }
+ })
+ if missing == "" {
+ top = append(top, decl)
+ } else {
+ // Produce a warning about unsupported syscall/resource/struct.
+ // Unsupported syscalls are discarded.
+ // Unsupported resource/struct lead to compilation error.
+ // Fixing that would require removing all uses of the resource/struct.
+ typ, pos, name, fatal := "", ast.Pos{}, "", false
+ switch n := decl.(type) {
+ case *ast.Call:
+ typ, pos, name, fatal = "syscall", n.Pos, n.Name.Name, false
+ case *ast.Resource:
+ typ, pos, name, fatal = "resource", n.Pos, n.Name.Name, true
+ case *ast.Struct:
+ typ, pos, name, fatal = "struct", n.Pos, n.Name.Name, true
+ default:
+ panic(fmt.Sprintf("unknown type: %#v", decl))
+ }
+ if id := typ + " " + name; !unsup[id] {
+ unsup[id] = true
+ eh(pos, fmt.Sprintf("unsupported %v: %v due to missing const %v",
+ typ, name, missing))
+ }
+ if fatal {
+ broken = true
+ }
+ }
+ case *ast.StrFlags:
+ top = append(top, decl)
+ case *ast.NewLine, *ast.Comment, *ast.Include, *ast.Incdir, *ast.Define:
+ // These are not needed anymore.
+ default:
+ panic(fmt.Sprintf("unknown node type: %#v", decl))
+ }
+ }
+ desc.Nodes = top
+ return unsup, !broken
+}
+
// ExtractConsts returns list of literal constants and other info required const value extraction.
-func ExtractConsts(top []interface{}) (consts, includes, incdirs []string, defines map[string]string) {
+func ExtractConsts(desc *ast.Description) (consts, includes, incdirs []string, defines map[string]string) {
constMap := make(map[string]bool)
defines = make(map[string]string)
- ast.Walk(top, func(n1 interface{}) {
+ ast.Walk(desc, func(n1 ast.Node) {
switch n := n1.(type) {
case *ast.Include:
includes = append(includes, n.File.Value)
@@ -37,11 +157,9 @@ func ExtractConsts(top []interface{}) (consts, includes, incdirs []string, defin
constMap["__NR_"+n.CallName] = true
}
case *ast.Type:
- if n.Ident == "const" && len(n.Args) > 0 {
- constMap[n.Args[0].Ident] = true
- }
- if n.Ident == "array" && len(n.Args) > 1 {
- constMap[n.Args[1].Ident] = true
+ if c := typeConstIdentifier(n); c != nil {
+ constMap[c.Ident] = true
+ constMap[c.Ident2] = true
}
case *ast.Int:
constMap[n.Ident] = true
@@ -52,6 +170,56 @@ func ExtractConsts(top []interface{}) (consts, includes, incdirs []string, defin
return
}
+func patchIntConst(pos ast.Pos, val *uint64, id *string,
+ consts map[string]uint64, unsup map[string]bool, missing *string, eh ast.ErrorHandler) bool {
+ if *id == "" {
+ return true
+ }
+ v, ok := consts[*id]
+ if !ok {
+ name := "const " + *id
+ if !unsup[name] {
+ unsup[name] = true
+ eh(pos, fmt.Sprintf("unsupported const: %v", *id))
+ }
+ if missing != nil && *missing == "" {
+ *missing = *id
+ }
+ }
+ *val = v
+ *id = ""
+ return ok
+}
+
+// typeConstIdentifier returns type arg that is an integer constant (subject for const patching), if any.
+func typeConstIdentifier(n *ast.Type) *ast.Type {
+ if n.Ident == "const" && len(n.Args) > 0 {
+ return n.Args[0]
+ }
+ if n.Ident == "array" && len(n.Args) > 1 && n.Args[1].Ident != "opt" {
+ return n.Args[1]
+ }
+ if n.Ident == "vma" && len(n.Args) > 0 && n.Args[0].Ident != "opt" {
+ return n.Args[0]
+ }
+ if n.Ident == "vma" && len(n.Args) > 0 && n.Args[0].Ident != "opt" {
+ return n.Args[0]
+ }
+ if n.Ident == "string" && len(n.Args) > 1 && n.Args[1].Ident != "opt" {
+ return n.Args[1]
+ }
+ if n.Ident == "csum" && len(n.Args) > 2 && n.Args[1].Ident == "pseudo" {
+ return n.Args[2]
+ }
+ switch n.Ident {
+ case "int8", "int16", "int16be", "int32", "int32be", "int64", "int64be", "intptr":
+ if len(n.Args) > 0 && n.Args[0].Ident != "opt" {
+ return n.Args[0]
+ }
+ }
+ return nil
+}
+
func toArray(m map[string]bool) []string {
delete(m, "")
var res []string
diff --git a/pkg/compiler/compiler_test.go b/pkg/compiler/compiler_test.go
index 13ed97b4c..e68d97d3e 100644
--- a/pkg/compiler/compiler_test.go
+++ b/pkg/compiler/compiler_test.go
@@ -11,13 +11,15 @@ import (
)
func TestExtractConsts(t *testing.T) {
- top, ok := ast.Parse([]byte(extractConstsInput), "test", nil)
- if !ok {
+ desc := ast.Parse([]byte(extractConstsInput), "test", nil)
+ if desc == nil {
t.Fatalf("failed to parse input")
}
- consts, includes, incdirs, defines := ExtractConsts(top)
- wantConsts := []string{"CONST1", "CONST2", "CONST3", "CONST4", "CONST5",
- "CONST6", "CONST7", "__NR_bar", "__NR_foo"}
+ consts, includes, incdirs, defines := ExtractConsts(desc)
+ wantConsts := []string{"CONST1", "CONST10", "CONST11", "CONST12", "CONST13",
+ "CONST14", "CONST15", "CONST16",
+ "CONST2", "CONST3", "CONST4", "CONST5",
+ "CONST6", "CONST7", "CONST8", "CONST9", "__NR_bar", "__NR_foo"}
if !reflect.DeepEqual(consts, wantConsts) {
t.Fatalf("got consts:\n%q\nwant:\n%q", consts, wantConsts)
}
@@ -56,4 +58,8 @@ str {
f1 const[CONST6, int32]
f2 array[array[int8, CONST7]]
}
+
+bar$BAZ(x vma[opt], y vma[CONST8], z vma[CONST9:CONST10])
+bar$QUX(s ptr[in, string["foo", CONST11]], x csum[s, pseudo, CONST12])
+bar$FOO(x int8[8:CONST13], y int16be[CONST14:10], z intptr[CONST15:CONST16])
`