diff options
| author | Dmitry Vyukov <dvyukov@google.com> | 2018-05-05 11:43:00 +0200 |
|---|---|---|
| committer | Dmitry Vyukov <dvyukov@google.com> | 2018-05-05 11:43:00 +0200 |
| commit | 6a0382b54364673499ec19d3cdad20534c564bce (patch) | |
| tree | dbf5bfbbef8b1afa2f02ffee12690e90a9b521c6 /prog | |
| parent | afe402d20af0d54d4e0baeb9e70e668e2a26f188 (diff) | |
prog: rework validation code
The current code is total, unstructured mess.
Since we now have 1:1 type -> arg correspondence,
rework validation around args. This makes code
much cleaner and 30% shorter.
Diffstat (limited to 'prog')
| -rw-r--r-- | prog/prog.go | 1 | ||||
| -rw-r--r-- | prog/validation.go | 410 |
2 files changed, 170 insertions, 241 deletions
diff --git a/prog/prog.go b/prog/prog.go index 19f5e319f..3474950b9 100644 --- a/prog/prog.go +++ b/prog/prog.go @@ -21,6 +21,7 @@ type Call struct { type Arg interface { Type() Type Size() uint64 + validate(ctx *validCtx) error } type ArgCommon struct { diff --git a/prog/validation.go b/prog/validation.go index ac7faeee2..73b18c7e1 100644 --- a/prog/validation.go +++ b/prog/validation.go @@ -22,8 +22,11 @@ func (p *Prog) validate() error { uses: make(map[Arg]Arg), } for _, c := range p.Calls { - if err := p.validateCall(ctx, c); err != nil { - return err + if c.Meta == nil { + return fmt.Errorf("call does not have meta information") + } + if err := ctx.validateCall(c); err != nil { + return fmt.Errorf("call %v: %v", c.Meta.Name, err) } } for u, orig := range ctx.uses { @@ -34,288 +37,213 @@ func (p *Prog) validate() error { return nil } -func (p *Prog) validateCall(ctx *validCtx, c *Call) error { - if c.Meta == nil { - return fmt.Errorf("call does not have meta information") - } +func (ctx *validCtx) validateCall(c *Call) error { if len(c.Args) != len(c.Meta.Args) { - return fmt.Errorf("syscall %v: wrong number of arguments, want %v, got %v", - c.Meta.Name, len(c.Meta.Args), len(c.Args)) + return fmt.Errorf("wrong number of arguments, want %v, got %v", + len(c.Meta.Args), len(c.Args)) } for _, arg := range c.Args { - if err := validateArg(ctx, c, arg); err != nil { + if err := ctx.validateArg(arg); err != nil { return err } } + return ctx.validateRet(c) +} + +func (ctx *validCtx) validateRet(c *Call) error { if c.Meta.Ret == nil { if c.Ret != nil { - return fmt.Errorf("syscall %v: return value without type", c.Meta.Name) - } - } else { - if c.Ret == nil { - return fmt.Errorf("syscall %v: return value is absent", c.Meta.Name) - } - if c.Ret.Type() != c.Meta.Ret { - return fmt.Errorf("syscall %v: wrong return type", c.Meta.Name) - } - if c.Ret.Type().Dir() != DirOut { - return fmt.Errorf("syscall %v: return value %v is not output", c.Meta.Name, c.Ret) - } - if c.Ret.Res != nil || c.Ret.Val != 0 || c.Ret.OpDiv != 0 || c.Ret.OpAdd != 0 { - return fmt.Errorf("syscall %v: return value %v is not empty", c.Meta.Name, c.Ret) - } - if err := validateArg(ctx, c, c.Ret); err != nil { - return err + return fmt.Errorf("return value without type") } + return nil } - return nil + if c.Ret == nil { + return fmt.Errorf("return value is absent") + } + if c.Ret.Type() != c.Meta.Ret { + return fmt.Errorf("wrong return type: %#v vs %#v", c.Ret.Type(), c.Meta.Ret) + } + if c.Ret.Type().Dir() != DirOut { + return fmt.Errorf("return value %v is not output", c.Ret) + } + if c.Ret.Res != nil || c.Ret.Val != 0 || c.Ret.OpDiv != 0 || c.Ret.OpAdd != 0 { + return fmt.Errorf("return value %v is not empty", c.Ret) + } + return ctx.validateArg(c.Ret) } -// nolint: gocyclo -func validateArg(ctx *validCtx, c *Call, arg Arg) error { +func (ctx *validCtx) validateArg(arg Arg) error { if arg == nil { - return fmt.Errorf("syscall %v: nil arg", c.Meta.Name) + return fmt.Errorf("nil arg") } if ctx.args[arg] { - return fmt.Errorf("syscall %v: arg %#v is referenced several times in the tree", - c.Meta.Name, arg) - } - ctx.args[arg] = true - // TODO(dvyukov): move this to ResultArg verification. - if used, ok := arg.(*ResultArg); ok { - for u := range used.uses { - if u == nil { - return fmt.Errorf("syscall %v: nil reference in uses for arg %+v", - c.Meta.Name, arg) - } - ctx.uses[u] = arg - } + return fmt.Errorf("arg %#v is referenced several times in the tree", arg) } if arg.Type() == nil { - return fmt.Errorf("syscall %v: no type", c.Meta.Name) - } - if arg.Type().Dir() == DirOut { - switch a := arg.(type) { - case *ConstArg: - // We generate output len arguments, which makes sense since it can be - // a length of a variable-length array which is not known otherwise. - if _, ok := a.Type().(*LenType); ok { - break - } - if a.Val != 0 && a.Val != a.Type().Default() { - return fmt.Errorf("syscall %v: output arg '%v'/'%v' has non default value '%+v'", - c.Meta.Name, a.Type().FieldName(), a.Type().Name(), a) - } - case *DataArg: - if len(a.data) != 0 { - return fmt.Errorf("syscall %v: output arg '%v' has data", - c.Meta.Name, a.Type().Name()) - } - } + return fmt.Errorf("no arg type") } + ctx.args[arg] = true + return arg.validate(ctx) +} - switch typ1 := arg.Type().(type) { +func (arg *ConstArg) validate(ctx *validCtx) error { + switch typ := arg.Type().(type) { case *IntType: - switch a := arg.(type) { - case *ConstArg: - if a.Type().Dir() == DirOut && (a.Val != 0 && a.Val != a.Type().Default()) { - return fmt.Errorf("syscall %v: out int arg '%v' has bad const value %v", - c.Meta.Name, a.Type().Name(), a.Val) - } - default: - return fmt.Errorf("syscall %v: int arg '%v' has bad kind %v", - c.Meta.Name, arg.Type().Name(), arg) - } - case *ResourceType: - switch a := arg.(type) { - case *ResultArg: - if a.Type().Dir() == DirOut && (a.Val != 0 && a.Val != a.Type().Default()) { - return fmt.Errorf("syscall %v: out resource arg '%v' has bad const value %v", - c.Meta.Name, a.Type().Name(), a.Val) - } - default: - return fmt.Errorf("syscall %v: fd arg '%v' has bad kind %v", - c.Meta.Name, arg.Type().Name(), arg) - } - case *StructType, *ArrayType: - switch arg.(type) { - case *GroupArg: - default: - return fmt.Errorf("syscall %v: struct/array arg '%v' has bad kind %#v", - c.Meta.Name, arg.Type().Name(), arg) - } - case *UnionType: - switch arg.(type) { - case *UnionArg: - default: - return fmt.Errorf("syscall %v: union arg '%v' has bad kind %v", - c.Meta.Name, arg.Type().Name(), arg) + if typ.Dir() == DirOut && (arg.Val != 0 && arg.Val != typ.Default()) { + return fmt.Errorf("out int arg '%v' has bad const value %v", typ.Name(), arg.Val) } case *ProcType: - switch a := arg.(type) { - case *ConstArg: - if a.Val >= typ1.ValuesPerProc && a.Val != typ1.Default() { - return fmt.Errorf("syscall %v: per proc arg '%v' has bad value '%v'", - c.Meta.Name, a.Type().Name(), a.Val) - } - default: - return fmt.Errorf("syscall %v: proc arg '%v' has bad kind %v", - c.Meta.Name, arg.Type().Name(), arg) - } - case *BufferType: - switch a := arg.(type) { - case *DataArg: - switch typ1.Kind { - case BufferString: - if typ1.TypeSize != 0 && a.Size() != typ1.TypeSize { - return fmt.Errorf("syscall %v: string arg '%v' has size %v, which should be %v", - c.Meta.Name, a.Type().Name(), a.Size(), typ1.TypeSize) - } - } - default: - return fmt.Errorf("syscall %v: buffer arg '%v' has bad kind %v", - c.Meta.Name, arg.Type().Name(), arg) + if arg.Val >= typ.ValuesPerProc && arg.Val != typ.Default() { + return fmt.Errorf("per proc arg '%v' has bad value %v", typ.Name(), arg.Val) } case *CsumType: - switch a := arg.(type) { - case *ConstArg: - if a.Val != 0 { - return fmt.Errorf("syscall %v: csum arg '%v' has nonzero value %v", - c.Meta.Name, a.Type().Name(), a.Val) - } - default: - return fmt.Errorf("syscall %v: csum arg '%v' has bad kind %v", - c.Meta.Name, arg.Type().Name(), arg) + if arg.Val != 0 { + return fmt.Errorf("csum arg '%v' has nonzero value %v", typ.Name(), arg.Val) } - case *PtrType: - switch a := arg.(type) { - case *PointerArg: - if a.Type().Dir() == DirOut { - return fmt.Errorf("syscall %v: pointer arg '%v' has output direction", - c.Meta.Name, a.Type().Name()) - } - if a.Res == nil && !a.Type().Optional() { - return fmt.Errorf("syscall %v: non optional pointer arg '%v' is nil", - c.Meta.Name, a.Type().Name()) + case *ConstType: + case *FlagsType: + case *LenType: + default: + return fmt.Errorf("const arg %v has bad type %v", arg, typ.Name()) + } + if typ := arg.Type(); typ.Dir() == DirOut { + // We generate output len arguments, which makes sense since it can be + // a length of a variable-length array which is not known otherwise. + if _, isLen := typ.(*LenType); !isLen { + if arg.Val != 0 && arg.Val != typ.Default() { + return fmt.Errorf("output arg '%v'/'%v' has non default value '%+v'", + typ.FieldName(), typ.Name(), arg) } - default: - return fmt.Errorf("syscall %v: ptr arg '%v' has bad kind %v", - c.Meta.Name, arg.Type().Name(), arg) } } + return nil +} - switch a := arg.(type) { - case *ConstArg: - case *PointerArg: - maxMem := ctx.target.NumPages * ctx.target.PageSize - size := a.VmaSize - if size == 0 && a.Res != nil { - size = a.Res.Size() +func (arg *ResultArg) validate(ctx *validCtx) error { + typ, ok := arg.Type().(*ResourceType) + if !ok { + return fmt.Errorf("result arg %v has bad type %v", arg, arg.Type().Name()) + } + for u := range arg.uses { + if u == nil { + return fmt.Errorf("nil reference in uses for arg %+v", arg) } - if a.Address >= maxMem || a.Address+size > maxMem { - return fmt.Errorf("syscall %v: ptr %v has bad address %v/%v/%v", - c.Meta.Name, a.Type().Name(), a.Address, a.VmaSize, size) + ctx.uses[u] = arg + } + if typ.Dir() == DirOut && (arg.Val != 0 && arg.Val != typ.Default()) { + return fmt.Errorf("out resource arg '%v' has bad const value %v", typ.Name(), arg.Val) + } + if arg.Res != nil { + if !ctx.args[arg.Res] { + return fmt.Errorf("result arg %v references out-of-tree result: %#v -> %#v", + typ.Name(), arg, arg.Res) } - switch t := a.Type().(type) { - case *VmaType: - if a.Res != nil { - return fmt.Errorf("syscall %v: vma arg '%v' has data", - c.Meta.Name, a.Type().Name()) - } - if a.VmaSize == 0 && t.Dir() != DirOut && !t.Optional() { - return fmt.Errorf("syscall %v: vma arg '%v' has size 0", - c.Meta.Name, a.Type().Name()) - } - case *PtrType: - if a.Res != nil { - if err := validateArg(ctx, c, a.Res); err != nil { - return err - } - } - if a.VmaSize != 0 { - return fmt.Errorf("syscall %v: pointer arg '%v' has nonzero size", - c.Meta.Name, a.Type().Name()) - } - default: - return fmt.Errorf("syscall %v: pointer arg '%v' has bad meta type %+v", - c.Meta.Name, arg.Type().Name(), arg.Type()) + if !arg.Res.uses[arg] { + return fmt.Errorf("result arg '%v' has broken link (%+v)", typ.Name(), arg.Res.uses) } - case *DataArg: - typ1 := a.Type() - if !typ1.Varlen() && typ1.Size() != a.Size() { - return fmt.Errorf("syscall %v: data arg %v has wrong size %v, want %v", - c.Meta.Name, arg.Type().Name(), a.Size(), typ1.Size()) + } + return nil +} + +func (arg *DataArg) validate(ctx *validCtx) error { + typ, ok := arg.Type().(*BufferType) + if !ok { + return fmt.Errorf("data arg %v has bad type %v", arg, arg.Type().Name()) + } + if typ.Dir() == DirOut && len(arg.data) != 0 { + return fmt.Errorf("output arg '%v' has data", typ.Name()) + } + if !typ.Varlen() && typ.Size() != arg.Size() { + return fmt.Errorf("data arg %v has wrong size %v, want %v", + typ.Name(), arg.Size(), typ.Size()) + } + switch typ.Kind { + case BufferString: + if typ.TypeSize != 0 && arg.Size() != typ.TypeSize { + return fmt.Errorf("string arg '%v' has size %v, which should be %v", + typ.Name(), arg.Size(), typ.TypeSize) } - case *GroupArg: - switch typ1 := a.Type().(type) { - case *StructType: - if len(a.Inner) != len(typ1.Fields) { - return fmt.Errorf("syscall %v: struct arg '%v' has wrong number of fields: want %v, got %v", - c.Meta.Name, a.Type().Name(), len(typ1.Fields), len(a.Inner)) - } - for _, arg1 := range a.Inner { - if err := validateArg(ctx, c, arg1); err != nil { - return err - } - } - case *ArrayType: - if typ1.Kind == ArrayRangeLen && typ1.RangeBegin == typ1.RangeEnd && - uint64(len(a.Inner)) != typ1.RangeBegin { - return fmt.Errorf("syscall %v: array %v has wrong number"+ - " of elements %v, want %v", - c.Meta.Name, arg.Type().Name(), - len(a.Inner), typ1.RangeBegin) - } - for _, arg1 := range a.Inner { - if err := validateArg(ctx, c, arg1); err != nil { - return err - } - } - default: - return fmt.Errorf("syscall %v: group arg '%v' has bad underlying type %+v", - c.Meta.Name, arg.Type().Name(), arg.Type()) + } + return nil +} + +func (arg *GroupArg) validate(ctx *validCtx) error { + switch typ := arg.Type().(type) { + case *StructType: + if len(arg.Inner) != len(typ.Fields) { + return fmt.Errorf("struct arg '%v' has wrong number of fields: want %v, got %v", + typ.Name(), len(typ.Fields), len(arg.Inner)) + } + case *ArrayType: + if typ.Kind == ArrayRangeLen && typ.RangeBegin == typ.RangeEnd && + uint64(len(arg.Inner)) != typ.RangeBegin { + return fmt.Errorf("array %v has wrong number of elements %v, want %v", + typ.Name(), len(arg.Inner), typ.RangeBegin) } - case *UnionArg: - typ1, ok := a.Type().(*UnionType) - if !ok { - return fmt.Errorf("syscall %v: union arg '%v' has bad type", - c.Meta.Name, a.Type().Name()) + default: + return fmt.Errorf("group arg %v has bad type %v", arg, typ.Name()) + } + for _, arg1 := range arg.Inner { + if err := ctx.validateArg(arg1); err != nil { + return err } - found := false - for _, typ2 := range typ1.Fields { - if a.Option.Type().Name() == typ2.Name() { - found = true - break - } + } + return nil +} + +func (arg *UnionArg) validate(ctx *validCtx) error { + typ, ok := arg.Type().(*UnionType) + if !ok { + return fmt.Errorf("union arg %v has bad type %v", arg, arg.Type().Name()) + } + found := false + for _, typ1 := range typ.Fields { + if arg.Option.Type().Name() == typ1.Name() { + found = true + break } - if !found { - return fmt.Errorf("syscall %v: union arg '%v' has bad option", - c.Meta.Name, a.Type().Name()) + } + if !found { + return fmt.Errorf("union arg '%v' has bad option", typ.Name()) + } + return ctx.validateArg(arg.Option) +} + +func (arg *PointerArg) validate(ctx *validCtx) error { + maxMem := ctx.target.NumPages * ctx.target.PageSize + size := arg.VmaSize + if size == 0 && arg.Res != nil { + size = arg.Res.Size() + } + if arg.Address >= maxMem || arg.Address+size > maxMem { + return fmt.Errorf("ptr %v has bad address %v/%v/%v", + arg.Type().Name(), arg.Address, arg.VmaSize, size) + } + switch typ := arg.Type().(type) { + case *VmaType: + if arg.Res != nil { + return fmt.Errorf("vma arg '%v' has data", typ.Name()) } - if err := validateArg(ctx, c, a.Option); err != nil { - return err + if arg.VmaSize == 0 && typ.Dir() != DirOut && !typ.Optional() { + return fmt.Errorf("vma arg '%v' has size 0", typ.Name()) } - case *ResultArg: - switch a.Type().(type) { - case *ResourceType: - default: - return fmt.Errorf("syscall %v: result arg '%v' has bad meta type %+v", - c.Meta.Name, arg.Type().Name(), arg.Type()) + case *PtrType: + if arg.Res == nil && !arg.Type().Optional() { + return fmt.Errorf("non optional pointer arg '%v' is nil", typ.Name()) } - if a.Res == nil { - break + if arg.Res != nil { + if err := ctx.validateArg(arg.Res); err != nil { + return err + } } - if !ctx.args[a.Res] { - return fmt.Errorf("syscall %v: result arg %v references out-of-tree result: %#v -> %#v", - c.Meta.Name, a.Type().Name(), arg, a.Res) + if arg.VmaSize != 0 { + return fmt.Errorf("pointer arg '%v' has nonzero size", typ.Name()) } - if !a.Res.uses[a] { - return fmt.Errorf("syscall %v: result arg '%v' has broken link (%+v)", - c.Meta.Name, a.Type().Name(), a.Res.uses) + if typ.Dir() == DirOut { + return fmt.Errorf("pointer arg '%v' has output direction", typ.Name()) } default: - return fmt.Errorf("syscall %v: unknown arg '%v' kind", - c.Meta.Name, arg.Type().Name()) + return fmt.Errorf("ptr arg %v has bad type %v", arg, typ.Name()) } return nil } |
