aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--docs/syscall_descriptions_syntax.md22
-rw-r--r--pkg/ast/ast.go10
-rw-r--r--pkg/ast/clone.go137
-rw-r--r--pkg/ast/format.go76
-rw-r--r--pkg/ast/parser.go31
-rw-r--r--pkg/ast/test_util.go7
-rw-r--r--pkg/ast/testdata/all.txt19
-rw-r--r--pkg/ast/walk.go10
-rw-r--r--pkg/compiler/check.go354
-rw-r--r--pkg/compiler/compiler.go66
-rw-r--r--pkg/compiler/compiler_test.go94
-rw-r--r--pkg/compiler/consts.go258
-rw-r--r--pkg/compiler/consts_test.go21
-rw-r--r--pkg/compiler/gen.go3
-rw-r--r--pkg/compiler/testdata/all.txt60
-rw-r--r--pkg/compiler/testdata/consts.txt11
-rw-r--r--pkg/compiler/testdata/errors.txt65
-rw-r--r--pkg/compiler/testdata/errors2.txt26
-rw-r--r--pkg/compiler/types.go51
-rw-r--r--sys/syz-extract/extract.go37
20 files changed, 933 insertions, 425 deletions
diff --git a/docs/syscall_descriptions_syntax.md b/docs/syscall_descriptions_syntax.md
index a0b7b45b5..b8ded85bf 100644
--- a/docs/syscall_descriptions_syntax.md
+++ b/docs/syscall_descriptions_syntax.md
@@ -168,6 +168,28 @@ type bool64 int64[0:1]
type boolptr intptr[0:1]
```
+## Type Templates
+
+**Note: type templates are experimental, can have error handling bugs and are subject to change**
+
+Type templates can be declared as follows:
+
+```
+type buffer[DIR] ptr[DIR, array[int8]]
+type fileoff[BASE] BASE
+type nlattr[TYPE, PAYLOAD] {
+ nla_len len[parent, int16]
+ nla_type const[TYPE, int16]
+ payload PAYLOAD
+} [align_4]
+```
+
+and later used as follows:
+
+```
+syscall(a buffer[in], b fileoff[int64], c ptr[in, nlattr[FOO, int32]])
+```
+
## Length
You can specify length of a particular field in struct or a named argument by using `len`, `bytesize` and `bitsize` types, for example:
diff --git a/pkg/ast/ast.go b/pkg/ast/ast.go
index 454f28b37..08703ba33 100644
--- a/pkg/ast/ast.go
+++ b/pkg/ast/ast.go
@@ -21,9 +21,7 @@ type Description struct {
type Node interface {
Info() (pos Pos, typ string, name string)
// Clone makes a deep copy of the node.
- // If newPos is not zero, sets Pos of all nodes to newPos.
- // If newPos is zero, Pos of nodes is left intact.
- Clone(newPos Pos) Node
+ Clone() Node
// Walk calls callback cb for all child nodes of this node.
// Note: it's not recursive. Use Recursive helper for recursive walk.
Walk(cb func(Node))
@@ -140,7 +138,11 @@ func (n *StrFlags) Info() (Pos, string, string) {
type TypeDef struct {
Pos Pos
Name *Ident
- Type *Type
+ // Non-template type aliases have only Type filled.
+ // Templates have Args and either Type or Struct filled.
+ Args []*Ident
+ Type *Type
+ Struct *Struct
}
func (n *TypeDef) Info() (Pos, string, string) {
diff --git a/pkg/ast/clone.go b/pkg/ast/clone.go
index dcd715c0a..b915c1f33 100644
--- a/pkg/ast/clone.go
+++ b/pkg/ast/clone.go
@@ -6,86 +6,93 @@ package ast
func (desc *Description) Clone() *Description {
desc1 := &Description{}
for _, n := range desc.Nodes {
- desc1.Nodes = append(desc1.Nodes, n.Clone(Pos{}))
+ desc1.Nodes = append(desc1.Nodes, n.Clone())
}
return desc1
}
-func selectPos(newPos, oldPos Pos) Pos {
- if newPos.File != "" || newPos.Off != 0 || newPos.Line != 0 || newPos.Col != 0 {
- return newPos
- }
- return oldPos
-}
-
-func (n *NewLine) Clone(newPos Pos) Node {
+func (n *NewLine) Clone() Node {
return &NewLine{
- Pos: selectPos(newPos, n.Pos),
+ Pos: n.Pos,
}
}
-func (n *Comment) Clone(newPos Pos) Node {
+func (n *Comment) Clone() Node {
return &Comment{
- Pos: selectPos(newPos, n.Pos),
+ Pos: n.Pos,
Text: n.Text,
}
}
-func (n *Include) Clone(newPos Pos) Node {
+func (n *Include) Clone() Node {
return &Include{
- Pos: selectPos(newPos, n.Pos),
- File: n.File.Clone(newPos).(*String),
+ Pos: n.Pos,
+ File: n.File.Clone().(*String),
}
}
-func (n *Incdir) Clone(newPos Pos) Node {
+func (n *Incdir) Clone() Node {
return &Incdir{
- Pos: selectPos(newPos, n.Pos),
- Dir: n.Dir.Clone(newPos).(*String),
+ Pos: n.Pos,
+ Dir: n.Dir.Clone().(*String),
}
}
-func (n *Define) Clone(newPos Pos) Node {
+func (n *Define) Clone() Node {
return &Define{
- Pos: selectPos(newPos, n.Pos),
- Name: n.Name.Clone(newPos).(*Ident),
- Value: n.Value.Clone(newPos).(*Int),
+ Pos: n.Pos,
+ Name: n.Name.Clone().(*Ident),
+ Value: n.Value.Clone().(*Int),
}
}
-func (n *Resource) Clone(newPos Pos) Node {
+func (n *Resource) Clone() Node {
var values []*Int
for _, v := range n.Values {
- values = append(values, v.Clone(newPos).(*Int))
+ values = append(values, v.Clone().(*Int))
}
return &Resource{
- Pos: selectPos(newPos, n.Pos),
- Name: n.Name.Clone(newPos).(*Ident),
- Base: n.Base.Clone(newPos).(*Type),
+ Pos: n.Pos,
+ Name: n.Name.Clone().(*Ident),
+ Base: n.Base.Clone().(*Type),
Values: values,
}
}
-func (n *TypeDef) Clone(newPos Pos) Node {
+func (n *TypeDef) Clone() Node {
+ var args []*Ident
+ for _, v := range n.Args {
+ args = append(args, v.Clone().(*Ident))
+ }
+ var typ *Type
+ if n.Type != nil {
+ typ = n.Type.Clone().(*Type)
+ }
+ var str *Struct
+ if n.Struct != nil {
+ str = n.Struct.Clone().(*Struct)
+ }
return &TypeDef{
- Pos: selectPos(newPos, n.Pos),
- Name: n.Name.Clone(newPos).(*Ident),
- Type: n.Type.Clone(newPos).(*Type),
+ Pos: n.Pos,
+ Name: n.Name.Clone().(*Ident),
+ Args: args,
+ Type: typ,
+ Struct: str,
}
}
-func (n *Call) Clone(newPos Pos) Node {
+func (n *Call) Clone() Node {
var args []*Field
for _, a := range n.Args {
- args = append(args, a.Clone(newPos).(*Field))
+ args = append(args, a.Clone().(*Field))
}
var ret *Type
if n.Ret != nil {
- ret = n.Ret.Clone(newPos).(*Type)
+ ret = n.Ret.Clone().(*Type)
}
return &Call{
- Pos: selectPos(newPos, n.Pos),
- Name: n.Name.Clone(newPos).(*Ident),
+ Pos: n.Pos,
+ Name: n.Name.Clone().(*Ident),
CallName: n.CallName,
NR: n.NR,
Args: args,
@@ -93,22 +100,22 @@ func (n *Call) Clone(newPos Pos) Node {
}
}
-func (n *Struct) Clone(newPos Pos) Node {
+func (n *Struct) Clone() Node {
var fields []*Field
for _, f := range n.Fields {
- fields = append(fields, f.Clone(newPos).(*Field))
+ fields = append(fields, f.Clone().(*Field))
}
var attrs []*Ident
for _, a := range n.Attrs {
- attrs = append(attrs, a.Clone(newPos).(*Ident))
+ attrs = append(attrs, a.Clone().(*Ident))
}
var comments []*Comment
for _, c := range n.Comments {
- comments = append(comments, c.Clone(newPos).(*Comment))
+ comments = append(comments, c.Clone().(*Comment))
}
return &Struct{
- Pos: selectPos(newPos, n.Pos),
- Name: n.Name.Clone(newPos).(*Ident),
+ Pos: n.Pos,
+ Name: n.Name.Clone().(*Ident),
Fields: fields,
Attrs: attrs,
Comments: comments,
@@ -116,47 +123,47 @@ func (n *Struct) Clone(newPos Pos) Node {
}
}
-func (n *IntFlags) Clone(newPos Pos) Node {
+func (n *IntFlags) Clone() Node {
var values []*Int
for _, v := range n.Values {
- values = append(values, v.Clone(newPos).(*Int))
+ values = append(values, v.Clone().(*Int))
}
return &IntFlags{
- Pos: selectPos(newPos, n.Pos),
- Name: n.Name.Clone(newPos).(*Ident),
+ Pos: n.Pos,
+ Name: n.Name.Clone().(*Ident),
Values: values,
}
}
-func (n *StrFlags) Clone(newPos Pos) Node {
+func (n *StrFlags) Clone() Node {
var values []*String
for _, v := range n.Values {
- values = append(values, v.Clone(newPos).(*String))
+ values = append(values, v.Clone().(*String))
}
return &StrFlags{
- Pos: selectPos(newPos, n.Pos),
- Name: n.Name.Clone(newPos).(*Ident),
+ Pos: n.Pos,
+ Name: n.Name.Clone().(*Ident),
Values: values,
}
}
-func (n *Ident) Clone(newPos Pos) Node {
+func (n *Ident) Clone() Node {
return &Ident{
- Pos: selectPos(newPos, n.Pos),
+ Pos: n.Pos,
Name: n.Name,
}
}
-func (n *String) Clone(newPos Pos) Node {
+func (n *String) Clone() Node {
return &String{
- Pos: selectPos(newPos, n.Pos),
+ Pos: n.Pos,
Value: n.Value,
}
}
-func (n *Int) Clone(newPos Pos) Node {
+func (n *Int) Clone() Node {
return &Int{
- Pos: selectPos(newPos, n.Pos),
+ Pos: n.Pos,
Value: n.Value,
ValueHex: n.ValueHex,
Ident: n.Ident,
@@ -164,19 +171,19 @@ func (n *Int) Clone(newPos Pos) Node {
}
}
-func (n *Type) Clone(newPos Pos) Node {
+func (n *Type) Clone() Node {
var args []*Type
for _, a := range n.Args {
- args = append(args, a.Clone(newPos).(*Type))
+ args = append(args, a.Clone().(*Type))
}
return &Type{
- Pos: selectPos(newPos, n.Pos),
+ Pos: n.Pos,
Value: n.Value,
ValueHex: n.ValueHex,
Ident: n.Ident,
String: n.String,
HasColon: n.HasColon,
- Pos2: selectPos(newPos, n.Pos2),
+ Pos2: n.Pos2,
Value2: n.Value2,
Value2Hex: n.Value2Hex,
Ident2: n.Ident2,
@@ -184,15 +191,15 @@ func (n *Type) Clone(newPos Pos) Node {
}
}
-func (n *Field) Clone(newPos Pos) Node {
+func (n *Field) Clone() Node {
var comments []*Comment
for _, c := range n.Comments {
- comments = append(comments, c.Clone(newPos).(*Comment))
+ comments = append(comments, c.Clone().(*Comment))
}
return &Field{
- Pos: selectPos(newPos, n.Pos),
- Name: n.Name.Clone(newPos).(*Ident),
- Type: n.Type.Clone(newPos).(*Type),
+ Pos: n.Pos,
+ Name: n.Name.Clone().(*Ident),
+ Type: n.Type.Clone().(*Type),
NewBlock: n.NewBlock,
Comments: comments,
}
diff --git a/pkg/ast/format.go b/pkg/ast/format.go
index a77662df8..c1dd3a624 100644
--- a/pkg/ast/format.go
+++ b/pkg/ast/format.go
@@ -25,6 +25,16 @@ func FormatWriter(w io.Writer, desc *Description) {
}
}
+func SerializeNode(n Node) string {
+ s, ok := n.(serializer)
+ if !ok {
+ panic(fmt.Sprintf("unknown node: %#v", n))
+ }
+ buf := new(bytes.Buffer)
+ s.serialize(buf)
+ return buf.String()
+}
+
type serializer interface {
serialize(w io.Writer)
}
@@ -52,27 +62,25 @@ func (def *Define) serialize(w io.Writer) {
func (res *Resource) serialize(w io.Writer) {
fmt.Fprintf(w, "resource %v[%v]", res.Name.Name, fmtType(res.Base))
for i, v := range res.Values {
- if i == 0 {
- fmt.Fprintf(w, ": ")
- } else {
- fmt.Fprintf(w, ", ")
- }
- fmt.Fprintf(w, "%v", fmtInt(v))
+ fmt.Fprintf(w, "%v%v", comma(i, ": "), fmtInt(v))
}
fmt.Fprintf(w, "\n")
}
func (typedef *TypeDef) serialize(w io.Writer) {
- fmt.Fprintf(w, "type %v %v\n", typedef.Name.Name, fmtType(typedef.Type))
+ fmt.Fprintf(w, "type %v%v", typedef.Name.Name, fmtIdentList(typedef.Args, false))
+ if typedef.Type != nil {
+ fmt.Fprintf(w, " %v\n", fmtType(typedef.Type))
+ }
+ if typedef.Struct != nil {
+ typedef.Struct.serialize(w)
+ }
}
func (c *Call) serialize(w io.Writer) {
fmt.Fprintf(w, "%v(", c.Name.Name)
for i, a := range c.Args {
- if i != 0 {
- fmt.Fprintf(w, ", ")
- }
- fmt.Fprintf(w, "%v", fmtField(a))
+ fmt.Fprintf(w, "%v%v", comma(i, ""), fmtField(a))
}
fmt.Fprintf(w, ")")
if c.Ret != nil {
@@ -112,24 +120,13 @@ func (str *Struct) serialize(w io.Writer) {
for _, com := range str.Comments {
fmt.Fprintf(w, "#%v\n", com.Text)
}
- fmt.Fprintf(w, "%c", closing)
- if len(str.Attrs) != 0 {
- fmt.Fprintf(w, " [")
- for i, attr := range str.Attrs {
- fmt.Fprintf(w, "%v%v", comma(i), attr.Name)
- }
- fmt.Fprintf(w, "]")
- }
- fmt.Fprintf(w, "\n")
+ fmt.Fprintf(w, "%c%v\n", closing, fmtIdentList(str.Attrs, true))
}
func (flags *IntFlags) serialize(w io.Writer) {
fmt.Fprintf(w, "%v = ", flags.Name.Name)
for i, v := range flags.Values {
- if i != 0 {
- fmt.Fprintf(w, ", ")
- }
- fmt.Fprintf(w, "%v", fmtInt(v))
+ fmt.Fprintf(w, "%v%v", comma(i, ""), fmtInt(v))
}
fmt.Fprintf(w, "\n")
}
@@ -137,10 +134,7 @@ func (flags *IntFlags) serialize(w io.Writer) {
func (flags *StrFlags) serialize(w io.Writer) {
fmt.Fprintf(w, "%v = ", flags.Name.Name)
for i, v := range flags.Values {
- if i != 0 {
- fmt.Fprintf(w, ", ")
- }
- fmt.Fprintf(w, "\"%v\"", v.Value)
+ fmt.Fprintf(w, "%v\"%v\"", comma(i, ""), v.Value)
}
fmt.Fprintf(w, "\n")
}
@@ -149,6 +143,10 @@ func fmtField(f *Field) string {
return fmt.Sprintf("%v %v", f.Name.Name, fmtType(f.Type))
}
+func (n *Type) serialize(w io.Writer) {
+ w.Write([]byte(fmtType(n)))
+}
+
func fmtType(t *Type) string {
v := ""
switch {
@@ -178,7 +176,23 @@ func fmtTypeList(args []*Type) string {
w := new(bytes.Buffer)
fmt.Fprintf(w, "[")
for i, t := range args {
- fmt.Fprintf(w, "%v%v", comma(i), fmtType(t))
+ fmt.Fprintf(w, "%v%v", comma(i, ""), fmtType(t))
+ }
+ fmt.Fprintf(w, "]")
+ return w.String()
+}
+
+func fmtIdentList(args []*Ident, space bool) string {
+ if len(args) == 0 {
+ return ""
+ }
+ w := new(bytes.Buffer)
+ if space {
+ fmt.Fprintf(w, " ")
+ }
+ fmt.Fprintf(w, "[")
+ for i, arg := range args {
+ fmt.Fprintf(w, "%v%v", comma(i, ""), arg.Name)
}
fmt.Fprintf(w, "]")
return w.String()
@@ -202,9 +216,9 @@ func fmtIntValue(v uint64, hex bool) string {
return fmt.Sprint(v)
}
-func comma(i int) string {
+func comma(i int, or string) string {
if i == 0 {
- return ""
+ return or
}
return ", "
}
diff --git a/pkg/ast/parser.go b/pkg/ast/parser.go
index db211ab2a..bd5650ad5 100644
--- a/pkg/ast/parser.go
+++ b/pkg/ast/parser.go
@@ -251,11 +251,34 @@ func (p *parser) parseResource() *Resource {
func (p *parser) parseTypeDef() *TypeDef {
pos0 := p.pos
name := p.parseIdent()
- typ := p.parseType()
+ var typ *Type
+ var str *Struct
+ var args []*Ident
+ p.expect(tokLBrack, tokIdent)
+ if p.tryConsume(tokLBrack) {
+ args = append(args, p.parseIdent())
+ for p.tryConsume(tokComma) {
+ args = append(args, p.parseIdent())
+ }
+ p.consume(tokRBrack)
+ if p.tok == tokLBrace || p.tok == tokLBrack {
+ name := &Ident{
+ Pos: pos0,
+ Name: "",
+ }
+ str = p.parseStruct(name)
+ } else {
+ typ = p.parseType()
+ }
+ } else {
+ typ = p.parseType()
+ }
return &TypeDef{
- Pos: pos0,
- Name: name,
- Type: typ,
+ Pos: pos0,
+ Name: name,
+ Args: args,
+ Type: typ,
+ Struct: str,
}
}
diff --git a/pkg/ast/test_util.go b/pkg/ast/test_util.go
index 0aed0a2dc..b9fe12152 100644
--- a/pkg/ast/test_util.go
+++ b/pkg/ast/test_util.go
@@ -7,6 +7,7 @@ import (
"bufio"
"bytes"
"io/ioutil"
+ "path/filepath"
"strings"
"testing"
)
@@ -41,7 +42,7 @@ func NewErrorMatcher(t *testing.T, file string) *ErrorMatcher {
break
}
errors = append(errors, &errorDesc{
- file: file,
+ file: filepath.Base(file),
line: i,
text: strings.TrimSpace(string(ln[pos+3:])),
})
@@ -82,13 +83,13 @@ nextErr:
want.matched = true
continue nextErr
}
- t.Errorf("unexpected error: %v:%v:%v: %v", e.file, e.line, e.col, e.text)
+ t.Errorf("unexpected error:\n%v:%v:%v: %v", e.file, e.line, e.col, e.text)
}
for _, want := range em.expect {
if want.matched {
continue
}
- t.Errorf("unmatched error: %v:%v: %v", want.file, want.line, want.text)
+ t.Errorf("unmatched error:\n%v:%v: %v", want.file, want.line, want.text)
}
}
diff --git a/pkg/ast/testdata/all.txt b/pkg/ast/testdata/all.txt
index 268b49a47..d4452b34f 100644
--- a/pkg/ast/testdata/all.txt
+++ b/pkg/ast/testdata/all.txt
@@ -48,5 +48,20 @@ s2 {
type mybool8 int8
type net_port proc[1, 2, int16be]
-type mybool16 ### unexpected '\n', expecting int, identifier, string
-type type4:4 int32 ### unexpected ':', expecting int, identifier, string
+type mybool16 ### unexpected '\n', expecting '[', identifier
+type type4:4 int32 ### unexpected ':', expecting '[', identifier
+
+type templ0[] int8 ### unexpected ']', expecting identifier
+type templ1[A,] int8 ### unexpected ']', expecting identifier
+type templ2[,] int8 ### unexpected ',', expecting identifier
+type templ3[ ### unexpected '\n', expecting identifier
+type templ4[A] ### unexpected '\n', expecting int, identifier, string
+type templ5[A] const[A]
+type templ6[A, B] const[A, B]
+type templ7[0] ptr[in, int8] ### unexpected int, expecting identifier
+
+type templ_struct0[A, B] {
+ len len[parent, int16]
+ typ const[A, int16]
+ data B
+} [align_4]
diff --git a/pkg/ast/walk.go b/pkg/ast/walk.go
index fd5065013..fe9112578 100644
--- a/pkg/ast/walk.go
+++ b/pkg/ast/walk.go
@@ -49,7 +49,15 @@ func (n *Resource) Walk(cb func(Node)) {
func (n *TypeDef) Walk(cb func(Node)) {
cb(n.Name)
- cb(n.Type)
+ for _, a := range n.Args {
+ cb(a)
+ }
+ if n.Type != nil {
+ cb(n.Type)
+ }
+ if n.Struct != nil {
+ cb(n.Struct)
+ }
}
func (n *Call) Walk(cb func(Node)) {
diff --git a/pkg/compiler/check.go b/pkg/compiler/check.go
index e9ec872f5..923c4fb19 100644
--- a/pkg/compiler/check.go
+++ b/pkg/compiler/check.go
@@ -12,16 +12,15 @@ import (
"github.com/google/syzkaller/prog"
)
-func (comp *compiler) check() {
+func (comp *compiler) typecheck() {
comp.checkNames()
comp.checkFields()
+ comp.checkTypedefs()
comp.checkTypes()
- // The subsequent, more complex, checks expect basic validity of the tree,
- // in particular corrent number of type arguments. If there were errors,
- // don't proceed to avoid out-of-bounds references to type arguments.
- if comp.errors != 0 {
- return
- }
+}
+
+func (comp *compiler) check() {
+ comp.checkConsts()
comp.checkUsed()
comp.checkRecursion()
comp.checkLenTargets()
@@ -31,9 +30,33 @@ func (comp *compiler) check() {
}
func (comp *compiler) checkNames() {
+ includes := make(map[string]bool)
+ incdirs := make(map[string]bool)
+ defines := make(map[string]bool)
calls := make(map[string]*ast.Call)
for _, decl := range comp.desc.Nodes {
- switch decl.(type) {
+ switch n := decl.(type) {
+ case *ast.Include:
+ name := n.File.Value
+ path := n.Pos.File + "/" + name
+ if includes[path] {
+ comp.error(n.Pos, "duplicate include %q", name)
+ }
+ includes[path] = true
+ case *ast.Incdir:
+ name := n.Dir.Value
+ path := n.Pos.File + "/" + name
+ if incdirs[path] {
+ comp.error(n.Pos, "duplicate incdir %q", name)
+ }
+ incdirs[path] = true
+ case *ast.Define:
+ name := n.Name.Name
+ path := n.Pos.File + "/" + name
+ if defines[path] {
+ comp.error(n.Pos, "duplicate define %v", name)
+ }
+ defines[path] = true
case *ast.Resource, *ast.Struct, *ast.TypeDef:
pos, typ, name := decl.Info()
if reservedName[name] {
@@ -68,7 +91,6 @@ func (comp *compiler) checkNames() {
comp.structs[name] = str
}
case *ast.IntFlags:
- n := decl.(*ast.IntFlags)
name := n.Name.Name
if reservedName[name] {
comp.error(n.Pos, "flags uses reserved name %v", name)
@@ -81,7 +103,6 @@ func (comp *compiler) checkNames() {
}
comp.intFlags[name] = n
case *ast.StrFlags:
- n := decl.(*ast.StrFlags)
name := n.Name.Name
if reservedName[name] {
comp.error(n.Pos, "string flags uses reserved name %v", name)
@@ -94,13 +115,12 @@ func (comp *compiler) checkNames() {
}
comp.strFlags[name] = n
case *ast.Call:
- c := decl.(*ast.Call)
- name := c.Name.Name
+ name := n.Name.Name
if prev := calls[name]; prev != nil {
- comp.error(c.Pos, "syscall %v redeclared, previously declared at %v",
+ comp.error(n.Pos, "syscall %v redeclared, previously declared at %v",
name, prev.Pos)
}
- calls[name] = c
+ calls[name] = n
}
}
}
@@ -111,17 +131,7 @@ func (comp *compiler) checkFields() {
switch n := decl.(type) {
case *ast.Struct:
_, typ, name := n.Info()
- fields := make(map[string]bool)
- for _, f := range n.Fields {
- fn := f.Name.Name
- if fn == "parent" {
- comp.error(f.Pos, "reserved field name %v in %v %v", fn, typ, name)
- }
- if fields[fn] {
- comp.error(f.Pos, "duplicate field %v in %v %v", fn, typ, name)
- }
- fields[fn] = true
- }
+ comp.checkFieldGroup(n.Fields, "field", typ+" "+name)
if !n.IsUnion && len(n.Fields) < 1 {
comp.error(n.Pos, "struct %v has no fields, need at least 1 field", name)
}
@@ -131,19 +141,7 @@ func (comp *compiler) checkFields() {
}
case *ast.Call:
name := n.Name.Name
- args := make(map[string]bool)
- for _, a := range n.Args {
- an := a.Name.Name
- if an == "parent" {
- comp.error(a.Pos, "reserved argument name %v in syscall %v",
- an, name)
- }
- if args[an] {
- comp.error(a.Pos, "duplicate argument %v in syscall %v",
- an, name)
- }
- args[an] = true
- }
+ comp.checkFieldGroup(n.Args, "argument", "syscall "+name)
if len(n.Args) > maxArgs {
comp.error(n.Pos, "syscall %v has %v arguments, allowed maximum is %v",
name, len(n.Args), maxArgs)
@@ -152,40 +150,93 @@ func (comp *compiler) checkFields() {
}
}
-func (comp *compiler) checkTypes() {
+func (comp *compiler) checkFieldGroup(fields []*ast.Field, what, ctx string) {
+ existing := make(map[string]bool)
+ for _, f := range fields {
+ fn := f.Name.Name
+ if fn == "parent" {
+ comp.error(f.Pos, "reserved %v name %v in %v", what, fn, ctx)
+ }
+ if existing[fn] {
+ comp.error(f.Pos, "duplicate %v %v in %v", what, fn, ctx)
+ }
+ existing[fn] = true
+ }
+}
+
+func (comp *compiler) checkTypedefs() {
for _, decl := range comp.desc.Nodes {
switch n := decl.(type) {
case *ast.TypeDef:
- if comp.typedefs[n.Name.Name] == nil {
- continue
- }
- err0 := comp.errors
- comp.checkType(n.Type, false, false, false, false, true, true)
- if err0 != comp.errors {
- delete(comp.typedefs, n.Name.Name)
+ if len(n.Args) == 0 {
+ // Non-template types are fully typed, so we check them ahead of time.
+ err0 := comp.errors
+ comp.checkType(checkCtx{}, n.Type, checkIsTypedef)
+ if err0 != comp.errors {
+ // To not produce confusing errors on broken type usage.
+ delete(comp.typedefs, n.Name.Name)
+ }
+ } else {
+ // For templates we only do basic checks of arguments.
+ names := make(map[string]bool)
+ for _, arg := range n.Args {
+ if names[arg.Name] {
+ comp.error(arg.Pos, "duplicate type argument %v", arg.Name)
+ }
+ names[arg.Name] = true
+ for _, c := range arg.Name {
+ if c >= 'A' && c <= 'Z' ||
+ c >= '0' && c <= '9' ||
+ c == '_' {
+ continue
+ }
+ comp.error(arg.Pos, "type argument %v must be ALL_CAPS",
+ arg.Name)
+ break
+ }
+ }
}
}
}
+}
+
+func (comp *compiler) checkTypes() {
for _, decl := range comp.desc.Nodes {
switch n := decl.(type) {
case *ast.Resource:
- comp.checkType(n.Base, false, false, false, true, false, false)
+ comp.checkType(checkCtx{}, n.Base, checkIsResourceBase)
case *ast.Struct:
- for _, f := range n.Fields {
- comp.checkType(f.Type, false, false, !n.IsUnion, false, false, false)
- }
- comp.checkStruct(n)
+ comp.checkStruct(checkCtx{}, n)
case *ast.Call:
for _, a := range n.Args {
- comp.checkType(a.Type, true, false, false, false, false, false)
+ comp.checkType(checkCtx{}, a.Type, checkIsArg)
}
if n.Ret != nil {
- comp.checkType(n.Ret, true, true, false, false, false, false)
+ comp.checkType(checkCtx{}, n.Ret, checkIsArg|checkIsRet|checkIsRetCtx)
}
}
}
}
+func (comp *compiler) checkConsts() {
+ for _, decl := range comp.desc.Nodes {
+ switch decl.(type) {
+ case *ast.Call, *ast.Struct, *ast.Resource, *ast.TypeDef:
+ comp.foreachType(decl, func(t *ast.Type, desc *typeDesc,
+ args []*ast.Type, base prog.IntTypeCommon) {
+ if desc.CheckConsts != nil {
+ desc.CheckConsts(comp, t, args, base)
+ }
+ for i, arg := range args {
+ if check := desc.Args[i].Type.CheckConsts; check != nil {
+ check(comp, arg)
+ }
+ }
+ })
+ }
+ }
+}
+
func (comp *compiler) checkLenTargets() {
for _, decl := range comp.desc.Nodes {
switch n := decl.(type) {
@@ -471,7 +522,14 @@ func (comp *compiler) recurseField(checked map[string]bool, t *ast.Type, path []
}
}
-func (comp *compiler) checkStruct(n *ast.Struct) {
+func (comp *compiler) checkStruct(ctx checkCtx, n *ast.Struct) {
+ var flags checkFlags
+ if !n.IsUnion {
+ flags |= checkIsStruct
+ }
+ for _, f := range n.Fields {
+ comp.checkType(ctx, f.Type, flags)
+ }
if n.IsUnion {
comp.parseUnionAttrs(n)
} else {
@@ -479,7 +537,22 @@ func (comp *compiler) checkStruct(n *ast.Struct) {
}
}
-func (comp *compiler) checkType(t *ast.Type, isArg, isRet, isStruct, isResourceBase, isTypedef, isTypedefCtx bool) {
+type checkFlags int
+
+const (
+ checkIsArg checkFlags = 1 << iota // immidiate syscall arg type
+ checkIsRet // immidiate syscall ret type
+ checkIsRetCtx // inside of syscall ret type
+ checkIsStruct // immidiate struct field type
+ checkIsResourceBase // immidiate resource base type
+ checkIsTypedef // immidiate type alias/template type
+)
+
+type checkCtx struct {
+ instantiationStack []string
+}
+
+func (comp *compiler) checkType(ctx checkCtx, t *ast.Type, flags checkFlags) {
if unexpected, _, ok := checkTypeKind(t, kindIdent); !ok {
comp.error(t.Pos, "unexpected %v, expect type", unexpected)
return
@@ -490,29 +563,13 @@ func (comp *compiler) checkType(t *ast.Type, isArg, isRet, isStruct, isResourceB
return
}
if desc == typeTypedef {
- if isTypedefCtx {
- comp.error(t.Pos, "type aliases can't refer to other type aliases")
- return
- }
- if t.HasColon {
- comp.error(t.Pos, "type alias %v with ':'", t.Ident)
- return
- }
- if len(t.Args) != 0 {
- comp.error(t.Pos, "type alias %v with arguments", t.Ident)
- return
- }
- *t = *comp.typedefs[t.Ident].Type.Clone(t.Pos).(*ast.Type)
- desc = comp.getTypeDesc(t)
- if isArg && desc.NeedBase {
- baseTypePos := len(t.Args) - 1
- if t.Args[baseTypePos].Ident == "opt" {
- baseTypePos--
- }
- copy(t.Args[baseTypePos:], t.Args[baseTypePos+1:])
- t.Args = t.Args[:len(t.Args)-1]
+ err0 := comp.errors
+ // Replace t with type alias/template target type inplace,
+ // and check the replaced type recursively.
+ comp.replaceTypedef(&ctx, t, desc, flags)
+ if err0 == comp.errors {
+ comp.checkType(ctx, t, flags)
}
- comp.checkType(t, isArg, isRet, isStruct, isResourceBase, isTypedef, isTypedefCtx)
return
}
if t.HasColon {
@@ -520,39 +577,47 @@ func (comp *compiler) checkType(t *ast.Type, isArg, isRet, isStruct, isResourceB
comp.error(t.Pos2, "unexpected ':'")
return
}
- if !isStruct {
+ if flags&checkIsStruct == 0 {
comp.error(t.Pos2, "unexpected ':', only struct fields can be bitfields")
return
}
}
- if isRet && (!desc.CanBeArg || desc.CantBeRet) {
+ if flags&checkIsRet != 0 && (!desc.CanBeArg || desc.CantBeRet) {
comp.error(t.Pos, "%v can't be syscall return", t.Ident)
return
}
- if isArg && !desc.CanBeArg {
+ if flags&checkIsRetCtx != 0 && desc.CantBeRet {
+ comp.error(t.Pos, "%v can't be used in syscall return", t.Ident)
+ return
+ }
+ if flags&checkIsArg != 0 && !desc.CanBeArg {
comp.error(t.Pos, "%v can't be syscall argument", t.Ident)
return
}
- if isTypedef && !desc.CanBeTypedef {
+ if flags&checkIsTypedef != 0 && !desc.CanBeTypedef {
comp.error(t.Pos, "%v can't be type alias target", t.Ident)
return
}
- if isResourceBase && !desc.ResourceBase {
+ if flags&checkIsResourceBase != 0 && !desc.ResourceBase {
comp.error(t.Pos, "%v can't be resource base (int types can)", t.Ident)
return
}
args, opt := removeOpt(t)
- if opt && (desc.CantBeOpt || isResourceBase) {
- what := "resource base"
- if desc.CantBeOpt {
- what = t.Ident
+ if opt != nil {
+ if len(opt.Args) != 0 {
+ comp.error(opt.Pos, "opt can't have arguments")
+ }
+ if flags&checkIsResourceBase != 0 || desc.CantBeOpt {
+ what := "resource base"
+ if desc.CantBeOpt {
+ what = t.Ident
+ }
+ comp.error(opt.Pos, "%v can't be marked as opt", what)
+ return
}
- pos := t.Args[len(t.Args)-1].Pos
- comp.error(pos, "%v can't be marked as opt", what)
- return
}
addArgs := 0
- needBase := !isArg && desc.NeedBase
+ needBase := flags&checkIsArg == 0 && desc.NeedBase
if needBase {
addArgs++ // last arg must be base type, e.g. const[0, int32]
}
@@ -569,18 +634,111 @@ func (comp *compiler) checkType(t *ast.Type, isArg, isRet, isStruct, isResourceB
err0 := comp.errors
for i, arg := range args {
if desc.Args[i].Type == typeArgType {
- comp.checkType(arg, false, isRet, false, false, false, isTypedefCtx)
+ comp.checkType(ctx, arg, flags&checkIsRetCtx)
} else {
comp.checkTypeArg(t, arg, desc.Args[i])
}
}
- if err0 != comp.errors {
+ if desc.Check != nil && err0 == comp.errors {
+ _, args, base := comp.getArgsBase(t, "", prog.DirIn, flags&checkIsArg != 0)
+ desc.Check(comp, t, args, base)
+ }
+}
+
+func (comp *compiler) replaceTypedef(ctx *checkCtx, t *ast.Type, desc *typeDesc, flags checkFlags) {
+ typedefName := t.Ident
+ if t.HasColon {
+ comp.error(t.Pos, "type alias %v with ':'", t.Ident)
return
}
- if desc.Check != nil {
- _, args, base := comp.getArgsBase(t, "", prog.DirIn, isArg)
- desc.Check(comp, t, args, base)
+ typedef := comp.typedefs[typedefName]
+ fullTypeName := ast.SerializeNode(t)
+ for i, prev := range ctx.instantiationStack {
+ if prev == fullTypeName {
+ ctx.instantiationStack = append(ctx.instantiationStack, fullTypeName)
+ path := ""
+ for j := i; j < len(ctx.instantiationStack); j++ {
+ if j != i {
+ path += " -> "
+ }
+ path += ctx.instantiationStack[j]
+ }
+ comp.error(t.Pos, "type instantiation loop: %v", path)
+ return
+ }
}
+ ctx.instantiationStack = append(ctx.instantiationStack, fullTypeName)
+ nargs := len(typedef.Args)
+ args := t.Args
+ for _, arg := range args {
+ if arg.String != "" {
+ comp.error(arg.Pos, "template arguments can't be strings (%q)", arg.String)
+ return
+ }
+ }
+ if nargs != len(t.Args) {
+ if nargs == 0 {
+ comp.error(t.Pos, "type %v is not a template", typedefName)
+ } else {
+ comp.error(t.Pos, "template %v needs %v arguments instead of %v",
+ typedefName, nargs, len(t.Args))
+ }
+ return
+ }
+ if typedef.Type != nil {
+ *t = *typedef.Type.Clone().(*ast.Type)
+ comp.instantiate(t, typedef.Args, args)
+ } else {
+ if comp.structs[fullTypeName] == nil {
+ inst := typedef.Struct.Clone().(*ast.Struct)
+ inst.Name.Name = fullTypeName
+ comp.instantiate(inst, typedef.Args, args)
+ comp.checkStruct(*ctx, inst)
+ comp.desc.Nodes = append(comp.desc.Nodes, inst)
+ comp.structs[fullTypeName] = inst
+ }
+ *t = ast.Type{
+ Pos: t.Pos,
+ Ident: fullTypeName,
+ }
+ }
+
+ // Remove base type if it's not needed in this context.
+ desc = comp.getTypeDesc(t)
+ if flags&checkIsArg != 0 && desc.NeedBase {
+ baseTypePos := len(t.Args) - 1
+ if t.Args[baseTypePos].Ident == "opt" {
+ baseTypePos--
+ }
+ copy(t.Args[baseTypePos:], t.Args[baseTypePos+1:])
+ t.Args = t.Args[:len(t.Args)-1]
+ }
+}
+
+func (comp *compiler) instantiate(templ ast.Node, params []*ast.Ident, args []*ast.Type) {
+ if len(params) == 0 {
+ return
+ }
+ argMap := make(map[string]*ast.Type)
+ for i, param := range params {
+ argMap[param.Name] = args[i]
+ }
+ templ.Walk(ast.Recursive(func(n ast.Node) {
+ templArg, ok := n.(*ast.Type)
+ if !ok {
+ return
+ }
+ if concreteArg := argMap[templArg.Ident]; concreteArg != nil {
+ *templArg = *concreteArg.Clone().(*ast.Type)
+ }
+ // TODO(dvyukov): somewhat hacky, but required for int8[0:CONST_ARG]
+ // Need more checks here. E.g. that CONST_ARG does not have subargs.
+ // And if CONST_ARG is a value, then use concreteArg.Value.
+ if concreteArg := argMap[templArg.Ident2]; concreteArg != nil {
+ templArg.Ident2 = concreteArg.Ident
+ templArg.Pos2 = concreteArg.Pos
+ }
+ }))
}
func (comp *compiler) checkTypeArg(t, arg *ast.Type, argDesc namedArg) {
@@ -655,7 +813,7 @@ func checkTypeKind(t *ast.Type, kind int) (unexpected string, expect string, ok
unexpected = fmt.Sprintf("string %q", t.String)
}
case t.Ident != "":
- ok = kind == kindIdent
+ ok = kind == kindIdent || kind == kindInt
if !ok {
unexpected = fmt.Sprintf("identifier %v", t.Ident)
}
@@ -688,8 +846,8 @@ func (comp *compiler) checkVarlens() {
}
func (comp *compiler) isVarlen(t *ast.Type) bool {
- desc, args, base := comp.getArgsBase(t, "", prog.DirIn, false)
- return desc.Varlen != nil && desc.Varlen(comp, t, args, base)
+ desc, args, _ := comp.getArgsBase(t, "", prog.DirIn, false)
+ return desc.Varlen != nil && desc.Varlen(comp, t, args)
}
func (comp *compiler) checkVarlen(n *ast.Struct) {
diff --git a/pkg/compiler/compiler.go b/pkg/compiler/compiler.go
index 6019bda94..98348b0f8 100644
--- a/pkg/compiler/compiler.go
+++ b/pkg/compiler/compiler.go
@@ -39,6 +39,8 @@ type Prog struct {
StructDescs []*prog.KeyedStruct
// Set of unsupported syscalls/flags.
Unsupported map[string]bool
+ // Returned if consts was nil.
+ fileConsts map[string]*ConstInfo
}
// Compile compiles sys description.
@@ -65,6 +67,20 @@ func Compile(desc *ast.Description, consts map[string]uint64, target *targets.Ta
for name, typedef := range builtinTypedefs {
comp.typedefs[name] = typedef
}
+ comp.typecheck()
+ // The subsequent, more complex, checks expect basic validity of the tree,
+ // in particular corrent number of type arguments. If there were errors,
+ // don't proceed to avoid out-of-bounds references to type arguments.
+ if comp.errors != 0 {
+ return nil
+ }
+ if consts == nil {
+ fileConsts := comp.extractConsts()
+ if comp.errors != 0 {
+ return nil
+ }
+ return &Prog{fileConsts: fileConsts}
+ }
comp.assignSyscallNumbers(consts)
comp.patchConsts(consts)
comp.check()
@@ -177,9 +193,11 @@ func (comp *compiler) getTypeDesc(t *ast.Type) *typeDesc {
func (comp *compiler) getArgsBase(t *ast.Type, field string, dir prog.Dir, isArg bool) (
*typeDesc, []*ast.Type, prog.IntTypeCommon) {
desc := comp.getTypeDesc(t)
+ if desc == nil {
+ panic(fmt.Sprintf("no type desc for %#v", *t))
+ }
args, opt := removeOpt(t)
- size := sizeUnassigned
- com := genCommon(t.Ident, field, size, dir, opt)
+ com := genCommon(t.Ident, field, sizeUnassigned, dir, opt != nil)
base := genIntCommon(com, 0, false)
if desc.NeedBase {
base.TypeSize = comp.ptrSize
@@ -192,12 +210,48 @@ func (comp *compiler) getArgsBase(t *ast.Type, field string, dir prog.Dir, isArg
return desc, args, base
}
-func removeOpt(t *ast.Type) ([]*ast.Type, bool) {
+func (comp *compiler) foreachType(n0 ast.Node,
+ cb func(*ast.Type, *typeDesc, []*ast.Type, prog.IntTypeCommon)) {
+ switch n := n0.(type) {
+ case *ast.Call:
+ for _, arg := range n.Args {
+ comp.foreachSubType(arg.Type, true, cb)
+ }
+ if n.Ret != nil {
+ comp.foreachSubType(n.Ret, true, cb)
+ }
+ case *ast.Resource:
+ comp.foreachSubType(n.Base, false, cb)
+ case *ast.Struct:
+ for _, f := range n.Fields {
+ comp.foreachSubType(f.Type, false, cb)
+ }
+ case *ast.TypeDef:
+ if len(n.Args) == 0 {
+ comp.foreachSubType(n.Type, false, cb)
+ }
+ default:
+ panic(fmt.Sprintf("unexpected node %#v", n0))
+ }
+}
+
+func (comp *compiler) foreachSubType(t *ast.Type, isArg bool,
+ cb func(*ast.Type, *typeDesc, []*ast.Type, prog.IntTypeCommon)) {
+ desc, args, base := comp.getArgsBase(t, "", prog.DirIn, isArg)
+ cb(t, desc, args, base)
+ for i, arg := range args {
+ if desc.Args[i].Type == typeArgType {
+ comp.foreachSubType(arg, false, cb)
+ }
+ }
+}
+
+func removeOpt(t *ast.Type) ([]*ast.Type, *ast.Type) {
args := t.Args
- if len(args) != 0 && args[len(args)-1].Ident == "opt" {
- return args[:len(args)-1], true
+ if last := len(args) - 1; last >= 0 && args[last].Ident == "opt" {
+ return args[:last], args[last]
}
- return args, false
+ return args, nil
}
func (comp *compiler) parseIntType(name string) (size uint64, bigEndian bool) {
diff --git a/pkg/compiler/compiler_test.go b/pkg/compiler/compiler_test.go
index f26c272e6..a5cc6a86a 100644
--- a/pkg/compiler/compiler_test.go
+++ b/pkg/compiler/compiler_test.go
@@ -4,10 +4,14 @@
package compiler
import (
+ "bytes"
+ "fmt"
+ "io/ioutil"
"path/filepath"
"testing"
"github.com/google/syzkaller/pkg/ast"
+ "github.com/google/syzkaller/pkg/serializer"
"github.com/google/syzkaller/sys/targets"
)
@@ -42,24 +46,98 @@ func TestCompileAll(t *testing.T) {
}
}
-func TestErrors(t *testing.T) {
+func TestNoErrors(t *testing.T) {
+ t.Parallel()
consts := map[string]uint64{
"__NR_foo": 1,
"C0": 0,
"C1": 1,
"C2": 2,
}
- target := targets.List["test"]["64"]
- for _, name := range []string{"errors.txt", "errors2.txt"} {
- name := name
- t.Run(name, func(t *testing.T) {
- em := ast.NewErrorMatcher(t, filepath.Join("testdata", name))
- desc := ast.Parse(em.Data, name, em.ErrorHandler)
+ for _, name := range []string{"all.txt"} {
+ for _, arch := range []string{"32", "64"} {
+ name, arch := name, arch
+ t.Run(fmt.Sprintf("%v/%v", name, arch), func(t *testing.T) {
+ t.Parallel()
+ target := targets.List["test"][arch]
+ eh := func(pos ast.Pos, msg string) {
+ t.Logf("%v: %v", pos, msg)
+ }
+ data, err := ioutil.ReadFile(filepath.Join("testdata", name))
+ if err != nil {
+ t.Fatal(err)
+ }
+ astDesc := ast.Parse(data, name, eh)
+ if astDesc == nil {
+ t.Fatalf("parsing failed")
+ }
+ constInfo := ExtractConsts(astDesc, target, eh)
+ if constInfo == nil {
+ t.Fatalf("const extraction failed")
+ }
+ desc := Compile(astDesc, consts, target, eh)
+ if desc == nil {
+ t.Fatalf("compilation failed")
+ }
+ if len(desc.Unsupported) != 0 {
+ t.Fatalf("something is unsupported:\n%+v", desc.Unsupported)
+ }
+ out := new(bytes.Buffer)
+ fmt.Fprintf(out, "\n\nRESOURCES:\n")
+ serializer.Write(out, desc.Resources)
+ fmt.Fprintf(out, "\n\nSTRUCTS:\n")
+ serializer.Write(out, desc.StructDescs)
+ fmt.Fprintf(out, "\n\nSYSCALLS:\n")
+ serializer.Write(out, desc.Syscalls)
+ if false {
+ t.Log(out.String()) // useful for debugging
+ }
+ })
+ }
+ }
+}
+
+func TestErrors(t *testing.T) {
+ t.Parallel()
+ for _, arch := range []string{"32", "64"} {
+ target := targets.List["test"][arch]
+ t.Run(arch, func(t *testing.T) {
+ t.Parallel()
+ em := ast.NewErrorMatcher(t, filepath.Join("testdata", "errors.txt"))
+ desc := ast.Parse(em.Data, "errors.txt", em.ErrorHandler)
if desc == nil {
em.DumpErrors(t)
t.Fatalf("parsing failed")
}
ExtractConsts(desc, target, em.ErrorHandler)
+ em.Check(t)
+ })
+ }
+}
+
+func TestErrors2(t *testing.T) {
+ t.Parallel()
+ consts := map[string]uint64{
+ "__NR_foo": 1,
+ "C0": 0,
+ "C1": 1,
+ "C2": 2,
+ }
+ for _, arch := range []string{"32", "64"} {
+ target := targets.List["test"][arch]
+ t.Run(arch, func(t *testing.T) {
+ t.Parallel()
+ em := ast.NewErrorMatcher(t, filepath.Join("testdata", "errors2.txt"))
+ desc := ast.Parse(em.Data, "errors2.txt", em.ErrorHandler)
+ if desc == nil {
+ em.DumpErrors(t)
+ t.Fatalf("parsing failed")
+ }
+ info := ExtractConsts(desc, target, em.ErrorHandler)
+ if info == nil {
+ em.DumpErrors(t)
+ t.Fatalf("const extraction failed")
+ }
Compile(desc, consts, target, em.ErrorHandler)
em.Check(t)
})
@@ -67,6 +145,7 @@ func TestErrors(t *testing.T) {
}
func TestFuzz(t *testing.T) {
+ t.Parallel()
inputs := []string{
"d~^gB̉`i\u007f?\xb0.",
"da[",
@@ -86,6 +165,7 @@ func TestFuzz(t *testing.T) {
}
func TestAlign(t *testing.T) {
+ t.Parallel()
const input = `
foo$0(a ptr[in, s0])
s0 {
diff --git a/pkg/compiler/consts.go b/pkg/compiler/consts.go
index f2e4d4850..7f0cb7e38 100644
--- a/pkg/compiler/consts.go
+++ b/pkg/compiler/consts.go
@@ -14,6 +14,7 @@ import (
"strings"
"github.com/google/syzkaller/pkg/ast"
+ "github.com/google/syzkaller/prog"
"github.com/google/syzkaller/sys/targets"
)
@@ -24,41 +25,25 @@ type ConstInfo struct {
Defines map[string]string
}
-// ExtractConsts returns list of literal constants and other info required const value extraction.
-func ExtractConsts(desc *ast.Description, target *targets.Target, eh0 ast.ErrorHandler) *ConstInfo {
- errors := 0
- eh := func(pos ast.Pos, msg string, args ...interface{}) {
- errors++
- msg = fmt.Sprintf(msg, args...)
- if eh0 != nil {
- eh0(pos, msg)
- } else {
- ast.LoggingHandler(pos, msg)
- }
- }
- info := &ConstInfo{
- Defines: make(map[string]string),
+func ExtractConsts(desc *ast.Description, target *targets.Target, eh ast.ErrorHandler) map[string]*ConstInfo {
+ res := Compile(desc, nil, target, eh)
+ if res == nil {
+ return nil
}
- includeMap := make(map[string]bool)
- incdirMap := make(map[string]bool)
- constMap := make(map[string]bool)
+ return res.fileConsts
+}
- desc.Walk(ast.Recursive(func(n0 ast.Node) {
- switch n := n0.(type) {
+// extractConsts returns list of literal constants and other info required for const value extraction.
+func (comp *compiler) extractConsts() map[string]*ConstInfo {
+ infos := make(map[string]*constInfo)
+ for _, decl := range comp.desc.Nodes {
+ pos, _, _ := decl.Info()
+ info := getConstInfo(infos, pos)
+ switch n := decl.(type) {
case *ast.Include:
- file := n.File.Value
- if includeMap[file] {
- eh(n.Pos, "duplicate include %q", file)
- }
- includeMap[file] = true
- info.Includes = append(info.Includes, file)
+ info.includeArray = append(info.includeArray, n.File.Value)
case *ast.Incdir:
- dir := n.Dir.Value
- if incdirMap[dir] {
- eh(n.Pos, "duplicate incdir %q", dir)
- }
- incdirMap[dir] = true
- info.Incdirs = append(info.Incdirs, dir)
+ info.incdirArray = append(info.incdirArray, n.Dir.Value)
case *ast.Define:
v := fmt.Sprint(n.Value.Value)
switch {
@@ -68,34 +53,79 @@ func ExtractConsts(desc *ast.Description, target *targets.Target, eh0 ast.ErrorH
v = n.Value.Ident
}
name := n.Name.Name
- if info.Defines[name] != "" {
- eh(n.Pos, "duplicate define %v", name)
- }
- info.Defines[name] = v
- constMap[name] = true
+ info.defines[name] = v
+ info.consts[name] = true
case *ast.Call:
- if target.SyscallNumbers && !strings.HasPrefix(n.CallName, "syz_") {
- constMap[target.SyscallPrefix+n.CallName] = true
+ if comp.target.SyscallNumbers && !strings.HasPrefix(n.CallName, "syz_") {
+ info.consts[comp.target.SyscallPrefix+n.CallName] = true
}
- case *ast.Type:
- if c := typeConstIdentifier(n); c != nil {
- constMap[c.Ident] = true
- constMap[c.Ident2] = true
- }
- case *ast.Int:
- constMap[n.Ident] = true
+ }
+ }
+
+ for _, decl := range comp.desc.Nodes {
+ switch decl.(type) {
+ case *ast.Call, *ast.Struct, *ast.Resource, *ast.TypeDef:
+ comp.foreachType(decl, func(t *ast.Type, desc *typeDesc,
+ args []*ast.Type, _ prog.IntTypeCommon) {
+ for i, arg := range args {
+ if desc.Args[i].Type.Kind == kindInt {
+ if arg.Ident != "" {
+ info := getConstInfo(infos, arg.Pos)
+ info.consts[arg.Ident] = true
+ }
+ if arg.Ident2 != "" {
+ info := getConstInfo(infos, arg.Pos2)
+ info.consts[arg.Ident2] = true
+ }
+ }
+ }
+ })
+ }
+ }
+
+ comp.desc.Walk(ast.Recursive(func(n0 ast.Node) {
+ if n, ok := n0.(*ast.Int); ok {
+ info := getConstInfo(infos, n.Pos)
+ info.consts[n.Ident] = true
}
}))
- if errors != 0 {
- return nil
+ return convertConstInfo(infos)
+}
+
+type constInfo struct {
+ consts map[string]bool
+ defines map[string]string
+ includeArray []string
+ incdirArray []string
+}
+
+func getConstInfo(infos map[string]*constInfo, pos ast.Pos) *constInfo {
+ info := infos[pos.File]
+ if info == nil {
+ info = &constInfo{
+ consts: make(map[string]bool),
+ defines: make(map[string]string),
+ }
+ infos[pos.File] = info
}
- info.Consts = toArray(constMap)
return info
}
-// assignSyscallNumbers assigns syscall numbers, discards unsupported syscalls
-// and removes no longer irrelevant nodes from the tree (comments, new lines, etc).
+func convertConstInfo(infos map[string]*constInfo) map[string]*ConstInfo {
+ res := make(map[string]*ConstInfo)
+ for file, info := range infos {
+ res[file] = &ConstInfo{
+ Consts: toArray(info.consts),
+ Includes: info.includeArray,
+ Incdirs: info.incdirArray,
+ Defines: info.defines,
+ }
+ }
+ return res
+}
+
+// assignSyscallNumbers assigns syscall numbers, discards unsupported syscalls.
func (comp *compiler) assignSyscallNumbers(consts map[string]uint64) {
// Pseudo syscalls starting from syz_ are assigned numbers starting from syzbase.
// Note: the numbers must be stable (not depend on file reading order, etc),
@@ -116,51 +146,39 @@ func (comp *compiler) assignSyscallNumbers(consts map[string]uint64) {
syznr[name] = syzbase + uint64(i)
}
- var top []ast.Node
for _, decl := range comp.desc.Nodes {
- switch decl.(type) {
- case *ast.Call:
- c := decl.(*ast.Call)
- if strings.HasPrefix(c.CallName, "syz_") {
- c.NR = syznr[c.CallName]
- top = append(top, decl)
- continue
- }
- if !comp.target.SyscallNumbers {
- top = append(top, decl)
- continue
- }
- // Lookup in consts.
- str := comp.target.SyscallPrefix + c.CallName
- nr, ok := consts[str]
- top = append(top, decl)
- if ok {
- c.NR = nr
- continue
- }
- c.NR = ^uint64(0) // mark as unused to not generate it
- name := "syscall " + c.CallName
- if !comp.unsupported[name] {
- comp.unsupported[name] = true
- comp.warning(c.Pos, "unsupported syscall: %v due to missing const %v",
- c.CallName, str)
- }
- case *ast.IntFlags, *ast.Resource, *ast.Struct, *ast.StrFlags, *ast.TypeDef:
- top = append(top, decl)
- case *ast.NewLine, *ast.Comment, *ast.Include, *ast.Incdir, *ast.Define:
- // These are not needed anymore.
- default:
- panic(fmt.Sprintf("unknown node type: %#v", decl))
+ c, ok := decl.(*ast.Call)
+ if !ok {
+ continue
+ }
+ if strings.HasPrefix(c.CallName, "syz_") {
+ c.NR = syznr[c.CallName]
+ continue
+ }
+ // TODO(dvyukov): we don't need even syz consts in this case.
+ if !comp.target.SyscallNumbers {
+ continue
+ }
+ // Lookup in consts.
+ str := comp.target.SyscallPrefix + c.CallName
+ nr, ok := consts[str]
+ if ok {
+ c.NR = nr
+ continue
+ }
+ c.NR = ^uint64(0) // mark as unused to not generate it
+ name := "syscall " + c.CallName
+ if !comp.unsupported[name] {
+ comp.unsupported[name] = true
+ comp.warning(c.Pos, "unsupported syscall: %v due to missing const %v",
+ c.CallName, str)
}
}
- comp.desc.Nodes = top
}
// patchConsts replaces all symbolic consts with their numeric values taken from consts map.
// Updates desc and returns set of unsupported syscalls and flags.
-// After this pass consts are not needed for compilation.
func (comp *compiler) patchConsts(consts map[string]uint64) {
- var top []ast.Node
for _, decl := range comp.desc.Nodes {
switch decl.(type) {
case *ast.IntFlags:
@@ -173,29 +191,29 @@ func (comp *compiler) patchConsts(consts map[string]uint64) {
}
}
n.Values = values
- top = append(top, n)
- case *ast.StrFlags:
- top = append(top, decl)
case *ast.Resource, *ast.Struct, *ast.Call, *ast.TypeDef:
- // Walk whole tree and replace consts in Int's and Type's.
+ // Walk whole tree and replace consts in Type's and Int's.
missing := ""
- decl.Walk(ast.Recursive(func(n0 ast.Node) {
- switch n := n0.(type) {
- case *ast.Int:
- comp.patchIntConst(n.Pos, &n.Value, &n.Ident, consts, &missing)
- case *ast.Type:
- if c := typeConstIdentifier(n); c != nil {
- comp.patchIntConst(c.Pos, &c.Value, &c.Ident,
- consts, &missing)
- if c.HasColon {
- comp.patchIntConst(c.Pos2, &c.Value2, &c.Ident2,
- consts, &missing)
+ comp.foreachType(decl, func(_ *ast.Type, desc *typeDesc,
+ args []*ast.Type, _ prog.IntTypeCommon) {
+ for i, arg := range args {
+ if desc.Args[i].Type.Kind == kindInt {
+ comp.patchIntConst(arg.Pos, &arg.Value,
+ &arg.Ident, consts, &missing)
+ if arg.HasColon {
+ comp.patchIntConst(arg.Pos2, &arg.Value2,
+ &arg.Ident2, consts, &missing)
}
}
}
- }))
+ })
+ if n, ok := decl.(*ast.Resource); ok {
+ for _, v := range n.Values {
+ comp.patchIntConst(v.Pos, &v.Value,
+ &v.Ident, consts, &missing)
+ }
+ }
if missing == "" {
- top = append(top, decl)
continue
}
// Produce a warning about unsupported syscall/resource/struct.
@@ -209,15 +227,11 @@ func (comp *compiler) patchConsts(consts map[string]uint64) {
comp.warning(pos, "unsupported %v: %v due to missing const %v",
typ, name, missing)
}
- // We have to keep partially broken resources and structs,
- // because otherwise their usages will error.
- top = append(top, decl)
if c, ok := decl.(*ast.Call); ok {
c.NR = ^uint64(0) // mark as unused to not generate it
}
}
}
- comp.desc.Nodes = top
}
func (comp *compiler) patchIntConst(pos ast.Pos, val *uint64, id *string,
@@ -237,37 +251,9 @@ func (comp *compiler) patchIntConst(pos ast.Pos, val *uint64, id *string,
}
}
*val = v
- *id = ""
return ok
}
-// typeConstIdentifier returns type arg that is an integer constant (subject for const patching), if any.
-func typeConstIdentifier(n *ast.Type) *ast.Type {
- // TODO: see if we can extract this info from typeDesc/typeArg.
- if n.Ident == "const" && len(n.Args) > 0 {
- return n.Args[0]
- }
- if n.Ident == "array" && len(n.Args) > 1 && n.Args[1].Ident != "opt" {
- return n.Args[1]
- }
- if n.Ident == "vma" && len(n.Args) > 0 && n.Args[0].Ident != "opt" {
- return n.Args[0]
- }
- if n.Ident == "string" && len(n.Args) > 1 && n.Args[1].Ident != "opt" {
- return n.Args[1]
- }
- if n.Ident == "csum" && len(n.Args) > 2 && n.Args[1].Ident == "pseudo" {
- return n.Args[2]
- }
- switch n.Ident {
- case "int8", "int16", "int16be", "int32", "int32be", "int64", "int64be", "intptr":
- if len(n.Args) > 0 && n.Args[0].Ident != "opt" {
- return n.Args[0]
- }
- }
- return nil
-}
-
func SerializeConsts(consts map[string]uint64) []byte {
type nameValuePair struct {
name string
diff --git a/pkg/compiler/consts_test.go b/pkg/compiler/consts_test.go
index 918f61d0c..13780a128 100644
--- a/pkg/compiler/consts_test.go
+++ b/pkg/compiler/consts_test.go
@@ -7,6 +7,7 @@ import (
"io/ioutil"
"path/filepath"
"reflect"
+ "sort"
"testing"
"github.com/google/syzkaller/pkg/ast"
@@ -18,18 +19,26 @@ func TestExtractConsts(t *testing.T) {
if err != nil {
t.Fatalf("failed to read input file: %v", err)
}
- desc := ast.Parse(data, "test", nil)
+ desc := ast.Parse(data, "consts.txt", nil)
if desc == nil {
t.Fatalf("failed to parse input")
}
target := targets.List["linux"]["amd64"]
- info := ExtractConsts(desc, target, func(pos ast.Pos, msg string) {
+ fileInfo := ExtractConsts(desc, target, func(pos ast.Pos, msg string) {
t.Fatalf("%v: %v", pos, msg)
})
- wantConsts := []string{"CONST1", "CONST10", "CONST11", "CONST12", "CONST13",
- "CONST14", "CONST15", "CONST16",
- "CONST2", "CONST3", "CONST4", "CONST5",
- "CONST6", "CONST7", "CONST8", "CONST9", "__NR_bar", "__NR_foo"}
+ info := fileInfo["consts.txt"]
+ if info == nil || len(fileInfo) != 1 {
+ t.Fatalf("bad file info returned: %+v", info)
+ }
+ wantConsts := []string{
+ "__NR_bar", "__NR_foo",
+ "CONST1", "CONST2", "CONST3", "CONST4", "CONST5",
+ "CONST6", "CONST7", "CONST8", "CONST9", "CONST10",
+ "CONST11", "CONST12", "CONST13", "CONST14", "CONST15",
+ "CONST16", "CONST17", "CONST18", "CONST19", "CONST20",
+ }
+ sort.Strings(wantConsts)
if !reflect.DeepEqual(info.Consts, wantConsts) {
t.Fatalf("got consts:\n%q\nwant:\n%q", info.Consts, wantConsts)
}
diff --git a/pkg/compiler/gen.go b/pkg/compiler/gen.go
index b91080db2..1e0e64307 100644
--- a/pkg/compiler/gen.go
+++ b/pkg/compiler/gen.go
@@ -376,6 +376,9 @@ func (comp *compiler) genFieldArray(fields []*ast.Field, dir prog.Dir, isArg boo
func (comp *compiler) genType(t *ast.Type, field string, dir prog.Dir, isArg bool) prog.Type {
desc, args, base := comp.getArgsBase(t, field, dir, isArg)
+ if desc.Gen == nil {
+ panic(fmt.Sprintf("no gen for %v %#v", field, t))
+ }
return desc.Gen(comp, t, args, base)
}
diff --git a/pkg/compiler/testdata/all.txt b/pkg/compiler/testdata/all.txt
new file mode 100644
index 000000000..4aeae45df
--- /dev/null
+++ b/pkg/compiler/testdata/all.txt
@@ -0,0 +1,60 @@
+# Copyright 2018 syzkaller project authors. All rights reserved.
+# Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
+
+foo$0(a int8)
+foo$1(a int8[C1:C2])
+foo$2() ptr[out, array[int32]]
+
+# Proc type.
+
+proc_struct1 {
+ f1 proc[C0, 8, int8]
+}
+
+# Bitfields.
+
+bitfield0 {
+ f1 int8:1
+ f2 int8:2
+}
+
+# Type templates.
+
+type type0 int8
+type templ0[A, B] const[A, B]
+
+type templ_struct0[A, B] {
+ len len[parent, int16]
+ typ const[A, int16]
+ data B
+} [align_4]
+
+type templ_struct1[C] {
+ f1 const[C, int8]
+ f2 int8[0:C]
+}
+
+union_with_templ_struct [
+ f1 templ_struct0[C1, type0]
+ f2 templ_struct0[C2, struct0]
+] [varlen]
+
+struct0 {
+ f1 int8
+ f2 int16
+}
+
+type templ_struct2[A] templ_struct0[A, int8]
+type templ_struct3 templ_struct2[C1]
+type templ_struct4 templ_struct3
+type templ_struct5 templ_struct0[C1, templ_struct0[C2, int8]]
+type templ_struct6 templ_struct0[C1, templ_struct2[C2]]
+type templ_union union_with_templ_struct
+
+foo$templ0(a templ0[42, int8])
+foo$templ1(a ptr[in, templ_struct0[C2, int8]])
+foo$templ2(a ptr[in, union_with_templ_struct])
+foo$templ3(a ptr[in, templ_struct1[1]], b ptr[in, templ_struct1[2]])
+foo$templ4(a ptr[in, templ_struct1[3]])
+foo$templ5(a ptr[in, templ_struct1[3]])
+foo$templ6(a ptr[in, templ_struct4])
diff --git a/pkg/compiler/testdata/consts.txt b/pkg/compiler/testdata/consts.txt
index 179efe081..468f86681 100644
--- a/pkg/compiler/testdata/consts.txt
+++ b/pkg/compiler/testdata/consts.txt
@@ -20,5 +20,14 @@ str {
}
bar$BAZ(x vma[opt], y vma[CONST8], z vma[CONST9:CONST10])
-bar$QUX(s ptr[in, string["foo", CONST11]], x csum[s, pseudo, CONST12])
+bar$QUX(s ptr[in, string["foo", CONST11]], x ptr[in, csum[s, pseudo, CONST12, int16]])
bar$FOO(x int8[8:CONST13], y int16be[CONST14:10], z intptr[CONST15:CONST16])
+
+type type0 const[CONST17, int8]
+type templ0[C] const[C, int8]
+foo$0(a templ0[CONST18])
+type templ1[C] {
+ f1 const[CONST19, int8]
+ f2 const[C, int8]
+}
+foo$1(a ptr[in, templ1[CONST20]])
diff --git a/pkg/compiler/testdata/errors.txt b/pkg/compiler/testdata/errors.txt
index 0a9363924..3c67ac66a 100644
--- a/pkg/compiler/testdata/errors.txt
+++ b/pkg/compiler/testdata/errors.txt
@@ -76,7 +76,6 @@ foo$11(a buffer["in"]) ### unexpected string "in" for direction argument of buf
foo$12(a buffer[10]) ### unexpected int 10 for direction argument of buffer type, expect [in out inout]
foo$13(a int32[2:3])
foo$14(a int32[2:2])
-foo$15(a int32[3:2]) ### bad int range [3:2]
foo$16(a int32[3])
foo$17(a ptr[in, int32])
foo$18(a ptr[in, int32[2:3]])
@@ -85,12 +84,10 @@ foo$20(a ptr) ### wrong number of arguments for type ptr, expect direction, ty
foo$21(a ptr["foo"]) ### wrong number of arguments for type ptr, expect direction, type, [opt]
foo$22(a ptr[in]) ### wrong number of arguments for type ptr, expect direction, type, [opt]
foo$23(a ptr[in, s3[in]]) ### wrong number of arguments for type s3, expect no arguments
-foo$24(a ptr[in, int32[3:2]]) ### bad int range [3:2]
foo$25(a proc[0, "foo"]) ### unexpected string "foo" for per-proc values argument of proc type, expect int
foo$26(a flags[no]) ### unknown flags no
foo$27(a flags["foo"]) ### unexpected string "foo" for flags argument of flags type, expect identifier
foo$28(a ptr[in, string["foo"]], b ptr[in, string["foo", 4]])
-foo$29(a ptr[in, string["foo", 3]]) ### string value "foo\x00" exceeds buffer length 3
foo$30(a ptr[in, string[no]]) ### unknown string flags no
foo$31(a int8, b ptr[in, csum[a, inet]]) ### wrong number of arguments for type csum, expect csum target, kind, [proto], base type
foo$32(a int8, b ptr[in, csum[a, inet, 1, int32]]) ### only pseudo csum can have proto
@@ -98,12 +95,9 @@ foo$33(a int8, b ptr[in, csum[a, pseudo, 1, int32]])
foo$34(a int32["foo"]) ### unexpected string "foo" for range argument of int32 type, expect int
foo$35(a ptr[in, s3[opt]]) ### s3 can't be marked as opt
foo$36(a const[1:2]) ### unexpected ':'
-foo$37(a ptr[in, proc[1000, 1, int8]]) ### values starting from 1000 overflow base type
-foo$38(a ptr[in, proc[20, 10, int8]]) ### values starting from 20 with step 10 overflow base type for 32 procs
foo$39(a fileoff:1) ### unexpected ':'
foo$40(a len["a"]) ### unexpected string "a" for len target argument of len type, expect identifier
foo$41(a vma[C1:C2])
-foo$42(a proc[20, 0]) ### proc per-process values must not be 0
foo$43(a ptr[in, string[1]]) ### unexpected int 1, string arg must be a string literal or string flags
foo$44(a int32) len[a] ### len can't be syscall return
foo$45(a int32) len[b] ### len can't be syscall return
@@ -111,10 +105,11 @@ foo$46(a ptr[in, in]) ### unknown type in
foo$47(a int32:2) ### unexpected ':', only struct fields can be bitfields
foo$48(a ptr[in, int32:7]) ### unexpected ':', only struct fields can be bitfields
foo$49(a ptr[in, array[int32, 0:1]])
-foo$50(a ptr[in, array[int32, 0]]) ### arrays of size 0 are not supported
-foo$51(a ptr[in, array[int32, 0:0]]) ### arrays of size 0 are not supported
foo$52(a intptr, b bitsize[a])
foo$53(a proc[20, 10, opt])
+# This must not error yet (consts are not patched).
+foo$54(a ptr[in, string["foo", C1]])
+foo$55(a int8[opt[int8]]) ### opt can't have arguments
opt { ### struct uses reserved name opt
f1 int32
@@ -140,6 +135,7 @@ s3 {
f5 int8:9 ### bitfield of size 9 is too large for base type of size 8
f6 int32:32
f7 int32:33 ### bitfield of size 33 is too large for base type of size 32
+ f8 const[0, int32:C1] ### literal const bitfield sizes are not supported
} [packed, align_4]
s4 {
@@ -189,20 +185,15 @@ typestruct {
f1 mybool8
f2 mybool16
}
-typeunion [
- f1 mybool8
- f2 mybool16
-]
type type0 int8
-type type0 int8 ### type type0 redeclared, previously declared as type alias at errors.txt:197:6
-resource type0[int32] ### type type0 redeclared, previously declared as type alias at errors.txt:197:6
+type type0 int8 ### type type0 redeclared, previously declared as type alias at errors.txt:189:6
+resource type0[int32] ### type type0 redeclared, previously declared as type alias at errors.txt:189:6
type0 = 0, 1
-type type1 type1 ### type aliases can't refer to other type aliases
+type type1 type1 ### type instantiation loop: type1 -> type1
type type2 int8:4 ### unexpected ':', only struct fields can be bitfields
type type3 type2 ### unknown type type2
type type4 const[0] ### wrong number of arguments for type const, expect value, base type
-type type5 typeunion ### typeunion can't be type alias target
type type6 len[foo, int32] ### len can't be type alias target
type type7 len[foo] ### len can't be type alias target
resource typeres1[int32]
@@ -210,13 +201,10 @@ type type8 typeres1 ### typeres1 can't be type alias target
type int8 int8 ### type name int8 conflicts with builtin type
type opt int8 ### type uses reserved name opt
type type9 const[0, int8]
-type type10 type0 ### type aliases can't refer to other type aliases
-type type11 typestruct11 ### typestruct11 can't be type alias target
type type12 proc[123, 2, int16, opt]
type type13 ptr[in, typestruct13]
type type14 flags[type0, int32]
type type15 const[0, type0] ### unexpected value type0 for base type argument of const type, expect [int8 int16 int32 int64 int16be int32be int64be intptr]
-type type16 ptr[in, type0] ### type aliases can't refer to other type aliases
type bool8 int8[0:1] ### type name bool8 conflicts with builtin type
typestruct11 {
@@ -233,12 +221,47 @@ typestruct13 {
}
foo$100(a mybool8, b mybool16)
-foo$101(a type5) ### unknown type type5
foo$102(a type2) ### unknown type type2
foo$103(a type0:4) ### type alias type0 with ':'
-foo$104(a type0[opt]) ### type alias type0 with arguments
+foo$104(a type0[opt]) ### type type0 is not a template
foo$105() type0
foo$106() type6 ### unknown type type6
foo$107(a type9, b type12)
foo$108(a flags[type0])
foo$109(a ptr[in, type0])
+
+# Type templates.
+
+type templ0[A, B] const[A, B]
+type templ2[A] A[0]
+type templ3[A] ptr[in, A]
+type templ4[A, A] ptr[in, A] ### duplicate type argument A
+type templ5[abc] ptr[in, abc] ### type argument abc must be ALL_CAPS
+type templ6[T] ptr[in, T]
+type templ7 templ0[templ6, int8]
+
+# Note: here 42 is stripped as base type, so const ends up without arguments.
+foo$201(a templ1[42])
+type templ1[A] const[A] ### wrong number of arguments for type const, expect value
+
+type templ_struct0[A, B] {
+ len len[parent, int16]
+ typ const[A, int16]
+ data B
+} [align_4]
+
+type templ_struct1[STR] {
+ f string[STR, 40]
+}
+
+type templ_struct2[A] {
+ f B ### unknown type B
+}
+
+foo$200(a templ0[42, int8])
+foo$202(a templ0) ### template templ0 needs 2 arguments instead of 0
+foo$203(a type0[42]) ### type type0 is not a template
+foo$204(a ptr[in, templ_struct0[42, int8]])
+foo$205(a ptr[in, templ_struct0[int8, int8]])
+foo$206(a ptr[in, templ_struct1["foo"]]) ### template arguments can't be strings ("foo")
+foo$207(a ptr[in, templ_struct2[1]])
diff --git a/pkg/compiler/testdata/errors2.txt b/pkg/compiler/testdata/errors2.txt
index 5b418ab55..67127c6ad 100644
--- a/pkg/compiler/testdata/errors2.txt
+++ b/pkg/compiler/testdata/errors2.txt
@@ -42,6 +42,18 @@ sr7 {
f1 ptr[in, sr7, opt]
}
+type templ_sr[T] {
+ f T
+}
+
+sr8 {
+ f templ_sr[sr8] ### recursive declaration: sr8.f -> templ_sr[sr8].f -> sr8 (mark some pointers as opt)
+}
+
+sr9 {
+ f templ_sr[ptr[in, sr9]] ### recursive declaration: sr9.f -> templ_sr[ptr[in, sr9]].f -> sr9 (mark some pointers as opt)
+}
+
# Len target tests.
foo$100(a int32, b len[a])
@@ -134,3 +146,17 @@ s403 {
sf400 = "foo", "bar", "baz"
sf401 = "a", "b", "cd"
+
+# Const argument checks.
+
+foo$500(a int32[3:2]) ### bad int range [3:2]
+foo$501(a ptr[in, int32[3:2]]) ### bad int range [3:2]
+foo$502(a ptr[in, string["foo", C1]]) ### string value "foo\x00" exceeds buffer length 1
+foo$503(a ptr[in, proc[1000, 1, int8]]) ### values starting from 1000 overflow base type
+foo$504(a ptr[in, proc[20, 10, int8]]) ### values starting from 20 with step 10 overflow base type for 32 procs
+foo$505(a proc[20, 0]) ### proc per-process values must not be 0
+foo$506(a ptr[in, array[int32, 0]]) ### arrays of size 0 are not supported
+foo$507(a ptr[in, array[int32, 0:0]]) ### arrays of size 0 are not supported
+foo$508(a ptr[in, string["foo", 3]]) ### string value "foo\x00" exceeds buffer length 3
+
+type type500 proc[C1, 8, int8] ### values starting from 1 with step 8 overflow base type for 32 procs
diff --git a/pkg/compiler/types.go b/pkg/compiler/types.go
index ee6e4a559..b1c9854b2 100644
--- a/pkg/compiler/types.go
+++ b/pkg/compiler/types.go
@@ -23,10 +23,12 @@ type typeDesc struct {
ResourceBase bool // can be resource base type?
OptArgs int // number of optional arguments in Args array
Args []namedArg // type arguments
- // Check does custom verification of the type (optional).
+ // Check does custom verification of the type (optional, consts are not patched yet).
Check func(comp *compiler, t *ast.Type, args []*ast.Type, base prog.IntTypeCommon)
+ // CheckConsts does custom verification of the type (optional, consts are patched).
+ CheckConsts func(comp *compiler, t *ast.Type, args []*ast.Type, base prog.IntTypeCommon)
// Varlen returns if the type is variable-length (false if not set).
- Varlen func(comp *compiler, t *ast.Type, args []*ast.Type, base prog.IntTypeCommon) bool
+ Varlen func(comp *compiler, t *ast.Type, args []*ast.Type) bool
// Gen generates corresponding prog.Type.
Gen func(comp *compiler, t *ast.Type, args []*ast.Type, base prog.IntTypeCommon) prog.Type
}
@@ -37,7 +39,8 @@ type typeArg struct {
Kind int // int/ident/string
AllowColon bool // allow colon (2:3)?
// Check does custom verification of the arg (optional).
- Check func(comp *compiler, t *ast.Type)
+ Check func(comp *compiler, t *ast.Type)
+ CheckConsts func(comp *compiler, t *ast.Type)
}
type namedArg struct {
@@ -53,7 +56,7 @@ const (
)
var typeInt = &typeDesc{
- Names: []string{"int8", "int16", "int32", "int64", "int16be", "int32be", "int64be", "intptr"},
+ Names: typeArgBase.Type.Names,
CanBeArg: true,
CanBeTypedef: true,
AllowColon: true,
@@ -102,13 +105,13 @@ var typeArray = &typeDesc{
CantBeOpt: true,
OptArgs: 1,
Args: []namedArg{{"type", typeArgType}, {"size", typeArgRange}},
- Check: func(comp *compiler, t *ast.Type, args []*ast.Type, base prog.IntTypeCommon) {
+ CheckConsts: func(comp *compiler, t *ast.Type, args []*ast.Type, base prog.IntTypeCommon) {
if len(args) > 1 && args[1].Value == 0 && args[1].Value2 == 0 {
// This is the only case that can yield 0 static type size.
comp.error(args[1].Pos, "arrays of size 0 are not supported")
}
},
- Varlen: func(comp *compiler, t *ast.Type, args []*ast.Type, base prog.IntTypeCommon) bool {
+ Varlen: func(comp *compiler, t *ast.Type, args []*ast.Type) bool {
if comp.isVarlen(args[0]) {
return true
}
@@ -234,7 +237,7 @@ var typeArgFlags = &typeArg{
var typeFilename = &typeDesc{
Names: []string{"filename"},
CantBeOpt: true,
- Varlen: func(comp *compiler, t *ast.Type, args []*ast.Type, base prog.IntTypeCommon) bool {
+ Varlen: func(comp *compiler, t *ast.Type, args []*ast.Type) bool {
return true
},
Gen: func(comp *compiler, t *ast.Type, args []*ast.Type, base prog.IntTypeCommon) prog.Type {
@@ -326,7 +329,7 @@ var typeProc = &typeDesc{
CanBeTypedef: true,
NeedBase: true,
Args: []namedArg{{"range start", typeArgInt}, {"per-proc values", typeArgInt}},
- Check: func(comp *compiler, t *ast.Type, args []*ast.Type, base prog.IntTypeCommon) {
+ CheckConsts: func(comp *compiler, t *ast.Type, args []*ast.Type, base prog.IntTypeCommon) {
start := args[0].Value
perProc := args[1].Value
if perProc == 0 {
@@ -338,7 +341,7 @@ var typeProc = &typeDesc{
const maxPids = 32 // executor knows about this constant (MAX_PIDS)
if start >= 1<<size {
comp.error(args[0].Pos, "values starting from %v overflow base type", start)
- } else if start+maxPids*perProc >= 1<<size {
+ } else if start+maxPids*perProc > 1<<size {
comp.error(args[0].Pos, "values starting from %v with step %v overflow base type for %v procs",
start, perProc, maxPids)
}
@@ -357,7 +360,7 @@ var typeText = &typeDesc{
Names: []string{"text"},
CantBeOpt: true,
Args: []namedArg{{"kind", typeArgTextType}},
- Varlen: func(comp *compiler, t *ast.Type, args []*ast.Type, base prog.IntTypeCommon) bool {
+ Varlen: func(comp *compiler, t *ast.Type, args []*ast.Type) bool {
return true
},
Gen: func(comp *compiler, t *ast.Type, args []*ast.Type, base prog.IntTypeCommon) prog.Type {
@@ -413,7 +416,7 @@ var typeString = &typeDesc{
Names: []string{"string"},
OptArgs: 2,
Args: []namedArg{{"literal or flags", typeArgStringFlags}, {"size", typeArgInt}},
- Check: func(comp *compiler, t *ast.Type, args []*ast.Type, base prog.IntTypeCommon) {
+ CheckConsts: func(comp *compiler, t *ast.Type, args []*ast.Type, base prog.IntTypeCommon) {
if len(args) > 1 {
size := args[1].Value
vals := []string{args[0].String}
@@ -429,7 +432,7 @@ var typeString = &typeDesc{
}
}
},
- Varlen: func(comp *compiler, t *ast.Type, args []*ast.Type, base prog.IntTypeCommon) bool {
+ Varlen: func(comp *compiler, t *ast.Type, args []*ast.Type) bool {
return comp.stringSize(args) == 0
},
Gen: func(comp *compiler, t *ast.Type, args []*ast.Type, base prog.IntTypeCommon) prog.Type {
@@ -503,11 +506,7 @@ var typeArgStringFlags = &typeArg{
}
// typeArgType is used as placeholder for any type (e.g. ptr target type).
-var typeArgType = &typeArg{
- Check: func(comp *compiler, t *ast.Type) {
- panic("must not be called")
- },
-}
+var typeArgType = &typeArg{}
var typeResource = &typeDesc{
// No Names, but getTypeDesc knows how to match it.
@@ -533,12 +532,13 @@ func init() {
var typeStruct = &typeDesc{
// No Names, but getTypeDesc knows how to match it.
- CantBeOpt: true,
+ CantBeOpt: true,
+ CanBeTypedef: true,
// Varlen/Gen are assigned below due to initialization cycle.
}
func init() {
- typeStruct.Varlen = func(comp *compiler, t *ast.Type, args []*ast.Type, base prog.IntTypeCommon) bool {
+ typeStruct.Varlen = func(comp *compiler, t *ast.Type, args []*ast.Type) bool {
return comp.isStructVarlen(t.Ident)
}
typeStruct.Gen = func(comp *compiler, t *ast.Type, args []*ast.Type, base prog.IntTypeCommon) prog.Type {
@@ -572,9 +572,6 @@ var typeTypedef = &typeDesc{
Check: func(comp *compiler, t *ast.Type, args []*ast.Type, base prog.IntTypeCommon) {
panic("must not be called")
},
- Gen: func(comp *compiler, t *ast.Type, args []*ast.Type, base prog.IntTypeCommon) prog.Type {
- panic("must not be called")
- },
}
var typeArgDir = &typeArg{
@@ -602,7 +599,7 @@ var typeArgInt = &typeArg{
var typeArgRange = &typeArg{
Kind: kindInt,
AllowColon: true,
- Check: func(comp *compiler, t *ast.Type) {
+ CheckConsts: func(comp *compiler, t *ast.Type) {
if !t.HasColon {
t.Value2 = t.Value
}
@@ -620,6 +617,10 @@ var typeArgBase = namedArg{
AllowColon: true,
Check: func(comp *compiler, t *ast.Type) {
if t.HasColon {
+ if t.Ident2 != "" {
+ comp.error(t.Pos2, "literal const bitfield sizes are not supported")
+ return
+ }
if t.Value2 == 0 {
// This was not supported historically
// and does not work the way C bitfields of size 0 work.
@@ -667,12 +668,12 @@ func init() {
typeConst,
typeFlags,
typeFilename,
- typeFileoff, // make a type alias
+ typeFileoff,
typeVMA,
typeCsum,
typeProc,
typeText,
- typeBuffer, // make a type alias
+ typeBuffer,
typeString,
}
for _, desc := range builtins {
diff --git a/sys/syz-extract/extract.go b/sys/syz-extract/extract.go
index 81067615a..5e5e95144 100644
--- a/sys/syz-extract/extract.go
+++ b/sys/syz-extract/extract.go
@@ -42,6 +42,7 @@ type File struct {
arch *Arch
name string
undeclared map[string]bool
+ info *compiler.ConstInfo
err error
}
@@ -160,15 +161,17 @@ func main() {
for job := range jobC {
switch j := job.(type) {
case *Arch:
- j.err = OS.prepareArch(j)
+ infos, err := processArch(OS, j)
+ j.err = err
if j.err == nil {
for _, f := range j.files {
+ f.info = infos[f.name]
wg.Add(1)
jobC <- f
}
}
case *File:
- j.undeclared, j.err = processFile(OS, j.arch, j.name)
+ j.undeclared, j.err = processFile(OS, j.arch, j)
}
wg.Done()
}
@@ -205,29 +208,33 @@ func main() {
}
}
-func processFile(OS OS, arch *Arch, inname string) (map[string]bool, error) {
- inname = filepath.Join("sys", arch.target.OS, inname)
- outname := strings.TrimSuffix(inname, ".txt") + "_" + arch.target.Arch + ".const"
- indata, err := ioutil.ReadFile(inname)
- if err != nil {
- return nil, fmt.Errorf("failed to read input file: %v", err)
- }
+func processArch(OS OS, arch *Arch) (map[string]*compiler.ConstInfo, error) {
errBuf := new(bytes.Buffer)
eh := func(pos ast.Pos, msg string) {
fmt.Fprintf(errBuf, "%v: %v\n", pos, msg)
}
- desc := ast.Parse(indata, filepath.Base(inname), eh)
- if desc == nil {
+ top := ast.ParseGlob(filepath.Join("sys", arch.target.OS, "*.txt"), eh)
+ if top == nil {
return nil, fmt.Errorf("%v", errBuf.String())
}
- info := compiler.ExtractConsts(desc, arch.target, eh)
- if info == nil {
+ infos := compiler.ExtractConsts(top, arch.target, eh)
+ if infos == nil {
return nil, fmt.Errorf("%v", errBuf.String())
}
- if len(info.Consts) == 0 {
+ if err := OS.prepareArch(arch); err != nil {
+ return nil, err
+ }
+ return infos, nil
+}
+
+func processFile(OS OS, arch *Arch, file *File) (map[string]bool, error) {
+ inname := filepath.Join("sys", arch.target.OS, file.name)
+ outname := strings.TrimSuffix(inname, ".txt") + "_" + arch.target.Arch + ".const"
+ if len(file.info.Consts) == 0 {
+ os.Remove(outname)
return nil, nil
}
- consts, undeclared, err := OS.processFile(arch, info)
+ consts, undeclared, err := OS.processFile(arch, file.info)
if err != nil {
return nil, err
}