aboutsummaryrefslogtreecommitdiffstats
path: root/pkg/ast
diff options
context:
space:
mode:
authorAleksandr Nogikh <nogikh@google.com>2023-11-29 16:01:41 +0100
committerAleksandr Nogikh <nogikh@google.com>2024-02-19 11:54:01 +0000
commite59ec59b027f921a6bfbe5014b15c2a802445ada (patch)
tree1c69d1db1b34e0b53d620f9fe272cf9f374b1400 /pkg/ast
parent164800ebad7f26d05eefb0095d190462ed97bee0 (diff)
pkg/ast: support expressions with ast.Type
So far they have the following grammar: OP = "==", "!=", "&" value-expr = value-expr OP value-expr | factor factor = "(" and-expr ")" | integer | identifier | string Operators are left associative, e.g. A & B & C is the same as (A & B) & C. Further restrictions will be imposed in pkg/compiler. This will help implement conditionally included fields.
Diffstat (limited to 'pkg/ast')
-rw-r--r--pkg/ast/ast.go34
-rw-r--r--pkg/ast/clone.go15
-rw-r--r--pkg/ast/format.go45
-rw-r--r--pkg/ast/parser.go51
-rw-r--r--pkg/ast/parser_test.go15
-rw-r--r--pkg/ast/scanner.go21
-rw-r--r--pkg/ast/testdata/all.txt19
-rw-r--r--pkg/ast/testdata/errors.txt10
-rw-r--r--pkg/ast/walk.go15
9 files changed, 203 insertions, 22 deletions
diff --git a/pkg/ast/ast.go b/pkg/ast/ast.go
index 2458c1245..2726af602 100644
--- a/pkg/ast/ast.go
+++ b/pkg/ast/ast.go
@@ -252,15 +252,35 @@ func (n *Int) GetName() string {
return n.Ident
}
+type Operator int
+
+const (
+ OperatorCompareEq = iota + 1
+ OperatorCompareNeq
+ OperatorBinaryAnd
+)
+
+type BinaryExpression struct {
+ Pos Pos
+ Operator Operator
+ Left *Type
+ Right *Type
+}
+
+func (n *BinaryExpression) Info() (Pos, string, string) {
+ return n.Pos, "binary-expression", ""
+}
+
type Type struct {
Pos Pos
- // Only one of Value, Ident, String is filled.
- Value uint64
- ValueFmt IntFmt
- Ident string
- String string
- StringFmt StrFmt
- HasString bool
+ // Only one of Value, Ident, String, Expression is filled.
+ Value uint64
+ ValueFmt IntFmt
+ Ident string
+ String string
+ StringFmt StrFmt
+ HasString bool
+ Expression *BinaryExpression
// Parts after COLON (for ranges and bitfields).
Colon []*Type
// Sub-types in [].
diff --git a/pkg/ast/clone.go b/pkg/ast/clone.go
index 0c9c831f0..e23a9eb70 100644
--- a/pkg/ast/clone.go
+++ b/pkg/ast/clone.go
@@ -157,7 +157,7 @@ func (n *Int) Clone() Node {
}
func (n *Type) Clone() Node {
- return &Type{
+ ret := &Type{
Pos: n.Pos,
Value: n.Value,
ValueFmt: n.ValueFmt,
@@ -168,6 +168,10 @@ func (n *Type) Clone() Node {
Colon: cloneTypes(n.Colon),
Args: cloneTypes(n.Args),
}
+ if n.Expression != nil {
+ ret.Expression = n.Expression.Clone().(*BinaryExpression)
+ }
+ return ret
}
func (n *Field) Clone() Node {
@@ -181,6 +185,15 @@ func (n *Field) Clone() Node {
}
}
+func (n *BinaryExpression) Clone() Node {
+ return &BinaryExpression{
+ Pos: n.Pos,
+ Operator: n.Operator,
+ Left: n.Left.Clone().(*Type),
+ Right: n.Right.Clone().(*Type),
+ }
+}
+
func cloneFields(list []*Field) (res []*Field) {
for _, n := range list {
res = append(res, n.Clone().(*Field))
diff --git a/pkg/ast/format.go b/pkg/ast/format.go
index 420eac916..4b54466c2 100644
--- a/pkg/ast/format.go
+++ b/pkg/ast/format.go
@@ -194,6 +194,12 @@ func (n *Type) serialize(w io.Writer) {
}
func fmtType(t *Type) string {
+ var sb strings.Builder
+ fmtExpressionRec(&sb, t, -1)
+ return sb.String()
+}
+
+func fmtEndType(t *Type) string {
v := ""
switch {
case t.Ident != "":
@@ -247,6 +253,45 @@ func fmtInt(i *Int) string {
}
}
+func fmtExpressionRec(sb *strings.Builder, t *Type, parentPrio int) {
+ if t.Expression == nil {
+ sb.WriteString(fmtEndType(t))
+ return
+ }
+ be := t.Expression
+ myPrio := operatorPrio(be.Operator)
+ parentheses := myPrio < parentPrio
+ if parentheses {
+ sb.WriteByte('(')
+ }
+ fmtExpressionRec(sb, be.Left, myPrio)
+ sb.WriteByte(' ')
+ switch be.Operator {
+ case OperatorCompareEq:
+ sb.WriteString("==")
+ case OperatorCompareNeq:
+ sb.WriteString("!=")
+ case OperatorBinaryAnd:
+ sb.WriteString("&")
+ default:
+ panic(fmt.Sprintf("unknown operator %q", be.Operator))
+ }
+ sb.WriteByte(' ')
+ fmtExpressionRec(sb, be.Right, myPrio)
+ if parentheses {
+ sb.WriteByte(')')
+ }
+}
+
+func operatorPrio(op Operator) int {
+ for _, info := range binaryOperators {
+ if info.op == op {
+ return info.prio
+ }
+ }
+ panic(fmt.Sprintf("unknown operator %q", op))
+}
+
func comma(i int, or string) string {
if i == 0 {
return or
diff --git a/pkg/ast/parser.go b/pkg/ast/parser.go
index c87d43f5e..2f2e62055 100644
--- a/pkg/ast/parser.go
+++ b/pkg/ast/parser.go
@@ -433,7 +433,58 @@ func (p *parser) parseField(parseAttrs bool) *Field {
return field
}
+type operatorInfo struct {
+ op Operator
+ prio int
+}
+
+const maxOperatorPrio = 1
+
+// The highest priority is 0.
+var binaryOperators = map[token]operatorInfo{
+ tokCmpEq: {op: OperatorCompareEq, prio: 0},
+ tokCmpNeq: {op: OperatorCompareNeq, prio: 0},
+ tokBinAnd: {op: OperatorBinaryAnd, prio: 1},
+}
+
+// Parse out a single Type object, which can either be a plain object or an expression.
+// For now, only expressions constructed via '(', ')', "==", "!=", '&' are supported.
func (p *parser) parseType() *Type {
+ return p.parseBinaryExpr(0)
+}
+
+func (p *parser) parseBinaryExpr(expectPrio int) *Type {
+ if expectPrio > maxOperatorPrio {
+ return p.parseExprFactor()
+ }
+ lastPos := p.pos
+ curr := p.parseBinaryExpr(expectPrio + 1)
+ for {
+ info, ok := binaryOperators[p.tok]
+ if !ok || info.prio != expectPrio {
+ return curr
+ }
+ p.consume(p.tok)
+ curr = &Type{
+ Pos: lastPos,
+ Expression: &BinaryExpression{
+ Pos: p.pos,
+ Operator: info.op,
+ Left: curr,
+ Right: p.parseBinaryExpr(expectPrio + 1),
+ },
+ }
+ lastPos = p.pos
+ }
+}
+
+func (p *parser) parseExprFactor() *Type {
+ if p.tok == tokLParen {
+ p.consume(tokLParen)
+ ret := p.parseBinaryExpr(0)
+ p.consume(tokRParen)
+ return ret
+ }
arg := &Type{
Pos: p.pos,
}
diff --git a/pkg/ast/parser_test.go b/pkg/ast/parser_test.go
index cecdcc9b2..5e25157f5 100644
--- a/pkg/ast/parser_test.go
+++ b/pkg/ast/parser_test.go
@@ -4,14 +4,13 @@
package ast
import (
- "bytes"
"os"
"path/filepath"
- "reflect"
"strings"
"testing"
"github.com/google/syzkaller/sys/targets"
+ "github.com/stretchr/testify/assert"
)
func TestParseAll(t *testing.T) {
@@ -47,14 +46,9 @@ func TestParseAll(t *testing.T) {
if n1 == nil {
t.Fatalf("got nil node")
}
- if !reflect.DeepEqual(n1, n2) {
- t.Fatalf("formatting changed code:\n%#v\nvs:\n%#v", n1, n2)
- }
- }
- data3 := Format(desc.Clone())
- if !bytes.Equal(data, data3) {
- t.Fatalf("Clone lost data")
+ assert.Equal(t, n1, n2, "formating changed code")
}
+ assert.Equal(t, string(data), string(Format(desc.Clone())))
nodes0 := 0
desc.Walk(func(n Node) {
nodes0++
@@ -63,12 +57,13 @@ func TestParseAll(t *testing.T) {
}
})
nodes1 := 0
- desc.Walk(Recursive(func(n Node) {
+ desc.Walk(Recursive(func(n Node) bool {
nodes1++
pos, typ, _ := n.Info()
if typ == "" {
t.Fatalf("%v: node has empty typ=%q: %#v", pos, typ, n)
}
+ return true
}))
nodes2 := 0
desc.Walk(PostRecursive(func(n Node) {
diff --git a/pkg/ast/scanner.go b/pkg/ast/scanner.go
index 5607e3082..594b9ea8d 100644
--- a/pkg/ast/scanner.go
+++ b/pkg/ast/scanner.go
@@ -4,6 +4,7 @@
package ast
import (
+ "bytes"
"encoding/hex"
"fmt"
"os"
@@ -35,6 +36,9 @@ const (
tokEq
tokComma
tokColon
+ tokBinAnd
+ tokCmpEq
+ tokCmpNeq
tokEOF
)
@@ -50,6 +54,7 @@ var punctuation = [256]token{
'=': tokEq,
',': tokComma,
':': tokColon,
+ '&': tokBinAnd,
}
var tok2str = [...]string{
@@ -66,6 +71,8 @@ var tok2str = [...]string{
tokInt: "int",
tokNewLine: "NEWLINE",
tokEOF: "EOF",
+ tokCmpEq: "==",
+ tokCmpNeq: "!=",
}
func init() {
@@ -181,6 +188,10 @@ func (s *scanner) Scan() (tok token, lit string, pos Pos) {
lit = s.scanChar(pos)
case s.ch == '_' || s.ch >= 'a' && s.ch <= 'z' || s.ch >= 'A' && s.ch <= 'Z':
tok, lit = s.scanIdent(pos)
+ case s.tryConsume("=="):
+ tok = tokCmpEq
+ case s.tryConsume("!="):
+ tok = tokCmpNeq
default:
tok = punctuation[s.ch]
if tok == tokIllegal {
@@ -313,6 +324,16 @@ func (s *scanner) next() {
}
}
+func (s *scanner) tryConsume(str string) bool {
+ if !bytes.HasPrefix(s.data[s.off:], []byte(str)) {
+ return false
+ }
+ for i := 0; i < len(str); i++ {
+ s.next()
+ }
+ return true
+}
+
func (s *scanner) skipWhitespace() {
for s.ch == ' ' || s.ch == '\t' {
s.next()
diff --git a/pkg/ast/testdata/all.txt b/pkg/ast/testdata/all.txt
index c9569840c..d04fd4992 100644
--- a/pkg/ast/testdata/all.txt
+++ b/pkg/ast/testdata/all.txt
@@ -8,3 +8,22 @@ incdir <some/path>
strflags0 = "foo", strflags1
strflags1 = "bar"
+
+expressions {
+ f0 int8 (if[value[X] & Y])
+ f1 int8 (if[X & Y == Z])
+ f2 int8 (if[X & Y & Z == value[X] & A])
+ f3 int8 (if[X & (A == B) & Z != C])
+}
+
+condFields {
+ mask int8
+# Simple expressions work.
+ f0 int16 (if[val[mask] == SOME_CONST])
+# Conditions and other attributes work together.
+ f1 int16 (out, if[val[mask] == SOME_CONST])
+# Test some more complex expressions.
+ f2 int16 (out, if[val[mask] & SOME_CONST == OTHER_CONST])
+ f3 int16 (out, if[val[mask] & SOME_CONST & OTHER_CONST == val[mask] & CONST_X])
+ f4 int16 (out, if[val[mask] & SOME_CONST])
+}
diff --git a/pkg/ast/testdata/errors.txt b/pkg/ast/testdata/errors.txt
index 266babf8f..b3d9e7f52 100644
--- a/pkg/ast/testdata/errors.txt
+++ b/pkg/ast/testdata/errors.txt
@@ -11,7 +11,7 @@ meta foo, bar ### unexpected ',', expecting '\n'
int_flags0 = 0, 0x1, 0xab
int_flags1 = 123ab0x ### bad integer "123ab0x"
-int_flags1 == 0, 1 ### unexpected '=', expecting int, identifier, string
+int_flags1 == 0, 1 ### unexpected ==, expecting '(', '{', '[', '='
int_flags = 0, "foo" ### unexpected string, expecting int, identifier
int_flags2 = ' ### char literal is not terminated
int_flags3 = 'a ### char literal is not terminated
@@ -67,6 +67,14 @@ s3 {
f1 int8
} [attribute[1, "foo"], another[and[another]]]
+sCondFieldsError1 {
+ f0 int16 (out, if[val[mask] SOME_CONST == val[mask]]) ### unexpected identifier, expecting ']'
+} ### unexpected '}', expecting comment, define, include, resource, identifier
+
+sCondFieldsError2 {
+ f5 int16 (out, if[val[mask] & == val[mask]]) ### unexpected ==, expecting int, identifier, string
+} ### unexpected '}', expecting comment, define, include, resource, identifier
+
type mybool8 int8
type net_port proc[1, 2, int16be]
type mybool16 ### unexpected '\n', expecting '[', identifier
diff --git a/pkg/ast/walk.go b/pkg/ast/walk.go
index b187769f3..2322d32e0 100644
--- a/pkg/ast/walk.go
+++ b/pkg/ast/walk.go
@@ -11,11 +11,12 @@ func (desc *Description) Walk(cb func(Node)) {
}
}
-func Recursive(cb func(Node)) func(Node) {
+func Recursive(cb func(Node) bool) func(Node) {
var rec func(Node)
rec = func(n Node) {
- cb(n)
- n.walk(rec)
+ if cb(n) {
+ n.walk(rec)
+ }
}
return rec
}
@@ -117,6 +118,9 @@ func (n *Type) walk(cb func(Node)) {
for _, t := range n.Args {
cb(t)
}
+ if n.Expression != nil {
+ cb(n.Expression)
+ }
}
func (n *Field) walk(cb func(Node)) {
@@ -129,3 +133,8 @@ func (n *Field) walk(cb func(Node)) {
cb(c)
}
}
+
+func (n *BinaryExpression) walk(cb func(Node)) {
+ cb(n.Left)
+ cb(n.Right)
+}