From 6b52293f4defa6b45b564d037fd641be5d6d0e0e Mon Sep 17 00:00:00 2001 From: Dmitry Vyukov Date: Wed, 10 Jan 2018 16:13:34 +0100 Subject: pkg/compiler: support type templates Netlink descriptions contain tons of code duplication, and need much more for proper descriptions. Introduce type templates to simplify writing such descriptions and remove code duplication. Note: type templates are experimental, have poor error handling and are subject to change. Type templates can be declared as follows: ``` type buffer[DIR] ptr[DIR, array[int8]] type fileoff[BASE] BASE type nlattr[TYPE, PAYLOAD] { nla_len len[parent, int16] nla_type const[TYPE, int16] payload PAYLOAD } [align_4] ``` and later used as follows: ``` syscall(a buffer[in], b fileoff[int64], c ptr[in, nlattr[FOO, int32]]) ``` --- pkg/ast/ast.go | 10 +- pkg/ast/clone.go | 137 ++++++++------- pkg/ast/format.go | 76 ++++---- pkg/ast/parser.go | 31 +++- pkg/ast/test_util.go | 7 +- pkg/ast/testdata/all.txt | 19 +- pkg/ast/walk.go | 10 +- pkg/compiler/check.go | 354 +++++++++++++++++++++++++++----------- pkg/compiler/compiler.go | 66 ++++++- pkg/compiler/compiler_test.go | 94 +++++++++- pkg/compiler/consts.go | 258 +++++++++++++-------------- pkg/compiler/consts_test.go | 21 ++- pkg/compiler/gen.go | 3 + pkg/compiler/testdata/all.txt | 60 +++++++ pkg/compiler/testdata/consts.txt | 11 +- pkg/compiler/testdata/errors.txt | 65 ++++--- pkg/compiler/testdata/errors2.txt | 26 +++ pkg/compiler/types.go | 51 +++--- 18 files changed, 889 insertions(+), 410 deletions(-) create mode 100644 pkg/compiler/testdata/all.txt (limited to 'pkg') diff --git a/pkg/ast/ast.go b/pkg/ast/ast.go index 454f28b37..08703ba33 100644 --- a/pkg/ast/ast.go +++ b/pkg/ast/ast.go @@ -21,9 +21,7 @@ type Description struct { type Node interface { Info() (pos Pos, typ string, name string) // Clone makes a deep copy of the node. - // If newPos is not zero, sets Pos of all nodes to newPos. - // If newPos is zero, Pos of nodes is left intact. - Clone(newPos Pos) Node + Clone() Node // Walk calls callback cb for all child nodes of this node. // Note: it's not recursive. Use Recursive helper for recursive walk. Walk(cb func(Node)) @@ -140,7 +138,11 @@ func (n *StrFlags) Info() (Pos, string, string) { type TypeDef struct { Pos Pos Name *Ident - Type *Type + // Non-template type aliases have only Type filled. + // Templates have Args and either Type or Struct filled. + Args []*Ident + Type *Type + Struct *Struct } func (n *TypeDef) Info() (Pos, string, string) { diff --git a/pkg/ast/clone.go b/pkg/ast/clone.go index dcd715c0a..b915c1f33 100644 --- a/pkg/ast/clone.go +++ b/pkg/ast/clone.go @@ -6,86 +6,93 @@ package ast func (desc *Description) Clone() *Description { desc1 := &Description{} for _, n := range desc.Nodes { - desc1.Nodes = append(desc1.Nodes, n.Clone(Pos{})) + desc1.Nodes = append(desc1.Nodes, n.Clone()) } return desc1 } -func selectPos(newPos, oldPos Pos) Pos { - if newPos.File != "" || newPos.Off != 0 || newPos.Line != 0 || newPos.Col != 0 { - return newPos - } - return oldPos -} - -func (n *NewLine) Clone(newPos Pos) Node { +func (n *NewLine) Clone() Node { return &NewLine{ - Pos: selectPos(newPos, n.Pos), + Pos: n.Pos, } } -func (n *Comment) Clone(newPos Pos) Node { +func (n *Comment) Clone() Node { return &Comment{ - Pos: selectPos(newPos, n.Pos), + Pos: n.Pos, Text: n.Text, } } -func (n *Include) Clone(newPos Pos) Node { +func (n *Include) Clone() Node { return &Include{ - Pos: selectPos(newPos, n.Pos), - File: n.File.Clone(newPos).(*String), + Pos: n.Pos, + File: n.File.Clone().(*String), } } -func (n *Incdir) Clone(newPos Pos) Node { +func (n *Incdir) Clone() Node { return &Incdir{ - Pos: selectPos(newPos, n.Pos), - Dir: n.Dir.Clone(newPos).(*String), + Pos: n.Pos, + Dir: n.Dir.Clone().(*String), } } -func (n *Define) Clone(newPos Pos) Node { +func (n *Define) Clone() Node { return &Define{ - Pos: selectPos(newPos, n.Pos), - Name: n.Name.Clone(newPos).(*Ident), - Value: n.Value.Clone(newPos).(*Int), + Pos: n.Pos, + Name: n.Name.Clone().(*Ident), + Value: n.Value.Clone().(*Int), } } -func (n *Resource) Clone(newPos Pos) Node { +func (n *Resource) Clone() Node { var values []*Int for _, v := range n.Values { - values = append(values, v.Clone(newPos).(*Int)) + values = append(values, v.Clone().(*Int)) } return &Resource{ - Pos: selectPos(newPos, n.Pos), - Name: n.Name.Clone(newPos).(*Ident), - Base: n.Base.Clone(newPos).(*Type), + Pos: n.Pos, + Name: n.Name.Clone().(*Ident), + Base: n.Base.Clone().(*Type), Values: values, } } -func (n *TypeDef) Clone(newPos Pos) Node { +func (n *TypeDef) Clone() Node { + var args []*Ident + for _, v := range n.Args { + args = append(args, v.Clone().(*Ident)) + } + var typ *Type + if n.Type != nil { + typ = n.Type.Clone().(*Type) + } + var str *Struct + if n.Struct != nil { + str = n.Struct.Clone().(*Struct) + } return &TypeDef{ - Pos: selectPos(newPos, n.Pos), - Name: n.Name.Clone(newPos).(*Ident), - Type: n.Type.Clone(newPos).(*Type), + Pos: n.Pos, + Name: n.Name.Clone().(*Ident), + Args: args, + Type: typ, + Struct: str, } } -func (n *Call) Clone(newPos Pos) Node { +func (n *Call) Clone() Node { var args []*Field for _, a := range n.Args { - args = append(args, a.Clone(newPos).(*Field)) + args = append(args, a.Clone().(*Field)) } var ret *Type if n.Ret != nil { - ret = n.Ret.Clone(newPos).(*Type) + ret = n.Ret.Clone().(*Type) } return &Call{ - Pos: selectPos(newPos, n.Pos), - Name: n.Name.Clone(newPos).(*Ident), + Pos: n.Pos, + Name: n.Name.Clone().(*Ident), CallName: n.CallName, NR: n.NR, Args: args, @@ -93,22 +100,22 @@ func (n *Call) Clone(newPos Pos) Node { } } -func (n *Struct) Clone(newPos Pos) Node { +func (n *Struct) Clone() Node { var fields []*Field for _, f := range n.Fields { - fields = append(fields, f.Clone(newPos).(*Field)) + fields = append(fields, f.Clone().(*Field)) } var attrs []*Ident for _, a := range n.Attrs { - attrs = append(attrs, a.Clone(newPos).(*Ident)) + attrs = append(attrs, a.Clone().(*Ident)) } var comments []*Comment for _, c := range n.Comments { - comments = append(comments, c.Clone(newPos).(*Comment)) + comments = append(comments, c.Clone().(*Comment)) } return &Struct{ - Pos: selectPos(newPos, n.Pos), - Name: n.Name.Clone(newPos).(*Ident), + Pos: n.Pos, + Name: n.Name.Clone().(*Ident), Fields: fields, Attrs: attrs, Comments: comments, @@ -116,47 +123,47 @@ func (n *Struct) Clone(newPos Pos) Node { } } -func (n *IntFlags) Clone(newPos Pos) Node { +func (n *IntFlags) Clone() Node { var values []*Int for _, v := range n.Values { - values = append(values, v.Clone(newPos).(*Int)) + values = append(values, v.Clone().(*Int)) } return &IntFlags{ - Pos: selectPos(newPos, n.Pos), - Name: n.Name.Clone(newPos).(*Ident), + Pos: n.Pos, + Name: n.Name.Clone().(*Ident), Values: values, } } -func (n *StrFlags) Clone(newPos Pos) Node { +func (n *StrFlags) Clone() Node { var values []*String for _, v := range n.Values { - values = append(values, v.Clone(newPos).(*String)) + values = append(values, v.Clone().(*String)) } return &StrFlags{ - Pos: selectPos(newPos, n.Pos), - Name: n.Name.Clone(newPos).(*Ident), + Pos: n.Pos, + Name: n.Name.Clone().(*Ident), Values: values, } } -func (n *Ident) Clone(newPos Pos) Node { +func (n *Ident) Clone() Node { return &Ident{ - Pos: selectPos(newPos, n.Pos), + Pos: n.Pos, Name: n.Name, } } -func (n *String) Clone(newPos Pos) Node { +func (n *String) Clone() Node { return &String{ - Pos: selectPos(newPos, n.Pos), + Pos: n.Pos, Value: n.Value, } } -func (n *Int) Clone(newPos Pos) Node { +func (n *Int) Clone() Node { return &Int{ - Pos: selectPos(newPos, n.Pos), + Pos: n.Pos, Value: n.Value, ValueHex: n.ValueHex, Ident: n.Ident, @@ -164,19 +171,19 @@ func (n *Int) Clone(newPos Pos) Node { } } -func (n *Type) Clone(newPos Pos) Node { +func (n *Type) Clone() Node { var args []*Type for _, a := range n.Args { - args = append(args, a.Clone(newPos).(*Type)) + args = append(args, a.Clone().(*Type)) } return &Type{ - Pos: selectPos(newPos, n.Pos), + Pos: n.Pos, Value: n.Value, ValueHex: n.ValueHex, Ident: n.Ident, String: n.String, HasColon: n.HasColon, - Pos2: selectPos(newPos, n.Pos2), + Pos2: n.Pos2, Value2: n.Value2, Value2Hex: n.Value2Hex, Ident2: n.Ident2, @@ -184,15 +191,15 @@ func (n *Type) Clone(newPos Pos) Node { } } -func (n *Field) Clone(newPos Pos) Node { +func (n *Field) Clone() Node { var comments []*Comment for _, c := range n.Comments { - comments = append(comments, c.Clone(newPos).(*Comment)) + comments = append(comments, c.Clone().(*Comment)) } return &Field{ - Pos: selectPos(newPos, n.Pos), - Name: n.Name.Clone(newPos).(*Ident), - Type: n.Type.Clone(newPos).(*Type), + Pos: n.Pos, + Name: n.Name.Clone().(*Ident), + Type: n.Type.Clone().(*Type), NewBlock: n.NewBlock, Comments: comments, } diff --git a/pkg/ast/format.go b/pkg/ast/format.go index a77662df8..c1dd3a624 100644 --- a/pkg/ast/format.go +++ b/pkg/ast/format.go @@ -25,6 +25,16 @@ func FormatWriter(w io.Writer, desc *Description) { } } +func SerializeNode(n Node) string { + s, ok := n.(serializer) + if !ok { + panic(fmt.Sprintf("unknown node: %#v", n)) + } + buf := new(bytes.Buffer) + s.serialize(buf) + return buf.String() +} + type serializer interface { serialize(w io.Writer) } @@ -52,27 +62,25 @@ func (def *Define) serialize(w io.Writer) { func (res *Resource) serialize(w io.Writer) { fmt.Fprintf(w, "resource %v[%v]", res.Name.Name, fmtType(res.Base)) for i, v := range res.Values { - if i == 0 { - fmt.Fprintf(w, ": ") - } else { - fmt.Fprintf(w, ", ") - } - fmt.Fprintf(w, "%v", fmtInt(v)) + fmt.Fprintf(w, "%v%v", comma(i, ": "), fmtInt(v)) } fmt.Fprintf(w, "\n") } func (typedef *TypeDef) serialize(w io.Writer) { - fmt.Fprintf(w, "type %v %v\n", typedef.Name.Name, fmtType(typedef.Type)) + fmt.Fprintf(w, "type %v%v", typedef.Name.Name, fmtIdentList(typedef.Args, false)) + if typedef.Type != nil { + fmt.Fprintf(w, " %v\n", fmtType(typedef.Type)) + } + if typedef.Struct != nil { + typedef.Struct.serialize(w) + } } func (c *Call) serialize(w io.Writer) { fmt.Fprintf(w, "%v(", c.Name.Name) for i, a := range c.Args { - if i != 0 { - fmt.Fprintf(w, ", ") - } - fmt.Fprintf(w, "%v", fmtField(a)) + fmt.Fprintf(w, "%v%v", comma(i, ""), fmtField(a)) } fmt.Fprintf(w, ")") if c.Ret != nil { @@ -112,24 +120,13 @@ func (str *Struct) serialize(w io.Writer) { for _, com := range str.Comments { fmt.Fprintf(w, "#%v\n", com.Text) } - fmt.Fprintf(w, "%c", closing) - if len(str.Attrs) != 0 { - fmt.Fprintf(w, " [") - for i, attr := range str.Attrs { - fmt.Fprintf(w, "%v%v", comma(i), attr.Name) - } - fmt.Fprintf(w, "]") - } - fmt.Fprintf(w, "\n") + fmt.Fprintf(w, "%c%v\n", closing, fmtIdentList(str.Attrs, true)) } func (flags *IntFlags) serialize(w io.Writer) { fmt.Fprintf(w, "%v = ", flags.Name.Name) for i, v := range flags.Values { - if i != 0 { - fmt.Fprintf(w, ", ") - } - fmt.Fprintf(w, "%v", fmtInt(v)) + fmt.Fprintf(w, "%v%v", comma(i, ""), fmtInt(v)) } fmt.Fprintf(w, "\n") } @@ -137,10 +134,7 @@ func (flags *IntFlags) serialize(w io.Writer) { func (flags *StrFlags) serialize(w io.Writer) { fmt.Fprintf(w, "%v = ", flags.Name.Name) for i, v := range flags.Values { - if i != 0 { - fmt.Fprintf(w, ", ") - } - fmt.Fprintf(w, "\"%v\"", v.Value) + fmt.Fprintf(w, "%v\"%v\"", comma(i, ""), v.Value) } fmt.Fprintf(w, "\n") } @@ -149,6 +143,10 @@ func fmtField(f *Field) string { return fmt.Sprintf("%v %v", f.Name.Name, fmtType(f.Type)) } +func (n *Type) serialize(w io.Writer) { + w.Write([]byte(fmtType(n))) +} + func fmtType(t *Type) string { v := "" switch { @@ -178,7 +176,23 @@ func fmtTypeList(args []*Type) string { w := new(bytes.Buffer) fmt.Fprintf(w, "[") for i, t := range args { - fmt.Fprintf(w, "%v%v", comma(i), fmtType(t)) + fmt.Fprintf(w, "%v%v", comma(i, ""), fmtType(t)) + } + fmt.Fprintf(w, "]") + return w.String() +} + +func fmtIdentList(args []*Ident, space bool) string { + if len(args) == 0 { + return "" + } + w := new(bytes.Buffer) + if space { + fmt.Fprintf(w, " ") + } + fmt.Fprintf(w, "[") + for i, arg := range args { + fmt.Fprintf(w, "%v%v", comma(i, ""), arg.Name) } fmt.Fprintf(w, "]") return w.String() @@ -202,9 +216,9 @@ func fmtIntValue(v uint64, hex bool) string { return fmt.Sprint(v) } -func comma(i int) string { +func comma(i int, or string) string { if i == 0 { - return "" + return or } return ", " } diff --git a/pkg/ast/parser.go b/pkg/ast/parser.go index db211ab2a..bd5650ad5 100644 --- a/pkg/ast/parser.go +++ b/pkg/ast/parser.go @@ -251,11 +251,34 @@ func (p *parser) parseResource() *Resource { func (p *parser) parseTypeDef() *TypeDef { pos0 := p.pos name := p.parseIdent() - typ := p.parseType() + var typ *Type + var str *Struct + var args []*Ident + p.expect(tokLBrack, tokIdent) + if p.tryConsume(tokLBrack) { + args = append(args, p.parseIdent()) + for p.tryConsume(tokComma) { + args = append(args, p.parseIdent()) + } + p.consume(tokRBrack) + if p.tok == tokLBrace || p.tok == tokLBrack { + name := &Ident{ + Pos: pos0, + Name: "", + } + str = p.parseStruct(name) + } else { + typ = p.parseType() + } + } else { + typ = p.parseType() + } return &TypeDef{ - Pos: pos0, - Name: name, - Type: typ, + Pos: pos0, + Name: name, + Args: args, + Type: typ, + Struct: str, } } diff --git a/pkg/ast/test_util.go b/pkg/ast/test_util.go index 0aed0a2dc..b9fe12152 100644 --- a/pkg/ast/test_util.go +++ b/pkg/ast/test_util.go @@ -7,6 +7,7 @@ import ( "bufio" "bytes" "io/ioutil" + "path/filepath" "strings" "testing" ) @@ -41,7 +42,7 @@ func NewErrorMatcher(t *testing.T, file string) *ErrorMatcher { break } errors = append(errors, &errorDesc{ - file: file, + file: filepath.Base(file), line: i, text: strings.TrimSpace(string(ln[pos+3:])), }) @@ -82,13 +83,13 @@ nextErr: want.matched = true continue nextErr } - t.Errorf("unexpected error: %v:%v:%v: %v", e.file, e.line, e.col, e.text) + t.Errorf("unexpected error:\n%v:%v:%v: %v", e.file, e.line, e.col, e.text) } for _, want := range em.expect { if want.matched { continue } - t.Errorf("unmatched error: %v:%v: %v", want.file, want.line, want.text) + t.Errorf("unmatched error:\n%v:%v: %v", want.file, want.line, want.text) } } diff --git a/pkg/ast/testdata/all.txt b/pkg/ast/testdata/all.txt index 268b49a47..d4452b34f 100644 --- a/pkg/ast/testdata/all.txt +++ b/pkg/ast/testdata/all.txt @@ -48,5 +48,20 @@ s2 { type mybool8 int8 type net_port proc[1, 2, int16be] -type mybool16 ### unexpected '\n', expecting int, identifier, string -type type4:4 int32 ### unexpected ':', expecting int, identifier, string +type mybool16 ### unexpected '\n', expecting '[', identifier +type type4:4 int32 ### unexpected ':', expecting '[', identifier + +type templ0[] int8 ### unexpected ']', expecting identifier +type templ1[A,] int8 ### unexpected ']', expecting identifier +type templ2[,] int8 ### unexpected ',', expecting identifier +type templ3[ ### unexpected '\n', expecting identifier +type templ4[A] ### unexpected '\n', expecting int, identifier, string +type templ5[A] const[A] +type templ6[A, B] const[A, B] +type templ7[0] ptr[in, int8] ### unexpected int, expecting identifier + +type templ_struct0[A, B] { + len len[parent, int16] + typ const[A, int16] + data B +} [align_4] diff --git a/pkg/ast/walk.go b/pkg/ast/walk.go index fd5065013..fe9112578 100644 --- a/pkg/ast/walk.go +++ b/pkg/ast/walk.go @@ -49,7 +49,15 @@ func (n *Resource) Walk(cb func(Node)) { func (n *TypeDef) Walk(cb func(Node)) { cb(n.Name) - cb(n.Type) + for _, a := range n.Args { + cb(a) + } + if n.Type != nil { + cb(n.Type) + } + if n.Struct != nil { + cb(n.Struct) + } } func (n *Call) Walk(cb func(Node)) { diff --git a/pkg/compiler/check.go b/pkg/compiler/check.go index e9ec872f5..923c4fb19 100644 --- a/pkg/compiler/check.go +++ b/pkg/compiler/check.go @@ -12,16 +12,15 @@ import ( "github.com/google/syzkaller/prog" ) -func (comp *compiler) check() { +func (comp *compiler) typecheck() { comp.checkNames() comp.checkFields() + comp.checkTypedefs() comp.checkTypes() - // The subsequent, more complex, checks expect basic validity of the tree, - // in particular corrent number of type arguments. If there were errors, - // don't proceed to avoid out-of-bounds references to type arguments. - if comp.errors != 0 { - return - } +} + +func (comp *compiler) check() { + comp.checkConsts() comp.checkUsed() comp.checkRecursion() comp.checkLenTargets() @@ -31,9 +30,33 @@ func (comp *compiler) check() { } func (comp *compiler) checkNames() { + includes := make(map[string]bool) + incdirs := make(map[string]bool) + defines := make(map[string]bool) calls := make(map[string]*ast.Call) for _, decl := range comp.desc.Nodes { - switch decl.(type) { + switch n := decl.(type) { + case *ast.Include: + name := n.File.Value + path := n.Pos.File + "/" + name + if includes[path] { + comp.error(n.Pos, "duplicate include %q", name) + } + includes[path] = true + case *ast.Incdir: + name := n.Dir.Value + path := n.Pos.File + "/" + name + if incdirs[path] { + comp.error(n.Pos, "duplicate incdir %q", name) + } + incdirs[path] = true + case *ast.Define: + name := n.Name.Name + path := n.Pos.File + "/" + name + if defines[path] { + comp.error(n.Pos, "duplicate define %v", name) + } + defines[path] = true case *ast.Resource, *ast.Struct, *ast.TypeDef: pos, typ, name := decl.Info() if reservedName[name] { @@ -68,7 +91,6 @@ func (comp *compiler) checkNames() { comp.structs[name] = str } case *ast.IntFlags: - n := decl.(*ast.IntFlags) name := n.Name.Name if reservedName[name] { comp.error(n.Pos, "flags uses reserved name %v", name) @@ -81,7 +103,6 @@ func (comp *compiler) checkNames() { } comp.intFlags[name] = n case *ast.StrFlags: - n := decl.(*ast.StrFlags) name := n.Name.Name if reservedName[name] { comp.error(n.Pos, "string flags uses reserved name %v", name) @@ -94,13 +115,12 @@ func (comp *compiler) checkNames() { } comp.strFlags[name] = n case *ast.Call: - c := decl.(*ast.Call) - name := c.Name.Name + name := n.Name.Name if prev := calls[name]; prev != nil { - comp.error(c.Pos, "syscall %v redeclared, previously declared at %v", + comp.error(n.Pos, "syscall %v redeclared, previously declared at %v", name, prev.Pos) } - calls[name] = c + calls[name] = n } } } @@ -111,17 +131,7 @@ func (comp *compiler) checkFields() { switch n := decl.(type) { case *ast.Struct: _, typ, name := n.Info() - fields := make(map[string]bool) - for _, f := range n.Fields { - fn := f.Name.Name - if fn == "parent" { - comp.error(f.Pos, "reserved field name %v in %v %v", fn, typ, name) - } - if fields[fn] { - comp.error(f.Pos, "duplicate field %v in %v %v", fn, typ, name) - } - fields[fn] = true - } + comp.checkFieldGroup(n.Fields, "field", typ+" "+name) if !n.IsUnion && len(n.Fields) < 1 { comp.error(n.Pos, "struct %v has no fields, need at least 1 field", name) } @@ -131,19 +141,7 @@ func (comp *compiler) checkFields() { } case *ast.Call: name := n.Name.Name - args := make(map[string]bool) - for _, a := range n.Args { - an := a.Name.Name - if an == "parent" { - comp.error(a.Pos, "reserved argument name %v in syscall %v", - an, name) - } - if args[an] { - comp.error(a.Pos, "duplicate argument %v in syscall %v", - an, name) - } - args[an] = true - } + comp.checkFieldGroup(n.Args, "argument", "syscall "+name) if len(n.Args) > maxArgs { comp.error(n.Pos, "syscall %v has %v arguments, allowed maximum is %v", name, len(n.Args), maxArgs) @@ -152,40 +150,93 @@ func (comp *compiler) checkFields() { } } -func (comp *compiler) checkTypes() { +func (comp *compiler) checkFieldGroup(fields []*ast.Field, what, ctx string) { + existing := make(map[string]bool) + for _, f := range fields { + fn := f.Name.Name + if fn == "parent" { + comp.error(f.Pos, "reserved %v name %v in %v", what, fn, ctx) + } + if existing[fn] { + comp.error(f.Pos, "duplicate %v %v in %v", what, fn, ctx) + } + existing[fn] = true + } +} + +func (comp *compiler) checkTypedefs() { for _, decl := range comp.desc.Nodes { switch n := decl.(type) { case *ast.TypeDef: - if comp.typedefs[n.Name.Name] == nil { - continue - } - err0 := comp.errors - comp.checkType(n.Type, false, false, false, false, true, true) - if err0 != comp.errors { - delete(comp.typedefs, n.Name.Name) + if len(n.Args) == 0 { + // Non-template types are fully typed, so we check them ahead of time. + err0 := comp.errors + comp.checkType(checkCtx{}, n.Type, checkIsTypedef) + if err0 != comp.errors { + // To not produce confusing errors on broken type usage. + delete(comp.typedefs, n.Name.Name) + } + } else { + // For templates we only do basic checks of arguments. + names := make(map[string]bool) + for _, arg := range n.Args { + if names[arg.Name] { + comp.error(arg.Pos, "duplicate type argument %v", arg.Name) + } + names[arg.Name] = true + for _, c := range arg.Name { + if c >= 'A' && c <= 'Z' || + c >= '0' && c <= '9' || + c == '_' { + continue + } + comp.error(arg.Pos, "type argument %v must be ALL_CAPS", + arg.Name) + break + } + } } } } +} + +func (comp *compiler) checkTypes() { for _, decl := range comp.desc.Nodes { switch n := decl.(type) { case *ast.Resource: - comp.checkType(n.Base, false, false, false, true, false, false) + comp.checkType(checkCtx{}, n.Base, checkIsResourceBase) case *ast.Struct: - for _, f := range n.Fields { - comp.checkType(f.Type, false, false, !n.IsUnion, false, false, false) - } - comp.checkStruct(n) + comp.checkStruct(checkCtx{}, n) case *ast.Call: for _, a := range n.Args { - comp.checkType(a.Type, true, false, false, false, false, false) + comp.checkType(checkCtx{}, a.Type, checkIsArg) } if n.Ret != nil { - comp.checkType(n.Ret, true, true, false, false, false, false) + comp.checkType(checkCtx{}, n.Ret, checkIsArg|checkIsRet|checkIsRetCtx) } } } } +func (comp *compiler) checkConsts() { + for _, decl := range comp.desc.Nodes { + switch decl.(type) { + case *ast.Call, *ast.Struct, *ast.Resource, *ast.TypeDef: + comp.foreachType(decl, func(t *ast.Type, desc *typeDesc, + args []*ast.Type, base prog.IntTypeCommon) { + if desc.CheckConsts != nil { + desc.CheckConsts(comp, t, args, base) + } + for i, arg := range args { + if check := desc.Args[i].Type.CheckConsts; check != nil { + check(comp, arg) + } + } + }) + } + } +} + func (comp *compiler) checkLenTargets() { for _, decl := range comp.desc.Nodes { switch n := decl.(type) { @@ -471,7 +522,14 @@ func (comp *compiler) recurseField(checked map[string]bool, t *ast.Type, path [] } } -func (comp *compiler) checkStruct(n *ast.Struct) { +func (comp *compiler) checkStruct(ctx checkCtx, n *ast.Struct) { + var flags checkFlags + if !n.IsUnion { + flags |= checkIsStruct + } + for _, f := range n.Fields { + comp.checkType(ctx, f.Type, flags) + } if n.IsUnion { comp.parseUnionAttrs(n) } else { @@ -479,7 +537,22 @@ func (comp *compiler) checkStruct(n *ast.Struct) { } } -func (comp *compiler) checkType(t *ast.Type, isArg, isRet, isStruct, isResourceBase, isTypedef, isTypedefCtx bool) { +type checkFlags int + +const ( + checkIsArg checkFlags = 1 << iota // immidiate syscall arg type + checkIsRet // immidiate syscall ret type + checkIsRetCtx // inside of syscall ret type + checkIsStruct // immidiate struct field type + checkIsResourceBase // immidiate resource base type + checkIsTypedef // immidiate type alias/template type +) + +type checkCtx struct { + instantiationStack []string +} + +func (comp *compiler) checkType(ctx checkCtx, t *ast.Type, flags checkFlags) { if unexpected, _, ok := checkTypeKind(t, kindIdent); !ok { comp.error(t.Pos, "unexpected %v, expect type", unexpected) return @@ -490,29 +563,13 @@ func (comp *compiler) checkType(t *ast.Type, isArg, isRet, isStruct, isResourceB return } if desc == typeTypedef { - if isTypedefCtx { - comp.error(t.Pos, "type aliases can't refer to other type aliases") - return - } - if t.HasColon { - comp.error(t.Pos, "type alias %v with ':'", t.Ident) - return - } - if len(t.Args) != 0 { - comp.error(t.Pos, "type alias %v with arguments", t.Ident) - return - } - *t = *comp.typedefs[t.Ident].Type.Clone(t.Pos).(*ast.Type) - desc = comp.getTypeDesc(t) - if isArg && desc.NeedBase { - baseTypePos := len(t.Args) - 1 - if t.Args[baseTypePos].Ident == "opt" { - baseTypePos-- - } - copy(t.Args[baseTypePos:], t.Args[baseTypePos+1:]) - t.Args = t.Args[:len(t.Args)-1] + err0 := comp.errors + // Replace t with type alias/template target type inplace, + // and check the replaced type recursively. + comp.replaceTypedef(&ctx, t, desc, flags) + if err0 == comp.errors { + comp.checkType(ctx, t, flags) } - comp.checkType(t, isArg, isRet, isStruct, isResourceBase, isTypedef, isTypedefCtx) return } if t.HasColon { @@ -520,39 +577,47 @@ func (comp *compiler) checkType(t *ast.Type, isArg, isRet, isStruct, isResourceB comp.error(t.Pos2, "unexpected ':'") return } - if !isStruct { + if flags&checkIsStruct == 0 { comp.error(t.Pos2, "unexpected ':', only struct fields can be bitfields") return } } - if isRet && (!desc.CanBeArg || desc.CantBeRet) { + if flags&checkIsRet != 0 && (!desc.CanBeArg || desc.CantBeRet) { comp.error(t.Pos, "%v can't be syscall return", t.Ident) return } - if isArg && !desc.CanBeArg { + if flags&checkIsRetCtx != 0 && desc.CantBeRet { + comp.error(t.Pos, "%v can't be used in syscall return", t.Ident) + return + } + if flags&checkIsArg != 0 && !desc.CanBeArg { comp.error(t.Pos, "%v can't be syscall argument", t.Ident) return } - if isTypedef && !desc.CanBeTypedef { + if flags&checkIsTypedef != 0 && !desc.CanBeTypedef { comp.error(t.Pos, "%v can't be type alias target", t.Ident) return } - if isResourceBase && !desc.ResourceBase { + if flags&checkIsResourceBase != 0 && !desc.ResourceBase { comp.error(t.Pos, "%v can't be resource base (int types can)", t.Ident) return } args, opt := removeOpt(t) - if opt && (desc.CantBeOpt || isResourceBase) { - what := "resource base" - if desc.CantBeOpt { - what = t.Ident + if opt != nil { + if len(opt.Args) != 0 { + comp.error(opt.Pos, "opt can't have arguments") + } + if flags&checkIsResourceBase != 0 || desc.CantBeOpt { + what := "resource base" + if desc.CantBeOpt { + what = t.Ident + } + comp.error(opt.Pos, "%v can't be marked as opt", what) + return } - pos := t.Args[len(t.Args)-1].Pos - comp.error(pos, "%v can't be marked as opt", what) - return } addArgs := 0 - needBase := !isArg && desc.NeedBase + needBase := flags&checkIsArg == 0 && desc.NeedBase if needBase { addArgs++ // last arg must be base type, e.g. const[0, int32] } @@ -569,18 +634,111 @@ func (comp *compiler) checkType(t *ast.Type, isArg, isRet, isStruct, isResourceB err0 := comp.errors for i, arg := range args { if desc.Args[i].Type == typeArgType { - comp.checkType(arg, false, isRet, false, false, false, isTypedefCtx) + comp.checkType(ctx, arg, flags&checkIsRetCtx) } else { comp.checkTypeArg(t, arg, desc.Args[i]) } } - if err0 != comp.errors { + if desc.Check != nil && err0 == comp.errors { + _, args, base := comp.getArgsBase(t, "", prog.DirIn, flags&checkIsArg != 0) + desc.Check(comp, t, args, base) + } +} + +func (comp *compiler) replaceTypedef(ctx *checkCtx, t *ast.Type, desc *typeDesc, flags checkFlags) { + typedefName := t.Ident + if t.HasColon { + comp.error(t.Pos, "type alias %v with ':'", t.Ident) return } - if desc.Check != nil { - _, args, base := comp.getArgsBase(t, "", prog.DirIn, isArg) - desc.Check(comp, t, args, base) + typedef := comp.typedefs[typedefName] + fullTypeName := ast.SerializeNode(t) + for i, prev := range ctx.instantiationStack { + if prev == fullTypeName { + ctx.instantiationStack = append(ctx.instantiationStack, fullTypeName) + path := "" + for j := i; j < len(ctx.instantiationStack); j++ { + if j != i { + path += " -> " + } + path += ctx.instantiationStack[j] + } + comp.error(t.Pos, "type instantiation loop: %v", path) + return + } } + ctx.instantiationStack = append(ctx.instantiationStack, fullTypeName) + nargs := len(typedef.Args) + args := t.Args + for _, arg := range args { + if arg.String != "" { + comp.error(arg.Pos, "template arguments can't be strings (%q)", arg.String) + return + } + } + if nargs != len(t.Args) { + if nargs == 0 { + comp.error(t.Pos, "type %v is not a template", typedefName) + } else { + comp.error(t.Pos, "template %v needs %v arguments instead of %v", + typedefName, nargs, len(t.Args)) + } + return + } + if typedef.Type != nil { + *t = *typedef.Type.Clone().(*ast.Type) + comp.instantiate(t, typedef.Args, args) + } else { + if comp.structs[fullTypeName] == nil { + inst := typedef.Struct.Clone().(*ast.Struct) + inst.Name.Name = fullTypeName + comp.instantiate(inst, typedef.Args, args) + comp.checkStruct(*ctx, inst) + comp.desc.Nodes = append(comp.desc.Nodes, inst) + comp.structs[fullTypeName] = inst + } + *t = ast.Type{ + Pos: t.Pos, + Ident: fullTypeName, + } + } + + // Remove base type if it's not needed in this context. + desc = comp.getTypeDesc(t) + if flags&checkIsArg != 0 && desc.NeedBase { + baseTypePos := len(t.Args) - 1 + if t.Args[baseTypePos].Ident == "opt" { + baseTypePos-- + } + copy(t.Args[baseTypePos:], t.Args[baseTypePos+1:]) + t.Args = t.Args[:len(t.Args)-1] + } +} + +func (comp *compiler) instantiate(templ ast.Node, params []*ast.Ident, args []*ast.Type) { + if len(params) == 0 { + return + } + argMap := make(map[string]*ast.Type) + for i, param := range params { + argMap[param.Name] = args[i] + } + templ.Walk(ast.Recursive(func(n ast.Node) { + templArg, ok := n.(*ast.Type) + if !ok { + return + } + if concreteArg := argMap[templArg.Ident]; concreteArg != nil { + *templArg = *concreteArg.Clone().(*ast.Type) + } + // TODO(dvyukov): somewhat hacky, but required for int8[0:CONST_ARG] + // Need more checks here. E.g. that CONST_ARG does not have subargs. + // And if CONST_ARG is a value, then use concreteArg.Value. + if concreteArg := argMap[templArg.Ident2]; concreteArg != nil { + templArg.Ident2 = concreteArg.Ident + templArg.Pos2 = concreteArg.Pos + } + })) } func (comp *compiler) checkTypeArg(t, arg *ast.Type, argDesc namedArg) { @@ -655,7 +813,7 @@ func checkTypeKind(t *ast.Type, kind int) (unexpected string, expect string, ok unexpected = fmt.Sprintf("string %q", t.String) } case t.Ident != "": - ok = kind == kindIdent + ok = kind == kindIdent || kind == kindInt if !ok { unexpected = fmt.Sprintf("identifier %v", t.Ident) } @@ -688,8 +846,8 @@ func (comp *compiler) checkVarlens() { } func (comp *compiler) isVarlen(t *ast.Type) bool { - desc, args, base := comp.getArgsBase(t, "", prog.DirIn, false) - return desc.Varlen != nil && desc.Varlen(comp, t, args, base) + desc, args, _ := comp.getArgsBase(t, "", prog.DirIn, false) + return desc.Varlen != nil && desc.Varlen(comp, t, args) } func (comp *compiler) checkVarlen(n *ast.Struct) { diff --git a/pkg/compiler/compiler.go b/pkg/compiler/compiler.go index 6019bda94..98348b0f8 100644 --- a/pkg/compiler/compiler.go +++ b/pkg/compiler/compiler.go @@ -39,6 +39,8 @@ type Prog struct { StructDescs []*prog.KeyedStruct // Set of unsupported syscalls/flags. Unsupported map[string]bool + // Returned if consts was nil. + fileConsts map[string]*ConstInfo } // Compile compiles sys description. @@ -65,6 +67,20 @@ func Compile(desc *ast.Description, consts map[string]uint64, target *targets.Ta for name, typedef := range builtinTypedefs { comp.typedefs[name] = typedef } + comp.typecheck() + // The subsequent, more complex, checks expect basic validity of the tree, + // in particular corrent number of type arguments. If there were errors, + // don't proceed to avoid out-of-bounds references to type arguments. + if comp.errors != 0 { + return nil + } + if consts == nil { + fileConsts := comp.extractConsts() + if comp.errors != 0 { + return nil + } + return &Prog{fileConsts: fileConsts} + } comp.assignSyscallNumbers(consts) comp.patchConsts(consts) comp.check() @@ -177,9 +193,11 @@ func (comp *compiler) getTypeDesc(t *ast.Type) *typeDesc { func (comp *compiler) getArgsBase(t *ast.Type, field string, dir prog.Dir, isArg bool) ( *typeDesc, []*ast.Type, prog.IntTypeCommon) { desc := comp.getTypeDesc(t) + if desc == nil { + panic(fmt.Sprintf("no type desc for %#v", *t)) + } args, opt := removeOpt(t) - size := sizeUnassigned - com := genCommon(t.Ident, field, size, dir, opt) + com := genCommon(t.Ident, field, sizeUnassigned, dir, opt != nil) base := genIntCommon(com, 0, false) if desc.NeedBase { base.TypeSize = comp.ptrSize @@ -192,12 +210,48 @@ func (comp *compiler) getArgsBase(t *ast.Type, field string, dir prog.Dir, isArg return desc, args, base } -func removeOpt(t *ast.Type) ([]*ast.Type, bool) { +func (comp *compiler) foreachType(n0 ast.Node, + cb func(*ast.Type, *typeDesc, []*ast.Type, prog.IntTypeCommon)) { + switch n := n0.(type) { + case *ast.Call: + for _, arg := range n.Args { + comp.foreachSubType(arg.Type, true, cb) + } + if n.Ret != nil { + comp.foreachSubType(n.Ret, true, cb) + } + case *ast.Resource: + comp.foreachSubType(n.Base, false, cb) + case *ast.Struct: + for _, f := range n.Fields { + comp.foreachSubType(f.Type, false, cb) + } + case *ast.TypeDef: + if len(n.Args) == 0 { + comp.foreachSubType(n.Type, false, cb) + } + default: + panic(fmt.Sprintf("unexpected node %#v", n0)) + } +} + +func (comp *compiler) foreachSubType(t *ast.Type, isArg bool, + cb func(*ast.Type, *typeDesc, []*ast.Type, prog.IntTypeCommon)) { + desc, args, base := comp.getArgsBase(t, "", prog.DirIn, isArg) + cb(t, desc, args, base) + for i, arg := range args { + if desc.Args[i].Type == typeArgType { + comp.foreachSubType(arg, false, cb) + } + } +} + +func removeOpt(t *ast.Type) ([]*ast.Type, *ast.Type) { args := t.Args - if len(args) != 0 && args[len(args)-1].Ident == "opt" { - return args[:len(args)-1], true + if last := len(args) - 1; last >= 0 && args[last].Ident == "opt" { + return args[:last], args[last] } - return args, false + return args, nil } func (comp *compiler) parseIntType(name string) (size uint64, bigEndian bool) { diff --git a/pkg/compiler/compiler_test.go b/pkg/compiler/compiler_test.go index f26c272e6..a5cc6a86a 100644 --- a/pkg/compiler/compiler_test.go +++ b/pkg/compiler/compiler_test.go @@ -4,10 +4,14 @@ package compiler import ( + "bytes" + "fmt" + "io/ioutil" "path/filepath" "testing" "github.com/google/syzkaller/pkg/ast" + "github.com/google/syzkaller/pkg/serializer" "github.com/google/syzkaller/sys/targets" ) @@ -42,24 +46,98 @@ func TestCompileAll(t *testing.T) { } } -func TestErrors(t *testing.T) { +func TestNoErrors(t *testing.T) { + t.Parallel() consts := map[string]uint64{ "__NR_foo": 1, "C0": 0, "C1": 1, "C2": 2, } - target := targets.List["test"]["64"] - for _, name := range []string{"errors.txt", "errors2.txt"} { - name := name - t.Run(name, func(t *testing.T) { - em := ast.NewErrorMatcher(t, filepath.Join("testdata", name)) - desc := ast.Parse(em.Data, name, em.ErrorHandler) + for _, name := range []string{"all.txt"} { + for _, arch := range []string{"32", "64"} { + name, arch := name, arch + t.Run(fmt.Sprintf("%v/%v", name, arch), func(t *testing.T) { + t.Parallel() + target := targets.List["test"][arch] + eh := func(pos ast.Pos, msg string) { + t.Logf("%v: %v", pos, msg) + } + data, err := ioutil.ReadFile(filepath.Join("testdata", name)) + if err != nil { + t.Fatal(err) + } + astDesc := ast.Parse(data, name, eh) + if astDesc == nil { + t.Fatalf("parsing failed") + } + constInfo := ExtractConsts(astDesc, target, eh) + if constInfo == nil { + t.Fatalf("const extraction failed") + } + desc := Compile(astDesc, consts, target, eh) + if desc == nil { + t.Fatalf("compilation failed") + } + if len(desc.Unsupported) != 0 { + t.Fatalf("something is unsupported:\n%+v", desc.Unsupported) + } + out := new(bytes.Buffer) + fmt.Fprintf(out, "\n\nRESOURCES:\n") + serializer.Write(out, desc.Resources) + fmt.Fprintf(out, "\n\nSTRUCTS:\n") + serializer.Write(out, desc.StructDescs) + fmt.Fprintf(out, "\n\nSYSCALLS:\n") + serializer.Write(out, desc.Syscalls) + if false { + t.Log(out.String()) // useful for debugging + } + }) + } + } +} + +func TestErrors(t *testing.T) { + t.Parallel() + for _, arch := range []string{"32", "64"} { + target := targets.List["test"][arch] + t.Run(arch, func(t *testing.T) { + t.Parallel() + em := ast.NewErrorMatcher(t, filepath.Join("testdata", "errors.txt")) + desc := ast.Parse(em.Data, "errors.txt", em.ErrorHandler) if desc == nil { em.DumpErrors(t) t.Fatalf("parsing failed") } ExtractConsts(desc, target, em.ErrorHandler) + em.Check(t) + }) + } +} + +func TestErrors2(t *testing.T) { + t.Parallel() + consts := map[string]uint64{ + "__NR_foo": 1, + "C0": 0, + "C1": 1, + "C2": 2, + } + for _, arch := range []string{"32", "64"} { + target := targets.List["test"][arch] + t.Run(arch, func(t *testing.T) { + t.Parallel() + em := ast.NewErrorMatcher(t, filepath.Join("testdata", "errors2.txt")) + desc := ast.Parse(em.Data, "errors2.txt", em.ErrorHandler) + if desc == nil { + em.DumpErrors(t) + t.Fatalf("parsing failed") + } + info := ExtractConsts(desc, target, em.ErrorHandler) + if info == nil { + em.DumpErrors(t) + t.Fatalf("const extraction failed") + } Compile(desc, consts, target, em.ErrorHandler) em.Check(t) }) @@ -67,6 +145,7 @@ func TestErrors(t *testing.T) { } func TestFuzz(t *testing.T) { + t.Parallel() inputs := []string{ "d~^gB̉`i\u007f?\xb0.", "da[", @@ -86,6 +165,7 @@ func TestFuzz(t *testing.T) { } func TestAlign(t *testing.T) { + t.Parallel() const input = ` foo$0(a ptr[in, s0]) s0 { diff --git a/pkg/compiler/consts.go b/pkg/compiler/consts.go index f2e4d4850..7f0cb7e38 100644 --- a/pkg/compiler/consts.go +++ b/pkg/compiler/consts.go @@ -14,6 +14,7 @@ import ( "strings" "github.com/google/syzkaller/pkg/ast" + "github.com/google/syzkaller/prog" "github.com/google/syzkaller/sys/targets" ) @@ -24,41 +25,25 @@ type ConstInfo struct { Defines map[string]string } -// ExtractConsts returns list of literal constants and other info required const value extraction. -func ExtractConsts(desc *ast.Description, target *targets.Target, eh0 ast.ErrorHandler) *ConstInfo { - errors := 0 - eh := func(pos ast.Pos, msg string, args ...interface{}) { - errors++ - msg = fmt.Sprintf(msg, args...) - if eh0 != nil { - eh0(pos, msg) - } else { - ast.LoggingHandler(pos, msg) - } - } - info := &ConstInfo{ - Defines: make(map[string]string), +func ExtractConsts(desc *ast.Description, target *targets.Target, eh ast.ErrorHandler) map[string]*ConstInfo { + res := Compile(desc, nil, target, eh) + if res == nil { + return nil } - includeMap := make(map[string]bool) - incdirMap := make(map[string]bool) - constMap := make(map[string]bool) + return res.fileConsts +} - desc.Walk(ast.Recursive(func(n0 ast.Node) { - switch n := n0.(type) { +// extractConsts returns list of literal constants and other info required for const value extraction. +func (comp *compiler) extractConsts() map[string]*ConstInfo { + infos := make(map[string]*constInfo) + for _, decl := range comp.desc.Nodes { + pos, _, _ := decl.Info() + info := getConstInfo(infos, pos) + switch n := decl.(type) { case *ast.Include: - file := n.File.Value - if includeMap[file] { - eh(n.Pos, "duplicate include %q", file) - } - includeMap[file] = true - info.Includes = append(info.Includes, file) + info.includeArray = append(info.includeArray, n.File.Value) case *ast.Incdir: - dir := n.Dir.Value - if incdirMap[dir] { - eh(n.Pos, "duplicate incdir %q", dir) - } - incdirMap[dir] = true - info.Incdirs = append(info.Incdirs, dir) + info.incdirArray = append(info.incdirArray, n.Dir.Value) case *ast.Define: v := fmt.Sprint(n.Value.Value) switch { @@ -68,34 +53,79 @@ func ExtractConsts(desc *ast.Description, target *targets.Target, eh0 ast.ErrorH v = n.Value.Ident } name := n.Name.Name - if info.Defines[name] != "" { - eh(n.Pos, "duplicate define %v", name) - } - info.Defines[name] = v - constMap[name] = true + info.defines[name] = v + info.consts[name] = true case *ast.Call: - if target.SyscallNumbers && !strings.HasPrefix(n.CallName, "syz_") { - constMap[target.SyscallPrefix+n.CallName] = true + if comp.target.SyscallNumbers && !strings.HasPrefix(n.CallName, "syz_") { + info.consts[comp.target.SyscallPrefix+n.CallName] = true } - case *ast.Type: - if c := typeConstIdentifier(n); c != nil { - constMap[c.Ident] = true - constMap[c.Ident2] = true - } - case *ast.Int: - constMap[n.Ident] = true + } + } + + for _, decl := range comp.desc.Nodes { + switch decl.(type) { + case *ast.Call, *ast.Struct, *ast.Resource, *ast.TypeDef: + comp.foreachType(decl, func(t *ast.Type, desc *typeDesc, + args []*ast.Type, _ prog.IntTypeCommon) { + for i, arg := range args { + if desc.Args[i].Type.Kind == kindInt { + if arg.Ident != "" { + info := getConstInfo(infos, arg.Pos) + info.consts[arg.Ident] = true + } + if arg.Ident2 != "" { + info := getConstInfo(infos, arg.Pos2) + info.consts[arg.Ident2] = true + } + } + } + }) + } + } + + comp.desc.Walk(ast.Recursive(func(n0 ast.Node) { + if n, ok := n0.(*ast.Int); ok { + info := getConstInfo(infos, n.Pos) + info.consts[n.Ident] = true } })) - if errors != 0 { - return nil + return convertConstInfo(infos) +} + +type constInfo struct { + consts map[string]bool + defines map[string]string + includeArray []string + incdirArray []string +} + +func getConstInfo(infos map[string]*constInfo, pos ast.Pos) *constInfo { + info := infos[pos.File] + if info == nil { + info = &constInfo{ + consts: make(map[string]bool), + defines: make(map[string]string), + } + infos[pos.File] = info } - info.Consts = toArray(constMap) return info } -// assignSyscallNumbers assigns syscall numbers, discards unsupported syscalls -// and removes no longer irrelevant nodes from the tree (comments, new lines, etc). +func convertConstInfo(infos map[string]*constInfo) map[string]*ConstInfo { + res := make(map[string]*ConstInfo) + for file, info := range infos { + res[file] = &ConstInfo{ + Consts: toArray(info.consts), + Includes: info.includeArray, + Incdirs: info.incdirArray, + Defines: info.defines, + } + } + return res +} + +// assignSyscallNumbers assigns syscall numbers, discards unsupported syscalls. func (comp *compiler) assignSyscallNumbers(consts map[string]uint64) { // Pseudo syscalls starting from syz_ are assigned numbers starting from syzbase. // Note: the numbers must be stable (not depend on file reading order, etc), @@ -116,51 +146,39 @@ func (comp *compiler) assignSyscallNumbers(consts map[string]uint64) { syznr[name] = syzbase + uint64(i) } - var top []ast.Node for _, decl := range comp.desc.Nodes { - switch decl.(type) { - case *ast.Call: - c := decl.(*ast.Call) - if strings.HasPrefix(c.CallName, "syz_") { - c.NR = syznr[c.CallName] - top = append(top, decl) - continue - } - if !comp.target.SyscallNumbers { - top = append(top, decl) - continue - } - // Lookup in consts. - str := comp.target.SyscallPrefix + c.CallName - nr, ok := consts[str] - top = append(top, decl) - if ok { - c.NR = nr - continue - } - c.NR = ^uint64(0) // mark as unused to not generate it - name := "syscall " + c.CallName - if !comp.unsupported[name] { - comp.unsupported[name] = true - comp.warning(c.Pos, "unsupported syscall: %v due to missing const %v", - c.CallName, str) - } - case *ast.IntFlags, *ast.Resource, *ast.Struct, *ast.StrFlags, *ast.TypeDef: - top = append(top, decl) - case *ast.NewLine, *ast.Comment, *ast.Include, *ast.Incdir, *ast.Define: - // These are not needed anymore. - default: - panic(fmt.Sprintf("unknown node type: %#v", decl)) + c, ok := decl.(*ast.Call) + if !ok { + continue + } + if strings.HasPrefix(c.CallName, "syz_") { + c.NR = syznr[c.CallName] + continue + } + // TODO(dvyukov): we don't need even syz consts in this case. + if !comp.target.SyscallNumbers { + continue + } + // Lookup in consts. + str := comp.target.SyscallPrefix + c.CallName + nr, ok := consts[str] + if ok { + c.NR = nr + continue + } + c.NR = ^uint64(0) // mark as unused to not generate it + name := "syscall " + c.CallName + if !comp.unsupported[name] { + comp.unsupported[name] = true + comp.warning(c.Pos, "unsupported syscall: %v due to missing const %v", + c.CallName, str) } } - comp.desc.Nodes = top } // patchConsts replaces all symbolic consts with their numeric values taken from consts map. // Updates desc and returns set of unsupported syscalls and flags. -// After this pass consts are not needed for compilation. func (comp *compiler) patchConsts(consts map[string]uint64) { - var top []ast.Node for _, decl := range comp.desc.Nodes { switch decl.(type) { case *ast.IntFlags: @@ -173,29 +191,29 @@ func (comp *compiler) patchConsts(consts map[string]uint64) { } } n.Values = values - top = append(top, n) - case *ast.StrFlags: - top = append(top, decl) case *ast.Resource, *ast.Struct, *ast.Call, *ast.TypeDef: - // Walk whole tree and replace consts in Int's and Type's. + // Walk whole tree and replace consts in Type's and Int's. missing := "" - decl.Walk(ast.Recursive(func(n0 ast.Node) { - switch n := n0.(type) { - case *ast.Int: - comp.patchIntConst(n.Pos, &n.Value, &n.Ident, consts, &missing) - case *ast.Type: - if c := typeConstIdentifier(n); c != nil { - comp.patchIntConst(c.Pos, &c.Value, &c.Ident, - consts, &missing) - if c.HasColon { - comp.patchIntConst(c.Pos2, &c.Value2, &c.Ident2, - consts, &missing) + comp.foreachType(decl, func(_ *ast.Type, desc *typeDesc, + args []*ast.Type, _ prog.IntTypeCommon) { + for i, arg := range args { + if desc.Args[i].Type.Kind == kindInt { + comp.patchIntConst(arg.Pos, &arg.Value, + &arg.Ident, consts, &missing) + if arg.HasColon { + comp.patchIntConst(arg.Pos2, &arg.Value2, + &arg.Ident2, consts, &missing) } } } - })) + }) + if n, ok := decl.(*ast.Resource); ok { + for _, v := range n.Values { + comp.patchIntConst(v.Pos, &v.Value, + &v.Ident, consts, &missing) + } + } if missing == "" { - top = append(top, decl) continue } // Produce a warning about unsupported syscall/resource/struct. @@ -209,15 +227,11 @@ func (comp *compiler) patchConsts(consts map[string]uint64) { comp.warning(pos, "unsupported %v: %v due to missing const %v", typ, name, missing) } - // We have to keep partially broken resources and structs, - // because otherwise their usages will error. - top = append(top, decl) if c, ok := decl.(*ast.Call); ok { c.NR = ^uint64(0) // mark as unused to not generate it } } } - comp.desc.Nodes = top } func (comp *compiler) patchIntConst(pos ast.Pos, val *uint64, id *string, @@ -237,37 +251,9 @@ func (comp *compiler) patchIntConst(pos ast.Pos, val *uint64, id *string, } } *val = v - *id = "" return ok } -// typeConstIdentifier returns type arg that is an integer constant (subject for const patching), if any. -func typeConstIdentifier(n *ast.Type) *ast.Type { - // TODO: see if we can extract this info from typeDesc/typeArg. - if n.Ident == "const" && len(n.Args) > 0 { - return n.Args[0] - } - if n.Ident == "array" && len(n.Args) > 1 && n.Args[1].Ident != "opt" { - return n.Args[1] - } - if n.Ident == "vma" && len(n.Args) > 0 && n.Args[0].Ident != "opt" { - return n.Args[0] - } - if n.Ident == "string" && len(n.Args) > 1 && n.Args[1].Ident != "opt" { - return n.Args[1] - } - if n.Ident == "csum" && len(n.Args) > 2 && n.Args[1].Ident == "pseudo" { - return n.Args[2] - } - switch n.Ident { - case "int8", "int16", "int16be", "int32", "int32be", "int64", "int64be", "intptr": - if len(n.Args) > 0 && n.Args[0].Ident != "opt" { - return n.Args[0] - } - } - return nil -} - func SerializeConsts(consts map[string]uint64) []byte { type nameValuePair struct { name string diff --git a/pkg/compiler/consts_test.go b/pkg/compiler/consts_test.go index 918f61d0c..13780a128 100644 --- a/pkg/compiler/consts_test.go +++ b/pkg/compiler/consts_test.go @@ -7,6 +7,7 @@ import ( "io/ioutil" "path/filepath" "reflect" + "sort" "testing" "github.com/google/syzkaller/pkg/ast" @@ -18,18 +19,26 @@ func TestExtractConsts(t *testing.T) { if err != nil { t.Fatalf("failed to read input file: %v", err) } - desc := ast.Parse(data, "test", nil) + desc := ast.Parse(data, "consts.txt", nil) if desc == nil { t.Fatalf("failed to parse input") } target := targets.List["linux"]["amd64"] - info := ExtractConsts(desc, target, func(pos ast.Pos, msg string) { + fileInfo := ExtractConsts(desc, target, func(pos ast.Pos, msg string) { t.Fatalf("%v: %v", pos, msg) }) - wantConsts := []string{"CONST1", "CONST10", "CONST11", "CONST12", "CONST13", - "CONST14", "CONST15", "CONST16", - "CONST2", "CONST3", "CONST4", "CONST5", - "CONST6", "CONST7", "CONST8", "CONST9", "__NR_bar", "__NR_foo"} + info := fileInfo["consts.txt"] + if info == nil || len(fileInfo) != 1 { + t.Fatalf("bad file info returned: %+v", info) + } + wantConsts := []string{ + "__NR_bar", "__NR_foo", + "CONST1", "CONST2", "CONST3", "CONST4", "CONST5", + "CONST6", "CONST7", "CONST8", "CONST9", "CONST10", + "CONST11", "CONST12", "CONST13", "CONST14", "CONST15", + "CONST16", "CONST17", "CONST18", "CONST19", "CONST20", + } + sort.Strings(wantConsts) if !reflect.DeepEqual(info.Consts, wantConsts) { t.Fatalf("got consts:\n%q\nwant:\n%q", info.Consts, wantConsts) } diff --git a/pkg/compiler/gen.go b/pkg/compiler/gen.go index b91080db2..1e0e64307 100644 --- a/pkg/compiler/gen.go +++ b/pkg/compiler/gen.go @@ -376,6 +376,9 @@ func (comp *compiler) genFieldArray(fields []*ast.Field, dir prog.Dir, isArg boo func (comp *compiler) genType(t *ast.Type, field string, dir prog.Dir, isArg bool) prog.Type { desc, args, base := comp.getArgsBase(t, field, dir, isArg) + if desc.Gen == nil { + panic(fmt.Sprintf("no gen for %v %#v", field, t)) + } return desc.Gen(comp, t, args, base) } diff --git a/pkg/compiler/testdata/all.txt b/pkg/compiler/testdata/all.txt new file mode 100644 index 000000000..4aeae45df --- /dev/null +++ b/pkg/compiler/testdata/all.txt @@ -0,0 +1,60 @@ +# Copyright 2018 syzkaller project authors. All rights reserved. +# Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file. + +foo$0(a int8) +foo$1(a int8[C1:C2]) +foo$2() ptr[out, array[int32]] + +# Proc type. + +proc_struct1 { + f1 proc[C0, 8, int8] +} + +# Bitfields. + +bitfield0 { + f1 int8:1 + f2 int8:2 +} + +# Type templates. + +type type0 int8 +type templ0[A, B] const[A, B] + +type templ_struct0[A, B] { + len len[parent, int16] + typ const[A, int16] + data B +} [align_4] + +type templ_struct1[C] { + f1 const[C, int8] + f2 int8[0:C] +} + +union_with_templ_struct [ + f1 templ_struct0[C1, type0] + f2 templ_struct0[C2, struct0] +] [varlen] + +struct0 { + f1 int8 + f2 int16 +} + +type templ_struct2[A] templ_struct0[A, int8] +type templ_struct3 templ_struct2[C1] +type templ_struct4 templ_struct3 +type templ_struct5 templ_struct0[C1, templ_struct0[C2, int8]] +type templ_struct6 templ_struct0[C1, templ_struct2[C2]] +type templ_union union_with_templ_struct + +foo$templ0(a templ0[42, int8]) +foo$templ1(a ptr[in, templ_struct0[C2, int8]]) +foo$templ2(a ptr[in, union_with_templ_struct]) +foo$templ3(a ptr[in, templ_struct1[1]], b ptr[in, templ_struct1[2]]) +foo$templ4(a ptr[in, templ_struct1[3]]) +foo$templ5(a ptr[in, templ_struct1[3]]) +foo$templ6(a ptr[in, templ_struct4]) diff --git a/pkg/compiler/testdata/consts.txt b/pkg/compiler/testdata/consts.txt index 179efe081..468f86681 100644 --- a/pkg/compiler/testdata/consts.txt +++ b/pkg/compiler/testdata/consts.txt @@ -20,5 +20,14 @@ str { } bar$BAZ(x vma[opt], y vma[CONST8], z vma[CONST9:CONST10]) -bar$QUX(s ptr[in, string["foo", CONST11]], x csum[s, pseudo, CONST12]) +bar$QUX(s ptr[in, string["foo", CONST11]], x ptr[in, csum[s, pseudo, CONST12, int16]]) bar$FOO(x int8[8:CONST13], y int16be[CONST14:10], z intptr[CONST15:CONST16]) + +type type0 const[CONST17, int8] +type templ0[C] const[C, int8] +foo$0(a templ0[CONST18]) +type templ1[C] { + f1 const[CONST19, int8] + f2 const[C, int8] +} +foo$1(a ptr[in, templ1[CONST20]]) diff --git a/pkg/compiler/testdata/errors.txt b/pkg/compiler/testdata/errors.txt index 0a9363924..3c67ac66a 100644 --- a/pkg/compiler/testdata/errors.txt +++ b/pkg/compiler/testdata/errors.txt @@ -76,7 +76,6 @@ foo$11(a buffer["in"]) ### unexpected string "in" for direction argument of buf foo$12(a buffer[10]) ### unexpected int 10 for direction argument of buffer type, expect [in out inout] foo$13(a int32[2:3]) foo$14(a int32[2:2]) -foo$15(a int32[3:2]) ### bad int range [3:2] foo$16(a int32[3]) foo$17(a ptr[in, int32]) foo$18(a ptr[in, int32[2:3]]) @@ -85,12 +84,10 @@ foo$20(a ptr) ### wrong number of arguments for type ptr, expect direction, ty foo$21(a ptr["foo"]) ### wrong number of arguments for type ptr, expect direction, type, [opt] foo$22(a ptr[in]) ### wrong number of arguments for type ptr, expect direction, type, [opt] foo$23(a ptr[in, s3[in]]) ### wrong number of arguments for type s3, expect no arguments -foo$24(a ptr[in, int32[3:2]]) ### bad int range [3:2] foo$25(a proc[0, "foo"]) ### unexpected string "foo" for per-proc values argument of proc type, expect int foo$26(a flags[no]) ### unknown flags no foo$27(a flags["foo"]) ### unexpected string "foo" for flags argument of flags type, expect identifier foo$28(a ptr[in, string["foo"]], b ptr[in, string["foo", 4]]) -foo$29(a ptr[in, string["foo", 3]]) ### string value "foo\x00" exceeds buffer length 3 foo$30(a ptr[in, string[no]]) ### unknown string flags no foo$31(a int8, b ptr[in, csum[a, inet]]) ### wrong number of arguments for type csum, expect csum target, kind, [proto], base type foo$32(a int8, b ptr[in, csum[a, inet, 1, int32]]) ### only pseudo csum can have proto @@ -98,12 +95,9 @@ foo$33(a int8, b ptr[in, csum[a, pseudo, 1, int32]]) foo$34(a int32["foo"]) ### unexpected string "foo" for range argument of int32 type, expect int foo$35(a ptr[in, s3[opt]]) ### s3 can't be marked as opt foo$36(a const[1:2]) ### unexpected ':' -foo$37(a ptr[in, proc[1000, 1, int8]]) ### values starting from 1000 overflow base type -foo$38(a ptr[in, proc[20, 10, int8]]) ### values starting from 20 with step 10 overflow base type for 32 procs foo$39(a fileoff:1) ### unexpected ':' foo$40(a len["a"]) ### unexpected string "a" for len target argument of len type, expect identifier foo$41(a vma[C1:C2]) -foo$42(a proc[20, 0]) ### proc per-process values must not be 0 foo$43(a ptr[in, string[1]]) ### unexpected int 1, string arg must be a string literal or string flags foo$44(a int32) len[a] ### len can't be syscall return foo$45(a int32) len[b] ### len can't be syscall return @@ -111,10 +105,11 @@ foo$46(a ptr[in, in]) ### unknown type in foo$47(a int32:2) ### unexpected ':', only struct fields can be bitfields foo$48(a ptr[in, int32:7]) ### unexpected ':', only struct fields can be bitfields foo$49(a ptr[in, array[int32, 0:1]]) -foo$50(a ptr[in, array[int32, 0]]) ### arrays of size 0 are not supported -foo$51(a ptr[in, array[int32, 0:0]]) ### arrays of size 0 are not supported foo$52(a intptr, b bitsize[a]) foo$53(a proc[20, 10, opt]) +# This must not error yet (consts are not patched). +foo$54(a ptr[in, string["foo", C1]]) +foo$55(a int8[opt[int8]]) ### opt can't have arguments opt { ### struct uses reserved name opt f1 int32 @@ -140,6 +135,7 @@ s3 { f5 int8:9 ### bitfield of size 9 is too large for base type of size 8 f6 int32:32 f7 int32:33 ### bitfield of size 33 is too large for base type of size 32 + f8 const[0, int32:C1] ### literal const bitfield sizes are not supported } [packed, align_4] s4 { @@ -189,20 +185,15 @@ typestruct { f1 mybool8 f2 mybool16 } -typeunion [ - f1 mybool8 - f2 mybool16 -] type type0 int8 -type type0 int8 ### type type0 redeclared, previously declared as type alias at errors.txt:197:6 -resource type0[int32] ### type type0 redeclared, previously declared as type alias at errors.txt:197:6 +type type0 int8 ### type type0 redeclared, previously declared as type alias at errors.txt:189:6 +resource type0[int32] ### type type0 redeclared, previously declared as type alias at errors.txt:189:6 type0 = 0, 1 -type type1 type1 ### type aliases can't refer to other type aliases +type type1 type1 ### type instantiation loop: type1 -> type1 type type2 int8:4 ### unexpected ':', only struct fields can be bitfields type type3 type2 ### unknown type type2 type type4 const[0] ### wrong number of arguments for type const, expect value, base type -type type5 typeunion ### typeunion can't be type alias target type type6 len[foo, int32] ### len can't be type alias target type type7 len[foo] ### len can't be type alias target resource typeres1[int32] @@ -210,13 +201,10 @@ type type8 typeres1 ### typeres1 can't be type alias target type int8 int8 ### type name int8 conflicts with builtin type type opt int8 ### type uses reserved name opt type type9 const[0, int8] -type type10 type0 ### type aliases can't refer to other type aliases -type type11 typestruct11 ### typestruct11 can't be type alias target type type12 proc[123, 2, int16, opt] type type13 ptr[in, typestruct13] type type14 flags[type0, int32] type type15 const[0, type0] ### unexpected value type0 for base type argument of const type, expect [int8 int16 int32 int64 int16be int32be int64be intptr] -type type16 ptr[in, type0] ### type aliases can't refer to other type aliases type bool8 int8[0:1] ### type name bool8 conflicts with builtin type typestruct11 { @@ -233,12 +221,47 @@ typestruct13 { } foo$100(a mybool8, b mybool16) -foo$101(a type5) ### unknown type type5 foo$102(a type2) ### unknown type type2 foo$103(a type0:4) ### type alias type0 with ':' -foo$104(a type0[opt]) ### type alias type0 with arguments +foo$104(a type0[opt]) ### type type0 is not a template foo$105() type0 foo$106() type6 ### unknown type type6 foo$107(a type9, b type12) foo$108(a flags[type0]) foo$109(a ptr[in, type0]) + +# Type templates. + +type templ0[A, B] const[A, B] +type templ2[A] A[0] +type templ3[A] ptr[in, A] +type templ4[A, A] ptr[in, A] ### duplicate type argument A +type templ5[abc] ptr[in, abc] ### type argument abc must be ALL_CAPS +type templ6[T] ptr[in, T] +type templ7 templ0[templ6, int8] + +# Note: here 42 is stripped as base type, so const ends up without arguments. +foo$201(a templ1[42]) +type templ1[A] const[A] ### wrong number of arguments for type const, expect value + +type templ_struct0[A, B] { + len len[parent, int16] + typ const[A, int16] + data B +} [align_4] + +type templ_struct1[STR] { + f string[STR, 40] +} + +type templ_struct2[A] { + f B ### unknown type B +} + +foo$200(a templ0[42, int8]) +foo$202(a templ0) ### template templ0 needs 2 arguments instead of 0 +foo$203(a type0[42]) ### type type0 is not a template +foo$204(a ptr[in, templ_struct0[42, int8]]) +foo$205(a ptr[in, templ_struct0[int8, int8]]) +foo$206(a ptr[in, templ_struct1["foo"]]) ### template arguments can't be strings ("foo") +foo$207(a ptr[in, templ_struct2[1]]) diff --git a/pkg/compiler/testdata/errors2.txt b/pkg/compiler/testdata/errors2.txt index 5b418ab55..67127c6ad 100644 --- a/pkg/compiler/testdata/errors2.txt +++ b/pkg/compiler/testdata/errors2.txt @@ -42,6 +42,18 @@ sr7 { f1 ptr[in, sr7, opt] } +type templ_sr[T] { + f T +} + +sr8 { + f templ_sr[sr8] ### recursive declaration: sr8.f -> templ_sr[sr8].f -> sr8 (mark some pointers as opt) +} + +sr9 { + f templ_sr[ptr[in, sr9]] ### recursive declaration: sr9.f -> templ_sr[ptr[in, sr9]].f -> sr9 (mark some pointers as opt) +} + # Len target tests. foo$100(a int32, b len[a]) @@ -134,3 +146,17 @@ s403 { sf400 = "foo", "bar", "baz" sf401 = "a", "b", "cd" + +# Const argument checks. + +foo$500(a int32[3:2]) ### bad int range [3:2] +foo$501(a ptr[in, int32[3:2]]) ### bad int range [3:2] +foo$502(a ptr[in, string["foo", C1]]) ### string value "foo\x00" exceeds buffer length 1 +foo$503(a ptr[in, proc[1000, 1, int8]]) ### values starting from 1000 overflow base type +foo$504(a ptr[in, proc[20, 10, int8]]) ### values starting from 20 with step 10 overflow base type for 32 procs +foo$505(a proc[20, 0]) ### proc per-process values must not be 0 +foo$506(a ptr[in, array[int32, 0]]) ### arrays of size 0 are not supported +foo$507(a ptr[in, array[int32, 0:0]]) ### arrays of size 0 are not supported +foo$508(a ptr[in, string["foo", 3]]) ### string value "foo\x00" exceeds buffer length 3 + +type type500 proc[C1, 8, int8] ### values starting from 1 with step 8 overflow base type for 32 procs diff --git a/pkg/compiler/types.go b/pkg/compiler/types.go index ee6e4a559..b1c9854b2 100644 --- a/pkg/compiler/types.go +++ b/pkg/compiler/types.go @@ -23,10 +23,12 @@ type typeDesc struct { ResourceBase bool // can be resource base type? OptArgs int // number of optional arguments in Args array Args []namedArg // type arguments - // Check does custom verification of the type (optional). + // Check does custom verification of the type (optional, consts are not patched yet). Check func(comp *compiler, t *ast.Type, args []*ast.Type, base prog.IntTypeCommon) + // CheckConsts does custom verification of the type (optional, consts are patched). + CheckConsts func(comp *compiler, t *ast.Type, args []*ast.Type, base prog.IntTypeCommon) // Varlen returns if the type is variable-length (false if not set). - Varlen func(comp *compiler, t *ast.Type, args []*ast.Type, base prog.IntTypeCommon) bool + Varlen func(comp *compiler, t *ast.Type, args []*ast.Type) bool // Gen generates corresponding prog.Type. Gen func(comp *compiler, t *ast.Type, args []*ast.Type, base prog.IntTypeCommon) prog.Type } @@ -37,7 +39,8 @@ type typeArg struct { Kind int // int/ident/string AllowColon bool // allow colon (2:3)? // Check does custom verification of the arg (optional). - Check func(comp *compiler, t *ast.Type) + Check func(comp *compiler, t *ast.Type) + CheckConsts func(comp *compiler, t *ast.Type) } type namedArg struct { @@ -53,7 +56,7 @@ const ( ) var typeInt = &typeDesc{ - Names: []string{"int8", "int16", "int32", "int64", "int16be", "int32be", "int64be", "intptr"}, + Names: typeArgBase.Type.Names, CanBeArg: true, CanBeTypedef: true, AllowColon: true, @@ -102,13 +105,13 @@ var typeArray = &typeDesc{ CantBeOpt: true, OptArgs: 1, Args: []namedArg{{"type", typeArgType}, {"size", typeArgRange}}, - Check: func(comp *compiler, t *ast.Type, args []*ast.Type, base prog.IntTypeCommon) { + CheckConsts: func(comp *compiler, t *ast.Type, args []*ast.Type, base prog.IntTypeCommon) { if len(args) > 1 && args[1].Value == 0 && args[1].Value2 == 0 { // This is the only case that can yield 0 static type size. comp.error(args[1].Pos, "arrays of size 0 are not supported") } }, - Varlen: func(comp *compiler, t *ast.Type, args []*ast.Type, base prog.IntTypeCommon) bool { + Varlen: func(comp *compiler, t *ast.Type, args []*ast.Type) bool { if comp.isVarlen(args[0]) { return true } @@ -234,7 +237,7 @@ var typeArgFlags = &typeArg{ var typeFilename = &typeDesc{ Names: []string{"filename"}, CantBeOpt: true, - Varlen: func(comp *compiler, t *ast.Type, args []*ast.Type, base prog.IntTypeCommon) bool { + Varlen: func(comp *compiler, t *ast.Type, args []*ast.Type) bool { return true }, Gen: func(comp *compiler, t *ast.Type, args []*ast.Type, base prog.IntTypeCommon) prog.Type { @@ -326,7 +329,7 @@ var typeProc = &typeDesc{ CanBeTypedef: true, NeedBase: true, Args: []namedArg{{"range start", typeArgInt}, {"per-proc values", typeArgInt}}, - Check: func(comp *compiler, t *ast.Type, args []*ast.Type, base prog.IntTypeCommon) { + CheckConsts: func(comp *compiler, t *ast.Type, args []*ast.Type, base prog.IntTypeCommon) { start := args[0].Value perProc := args[1].Value if perProc == 0 { @@ -338,7 +341,7 @@ var typeProc = &typeDesc{ const maxPids = 32 // executor knows about this constant (MAX_PIDS) if start >= 1<= 1< 1< 1 { size := args[1].Value vals := []string{args[0].String} @@ -429,7 +432,7 @@ var typeString = &typeDesc{ } } }, - Varlen: func(comp *compiler, t *ast.Type, args []*ast.Type, base prog.IntTypeCommon) bool { + Varlen: func(comp *compiler, t *ast.Type, args []*ast.Type) bool { return comp.stringSize(args) == 0 }, Gen: func(comp *compiler, t *ast.Type, args []*ast.Type, base prog.IntTypeCommon) prog.Type { @@ -503,11 +506,7 @@ var typeArgStringFlags = &typeArg{ } // typeArgType is used as placeholder for any type (e.g. ptr target type). -var typeArgType = &typeArg{ - Check: func(comp *compiler, t *ast.Type) { - panic("must not be called") - }, -} +var typeArgType = &typeArg{} var typeResource = &typeDesc{ // No Names, but getTypeDesc knows how to match it. @@ -533,12 +532,13 @@ func init() { var typeStruct = &typeDesc{ // No Names, but getTypeDesc knows how to match it. - CantBeOpt: true, + CantBeOpt: true, + CanBeTypedef: true, // Varlen/Gen are assigned below due to initialization cycle. } func init() { - typeStruct.Varlen = func(comp *compiler, t *ast.Type, args []*ast.Type, base prog.IntTypeCommon) bool { + typeStruct.Varlen = func(comp *compiler, t *ast.Type, args []*ast.Type) bool { return comp.isStructVarlen(t.Ident) } typeStruct.Gen = func(comp *compiler, t *ast.Type, args []*ast.Type, base prog.IntTypeCommon) prog.Type { @@ -572,9 +572,6 @@ var typeTypedef = &typeDesc{ Check: func(comp *compiler, t *ast.Type, args []*ast.Type, base prog.IntTypeCommon) { panic("must not be called") }, - Gen: func(comp *compiler, t *ast.Type, args []*ast.Type, base prog.IntTypeCommon) prog.Type { - panic("must not be called") - }, } var typeArgDir = &typeArg{ @@ -602,7 +599,7 @@ var typeArgInt = &typeArg{ var typeArgRange = &typeArg{ Kind: kindInt, AllowColon: true, - Check: func(comp *compiler, t *ast.Type) { + CheckConsts: func(comp *compiler, t *ast.Type) { if !t.HasColon { t.Value2 = t.Value } @@ -620,6 +617,10 @@ var typeArgBase = namedArg{ AllowColon: true, Check: func(comp *compiler, t *ast.Type) { if t.HasColon { + if t.Ident2 != "" { + comp.error(t.Pos2, "literal const bitfield sizes are not supported") + return + } if t.Value2 == 0 { // This was not supported historically // and does not work the way C bitfields of size 0 work. @@ -667,12 +668,12 @@ func init() { typeConst, typeFlags, typeFilename, - typeFileoff, // make a type alias + typeFileoff, typeVMA, typeCsum, typeProc, typeText, - typeBuffer, // make a type alias + typeBuffer, typeString, } for _, desc := range builtins { -- cgit mrf-deployment