From e59ec59b027f921a6bfbe5014b15c2a802445ada Mon Sep 17 00:00:00 2001 From: Aleksandr Nogikh Date: Wed, 29 Nov 2023 16:01:41 +0100 Subject: pkg/ast: support expressions with ast.Type So far they have the following grammar: OP = "==", "!=", "&" value-expr = value-expr OP value-expr | factor factor = "(" and-expr ")" | integer | identifier | string Operators are left associative, e.g. A & B & C is the same as (A & B) & C. Further restrictions will be imposed in pkg/compiler. This will help implement conditionally included fields. --- pkg/ast/ast.go | 34 +++++++++++++++++++++++------- pkg/ast/clone.go | 15 ++++++++++++- pkg/ast/format.go | 45 +++++++++++++++++++++++++++++++++++++++ pkg/ast/parser.go | 51 +++++++++++++++++++++++++++++++++++++++++++++ pkg/ast/parser_test.go | 15 +++++-------- pkg/ast/scanner.go | 21 +++++++++++++++++++ pkg/ast/testdata/all.txt | 19 +++++++++++++++++ pkg/ast/testdata/errors.txt | 10 ++++++++- pkg/ast/walk.go | 15 ++++++++++--- 9 files changed, 203 insertions(+), 22 deletions(-) (limited to 'pkg') diff --git a/pkg/ast/ast.go b/pkg/ast/ast.go index 2458c1245..2726af602 100644 --- a/pkg/ast/ast.go +++ b/pkg/ast/ast.go @@ -252,15 +252,35 @@ func (n *Int) GetName() string { return n.Ident } +type Operator int + +const ( + OperatorCompareEq = iota + 1 + OperatorCompareNeq + OperatorBinaryAnd +) + +type BinaryExpression struct { + Pos Pos + Operator Operator + Left *Type + Right *Type +} + +func (n *BinaryExpression) Info() (Pos, string, string) { + return n.Pos, "binary-expression", "" +} + type Type struct { Pos Pos - // Only one of Value, Ident, String is filled. - Value uint64 - ValueFmt IntFmt - Ident string - String string - StringFmt StrFmt - HasString bool + // Only one of Value, Ident, String, Expression is filled. + Value uint64 + ValueFmt IntFmt + Ident string + String string + StringFmt StrFmt + HasString bool + Expression *BinaryExpression // Parts after COLON (for ranges and bitfields). Colon []*Type // Sub-types in []. diff --git a/pkg/ast/clone.go b/pkg/ast/clone.go index 0c9c831f0..e23a9eb70 100644 --- a/pkg/ast/clone.go +++ b/pkg/ast/clone.go @@ -157,7 +157,7 @@ func (n *Int) Clone() Node { } func (n *Type) Clone() Node { - return &Type{ + ret := &Type{ Pos: n.Pos, Value: n.Value, ValueFmt: n.ValueFmt, @@ -168,6 +168,10 @@ func (n *Type) Clone() Node { Colon: cloneTypes(n.Colon), Args: cloneTypes(n.Args), } + if n.Expression != nil { + ret.Expression = n.Expression.Clone().(*BinaryExpression) + } + return ret } func (n *Field) Clone() Node { @@ -181,6 +185,15 @@ func (n *Field) Clone() Node { } } +func (n *BinaryExpression) Clone() Node { + return &BinaryExpression{ + Pos: n.Pos, + Operator: n.Operator, + Left: n.Left.Clone().(*Type), + Right: n.Right.Clone().(*Type), + } +} + func cloneFields(list []*Field) (res []*Field) { for _, n := range list { res = append(res, n.Clone().(*Field)) diff --git a/pkg/ast/format.go b/pkg/ast/format.go index 420eac916..4b54466c2 100644 --- a/pkg/ast/format.go +++ b/pkg/ast/format.go @@ -194,6 +194,12 @@ func (n *Type) serialize(w io.Writer) { } func fmtType(t *Type) string { + var sb strings.Builder + fmtExpressionRec(&sb, t, -1) + return sb.String() +} + +func fmtEndType(t *Type) string { v := "" switch { case t.Ident != "": @@ -247,6 +253,45 @@ func fmtInt(i *Int) string { } } +func fmtExpressionRec(sb *strings.Builder, t *Type, parentPrio int) { + if t.Expression == nil { + sb.WriteString(fmtEndType(t)) + return + } + be := t.Expression + myPrio := operatorPrio(be.Operator) + parentheses := myPrio < parentPrio + if parentheses { + sb.WriteByte('(') + } + fmtExpressionRec(sb, be.Left, myPrio) + sb.WriteByte(' ') + switch be.Operator { + case OperatorCompareEq: + sb.WriteString("==") + case OperatorCompareNeq: + sb.WriteString("!=") + case OperatorBinaryAnd: + sb.WriteString("&") + default: + panic(fmt.Sprintf("unknown operator %q", be.Operator)) + } + sb.WriteByte(' ') + fmtExpressionRec(sb, be.Right, myPrio) + if parentheses { + sb.WriteByte(')') + } +} + +func operatorPrio(op Operator) int { + for _, info := range binaryOperators { + if info.op == op { + return info.prio + } + } + panic(fmt.Sprintf("unknown operator %q", op)) +} + func comma(i int, or string) string { if i == 0 { return or diff --git a/pkg/ast/parser.go b/pkg/ast/parser.go index c87d43f5e..2f2e62055 100644 --- a/pkg/ast/parser.go +++ b/pkg/ast/parser.go @@ -433,7 +433,58 @@ func (p *parser) parseField(parseAttrs bool) *Field { return field } +type operatorInfo struct { + op Operator + prio int +} + +const maxOperatorPrio = 1 + +// The highest priority is 0. +var binaryOperators = map[token]operatorInfo{ + tokCmpEq: {op: OperatorCompareEq, prio: 0}, + tokCmpNeq: {op: OperatorCompareNeq, prio: 0}, + tokBinAnd: {op: OperatorBinaryAnd, prio: 1}, +} + +// Parse out a single Type object, which can either be a plain object or an expression. +// For now, only expressions constructed via '(', ')', "==", "!=", '&' are supported. func (p *parser) parseType() *Type { + return p.parseBinaryExpr(0) +} + +func (p *parser) parseBinaryExpr(expectPrio int) *Type { + if expectPrio > maxOperatorPrio { + return p.parseExprFactor() + } + lastPos := p.pos + curr := p.parseBinaryExpr(expectPrio + 1) + for { + info, ok := binaryOperators[p.tok] + if !ok || info.prio != expectPrio { + return curr + } + p.consume(p.tok) + curr = &Type{ + Pos: lastPos, + Expression: &BinaryExpression{ + Pos: p.pos, + Operator: info.op, + Left: curr, + Right: p.parseBinaryExpr(expectPrio + 1), + }, + } + lastPos = p.pos + } +} + +func (p *parser) parseExprFactor() *Type { + if p.tok == tokLParen { + p.consume(tokLParen) + ret := p.parseBinaryExpr(0) + p.consume(tokRParen) + return ret + } arg := &Type{ Pos: p.pos, } diff --git a/pkg/ast/parser_test.go b/pkg/ast/parser_test.go index cecdcc9b2..5e25157f5 100644 --- a/pkg/ast/parser_test.go +++ b/pkg/ast/parser_test.go @@ -4,14 +4,13 @@ package ast import ( - "bytes" "os" "path/filepath" - "reflect" "strings" "testing" "github.com/google/syzkaller/sys/targets" + "github.com/stretchr/testify/assert" ) func TestParseAll(t *testing.T) { @@ -47,14 +46,9 @@ func TestParseAll(t *testing.T) { if n1 == nil { t.Fatalf("got nil node") } - if !reflect.DeepEqual(n1, n2) { - t.Fatalf("formatting changed code:\n%#v\nvs:\n%#v", n1, n2) - } - } - data3 := Format(desc.Clone()) - if !bytes.Equal(data, data3) { - t.Fatalf("Clone lost data") + assert.Equal(t, n1, n2, "formating changed code") } + assert.Equal(t, string(data), string(Format(desc.Clone()))) nodes0 := 0 desc.Walk(func(n Node) { nodes0++ @@ -63,12 +57,13 @@ func TestParseAll(t *testing.T) { } }) nodes1 := 0 - desc.Walk(Recursive(func(n Node) { + desc.Walk(Recursive(func(n Node) bool { nodes1++ pos, typ, _ := n.Info() if typ == "" { t.Fatalf("%v: node has empty typ=%q: %#v", pos, typ, n) } + return true })) nodes2 := 0 desc.Walk(PostRecursive(func(n Node) { diff --git a/pkg/ast/scanner.go b/pkg/ast/scanner.go index 5607e3082..594b9ea8d 100644 --- a/pkg/ast/scanner.go +++ b/pkg/ast/scanner.go @@ -4,6 +4,7 @@ package ast import ( + "bytes" "encoding/hex" "fmt" "os" @@ -35,6 +36,9 @@ const ( tokEq tokComma tokColon + tokBinAnd + tokCmpEq + tokCmpNeq tokEOF ) @@ -50,6 +54,7 @@ var punctuation = [256]token{ '=': tokEq, ',': tokComma, ':': tokColon, + '&': tokBinAnd, } var tok2str = [...]string{ @@ -66,6 +71,8 @@ var tok2str = [...]string{ tokInt: "int", tokNewLine: "NEWLINE", tokEOF: "EOF", + tokCmpEq: "==", + tokCmpNeq: "!=", } func init() { @@ -181,6 +188,10 @@ func (s *scanner) Scan() (tok token, lit string, pos Pos) { lit = s.scanChar(pos) case s.ch == '_' || s.ch >= 'a' && s.ch <= 'z' || s.ch >= 'A' && s.ch <= 'Z': tok, lit = s.scanIdent(pos) + case s.tryConsume("=="): + tok = tokCmpEq + case s.tryConsume("!="): + tok = tokCmpNeq default: tok = punctuation[s.ch] if tok == tokIllegal { @@ -313,6 +324,16 @@ func (s *scanner) next() { } } +func (s *scanner) tryConsume(str string) bool { + if !bytes.HasPrefix(s.data[s.off:], []byte(str)) { + return false + } + for i := 0; i < len(str); i++ { + s.next() + } + return true +} + func (s *scanner) skipWhitespace() { for s.ch == ' ' || s.ch == '\t' { s.next() diff --git a/pkg/ast/testdata/all.txt b/pkg/ast/testdata/all.txt index c9569840c..d04fd4992 100644 --- a/pkg/ast/testdata/all.txt +++ b/pkg/ast/testdata/all.txt @@ -8,3 +8,22 @@ incdir strflags0 = "foo", strflags1 strflags1 = "bar" + +expressions { + f0 int8 (if[value[X] & Y]) + f1 int8 (if[X & Y == Z]) + f2 int8 (if[X & Y & Z == value[X] & A]) + f3 int8 (if[X & (A == B) & Z != C]) +} + +condFields { + mask int8 +# Simple expressions work. + f0 int16 (if[val[mask] == SOME_CONST]) +# Conditions and other attributes work together. + f1 int16 (out, if[val[mask] == SOME_CONST]) +# Test some more complex expressions. + f2 int16 (out, if[val[mask] & SOME_CONST == OTHER_CONST]) + f3 int16 (out, if[val[mask] & SOME_CONST & OTHER_CONST == val[mask] & CONST_X]) + f4 int16 (out, if[val[mask] & SOME_CONST]) +} diff --git a/pkg/ast/testdata/errors.txt b/pkg/ast/testdata/errors.txt index 266babf8f..b3d9e7f52 100644 --- a/pkg/ast/testdata/errors.txt +++ b/pkg/ast/testdata/errors.txt @@ -11,7 +11,7 @@ meta foo, bar ### unexpected ',', expecting '\n' int_flags0 = 0, 0x1, 0xab int_flags1 = 123ab0x ### bad integer "123ab0x" -int_flags1 == 0, 1 ### unexpected '=', expecting int, identifier, string +int_flags1 == 0, 1 ### unexpected ==, expecting '(', '{', '[', '=' int_flags = 0, "foo" ### unexpected string, expecting int, identifier int_flags2 = ' ### char literal is not terminated int_flags3 = 'a ### char literal is not terminated @@ -67,6 +67,14 @@ s3 { f1 int8 } [attribute[1, "foo"], another[and[another]]] +sCondFieldsError1 { + f0 int16 (out, if[val[mask] SOME_CONST == val[mask]]) ### unexpected identifier, expecting ']' +} ### unexpected '}', expecting comment, define, include, resource, identifier + +sCondFieldsError2 { + f5 int16 (out, if[val[mask] & == val[mask]]) ### unexpected ==, expecting int, identifier, string +} ### unexpected '}', expecting comment, define, include, resource, identifier + type mybool8 int8 type net_port proc[1, 2, int16be] type mybool16 ### unexpected '\n', expecting '[', identifier diff --git a/pkg/ast/walk.go b/pkg/ast/walk.go index b187769f3..2322d32e0 100644 --- a/pkg/ast/walk.go +++ b/pkg/ast/walk.go @@ -11,11 +11,12 @@ func (desc *Description) Walk(cb func(Node)) { } } -func Recursive(cb func(Node)) func(Node) { +func Recursive(cb func(Node) bool) func(Node) { var rec func(Node) rec = func(n Node) { - cb(n) - n.walk(rec) + if cb(n) { + n.walk(rec) + } } return rec } @@ -117,6 +118,9 @@ func (n *Type) walk(cb func(Node)) { for _, t := range n.Args { cb(t) } + if n.Expression != nil { + cb(n.Expression) + } } func (n *Field) walk(cb func(Node)) { @@ -129,3 +133,8 @@ func (n *Field) walk(cb func(Node)) { cb(c) } } + +func (n *BinaryExpression) walk(cb func(Node)) { + cb(n.Left) + cb(n.Right) +} -- cgit mrf-deployment