diff options
| author | Dmitry Vyukov <dvyukov@google.com> | 2018-08-01 21:06:38 +0200 |
|---|---|---|
| committer | Dmitry Vyukov <dvyukov@google.com> | 2018-08-02 16:57:31 +0200 |
| commit | 95a080a682778a08d682897f97d0dae5c2201d76 (patch) | |
| tree | e258c4abaad3a5a99e8225a4181238ad7dc01754 | |
| parent | ae2f24aa70517d79d364a7202e070f2de6fd4451 (diff) | |
prog: strenghten type checking during validation
Check that argument types match expected static types.
I.e. detect when, say, syscall argument is a resource,
but actual generated argument is a pointer.
| -rw-r--r-- | prog/types.go | 11 | ||||
| -rw-r--r-- | prog/validation.go | 41 |
2 files changed, 28 insertions, 24 deletions
diff --git a/prog/types.go b/prog/types.go index 004194707..68c80e23c 100644 --- a/prog/types.go +++ b/prog/types.go @@ -432,7 +432,7 @@ func (t *ArrayType) isDefaultArg(arg Arg) bool { return false } for _, elem := range a.Inner { - if !t.Type.isDefaultArg(elem) { + if !isDefault(elem) { return false } } @@ -460,7 +460,7 @@ func (t *PtrType) isDefaultArg(arg Arg) bool { if t.Optional() { return a.IsNull() } - return a.Address == 0 && t.Type.isDefaultArg(a.Res) + return a.Address == 0 && isDefault(a.Res) } type StructType struct { @@ -487,8 +487,8 @@ func (t *StructType) makeDefaultArg() Arg { func (t *StructType) isDefaultArg(arg Arg) bool { a := arg.(*GroupArg) - for i, elem := range a.Inner { - if !t.Fields[i].isDefaultArg(elem) { + for _, elem := range a.Inner { + if !isDefault(elem) { return false } } @@ -515,8 +515,7 @@ func (t *UnionType) makeDefaultArg() Arg { func (t *UnionType) isDefaultArg(arg Arg) bool { a := arg.(*UnionArg) - return a.Option.Type().FieldName() == t.Fields[0].FieldName() && - t.Fields[0].isDefaultArg(a.Option) + return a.Option.Type().FieldName() == t.Fields[0].FieldName() && isDefault(a.Option) } type StructDesc struct { diff --git a/prog/validation.go b/prog/validation.go index 3274fca9e..eaffac9b8 100644 --- a/prog/validation.go +++ b/prog/validation.go @@ -51,8 +51,8 @@ func (ctx *validCtx) validateCall(c *Call) error { return fmt.Errorf("wrong number of arguments, want %v, got %v", len(c.Meta.Args), len(c.Args)) } - for _, arg := range c.Args { - if err := ctx.validateArg(arg); err != nil { + for i, arg := range c.Args { + if err := ctx.validateArg(arg, c.Meta.Args[i]); err != nil { return err } } @@ -69,19 +69,16 @@ func (ctx *validCtx) validateRet(c *Call) error { if c.Ret == nil { return fmt.Errorf("return value is absent") } - if c.Ret.Type() != c.Meta.Ret { - return fmt.Errorf("wrong return type: %#v vs %#v", c.Ret.Type(), c.Meta.Ret) - } if c.Ret.Type().Dir() != DirOut { return fmt.Errorf("return value %v is not output", c.Ret) } if c.Ret.Res != nil || c.Ret.Val != 0 || c.Ret.OpDiv != 0 || c.Ret.OpAdd != 0 { return fmt.Errorf("return value %v is not empty", c.Ret) } - return ctx.validateArg(c.Ret) + return ctx.validateArg(c.Ret, c.Meta.Ret) } -func (ctx *validCtx) validateArg(arg Arg) error { +func (ctx *validCtx) validateArg(arg Arg, typ Type) error { if arg == nil { return fmt.Errorf("nil arg") } @@ -91,6 +88,9 @@ func (ctx *validCtx) validateArg(arg Arg) error { if arg.Type() == nil { return fmt.Errorf("no arg type") } + if !ctx.target.isAnyPtr(arg.Type()) && arg.Type() != typ { + return fmt.Errorf("bad arg type %#v, expect %#v", arg.Type(), typ) + } ctx.args[arg] = true return arg.validate(ctx) } @@ -194,20 +194,25 @@ func (arg *GroupArg) validate(ctx *validCtx) error { return fmt.Errorf("struct arg '%v' has wrong number of fields: want %v, got %v", typ.Name(), len(typ.Fields), len(arg.Inner)) } + for i, field := range arg.Inner { + if err := ctx.validateArg(field, typ.Fields[i]); err != nil { + return err + } + } case *ArrayType: if typ.Kind == ArrayRangeLen && typ.RangeBegin == typ.RangeEnd && uint64(len(arg.Inner)) != typ.RangeBegin { return fmt.Errorf("array %v has wrong number of elements %v, want %v", typ.Name(), len(arg.Inner), typ.RangeBegin) } + for _, elem := range arg.Inner { + if err := ctx.validateArg(elem, typ.Type); err != nil { + return err + } + } default: return fmt.Errorf("group arg %v has bad type %v", arg, typ.Name()) } - for _, arg1 := range arg.Inner { - if err := ctx.validateArg(arg1); err != nil { - return err - } - } return nil } @@ -216,17 +221,17 @@ func (arg *UnionArg) validate(ctx *validCtx) error { if !ok { return fmt.Errorf("union arg %v has bad type %v", arg, arg.Type().Name()) } - found := false + var optType Type for _, typ1 := range typ.Fields { - if arg.Option.Type().Name() == typ1.Name() { - found = true + if arg.Option.Type().FieldName() == typ1.FieldName() { + optType = typ1 break } } - if !found { + if optType == nil { return fmt.Errorf("union arg '%v' has bad option", typ.Name()) } - return ctx.validateArg(arg.Option) + return ctx.validateArg(arg.Option, optType) } func (arg *PointerArg) validate(ctx *validCtx) error { @@ -240,7 +245,7 @@ func (arg *PointerArg) validate(ctx *validCtx) error { return fmt.Errorf("non optional pointer arg '%v' is nil", typ.Name()) } if arg.Res != nil { - if err := ctx.validateArg(arg.Res); err != nil { + if err := ctx.validateArg(arg.Res, typ.Type); err != nil { return err } } |
