aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorDmitry Vyukov <dvyukov@google.com>2018-08-01 21:06:38 +0200
committerDmitry Vyukov <dvyukov@google.com>2018-08-02 16:57:31 +0200
commit95a080a682778a08d682897f97d0dae5c2201d76 (patch)
treee258c4abaad3a5a99e8225a4181238ad7dc01754
parentae2f24aa70517d79d364a7202e070f2de6fd4451 (diff)
prog: strenghten type checking during validation
Check that argument types match expected static types. I.e. detect when, say, syscall argument is a resource, but actual generated argument is a pointer.
-rw-r--r--prog/types.go11
-rw-r--r--prog/validation.go41
2 files changed, 28 insertions, 24 deletions
diff --git a/prog/types.go b/prog/types.go
index 004194707..68c80e23c 100644
--- a/prog/types.go
+++ b/prog/types.go
@@ -432,7 +432,7 @@ func (t *ArrayType) isDefaultArg(arg Arg) bool {
return false
}
for _, elem := range a.Inner {
- if !t.Type.isDefaultArg(elem) {
+ if !isDefault(elem) {
return false
}
}
@@ -460,7 +460,7 @@ func (t *PtrType) isDefaultArg(arg Arg) bool {
if t.Optional() {
return a.IsNull()
}
- return a.Address == 0 && t.Type.isDefaultArg(a.Res)
+ return a.Address == 0 && isDefault(a.Res)
}
type StructType struct {
@@ -487,8 +487,8 @@ func (t *StructType) makeDefaultArg() Arg {
func (t *StructType) isDefaultArg(arg Arg) bool {
a := arg.(*GroupArg)
- for i, elem := range a.Inner {
- if !t.Fields[i].isDefaultArg(elem) {
+ for _, elem := range a.Inner {
+ if !isDefault(elem) {
return false
}
}
@@ -515,8 +515,7 @@ func (t *UnionType) makeDefaultArg() Arg {
func (t *UnionType) isDefaultArg(arg Arg) bool {
a := arg.(*UnionArg)
- return a.Option.Type().FieldName() == t.Fields[0].FieldName() &&
- t.Fields[0].isDefaultArg(a.Option)
+ return a.Option.Type().FieldName() == t.Fields[0].FieldName() && isDefault(a.Option)
}
type StructDesc struct {
diff --git a/prog/validation.go b/prog/validation.go
index 3274fca9e..eaffac9b8 100644
--- a/prog/validation.go
+++ b/prog/validation.go
@@ -51,8 +51,8 @@ func (ctx *validCtx) validateCall(c *Call) error {
return fmt.Errorf("wrong number of arguments, want %v, got %v",
len(c.Meta.Args), len(c.Args))
}
- for _, arg := range c.Args {
- if err := ctx.validateArg(arg); err != nil {
+ for i, arg := range c.Args {
+ if err := ctx.validateArg(arg, c.Meta.Args[i]); err != nil {
return err
}
}
@@ -69,19 +69,16 @@ func (ctx *validCtx) validateRet(c *Call) error {
if c.Ret == nil {
return fmt.Errorf("return value is absent")
}
- if c.Ret.Type() != c.Meta.Ret {
- return fmt.Errorf("wrong return type: %#v vs %#v", c.Ret.Type(), c.Meta.Ret)
- }
if c.Ret.Type().Dir() != DirOut {
return fmt.Errorf("return value %v is not output", c.Ret)
}
if c.Ret.Res != nil || c.Ret.Val != 0 || c.Ret.OpDiv != 0 || c.Ret.OpAdd != 0 {
return fmt.Errorf("return value %v is not empty", c.Ret)
}
- return ctx.validateArg(c.Ret)
+ return ctx.validateArg(c.Ret, c.Meta.Ret)
}
-func (ctx *validCtx) validateArg(arg Arg) error {
+func (ctx *validCtx) validateArg(arg Arg, typ Type) error {
if arg == nil {
return fmt.Errorf("nil arg")
}
@@ -91,6 +88,9 @@ func (ctx *validCtx) validateArg(arg Arg) error {
if arg.Type() == nil {
return fmt.Errorf("no arg type")
}
+ if !ctx.target.isAnyPtr(arg.Type()) && arg.Type() != typ {
+ return fmt.Errorf("bad arg type %#v, expect %#v", arg.Type(), typ)
+ }
ctx.args[arg] = true
return arg.validate(ctx)
}
@@ -194,20 +194,25 @@ func (arg *GroupArg) validate(ctx *validCtx) error {
return fmt.Errorf("struct arg '%v' has wrong number of fields: want %v, got %v",
typ.Name(), len(typ.Fields), len(arg.Inner))
}
+ for i, field := range arg.Inner {
+ if err := ctx.validateArg(field, typ.Fields[i]); err != nil {
+ return err
+ }
+ }
case *ArrayType:
if typ.Kind == ArrayRangeLen && typ.RangeBegin == typ.RangeEnd &&
uint64(len(arg.Inner)) != typ.RangeBegin {
return fmt.Errorf("array %v has wrong number of elements %v, want %v",
typ.Name(), len(arg.Inner), typ.RangeBegin)
}
+ for _, elem := range arg.Inner {
+ if err := ctx.validateArg(elem, typ.Type); err != nil {
+ return err
+ }
+ }
default:
return fmt.Errorf("group arg %v has bad type %v", arg, typ.Name())
}
- for _, arg1 := range arg.Inner {
- if err := ctx.validateArg(arg1); err != nil {
- return err
- }
- }
return nil
}
@@ -216,17 +221,17 @@ func (arg *UnionArg) validate(ctx *validCtx) error {
if !ok {
return fmt.Errorf("union arg %v has bad type %v", arg, arg.Type().Name())
}
- found := false
+ var optType Type
for _, typ1 := range typ.Fields {
- if arg.Option.Type().Name() == typ1.Name() {
- found = true
+ if arg.Option.Type().FieldName() == typ1.FieldName() {
+ optType = typ1
break
}
}
- if !found {
+ if optType == nil {
return fmt.Errorf("union arg '%v' has bad option", typ.Name())
}
- return ctx.validateArg(arg.Option)
+ return ctx.validateArg(arg.Option, optType)
}
func (arg *PointerArg) validate(ctx *validCtx) error {
@@ -240,7 +245,7 @@ func (arg *PointerArg) validate(ctx *validCtx) error {
return fmt.Errorf("non optional pointer arg '%v' is nil", typ.Name())
}
if arg.Res != nil {
- if err := ctx.validateArg(arg.Res); err != nil {
+ if err := ctx.validateArg(arg.Res, typ.Type); err != nil {
return err
}
}