aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorAndrey Konovalov <andreyknvl@google.com>2017-02-09 21:33:14 +0100
committerAndrey Konovalov <andreyknvl@google.com>2017-02-09 21:33:14 +0100
commit9989eadf779cf109bf3c64beb10e54df5f11ae3c (patch)
tree0ea844b00fe5ae0296c2163f3120ba5aa7c02d87
parent592f352d7102c317ce2daa5a8f377501679f4571 (diff)
prog: fix cheking nonoptional nil pointers in validation
Also update validation code to use arg.Type instead of passing typ recusively.
-rw-r--r--prog/validation.go91
1 files changed, 43 insertions, 48 deletions
diff --git a/prog/validation.go b/prog/validation.go
index 28c619802..61a12eaab 100644
--- a/prog/validation.go
+++ b/prog/validation.go
@@ -38,8 +38,8 @@ func (c *Call) validate(ctx *validCtx) error {
if len(c.Args) != len(c.Meta.Args) {
return fmt.Errorf("syscall %v: wrong number of arguments, want %v, got %v", c.Meta.Name, len(c.Meta.Args), len(c.Args))
}
- var checkArg func(arg *Arg, typ sys.Type) error
- checkArg = func(arg *Arg, typ sys.Type) error {
+ var checkArg func(arg *Arg) error
+ checkArg = func(arg *Arg) error {
if arg == nil {
return fmt.Errorf("syscall %v: nil arg", c.Meta.Name)
}
@@ -53,24 +53,18 @@ func (c *Call) validate(ctx *validCtx) error {
if arg.Type == nil {
return fmt.Errorf("syscall %v: no type", c.Meta.Name)
}
- if arg.Type.Name() != typ.Name() {
- return fmt.Errorf("syscall %v: type name mismatch: %v vs %v", c.Meta.Name, arg.Type.Name(), typ.Name())
- }
- if arg.Type.FieldName() != typ.FieldName() {
- return fmt.Errorf("syscall %v: field name mismatch: %v vs %v", c.Meta.Name, arg.Type.FieldName(), typ.FieldName())
- }
if arg.Type.Dir() == sys.DirOut {
if (arg.Val != 0 && arg.Val != arg.Type.Default()) || arg.AddrPage != 0 || arg.AddrOffset != 0 {
// We generate output len arguments, which makes sense
// since it can be a length of a variable-length array
// which is not known otherwise.
if _, ok := arg.Type.(*sys.LenType); !ok {
- return fmt.Errorf("syscall %v: output arg '%v' has non default value '%v'", c.Meta.Name, typ.Name(), arg.Val)
+ return fmt.Errorf("syscall %v: output arg '%v' has non default value '%v'", c.Meta.Name, arg.Type.Name(), arg.Val)
}
}
for _, v := range arg.Data {
if v != 0 {
- return fmt.Errorf("syscall %v: output arg '%v' has data", c.Meta.Name, typ.Name())
+ return fmt.Errorf("syscall %v: output arg '%v' has data", c.Meta.Name, arg.Type.Name())
}
}
}
@@ -81,111 +75,112 @@ func (c *Call) validate(ctx *validCtx) error {
case ArgReturn:
case ArgConst:
if arg.Type.Dir() == sys.DirOut && (arg.Val != 0 && arg.Val != arg.Type.Default()) {
- return fmt.Errorf("syscall %v: out resource arg '%v' has bad const value %v", c.Meta.Name, typ.Name(), arg.Val)
+ return fmt.Errorf("syscall %v: out resource arg '%v' has bad const value %v", c.Meta.Name, arg.Type.Name(), arg.Val)
}
default:
- return fmt.Errorf("syscall %v: fd arg '%v' has bad kind %v", c.Meta.Name, typ.Name(), arg.Kind)
+ return fmt.Errorf("syscall %v: fd arg '%v' has bad kind %v", c.Meta.Name, arg.Type.Name(), arg.Kind)
}
case *sys.StructType, *sys.ArrayType:
switch arg.Kind {
case ArgGroup:
default:
- return fmt.Errorf("syscall %v: struct/array arg '%v' has bad kind %v", c.Meta.Name, typ.Name(), arg.Kind)
+ return fmt.Errorf("syscall %v: struct/array arg '%v' has bad kind %v", c.Meta.Name, arg.Type.Name(), arg.Kind)
}
case *sys.UnionType:
switch arg.Kind {
case ArgUnion:
default:
- return fmt.Errorf("syscall %v: union arg '%v' has bad kind %v", c.Meta.Name, typ.Name(), arg.Kind)
+ return fmt.Errorf("syscall %v: union arg '%v' has bad kind %v", c.Meta.Name, arg.Type.Name(), arg.Kind)
}
case *sys.ProcType:
if arg.Val >= uintptr(typ1.ValuesPerProc) {
- return fmt.Errorf("syscall %v: per proc arg '%v' has bad value '%v'", c.Meta.Name, typ.Name(), arg.Val)
+ return fmt.Errorf("syscall %v: per proc arg '%v' has bad value '%v'", c.Meta.Name, arg.Type.Name(), arg.Val)
}
case *sys.BufferType:
switch typ1.Kind {
case sys.BufferString:
if typ1.Length != 0 && len(arg.Data) != int(typ1.Length) {
- return fmt.Errorf("syscall %v: string arg '%v' has size %v, which should be %v", c.Meta.Name, typ.Name(), len(arg.Data), typ1.Length)
+ return fmt.Errorf("syscall %v: string arg '%v' has size %v, which should be %v", c.Meta.Name, arg.Type.Name(), len(arg.Data), typ1.Length)
}
}
case *sys.CsumType:
if arg.Val != 0 {
- return fmt.Errorf("syscall %v: csum arg '%v' has nonzero value %v", c.Meta.Name, typ.Name(), arg.Val)
+ return fmt.Errorf("syscall %v: csum arg '%v' has nonzero value %v", c.Meta.Name, arg.Type.Name(), arg.Val)
+ }
+ case *sys.PtrType:
+ if arg.Type.Dir() == sys.DirOut {
+ return fmt.Errorf("syscall %v: pointer arg '%v' has output direction", c.Meta.Name, arg.Type.Name())
+ }
+ if arg.Res == nil && !arg.Type.Optional() {
+ return fmt.Errorf("syscall %v: non optional pointer arg '%v' is nil", c.Meta.Name, arg.Type.Name())
}
}
switch arg.Kind {
case ArgConst:
case ArgResult:
if arg.Res == nil {
- return fmt.Errorf("syscall %v: result arg '%v' has no reference", c.Meta.Name, typ.Name())
+ return fmt.Errorf("syscall %v: result arg '%v' has no reference", c.Meta.Name, arg.Type.Name())
}
if !ctx.args[arg.Res] {
return fmt.Errorf("syscall %v: result arg '%v' references out-of-tree result: %p%+v -> %p%+v",
- c.Meta.Name, typ.Name(), arg, arg, arg.Res, arg.Res)
+ c.Meta.Name, arg.Type.Name(), arg, arg, arg.Res, arg.Res)
}
if _, ok := arg.Res.Uses[arg]; !ok {
- return fmt.Errorf("syscall %v: result arg '%v' has broken link (%+v)", c.Meta.Name, typ.Name(), arg.Res.Uses)
+ return fmt.Errorf("syscall %v: result arg '%v' has broken link (%+v)", c.Meta.Name, arg.Type.Name(), arg.Res.Uses)
}
case ArgPointer:
- switch typ1 := typ.(type) {
+ switch arg.Type.(type) {
case *sys.VmaType:
if arg.Res != nil {
- return fmt.Errorf("syscall %v: vma arg '%v' has data", c.Meta.Name, typ.Name())
+ return fmt.Errorf("syscall %v: vma arg '%v' has data", c.Meta.Name, arg.Type.Name())
}
if arg.AddrPagesNum == 0 {
- return fmt.Errorf("syscall %v: vma arg '%v' has size 0", c.Meta.Name, typ.Name())
+ return fmt.Errorf("syscall %v: vma arg '%v' has size 0", c.Meta.Name, arg.Type.Name())
}
case *sys.PtrType:
- if arg.Type.Dir() == sys.DirOut {
- return fmt.Errorf("syscall %v: pointer arg '%v' has output direction", c.Meta.Name, typ.Name())
- }
- if arg.Res == nil && !typ.Optional() {
- return fmt.Errorf("syscall %v: non optional pointer arg '%v' is nil", c.Meta.Name, typ.Name())
- }
if arg.Res != nil {
- if err := checkArg(arg.Res, typ1.Type); err != nil {
+ if err := checkArg(arg.Res); err != nil {
return err
}
}
if arg.AddrPagesNum != 0 {
- return fmt.Errorf("syscall %v: pointer arg '%v' has nonzero size", c.Meta.Name, typ.Name())
+ return fmt.Errorf("syscall %v: pointer arg '%v' has nonzero size", c.Meta.Name, arg.Type.Name())
}
default:
- return fmt.Errorf("syscall %v: pointer arg '%v' has bad meta type %+v", c.Meta.Name, typ.Name(), typ)
+ return fmt.Errorf("syscall %v: pointer arg '%v' has bad meta type %+v", c.Meta.Name, arg.Type.Name(), arg.Type)
}
case ArgPageSize:
case ArgData:
- switch typ1 := typ.(type) {
+ switch typ1 := arg.Type.(type) {
case *sys.ArrayType:
if typ2, ok := typ1.Type.(*sys.IntType); !ok || typ2.Size() != 1 {
- return fmt.Errorf("syscall %v: data arg '%v' should be an array", c.Meta.Name, typ.Name())
+ return fmt.Errorf("syscall %v: data arg '%v' should be an array", c.Meta.Name, arg.Type.Name())
}
}
case ArgGroup:
- switch typ1 := typ.(type) {
+ switch typ1 := arg.Type.(type) {
case *sys.StructType:
if len(arg.Inner) != len(typ1.Fields) {
- return fmt.Errorf("syscall %v: struct arg '%v' has wrong number of fields: want %v, got %v", c.Meta.Name, typ.Name(), len(typ1.Fields), len(arg.Inner))
+ return fmt.Errorf("syscall %v: struct arg '%v' has wrong number of fields: want %v, got %v", c.Meta.Name, arg.Type.Name(), len(typ1.Fields), len(arg.Inner))
}
- for i, arg1 := range arg.Inner {
- if err := checkArg(arg1, typ1.Fields[i]); err != nil {
+ for _, arg1 := range arg.Inner {
+ if err := checkArg(arg1); err != nil {
return err
}
}
case *sys.ArrayType:
for _, arg1 := range arg.Inner {
- if err := checkArg(arg1, typ1.Type); err != nil {
+ if err := checkArg(arg1); err != nil {
return err
}
}
default:
- return fmt.Errorf("syscall %v: group arg '%v' has bad underlying type %+v", c.Meta.Name, typ.Name(), typ)
+ return fmt.Errorf("syscall %v: group arg '%v' has bad underlying type %+v", c.Meta.Name, arg.Type.Name(), arg.Type)
}
case ArgUnion:
- typ1, ok := typ.(*sys.UnionType)
+ typ1, ok := arg.Type.(*sys.UnionType)
if !ok {
- return fmt.Errorf("syscall %v: union arg '%v' has bad type", c.Meta.Name, typ.Name())
+ return fmt.Errorf("syscall %v: union arg '%v' has bad type", c.Meta.Name, arg.Type.Name())
}
found := false
for _, typ2 := range typ1.Options {
@@ -195,22 +190,22 @@ func (c *Call) validate(ctx *validCtx) error {
}
}
if !found {
- return fmt.Errorf("syscall %v: union arg '%v' has bad option", c.Meta.Name, typ.Name())
+ return fmt.Errorf("syscall %v: union arg '%v' has bad option", c.Meta.Name, arg.Type.Name())
}
- if err := checkArg(arg.Option, arg.OptionType); err != nil {
+ if err := checkArg(arg.Option); err != nil {
return err
}
case ArgReturn:
default:
- return fmt.Errorf("syscall %v: unknown arg '%v' kind", c.Meta.Name, typ.Name())
+ return fmt.Errorf("syscall %v: unknown arg '%v' kind", c.Meta.Name, arg.Type.Name())
}
return nil
}
- for i, arg := range c.Args {
+ for _, arg := range c.Args {
if arg.Kind == ArgReturn {
return fmt.Errorf("syscall %v: arg '%v' has wrong return kind", c.Meta.Name, arg.Type.Name())
}
- if err := checkArg(arg, c.Meta.Args[i]); err != nil {
+ if err := checkArg(arg); err != nil {
return err
}
}
@@ -221,7 +216,7 @@ func (c *Call) validate(ctx *validCtx) error {
return fmt.Errorf("syscall %v: return value has wrong kind %v", c.Meta.Name, c.Ret.Kind)
}
if c.Meta.Ret != nil {
- if err := checkArg(c.Ret, c.Meta.Ret); err != nil {
+ if err := checkArg(c.Ret); err != nil {
return err
}
} else if c.Ret.Type != nil {