From 9989eadf779cf109bf3c64beb10e54df5f11ae3c Mon Sep 17 00:00:00 2001 From: Andrey Konovalov Date: Thu, 9 Feb 2017 21:33:14 +0100 Subject: prog: fix cheking nonoptional nil pointers in validation Also update validation code to use arg.Type instead of passing typ recusively. --- prog/validation.go | 91 ++++++++++++++++++++++++++---------------------------- 1 file changed, 43 insertions(+), 48 deletions(-) diff --git a/prog/validation.go b/prog/validation.go index 28c619802..61a12eaab 100644 --- a/prog/validation.go +++ b/prog/validation.go @@ -38,8 +38,8 @@ func (c *Call) validate(ctx *validCtx) error { if len(c.Args) != len(c.Meta.Args) { return fmt.Errorf("syscall %v: wrong number of arguments, want %v, got %v", c.Meta.Name, len(c.Meta.Args), len(c.Args)) } - var checkArg func(arg *Arg, typ sys.Type) error - checkArg = func(arg *Arg, typ sys.Type) error { + var checkArg func(arg *Arg) error + checkArg = func(arg *Arg) error { if arg == nil { return fmt.Errorf("syscall %v: nil arg", c.Meta.Name) } @@ -53,24 +53,18 @@ func (c *Call) validate(ctx *validCtx) error { if arg.Type == nil { return fmt.Errorf("syscall %v: no type", c.Meta.Name) } - if arg.Type.Name() != typ.Name() { - return fmt.Errorf("syscall %v: type name mismatch: %v vs %v", c.Meta.Name, arg.Type.Name(), typ.Name()) - } - if arg.Type.FieldName() != typ.FieldName() { - return fmt.Errorf("syscall %v: field name mismatch: %v vs %v", c.Meta.Name, arg.Type.FieldName(), typ.FieldName()) - } if arg.Type.Dir() == sys.DirOut { if (arg.Val != 0 && arg.Val != arg.Type.Default()) || arg.AddrPage != 0 || arg.AddrOffset != 0 { // We generate output len arguments, which makes sense // since it can be a length of a variable-length array // which is not known otherwise. if _, ok := arg.Type.(*sys.LenType); !ok { - return fmt.Errorf("syscall %v: output arg '%v' has non default value '%v'", c.Meta.Name, typ.Name(), arg.Val) + return fmt.Errorf("syscall %v: output arg '%v' has non default value '%v'", c.Meta.Name, arg.Type.Name(), arg.Val) } } for _, v := range arg.Data { if v != 0 { - return fmt.Errorf("syscall %v: output arg '%v' has data", c.Meta.Name, typ.Name()) + return fmt.Errorf("syscall %v: output arg '%v' has data", c.Meta.Name, arg.Type.Name()) } } } @@ -81,111 +75,112 @@ func (c *Call) validate(ctx *validCtx) error { case ArgReturn: case ArgConst: if arg.Type.Dir() == sys.DirOut && (arg.Val != 0 && arg.Val != arg.Type.Default()) { - return fmt.Errorf("syscall %v: out resource arg '%v' has bad const value %v", c.Meta.Name, typ.Name(), arg.Val) + return fmt.Errorf("syscall %v: out resource arg '%v' has bad const value %v", c.Meta.Name, arg.Type.Name(), arg.Val) } default: - return fmt.Errorf("syscall %v: fd arg '%v' has bad kind %v", c.Meta.Name, typ.Name(), arg.Kind) + return fmt.Errorf("syscall %v: fd arg '%v' has bad kind %v", c.Meta.Name, arg.Type.Name(), arg.Kind) } case *sys.StructType, *sys.ArrayType: switch arg.Kind { case ArgGroup: default: - return fmt.Errorf("syscall %v: struct/array arg '%v' has bad kind %v", c.Meta.Name, typ.Name(), arg.Kind) + return fmt.Errorf("syscall %v: struct/array arg '%v' has bad kind %v", c.Meta.Name, arg.Type.Name(), arg.Kind) } case *sys.UnionType: switch arg.Kind { case ArgUnion: default: - return fmt.Errorf("syscall %v: union arg '%v' has bad kind %v", c.Meta.Name, typ.Name(), arg.Kind) + return fmt.Errorf("syscall %v: union arg '%v' has bad kind %v", c.Meta.Name, arg.Type.Name(), arg.Kind) } case *sys.ProcType: if arg.Val >= uintptr(typ1.ValuesPerProc) { - return fmt.Errorf("syscall %v: per proc arg '%v' has bad value '%v'", c.Meta.Name, typ.Name(), arg.Val) + return fmt.Errorf("syscall %v: per proc arg '%v' has bad value '%v'", c.Meta.Name, arg.Type.Name(), arg.Val) } case *sys.BufferType: switch typ1.Kind { case sys.BufferString: if typ1.Length != 0 && len(arg.Data) != int(typ1.Length) { - return fmt.Errorf("syscall %v: string arg '%v' has size %v, which should be %v", c.Meta.Name, typ.Name(), len(arg.Data), typ1.Length) + return fmt.Errorf("syscall %v: string arg '%v' has size %v, which should be %v", c.Meta.Name, arg.Type.Name(), len(arg.Data), typ1.Length) } } case *sys.CsumType: if arg.Val != 0 { - return fmt.Errorf("syscall %v: csum arg '%v' has nonzero value %v", c.Meta.Name, typ.Name(), arg.Val) + return fmt.Errorf("syscall %v: csum arg '%v' has nonzero value %v", c.Meta.Name, arg.Type.Name(), arg.Val) + } + case *sys.PtrType: + if arg.Type.Dir() == sys.DirOut { + return fmt.Errorf("syscall %v: pointer arg '%v' has output direction", c.Meta.Name, arg.Type.Name()) + } + if arg.Res == nil && !arg.Type.Optional() { + return fmt.Errorf("syscall %v: non optional pointer arg '%v' is nil", c.Meta.Name, arg.Type.Name()) } } switch arg.Kind { case ArgConst: case ArgResult: if arg.Res == nil { - return fmt.Errorf("syscall %v: result arg '%v' has no reference", c.Meta.Name, typ.Name()) + return fmt.Errorf("syscall %v: result arg '%v' has no reference", c.Meta.Name, arg.Type.Name()) } if !ctx.args[arg.Res] { return fmt.Errorf("syscall %v: result arg '%v' references out-of-tree result: %p%+v -> %p%+v", - c.Meta.Name, typ.Name(), arg, arg, arg.Res, arg.Res) + c.Meta.Name, arg.Type.Name(), arg, arg, arg.Res, arg.Res) } if _, ok := arg.Res.Uses[arg]; !ok { - return fmt.Errorf("syscall %v: result arg '%v' has broken link (%+v)", c.Meta.Name, typ.Name(), arg.Res.Uses) + return fmt.Errorf("syscall %v: result arg '%v' has broken link (%+v)", c.Meta.Name, arg.Type.Name(), arg.Res.Uses) } case ArgPointer: - switch typ1 := typ.(type) { + switch arg.Type.(type) { case *sys.VmaType: if arg.Res != nil { - return fmt.Errorf("syscall %v: vma arg '%v' has data", c.Meta.Name, typ.Name()) + return fmt.Errorf("syscall %v: vma arg '%v' has data", c.Meta.Name, arg.Type.Name()) } if arg.AddrPagesNum == 0 { - return fmt.Errorf("syscall %v: vma arg '%v' has size 0", c.Meta.Name, typ.Name()) + return fmt.Errorf("syscall %v: vma arg '%v' has size 0", c.Meta.Name, arg.Type.Name()) } case *sys.PtrType: - if arg.Type.Dir() == sys.DirOut { - return fmt.Errorf("syscall %v: pointer arg '%v' has output direction", c.Meta.Name, typ.Name()) - } - if arg.Res == nil && !typ.Optional() { - return fmt.Errorf("syscall %v: non optional pointer arg '%v' is nil", c.Meta.Name, typ.Name()) - } if arg.Res != nil { - if err := checkArg(arg.Res, typ1.Type); err != nil { + if err := checkArg(arg.Res); err != nil { return err } } if arg.AddrPagesNum != 0 { - return fmt.Errorf("syscall %v: pointer arg '%v' has nonzero size", c.Meta.Name, typ.Name()) + return fmt.Errorf("syscall %v: pointer arg '%v' has nonzero size", c.Meta.Name, arg.Type.Name()) } default: - return fmt.Errorf("syscall %v: pointer arg '%v' has bad meta type %+v", c.Meta.Name, typ.Name(), typ) + return fmt.Errorf("syscall %v: pointer arg '%v' has bad meta type %+v", c.Meta.Name, arg.Type.Name(), arg.Type) } case ArgPageSize: case ArgData: - switch typ1 := typ.(type) { + switch typ1 := arg.Type.(type) { case *sys.ArrayType: if typ2, ok := typ1.Type.(*sys.IntType); !ok || typ2.Size() != 1 { - return fmt.Errorf("syscall %v: data arg '%v' should be an array", c.Meta.Name, typ.Name()) + return fmt.Errorf("syscall %v: data arg '%v' should be an array", c.Meta.Name, arg.Type.Name()) } } case ArgGroup: - switch typ1 := typ.(type) { + switch typ1 := arg.Type.(type) { case *sys.StructType: if len(arg.Inner) != len(typ1.Fields) { - return fmt.Errorf("syscall %v: struct arg '%v' has wrong number of fields: want %v, got %v", c.Meta.Name, typ.Name(), len(typ1.Fields), len(arg.Inner)) + return fmt.Errorf("syscall %v: struct arg '%v' has wrong number of fields: want %v, got %v", c.Meta.Name, arg.Type.Name(), len(typ1.Fields), len(arg.Inner)) } - for i, arg1 := range arg.Inner { - if err := checkArg(arg1, typ1.Fields[i]); err != nil { + for _, arg1 := range arg.Inner { + if err := checkArg(arg1); err != nil { return err } } case *sys.ArrayType: for _, arg1 := range arg.Inner { - if err := checkArg(arg1, typ1.Type); err != nil { + if err := checkArg(arg1); err != nil { return err } } default: - return fmt.Errorf("syscall %v: group arg '%v' has bad underlying type %+v", c.Meta.Name, typ.Name(), typ) + return fmt.Errorf("syscall %v: group arg '%v' has bad underlying type %+v", c.Meta.Name, arg.Type.Name(), arg.Type) } case ArgUnion: - typ1, ok := typ.(*sys.UnionType) + typ1, ok := arg.Type.(*sys.UnionType) if !ok { - return fmt.Errorf("syscall %v: union arg '%v' has bad type", c.Meta.Name, typ.Name()) + return fmt.Errorf("syscall %v: union arg '%v' has bad type", c.Meta.Name, arg.Type.Name()) } found := false for _, typ2 := range typ1.Options { @@ -195,22 +190,22 @@ func (c *Call) validate(ctx *validCtx) error { } } if !found { - return fmt.Errorf("syscall %v: union arg '%v' has bad option", c.Meta.Name, typ.Name()) + return fmt.Errorf("syscall %v: union arg '%v' has bad option", c.Meta.Name, arg.Type.Name()) } - if err := checkArg(arg.Option, arg.OptionType); err != nil { + if err := checkArg(arg.Option); err != nil { return err } case ArgReturn: default: - return fmt.Errorf("syscall %v: unknown arg '%v' kind", c.Meta.Name, typ.Name()) + return fmt.Errorf("syscall %v: unknown arg '%v' kind", c.Meta.Name, arg.Type.Name()) } return nil } - for i, arg := range c.Args { + for _, arg := range c.Args { if arg.Kind == ArgReturn { return fmt.Errorf("syscall %v: arg '%v' has wrong return kind", c.Meta.Name, arg.Type.Name()) } - if err := checkArg(arg, c.Meta.Args[i]); err != nil { + if err := checkArg(arg); err != nil { return err } } @@ -221,7 +216,7 @@ func (c *Call) validate(ctx *validCtx) error { return fmt.Errorf("syscall %v: return value has wrong kind %v", c.Meta.Name, c.Ret.Kind) } if c.Meta.Ret != nil { - if err := checkArg(c.Ret, c.Meta.Ret); err != nil { + if err := checkArg(c.Ret); err != nil { return err } } else if c.Ret.Type != nil { -- cgit mrf-deployment