aboutsummaryrefslogtreecommitdiffstats
path: root/prog
diff options
context:
space:
mode:
authorDmitry Vyukov <dvyukov@google.com>2018-05-05 11:43:00 +0200
committerDmitry Vyukov <dvyukov@google.com>2018-05-05 11:43:00 +0200
commit6a0382b54364673499ec19d3cdad20534c564bce (patch)
treedbf5bfbbef8b1afa2f02ffee12690e90a9b521c6 /prog
parentafe402d20af0d54d4e0baeb9e70e668e2a26f188 (diff)
prog: rework validation code
The current code is total, unstructured mess. Since we now have 1:1 type -> arg correspondence, rework validation around args. This makes code much cleaner and 30% shorter.
Diffstat (limited to 'prog')
-rw-r--r--prog/prog.go1
-rw-r--r--prog/validation.go410
2 files changed, 170 insertions, 241 deletions
diff --git a/prog/prog.go b/prog/prog.go
index 19f5e319f..3474950b9 100644
--- a/prog/prog.go
+++ b/prog/prog.go
@@ -21,6 +21,7 @@ type Call struct {
type Arg interface {
Type() Type
Size() uint64
+ validate(ctx *validCtx) error
}
type ArgCommon struct {
diff --git a/prog/validation.go b/prog/validation.go
index ac7faeee2..73b18c7e1 100644
--- a/prog/validation.go
+++ b/prog/validation.go
@@ -22,8 +22,11 @@ func (p *Prog) validate() error {
uses: make(map[Arg]Arg),
}
for _, c := range p.Calls {
- if err := p.validateCall(ctx, c); err != nil {
- return err
+ if c.Meta == nil {
+ return fmt.Errorf("call does not have meta information")
+ }
+ if err := ctx.validateCall(c); err != nil {
+ return fmt.Errorf("call %v: %v", c.Meta.Name, err)
}
}
for u, orig := range ctx.uses {
@@ -34,288 +37,213 @@ func (p *Prog) validate() error {
return nil
}
-func (p *Prog) validateCall(ctx *validCtx, c *Call) error {
- if c.Meta == nil {
- return fmt.Errorf("call does not have meta information")
- }
+func (ctx *validCtx) validateCall(c *Call) error {
if len(c.Args) != len(c.Meta.Args) {
- return fmt.Errorf("syscall %v: wrong number of arguments, want %v, got %v",
- c.Meta.Name, len(c.Meta.Args), len(c.Args))
+ return fmt.Errorf("wrong number of arguments, want %v, got %v",
+ len(c.Meta.Args), len(c.Args))
}
for _, arg := range c.Args {
- if err := validateArg(ctx, c, arg); err != nil {
+ if err := ctx.validateArg(arg); err != nil {
return err
}
}
+ return ctx.validateRet(c)
+}
+
+func (ctx *validCtx) validateRet(c *Call) error {
if c.Meta.Ret == nil {
if c.Ret != nil {
- return fmt.Errorf("syscall %v: return value without type", c.Meta.Name)
- }
- } else {
- if c.Ret == nil {
- return fmt.Errorf("syscall %v: return value is absent", c.Meta.Name)
- }
- if c.Ret.Type() != c.Meta.Ret {
- return fmt.Errorf("syscall %v: wrong return type", c.Meta.Name)
- }
- if c.Ret.Type().Dir() != DirOut {
- return fmt.Errorf("syscall %v: return value %v is not output", c.Meta.Name, c.Ret)
- }
- if c.Ret.Res != nil || c.Ret.Val != 0 || c.Ret.OpDiv != 0 || c.Ret.OpAdd != 0 {
- return fmt.Errorf("syscall %v: return value %v is not empty", c.Meta.Name, c.Ret)
- }
- if err := validateArg(ctx, c, c.Ret); err != nil {
- return err
+ return fmt.Errorf("return value without type")
}
+ return nil
}
- return nil
+ if c.Ret == nil {
+ return fmt.Errorf("return value is absent")
+ }
+ if c.Ret.Type() != c.Meta.Ret {
+ return fmt.Errorf("wrong return type: %#v vs %#v", c.Ret.Type(), c.Meta.Ret)
+ }
+ if c.Ret.Type().Dir() != DirOut {
+ return fmt.Errorf("return value %v is not output", c.Ret)
+ }
+ if c.Ret.Res != nil || c.Ret.Val != 0 || c.Ret.OpDiv != 0 || c.Ret.OpAdd != 0 {
+ return fmt.Errorf("return value %v is not empty", c.Ret)
+ }
+ return ctx.validateArg(c.Ret)
}
-// nolint: gocyclo
-func validateArg(ctx *validCtx, c *Call, arg Arg) error {
+func (ctx *validCtx) validateArg(arg Arg) error {
if arg == nil {
- return fmt.Errorf("syscall %v: nil arg", c.Meta.Name)
+ return fmt.Errorf("nil arg")
}
if ctx.args[arg] {
- return fmt.Errorf("syscall %v: arg %#v is referenced several times in the tree",
- c.Meta.Name, arg)
- }
- ctx.args[arg] = true
- // TODO(dvyukov): move this to ResultArg verification.
- if used, ok := arg.(*ResultArg); ok {
- for u := range used.uses {
- if u == nil {
- return fmt.Errorf("syscall %v: nil reference in uses for arg %+v",
- c.Meta.Name, arg)
- }
- ctx.uses[u] = arg
- }
+ return fmt.Errorf("arg %#v is referenced several times in the tree", arg)
}
if arg.Type() == nil {
- return fmt.Errorf("syscall %v: no type", c.Meta.Name)
- }
- if arg.Type().Dir() == DirOut {
- switch a := arg.(type) {
- case *ConstArg:
- // We generate output len arguments, which makes sense since it can be
- // a length of a variable-length array which is not known otherwise.
- if _, ok := a.Type().(*LenType); ok {
- break
- }
- if a.Val != 0 && a.Val != a.Type().Default() {
- return fmt.Errorf("syscall %v: output arg '%v'/'%v' has non default value '%+v'",
- c.Meta.Name, a.Type().FieldName(), a.Type().Name(), a)
- }
- case *DataArg:
- if len(a.data) != 0 {
- return fmt.Errorf("syscall %v: output arg '%v' has data",
- c.Meta.Name, a.Type().Name())
- }
- }
+ return fmt.Errorf("no arg type")
}
+ ctx.args[arg] = true
+ return arg.validate(ctx)
+}
- switch typ1 := arg.Type().(type) {
+func (arg *ConstArg) validate(ctx *validCtx) error {
+ switch typ := arg.Type().(type) {
case *IntType:
- switch a := arg.(type) {
- case *ConstArg:
- if a.Type().Dir() == DirOut && (a.Val != 0 && a.Val != a.Type().Default()) {
- return fmt.Errorf("syscall %v: out int arg '%v' has bad const value %v",
- c.Meta.Name, a.Type().Name(), a.Val)
- }
- default:
- return fmt.Errorf("syscall %v: int arg '%v' has bad kind %v",
- c.Meta.Name, arg.Type().Name(), arg)
- }
- case *ResourceType:
- switch a := arg.(type) {
- case *ResultArg:
- if a.Type().Dir() == DirOut && (a.Val != 0 && a.Val != a.Type().Default()) {
- return fmt.Errorf("syscall %v: out resource arg '%v' has bad const value %v",
- c.Meta.Name, a.Type().Name(), a.Val)
- }
- default:
- return fmt.Errorf("syscall %v: fd arg '%v' has bad kind %v",
- c.Meta.Name, arg.Type().Name(), arg)
- }
- case *StructType, *ArrayType:
- switch arg.(type) {
- case *GroupArg:
- default:
- return fmt.Errorf("syscall %v: struct/array arg '%v' has bad kind %#v",
- c.Meta.Name, arg.Type().Name(), arg)
- }
- case *UnionType:
- switch arg.(type) {
- case *UnionArg:
- default:
- return fmt.Errorf("syscall %v: union arg '%v' has bad kind %v",
- c.Meta.Name, arg.Type().Name(), arg)
+ if typ.Dir() == DirOut && (arg.Val != 0 && arg.Val != typ.Default()) {
+ return fmt.Errorf("out int arg '%v' has bad const value %v", typ.Name(), arg.Val)
}
case *ProcType:
- switch a := arg.(type) {
- case *ConstArg:
- if a.Val >= typ1.ValuesPerProc && a.Val != typ1.Default() {
- return fmt.Errorf("syscall %v: per proc arg '%v' has bad value '%v'",
- c.Meta.Name, a.Type().Name(), a.Val)
- }
- default:
- return fmt.Errorf("syscall %v: proc arg '%v' has bad kind %v",
- c.Meta.Name, arg.Type().Name(), arg)
- }
- case *BufferType:
- switch a := arg.(type) {
- case *DataArg:
- switch typ1.Kind {
- case BufferString:
- if typ1.TypeSize != 0 && a.Size() != typ1.TypeSize {
- return fmt.Errorf("syscall %v: string arg '%v' has size %v, which should be %v",
- c.Meta.Name, a.Type().Name(), a.Size(), typ1.TypeSize)
- }
- }
- default:
- return fmt.Errorf("syscall %v: buffer arg '%v' has bad kind %v",
- c.Meta.Name, arg.Type().Name(), arg)
+ if arg.Val >= typ.ValuesPerProc && arg.Val != typ.Default() {
+ return fmt.Errorf("per proc arg '%v' has bad value %v", typ.Name(), arg.Val)
}
case *CsumType:
- switch a := arg.(type) {
- case *ConstArg:
- if a.Val != 0 {
- return fmt.Errorf("syscall %v: csum arg '%v' has nonzero value %v",
- c.Meta.Name, a.Type().Name(), a.Val)
- }
- default:
- return fmt.Errorf("syscall %v: csum arg '%v' has bad kind %v",
- c.Meta.Name, arg.Type().Name(), arg)
+ if arg.Val != 0 {
+ return fmt.Errorf("csum arg '%v' has nonzero value %v", typ.Name(), arg.Val)
}
- case *PtrType:
- switch a := arg.(type) {
- case *PointerArg:
- if a.Type().Dir() == DirOut {
- return fmt.Errorf("syscall %v: pointer arg '%v' has output direction",
- c.Meta.Name, a.Type().Name())
- }
- if a.Res == nil && !a.Type().Optional() {
- return fmt.Errorf("syscall %v: non optional pointer arg '%v' is nil",
- c.Meta.Name, a.Type().Name())
+ case *ConstType:
+ case *FlagsType:
+ case *LenType:
+ default:
+ return fmt.Errorf("const arg %v has bad type %v", arg, typ.Name())
+ }
+ if typ := arg.Type(); typ.Dir() == DirOut {
+ // We generate output len arguments, which makes sense since it can be
+ // a length of a variable-length array which is not known otherwise.
+ if _, isLen := typ.(*LenType); !isLen {
+ if arg.Val != 0 && arg.Val != typ.Default() {
+ return fmt.Errorf("output arg '%v'/'%v' has non default value '%+v'",
+ typ.FieldName(), typ.Name(), arg)
}
- default:
- return fmt.Errorf("syscall %v: ptr arg '%v' has bad kind %v",
- c.Meta.Name, arg.Type().Name(), arg)
}
}
+ return nil
+}
- switch a := arg.(type) {
- case *ConstArg:
- case *PointerArg:
- maxMem := ctx.target.NumPages * ctx.target.PageSize
- size := a.VmaSize
- if size == 0 && a.Res != nil {
- size = a.Res.Size()
+func (arg *ResultArg) validate(ctx *validCtx) error {
+ typ, ok := arg.Type().(*ResourceType)
+ if !ok {
+ return fmt.Errorf("result arg %v has bad type %v", arg, arg.Type().Name())
+ }
+ for u := range arg.uses {
+ if u == nil {
+ return fmt.Errorf("nil reference in uses for arg %+v", arg)
}
- if a.Address >= maxMem || a.Address+size > maxMem {
- return fmt.Errorf("syscall %v: ptr %v has bad address %v/%v/%v",
- c.Meta.Name, a.Type().Name(), a.Address, a.VmaSize, size)
+ ctx.uses[u] = arg
+ }
+ if typ.Dir() == DirOut && (arg.Val != 0 && arg.Val != typ.Default()) {
+ return fmt.Errorf("out resource arg '%v' has bad const value %v", typ.Name(), arg.Val)
+ }
+ if arg.Res != nil {
+ if !ctx.args[arg.Res] {
+ return fmt.Errorf("result arg %v references out-of-tree result: %#v -> %#v",
+ typ.Name(), arg, arg.Res)
}
- switch t := a.Type().(type) {
- case *VmaType:
- if a.Res != nil {
- return fmt.Errorf("syscall %v: vma arg '%v' has data",
- c.Meta.Name, a.Type().Name())
- }
- if a.VmaSize == 0 && t.Dir() != DirOut && !t.Optional() {
- return fmt.Errorf("syscall %v: vma arg '%v' has size 0",
- c.Meta.Name, a.Type().Name())
- }
- case *PtrType:
- if a.Res != nil {
- if err := validateArg(ctx, c, a.Res); err != nil {
- return err
- }
- }
- if a.VmaSize != 0 {
- return fmt.Errorf("syscall %v: pointer arg '%v' has nonzero size",
- c.Meta.Name, a.Type().Name())
- }
- default:
- return fmt.Errorf("syscall %v: pointer arg '%v' has bad meta type %+v",
- c.Meta.Name, arg.Type().Name(), arg.Type())
+ if !arg.Res.uses[arg] {
+ return fmt.Errorf("result arg '%v' has broken link (%+v)", typ.Name(), arg.Res.uses)
}
- case *DataArg:
- typ1 := a.Type()
- if !typ1.Varlen() && typ1.Size() != a.Size() {
- return fmt.Errorf("syscall %v: data arg %v has wrong size %v, want %v",
- c.Meta.Name, arg.Type().Name(), a.Size(), typ1.Size())
+ }
+ return nil
+}
+
+func (arg *DataArg) validate(ctx *validCtx) error {
+ typ, ok := arg.Type().(*BufferType)
+ if !ok {
+ return fmt.Errorf("data arg %v has bad type %v", arg, arg.Type().Name())
+ }
+ if typ.Dir() == DirOut && len(arg.data) != 0 {
+ return fmt.Errorf("output arg '%v' has data", typ.Name())
+ }
+ if !typ.Varlen() && typ.Size() != arg.Size() {
+ return fmt.Errorf("data arg %v has wrong size %v, want %v",
+ typ.Name(), arg.Size(), typ.Size())
+ }
+ switch typ.Kind {
+ case BufferString:
+ if typ.TypeSize != 0 && arg.Size() != typ.TypeSize {
+ return fmt.Errorf("string arg '%v' has size %v, which should be %v",
+ typ.Name(), arg.Size(), typ.TypeSize)
}
- case *GroupArg:
- switch typ1 := a.Type().(type) {
- case *StructType:
- if len(a.Inner) != len(typ1.Fields) {
- return fmt.Errorf("syscall %v: struct arg '%v' has wrong number of fields: want %v, got %v",
- c.Meta.Name, a.Type().Name(), len(typ1.Fields), len(a.Inner))
- }
- for _, arg1 := range a.Inner {
- if err := validateArg(ctx, c, arg1); err != nil {
- return err
- }
- }
- case *ArrayType:
- if typ1.Kind == ArrayRangeLen && typ1.RangeBegin == typ1.RangeEnd &&
- uint64(len(a.Inner)) != typ1.RangeBegin {
- return fmt.Errorf("syscall %v: array %v has wrong number"+
- " of elements %v, want %v",
- c.Meta.Name, arg.Type().Name(),
- len(a.Inner), typ1.RangeBegin)
- }
- for _, arg1 := range a.Inner {
- if err := validateArg(ctx, c, arg1); err != nil {
- return err
- }
- }
- default:
- return fmt.Errorf("syscall %v: group arg '%v' has bad underlying type %+v",
- c.Meta.Name, arg.Type().Name(), arg.Type())
+ }
+ return nil
+}
+
+func (arg *GroupArg) validate(ctx *validCtx) error {
+ switch typ := arg.Type().(type) {
+ case *StructType:
+ if len(arg.Inner) != len(typ.Fields) {
+ return fmt.Errorf("struct arg '%v' has wrong number of fields: want %v, got %v",
+ typ.Name(), len(typ.Fields), len(arg.Inner))
+ }
+ case *ArrayType:
+ if typ.Kind == ArrayRangeLen && typ.RangeBegin == typ.RangeEnd &&
+ uint64(len(arg.Inner)) != typ.RangeBegin {
+ return fmt.Errorf("array %v has wrong number of elements %v, want %v",
+ typ.Name(), len(arg.Inner), typ.RangeBegin)
}
- case *UnionArg:
- typ1, ok := a.Type().(*UnionType)
- if !ok {
- return fmt.Errorf("syscall %v: union arg '%v' has bad type",
- c.Meta.Name, a.Type().Name())
+ default:
+ return fmt.Errorf("group arg %v has bad type %v", arg, typ.Name())
+ }
+ for _, arg1 := range arg.Inner {
+ if err := ctx.validateArg(arg1); err != nil {
+ return err
}
- found := false
- for _, typ2 := range typ1.Fields {
- if a.Option.Type().Name() == typ2.Name() {
- found = true
- break
- }
+ }
+ return nil
+}
+
+func (arg *UnionArg) validate(ctx *validCtx) error {
+ typ, ok := arg.Type().(*UnionType)
+ if !ok {
+ return fmt.Errorf("union arg %v has bad type %v", arg, arg.Type().Name())
+ }
+ found := false
+ for _, typ1 := range typ.Fields {
+ if arg.Option.Type().Name() == typ1.Name() {
+ found = true
+ break
}
- if !found {
- return fmt.Errorf("syscall %v: union arg '%v' has bad option",
- c.Meta.Name, a.Type().Name())
+ }
+ if !found {
+ return fmt.Errorf("union arg '%v' has bad option", typ.Name())
+ }
+ return ctx.validateArg(arg.Option)
+}
+
+func (arg *PointerArg) validate(ctx *validCtx) error {
+ maxMem := ctx.target.NumPages * ctx.target.PageSize
+ size := arg.VmaSize
+ if size == 0 && arg.Res != nil {
+ size = arg.Res.Size()
+ }
+ if arg.Address >= maxMem || arg.Address+size > maxMem {
+ return fmt.Errorf("ptr %v has bad address %v/%v/%v",
+ arg.Type().Name(), arg.Address, arg.VmaSize, size)
+ }
+ switch typ := arg.Type().(type) {
+ case *VmaType:
+ if arg.Res != nil {
+ return fmt.Errorf("vma arg '%v' has data", typ.Name())
}
- if err := validateArg(ctx, c, a.Option); err != nil {
- return err
+ if arg.VmaSize == 0 && typ.Dir() != DirOut && !typ.Optional() {
+ return fmt.Errorf("vma arg '%v' has size 0", typ.Name())
}
- case *ResultArg:
- switch a.Type().(type) {
- case *ResourceType:
- default:
- return fmt.Errorf("syscall %v: result arg '%v' has bad meta type %+v",
- c.Meta.Name, arg.Type().Name(), arg.Type())
+ case *PtrType:
+ if arg.Res == nil && !arg.Type().Optional() {
+ return fmt.Errorf("non optional pointer arg '%v' is nil", typ.Name())
}
- if a.Res == nil {
- break
+ if arg.Res != nil {
+ if err := ctx.validateArg(arg.Res); err != nil {
+ return err
+ }
}
- if !ctx.args[a.Res] {
- return fmt.Errorf("syscall %v: result arg %v references out-of-tree result: %#v -> %#v",
- c.Meta.Name, a.Type().Name(), arg, a.Res)
+ if arg.VmaSize != 0 {
+ return fmt.Errorf("pointer arg '%v' has nonzero size", typ.Name())
}
- if !a.Res.uses[a] {
- return fmt.Errorf("syscall %v: result arg '%v' has broken link (%+v)",
- c.Meta.Name, a.Type().Name(), a.Res.uses)
+ if typ.Dir() == DirOut {
+ return fmt.Errorf("pointer arg '%v' has output direction", typ.Name())
}
default:
- return fmt.Errorf("syscall %v: unknown arg '%v' kind",
- c.Meta.Name, arg.Type().Name())
+ return fmt.Errorf("ptr arg %v has bad type %v", arg, typ.Name())
}
return nil
}