diff options
| author | Dmitry Vyukov <dvyukov@google.com> | 2018-05-04 18:03:46 +0200 |
|---|---|---|
| committer | Dmitry Vyukov <dvyukov@google.com> | 2018-05-04 18:03:46 +0200 |
| commit | 2c7e14a847318974490ab59460f0834ea2ee0d24 (patch) | |
| tree | e9ac237ccfeaa3a7d508a07fc37ea7b7d28697df | |
| parent | 08141db61a7a947b701d06aa5c90cd825c55e350 (diff) | |
gometalinter: enable cyclomatic complexity checking
Refactor some functions to be simpler.
Update #538
| -rw-r--r-- | .gometalinter.json | 2 | ||||
| -rw-r--r-- | pkg/ast/scanner.go | 163 | ||||
| -rw-r--r-- | pkg/ifuzz/decode.go | 1 | ||||
| -rw-r--r-- | pkg/ifuzz/encode.go | 1 | ||||
| -rw-r--r-- | pkg/ifuzz/gen/gen.go | 156 | ||||
| -rw-r--r-- | prog/encoding.go | 450 | ||||
| -rw-r--r-- | prog/minimization.go | 281 | ||||
| -rw-r--r-- | prog/mutation.go | 411 | ||||
| -rw-r--r-- | prog/rand.go | 307 | ||||
| -rw-r--r-- | prog/types.go | 2 | ||||
| -rw-r--r-- | prog/validation.go | 501 |
11 files changed, 1184 insertions, 1091 deletions
diff --git a/.gometalinter.json b/.gometalinter.json index c15405271..4185948b1 100644 --- a/.gometalinter.json +++ b/.gometalinter.json @@ -5,6 +5,7 @@ "sort": ["path", "line"], "minconstlength": 7, "linelength": 120, + "cyclo": 50, "skip": [ "pkg/kd" ], @@ -19,6 +20,7 @@ "gosimple", "varcheck", "misspell", + "gocyclo", "vet", "lll" ], diff --git a/pkg/ast/scanner.go b/pkg/ast/scanner.go index 343c3b9ab..67e7f50bd 100644 --- a/pkg/ast/scanner.go +++ b/pkg/ast/scanner.go @@ -134,14 +134,7 @@ func (s *scanner) Scan() (tok token, lit string, pos Pos) { s.next() case s.ch == '`': tok = tokCExpr - for s.next(); s.ch != '`' && s.ch != '\n'; s.next() { - } - if s.ch == '\n' { - s.Error(pos, "C expression is not terminated") - } else { - lit = string(s.data[pos.Off+1 : s.off]) - s.next() - } + lit = s.scanCExpr(pos) case s.prev2 == tokDefine && s.prev1 == tokIdent: // Note: the old form for C expressions, not really lexable. // TODO(dvyukov): get rid of this eventually. @@ -155,74 +148,16 @@ func (s *scanner) Scan() (tok token, lit string, pos Pos) { } lit = string(s.data[pos.Off+1 : s.off]) case s.ch == '"' || s.ch == '<': - // TODO(dvyukov): get rid of <...> strings, that's only includes tok = tokString - closing := byte('"') - if s.ch == '<' { - closing = '>' - } - for s.next(); s.ch != closing; s.next() { - if s.ch == 0 || s.ch == '\n' { - s.Error(pos, "string literal is not terminated") - return - } - } - lit = string(s.data[pos.Off+1 : s.off]) - for i := 0; i < len(lit); i++ { - if lit[i] < 0x20 || lit[i] >= 0x80 { - pos1 := pos - pos1.Col += i + 1 - pos1.Off += i + 1 - s.Error(pos1, "illegal character %#U in string literal", lit[i]) - break - } - } - s.next() + lit = s.scanStr(pos) case s.ch >= '0' && s.ch <= '9': tok = tokInt - for s.ch >= '0' && s.ch <= '9' || - s.ch >= 'a' && s.ch <= 'f' || - s.ch >= 'A' && s.ch <= 'F' || s.ch == 'x' { - s.next() - } - lit = string(s.data[pos.Off:s.off]) - bad := false - if _, err := strconv.ParseUint(lit, 10, 64); err != nil { - if len(lit) > 2 && lit[0] == '0' && lit[1] == 'x' { - if _, err := strconv.ParseUint(lit[2:], 16, 64); err != nil { - bad = true - } - } else { - bad = true - } - } - if bad { - s.Error(pos, fmt.Sprintf("bad integer %q", lit)) - lit = "0" - } + lit = s.scanInt(pos) case s.ch == '\'': tok = tokInt - lit = "0" - s.next() - s.next() - if s.ch != '\'' { - s.Error(pos, "char literal is not terminated") - return - } - s.next() - lit = string(s.data[pos.Off : pos.Off+3]) + lit = s.scanChar(pos) case s.ch == '_' || s.ch >= 'a' && s.ch <= 'z' || s.ch >= 'A' && s.ch <= 'Z': - tok = tokIdent - for s.ch == '_' || s.ch == '$' || - s.ch >= 'a' && s.ch <= 'z' || - s.ch >= 'A' && s.ch <= 'Z' || - s.ch >= '0' && s.ch <= '9' { - s.next() - } - lit = string(s.data[pos.Off:s.off]) - if key, ok := keywords[lit]; ok { - tok = key - } + tok, lit = s.scanIdent(pos) default: tok = punctuation[s.ch] if tok == tokIllegal { @@ -235,6 +170,94 @@ func (s *scanner) Scan() (tok token, lit string, pos Pos) { return } +func (s *scanner) scanCExpr(pos Pos) string { + for s.next(); s.ch != '`' && s.ch != '\n'; s.next() { + } + if s.ch == '\n' { + s.Error(pos, "C expression is not terminated") + return "" + } + lit := string(s.data[pos.Off+1 : s.off]) + s.next() + return lit +} + +func (s *scanner) scanStr(pos Pos) string { + // TODO(dvyukov): get rid of <...> strings, that's only includes + closing := byte('"') + if s.ch == '<' { + closing = '>' + } + for s.next(); s.ch != closing; s.next() { + if s.ch == 0 || s.ch == '\n' { + s.Error(pos, "string literal is not terminated") + return "" + } + } + lit := string(s.data[pos.Off+1 : s.off]) + for i := 0; i < len(lit); i++ { + if lit[i] < 0x20 || lit[i] >= 0x80 { + pos1 := pos + pos1.Col += i + 1 + pos1.Off += i + 1 + s.Error(pos1, "illegal character %#U in string literal", lit[i]) + break + } + } + s.next() + return lit +} + +func (s *scanner) scanInt(pos Pos) string { + for s.ch >= '0' && s.ch <= '9' || + s.ch >= 'a' && s.ch <= 'f' || + s.ch >= 'A' && s.ch <= 'F' || s.ch == 'x' { + s.next() + } + lit := string(s.data[pos.Off:s.off]) + bad := false + if _, err := strconv.ParseUint(lit, 10, 64); err != nil { + if len(lit) > 2 && lit[0] == '0' && lit[1] == 'x' { + if _, err := strconv.ParseUint(lit[2:], 16, 64); err != nil { + bad = true + } + } else { + bad = true + } + } + if bad { + s.Error(pos, fmt.Sprintf("bad integer %q", lit)) + lit = "0" + } + return lit +} + +func (s *scanner) scanChar(pos Pos) string { + s.next() + s.next() + if s.ch != '\'' { + s.Error(pos, "char literal is not terminated") + return "0" + } + s.next() + return string(s.data[pos.Off : pos.Off+3]) +} + +func (s *scanner) scanIdent(pos Pos) (tok token, lit string) { + tok = tokIdent + for s.ch == '_' || s.ch == '$' || + s.ch >= 'a' && s.ch <= 'z' || + s.ch >= 'A' && s.ch <= 'Z' || + s.ch >= '0' && s.ch <= '9' { + s.next() + } + lit = string(s.data[pos.Off:s.off]) + if key, ok := keywords[lit]; ok { + tok = key + } + return +} + func (s *scanner) Error(pos Pos, msg string, args ...interface{}) { s.errors++ s.errorHandler(pos, fmt.Sprintf(msg, args...)) diff --git a/pkg/ifuzz/decode.go b/pkg/ifuzz/decode.go index 1e657ec00..4e4c5873e 100644 --- a/pkg/ifuzz/decode.go +++ b/pkg/ifuzz/decode.go @@ -10,6 +10,7 @@ import ( // Decode decodes instruction length for the given mode. // It can have falsely decode incorrect instructions, // but should not fail to decode correct instructions. +// nolint: gocyclo func Decode(mode int, text []byte) (int, error) { if len(text) == 0 { return 0, fmt.Errorf("zero-length instruction") diff --git a/pkg/ifuzz/encode.go b/pkg/ifuzz/encode.go index 0b3732aa7..4780c4bd8 100644 --- a/pkg/ifuzz/encode.go +++ b/pkg/ifuzz/encode.go @@ -11,6 +11,7 @@ import ( "math/rand" ) +// nolint: gocyclo func (insn *Insn) Encode(cfg *Config, r *rand.Rand) []byte { if !insn.isCompatible(cfg) { panic("instruction is not suitable for this mode") diff --git a/pkg/ifuzz/gen/gen.go b/pkg/ifuzz/gen/gen.go index 499d41633..1368d2089 100644 --- a/pkg/ifuzz/gen/gen.go +++ b/pkg/ifuzz/gen/gen.go @@ -176,12 +176,14 @@ func (err errSkip) Error() string { return string(err) } +// nolint: gocyclo func parsePattern(insn *ifuzz.Insn, vals []string) error { if insn.Opcode != nil { return fmt.Errorf("PATTERN is already parsed for the instruction") } // As spelled these have incorrect format for 16-bit addressing mode and with 67 prefix. - if insn.Name == "NOP5" || insn.Name == "NOP6" || insn.Name == "NOP7" || insn.Name == "NOP8" || insn.Name == "NOP9" { + if insn.Name == "NOP5" || insn.Name == "NOP6" || insn.Name == "NOP7" || + insn.Name == "NOP8" || insn.Name == "NOP9" { return errSkip("") } if insn.Mode == 0 { @@ -270,10 +272,8 @@ func parsePattern(insn *ifuzz.Insn, vals []string) error { insn.Mod = 1 case v == "MOD=2": insn.Mod = 2 - case v == "MODRM()": case v == "lock_prefix": insn.Prefix = append(insn.Prefix, 0xF0) - case v == "nolock_prefix": // Immediates. case v == "UIMM8()", v == "SIMM8()": @@ -331,8 +331,6 @@ func parsePattern(insn *ifuzz.Insn, vals []string) error { insn.VexL = -1 case v == "VL256", v == "VL=1": insn.VexL = 1 - case v == "VL512": - // VL=2 case v == "NOVSR": insn.VexNoR = true case v == "NOEVSR": @@ -340,49 +338,6 @@ func parsePattern(insn *ifuzz.Insn, vals []string) error { // VEXDEST3=0b1 VEXDEST210=0b111 VEXDEST4=0b0 case v == "SE_IMM8()": addImm(insn, 1) - case v == "VMODRM_XMM()": - case v == "VMODRM_YMM()": - case v == "BCRC=0": - case v == "BCRC=1": - case v == "ESIZE_8_BITS()": - case v == "ESIZE_16_BITS()": - case v == "ESIZE_32_BITS()": - case v == "ESIZE_64_BITS()": - case v == "NELEM_GPR_WRITER_STORE()": - case v == "NELEM_GPR_WRITER_STORE_BYTE()": - case v == "NELEM_GPR_WRITER_STORE_WORD()": - case v == "NELEM_GPR_WRITER_LDOP_Q()": - case v == "NELEM_GPR_WRITER_LDOP_D()": - case v == "NELEM_GPR_READER()": - case v == "NELEM_GPR_READER_BYTE()": - case v == "NELEM_GPR_READER_WORD()": - case v == "NELEM_GSCAT()": - case v == "NELEM_HALF()": - case v == "NELEM_FULL()": - case v == "NELEM_FULLMEM()": - case v == "NELEM_QUARTERMEM()": - case v == "NELEM_EIGHTHMEM()": - case v == "NELEM_HALFMEM()": - case v == "NELEM_QUARTERMEM()": - case v == "NELEM_MEM128()": - case v == "NELEM_SCALAR()": - case v == "NELEM_TUPLE1()": - case v == "NELEM_TUPLE2()": - case v == "NELEM_TUPLE4()": - case v == "NELEM_TUPLE8()": - case v == "NELEM_TUPLE1_4X()": - case v == "NELEM_TUPLE1_BYTE()": - case v == "NELEM_TUPLE1_WORD()": - case v == "NELEM_MOVDDUP()": - case v == "UISA_VMODRM_XMM()": - case v == "UISA_VMODRM_YMM()": - case v == "UISA_VMODRM_ZMM()": - case v == "MASK=0": - case v == "FIX_ROUND_LEN128()": - case v == "FIX_ROUND_LEN512()": - case v == "AVX512_ROUND()": - case v == "ZEROING=0": - case v == "SAE()": // Modes case v == "mode64": @@ -393,17 +348,16 @@ func parsePattern(insn *ifuzz.Insn, vals []string) error { insn.Mode &= 1 << ifuzz.ModeProt32 case v == "mode16": insn.Mode &= 1<<ifuzz.ModeProt16 | 1<<ifuzz.ModeReal16 - case v == "eamode64": - case v == "eamode32": - case v == "eamode16": - case v == "eanot16": + case v == "eamode64", + v == "eamode32", + v == "eamode16", + v == "eanot16": case v == "no_refining_prefix": insn.NoRepPrefix = true insn.No66Prefix = true - case v == "no66_prefix": + case v == "no66_prefix", v == "eosz32", v == "eosz64": insn.No66Prefix = true - case v == "not_refining_f3": case v == "f2_refining_prefix", v == "refining_f2", v == "repne", v == "REP=2": insn.Prefix = append(insn.Prefix, 0xF2) insn.NoRepPrefix = true @@ -419,23 +373,83 @@ func parsePattern(insn *ifuzz.Insn, vals []string) error { insn.Rexw = 1 case v == "norexw_prefix", v == "W0": insn.Rexw = -1 - case v == "MPXMODE=1", v == "MPXMODE=0": - case v == "TZCNT=1", v == "TZCNT=0": - case v == "LZCNT=1", v == "LZCNT=0": - case v == "CR_WIDTH()": - case v == "DF64()": - case v == "IMMUNE_REXW()": - case v == "FORCE64()": - case v == "eosz32", v == "eosz64": - insn.No66Prefix = true - case v == "EOSZ=1", v == "EOSZ!=1", v == "EOSZ=2", v == "EOSZ!=2", v == "EOSZ=3", v == "EOSZ!=3": - case v == "BRANCH_HINT()": - case v == "P4=1", v == "P4=0": - case v == "rexb_prefix", v == "norexb_prefix": - case strings.HasPrefix(v, "MODEP5="): - case v == "IMMUNE66()", v == "REFINING66()", v == "IGNORE66()", v == "IMMUNE66_LOOP64()": - case v == "OVERRIDE_SEG0()", v == "OVERRIDE_SEG1()", v == "REMOVE_SEGMENT()": - case v == "ONE()": + case v == "MPXMODE=1", + v == "MPXMODE=0", + v == "TZCNT=1", + v == "TZCNT=0", + v == "LZCNT=1", + v == "LZCNT=0", + v == "CR_WIDTH()", + v == "DF64()", + v == "IMMUNE_REXW()", + v == "FORCE64()", + v == "EOSZ=1", + v == "EOSZ!=1", + v == "EOSZ=2", + v == "EOSZ!=2", + v == "EOSZ=3", + v == "EOSZ!=3", + v == "BRANCH_HINT()", + v == "P4=1", + v == "P4=0", + v == "rexb_prefix", + v == "norexb_prefix", + v == "IMMUNE66()", + v == "REFINING66()", + v == "IGNORE66()", + v == "IMMUNE66_LOOP64()", + v == "OVERRIDE_SEG0()", + v == "OVERRIDE_SEG1()", + v == "REMOVE_SEGMENT()", + v == "ONE()", + v == "nolock_prefix", + v == "MODRM()", + v == "VMODRM_XMM()", + v == "VMODRM_YMM()", + v == "BCRC=0", + v == "BCRC=1", + v == "ESIZE_8_BITS()", + v == "ESIZE_16_BITS()", + v == "ESIZE_32_BITS()", + v == "ESIZE_64_BITS()", + v == "NELEM_GPR_WRITER_STORE()", + v == "NELEM_GPR_WRITER_STORE_BYTE()", + v == "NELEM_GPR_WRITER_STORE_WORD()", + v == "NELEM_GPR_WRITER_LDOP_Q()", + v == "NELEM_GPR_WRITER_LDOP_D()", + v == "NELEM_GPR_READER()", + v == "NELEM_GPR_READER_BYTE()", + v == "NELEM_GPR_READER_WORD()", + v == "NELEM_GSCAT()", + v == "NELEM_HALF()", + v == "NELEM_FULL()", + v == "NELEM_FULLMEM()", + v == "NELEM_QUARTERMEM()", + v == "NELEM_EIGHTHMEM()", + v == "NELEM_HALFMEM()", + v == "NELEM_QUARTERMEM()", + v == "NELEM_MEM128()", + v == "NELEM_SCALAR()", + v == "NELEM_TUPLE1()", + v == "NELEM_TUPLE2()", + v == "NELEM_TUPLE4()", + v == "NELEM_TUPLE8()", + v == "NELEM_TUPLE1_4X()", + v == "NELEM_TUPLE1_BYTE()", + v == "NELEM_TUPLE1_WORD()", + v == "NELEM_MOVDDUP()", + v == "UISA_VMODRM_XMM()", + v == "UISA_VMODRM_YMM()", + v == "UISA_VMODRM_ZMM()", + v == "MASK=0", + v == "FIX_ROUND_LEN128()", + v == "FIX_ROUND_LEN512()", + v == "AVX512_ROUND()", + v == "ZEROING=0", + v == "SAE()", + v == "VL512", // VL=2 + v == "not_refining_f3", + strings.HasPrefix(v, "MODEP5="): default: return errSkip(fmt.Sprintf("unknown pattern %v", v)) } diff --git a/prog/encoding.go b/prog/encoding.go index ad3aad16e..7a38b680f 100644 --- a/prog/encoding.go +++ b/prog/encoding.go @@ -225,7 +225,7 @@ func (target *Target) Deserialize(data []byte) (prog *Prog, err error) { return } -func (target *Target) parseArg(typ Type, p *parser, vars map[string]Arg /*, allowNil bool*/) (Arg, error) { +func (target *Target) parseArg(typ Type, p *parser, vars map[string]Arg) (Arg, error) { r := "" if p.Char() == '<' { p.Parse('<') @@ -233,249 +233,275 @@ func (target *Target) parseArg(typ Type, p *parser, vars map[string]Arg /*, allo p.Parse('=') p.Parse('>') } - var arg Arg -top: - switch p.Char() { - case '0': - val := p.Ident() - v, err := strconv.ParseUint(val, 0, 64) - if err != nil { - return nil, fmt.Errorf("wrong arg value '%v': %v", val, err) - } - switch typ.(type) { - case *ConstType, *IntType, *FlagsType, *ProcType, *LenType, *CsumType: - arg = MakeConstArg(typ, v) - case *ResourceType: - arg = MakeResultArg(typ, nil, v) - case *PtrType, *VmaType: - if typ.Optional() { - arg = MakeNullPointerArg(typ) - } else { - arg = target.defaultArg(typ) - } - default: - eatExcessive(p, true) + arg, err := target.parseArgImpl(typ, p, vars) + if err != nil { + return nil, err + } + if arg == nil { + if typ != nil { arg = target.defaultArg(typ) - break top + } else if r != "" { + return nil, fmt.Errorf("named nil argument") } + } + if r != "" { + vars[r] = arg + } + return arg, nil +} + +func (target *Target) parseArgImpl(typ Type, p *parser, vars map[string]Arg) (Arg, error) { + switch p.Char() { + case '0': + return target.parseArgInt(typ, p) case 'r': - id := p.Ident() - var div, add uint64 - if p.Char() == '/' { - p.Parse('/') - op := p.Ident() - v, err := strconv.ParseUint(op, 0, 64) - if err != nil { - return nil, fmt.Errorf("wrong result div op: '%v'", op) - } - div = v - } - if p.Char() == '+' { - p.Parse('+') - op := p.Ident() - v, err := strconv.ParseUint(op, 0, 64) - if err != nil { - return nil, fmt.Errorf("wrong result add op: '%v'", op) - } - add = v - } - v, ok := vars[id] - if !ok || v == nil { - arg = target.defaultArg(typ) - break - } - if _, ok := v.(ArgUsed); !ok { - arg = target.defaultArg(typ) - break - } - resArg := MakeResultArg(typ, v, 0) - resArg.OpDiv = div - resArg.OpAdd = add - arg = resArg + return target.parseArgRes(typ, p, vars) case '&': - var typ1 Type - switch t1 := typ.(type) { - case *PtrType: - typ1 = t1.Type - case *VmaType: - default: - eatExcessive(p, true) - arg = target.defaultArg(typ) - break top + return target.parseArgAddr(typ, p, vars) + case '"', '\'': + return target.parseArgString(typ, p) + case '{': + return target.parseArgStruct(typ, p, vars) + case '[': + return target.parseArgArray(typ, p, vars) + case '@': + return target.parseArgUnion(typ, p, vars) + case 'n': + p.Parse('n') + p.Parse('i') + p.Parse('l') + return nil, nil + + default: + return nil, fmt.Errorf("failed to parse argument at %v (line #%v/%v: %v)", + int(p.Char()), p.l, p.i, p.s) + } +} + +func (target *Target) parseArgInt(typ Type, p *parser) (Arg, error) { + val := p.Ident() + v, err := strconv.ParseUint(val, 0, 64) + if err != nil { + return nil, fmt.Errorf("wrong arg value '%v': %v", val, err) + } + switch typ.(type) { + case *ConstType, *IntType, *FlagsType, *ProcType, *LenType, *CsumType: + return MakeConstArg(typ, v), nil + case *ResourceType: + return MakeResultArg(typ, nil, v), nil + case *PtrType, *VmaType: + if typ.Optional() { + return MakeNullPointerArg(typ), nil + } + return target.defaultArg(typ), nil + default: + eatExcessive(p, true) + return target.defaultArg(typ), nil + } +} + +func (target *Target) parseArgRes(typ Type, p *parser, vars map[string]Arg) (Arg, error) { + id := p.Ident() + var div, add uint64 + if p.Char() == '/' { + p.Parse('/') + op := p.Ident() + v, err := strconv.ParseUint(op, 0, 64) + if err != nil { + return nil, fmt.Errorf("wrong result div op: '%v'", op) } - p.Parse('&') - addr, vmaSize, err := target.parseAddr(p) + div = v + } + if p.Char() == '+' { + p.Parse('+') + op := p.Ident() + v, err := strconv.ParseUint(op, 0, 64) if err != nil { - return nil, err + return nil, fmt.Errorf("wrong result add op: '%v'", op) } - var inner Arg - if p.Char() == '=' { + add = v + } + v, ok := vars[id] + if !ok || v == nil { + return target.defaultArg(typ), nil + } + if _, ok := v.(ArgUsed); !ok { + return target.defaultArg(typ), nil + } + arg := MakeResultArg(typ, v, 0) + arg.OpDiv = div + arg.OpAdd = add + return arg, nil +} + +func (target *Target) parseArgAddr(typ Type, p *parser, vars map[string]Arg) (Arg, error) { + var typ1 Type + switch t1 := typ.(type) { + case *PtrType: + typ1 = t1.Type + case *VmaType: + default: + eatExcessive(p, true) + return target.defaultArg(typ), nil + } + p.Parse('&') + addr, vmaSize, err := target.parseAddr(p) + if err != nil { + return nil, err + } + var inner Arg + if p.Char() == '=' { + p.Parse('=') + if p.Char() == 'A' { + p.Parse('A') + p.Parse('N') + p.Parse('Y') p.Parse('=') - if p.Char() == 'A' { - p.Parse('A') - p.Parse('N') - p.Parse('Y') - p.Parse('=') - typ = target.makeAnyPtrType(typ.Size(), typ.FieldName()) - typ1 = target.any.array - } - inner, err = target.parseArg(typ1, p, vars) - if err != nil { - return nil, err - } + typ = target.makeAnyPtrType(typ.Size(), typ.FieldName()) + typ1 = target.any.array } - if typ1 != nil { - if inner == nil { - inner = target.defaultArg(typ1) - } - arg = MakePointerArg(typ, addr, inner) - } else { - arg = MakeVmaPointerArg(typ, addr, vmaSize) - } - case '"', '\'': - if _, ok := typ.(*BufferType); !ok { - eatExcessive(p, true) - arg = target.defaultArg(typ) - break - } - data, err := deserializeData(p) + inner, err = target.parseArg(typ1, p, vars) if err != nil { return nil, err } - size := ^uint64(0) - if p.Char() == '/' { - p.Parse('/') - sizeStr := p.Ident() - size, err = strconv.ParseUint(sizeStr, 0, 64) - if err != nil { - return nil, fmt.Errorf("failed to parse buffer size: %q", sizeStr) - } - } - if !typ.Varlen() { - size = typ.Size() - } else if size == ^uint64(0) { - size = uint64(len(data)) - } - if typ.Dir() == DirOut { - arg = MakeOutDataArg(typ, size) - } else { - if diff := int(size) - len(data); diff > 0 { - data = append(data, make([]byte, diff)...) - } - data = data[:size] - arg = MakeDataArg(typ, data) - } - case '{': - p.Parse('{') - t1, ok := typ.(*StructType) - if !ok { - eatExcessive(p, false) - p.Parse('}') - arg = target.defaultArg(typ) - break - } - var inner []Arg - for i := 0; p.Char() != '}'; i++ { - if i >= len(t1.Fields) { - eatExcessive(p, false) - break - } - fld := t1.Fields[i] - if IsPad(fld) { - inner = append(inner, MakeConstArg(fld, 0)) - } else { - arg, err := target.parseArg(fld, p, vars) - if err != nil { - return nil, err - } - inner = append(inner, arg) - if p.Char() != '}' { - p.Parse(',') - } - } + } + if typ1 == nil { + return MakeVmaPointerArg(typ, addr, vmaSize), nil + } + if inner == nil { + inner = target.defaultArg(typ1) + } + return MakePointerArg(typ, addr, inner), nil +} + +func (target *Target) parseArgString(typ Type, p *parser) (Arg, error) { + if _, ok := typ.(*BufferType); !ok { + eatExcessive(p, true) + return target.defaultArg(typ), nil + } + data, err := deserializeData(p) + if err != nil { + return nil, err + } + size := ^uint64(0) + if p.Char() == '/' { + p.Parse('/') + sizeStr := p.Ident() + size, err = strconv.ParseUint(sizeStr, 0, 64) + if err != nil { + return nil, fmt.Errorf("failed to parse buffer size: %q", sizeStr) } + } + if !typ.Varlen() { + size = typ.Size() + } else if size == ^uint64(0) { + size = uint64(len(data)) + } + if typ.Dir() == DirOut { + return MakeOutDataArg(typ, size), nil + } + if diff := int(size) - len(data); diff > 0 { + data = append(data, make([]byte, diff)...) + } + data = data[:size] + return MakeDataArg(typ, data), nil +} + +func (target *Target) parseArgStruct(typ Type, p *parser, vars map[string]Arg) (Arg, error) { + p.Parse('{') + t1, ok := typ.(*StructType) + if !ok { + eatExcessive(p, false) p.Parse('}') - for len(inner) < len(t1.Fields) { - inner = append(inner, target.defaultArg(t1.Fields[len(inner)])) - } - arg = MakeGroupArg(typ, inner) - case '[': - p.Parse('[') - t1, ok := typ.(*ArrayType) - if !ok { + return target.defaultArg(typ), nil + } + var inner []Arg + for i := 0; p.Char() != '}'; i++ { + if i >= len(t1.Fields) { eatExcessive(p, false) - p.Parse(']') - arg = target.defaultArg(typ) break } - var inner []Arg - for i := 0; p.Char() != ']'; i++ { - arg, err := target.parseArg(t1.Type, p, vars) + fld := t1.Fields[i] + if IsPad(fld) { + inner = append(inner, MakeConstArg(fld, 0)) + } else { + arg, err := target.parseArg(fld, p, vars) if err != nil { return nil, err } inner = append(inner, arg) - if p.Char() != ']' { + if p.Char() != '}' { p.Parse(',') } } + } + p.Parse('}') + for len(inner) < len(t1.Fields) { + inner = append(inner, target.defaultArg(t1.Fields[len(inner)])) + } + return MakeGroupArg(typ, inner), nil +} + +func (target *Target) parseArgArray(typ Type, p *parser, vars map[string]Arg) (Arg, error) { + p.Parse('[') + t1, ok := typ.(*ArrayType) + if !ok { + eatExcessive(p, false) p.Parse(']') - if t1.Kind == ArrayRangeLen && t1.RangeBegin == t1.RangeEnd { - for uint64(len(inner)) < t1.RangeBegin { - inner = append(inner, target.defaultArg(t1.Type)) - } - inner = inner[:t1.RangeBegin] + return target.defaultArg(typ), nil + } + var inner []Arg + for i := 0; p.Char() != ']'; i++ { + arg, err := target.parseArg(t1.Type, p, vars) + if err != nil { + return nil, err } - arg = MakeGroupArg(typ, inner) - case '@': - t1, ok := typ.(*UnionType) - if !ok { - eatExcessive(p, true) - arg = target.defaultArg(typ) - break + inner = append(inner, arg) + if p.Char() != ']' { + p.Parse(',') } - p.Parse('@') - name := p.Ident() - var optType Type - for _, t2 := range t1.Fields { - if name == t2.FieldName() { - optType = t2 - break - } + } + p.Parse(']') + if t1.Kind == ArrayRangeLen && t1.RangeBegin == t1.RangeEnd { + for uint64(len(inner)) < t1.RangeBegin { + inner = append(inner, target.defaultArg(t1.Type)) } - if optType == nil { - eatExcessive(p, true) - arg = target.defaultArg(typ) + inner = inner[:t1.RangeBegin] + } + return MakeGroupArg(typ, inner), nil +} + +func (target *Target) parseArgUnion(typ Type, p *parser, vars map[string]Arg) (Arg, error) { + t1, ok := typ.(*UnionType) + if !ok { + eatExcessive(p, true) + return target.defaultArg(typ), nil + } + p.Parse('@') + name := p.Ident() + var optType Type + for _, t2 := range t1.Fields { + if name == t2.FieldName() { + optType = t2 break } - var opt Arg - if p.Char() == '=' { - p.Parse('=') - var err error - opt, err = target.parseArg(optType, p, vars) - if err != nil { - return nil, err - } - } else { - opt = target.defaultArg(optType) - } - arg = MakeUnionArg(typ, opt) - case 'n': - p.Parse('n') - p.Parse('i') - p.Parse('l') - if typ != nil { - arg = target.defaultArg(typ) - } else if r != "" { - return nil, fmt.Errorf("named nil argument") - } - default: - return nil, fmt.Errorf("failed to parse argument at %v (line #%v/%v: %v)", int(p.Char()), p.l, p.i, p.s) } - if r != "" { - vars[r] = arg + if optType == nil { + eatExcessive(p, true) + return target.defaultArg(typ), nil } - return arg, nil + var opt Arg + if p.Char() == '=' { + p.Parse('=') + var err error + opt, err = target.parseArg(optType, p, vars) + if err != nil { + return nil, err + } + } else { + opt = target.defaultArg(optType) + } + return MakeUnionArg(typ, opt), nil } // Eats excessive call arguments and struct fields to recover after description changes. diff --git a/prog/minimization.go b/prog/minimization.go index 8f2cbae61..71e0f5c63 100644 --- a/prog/minimization.go +++ b/prog/minimization.go @@ -30,6 +30,37 @@ func Minimize(p0 *Prog, callIndex0 int, crash bool, pred0 func(*Prog, int) bool) } // Try to remove all calls except the last one one-by-one. + p0, callIndex0 = removeCalls(p0, callIndex0, crash, pred) + + // Try to minimize individual args. + for i := 0; i < len(p0.Calls); i++ { + ctx := &minimizeArgsCtx{ + p0: &p0, + callIndex0: callIndex0, + crash: crash, + pred: pred, + triedPaths: make(map[string]bool), + } + again: + p := p0.Clone() + call := p.Calls[i] + for j, arg := range call.Args { + if ctx.do(p, call, arg, fmt.Sprintf("%v", j)) { + goto again + } + } + } + + if callIndex0 != -1 { + if callIndex0 < 0 || callIndex0 >= len(p0.Calls) || name0 != p0.Calls[callIndex0].Meta.Name { + panic(fmt.Sprintf("bad call index after minimization: ncalls=%v index=%v call=%v/%v", + len(p0.Calls), callIndex0, name0, p0.Calls[callIndex0].Meta.Name)) + } + } + return p0, callIndex0 +} + +func removeCalls(p0 *Prog, callIndex0 int, crash bool, pred func(*Prog, int) bool) (*Prog, int) { for i := len(p0.Calls) - 1; i >= 0; i-- { if i == callIndex0 { continue @@ -46,156 +77,134 @@ func Minimize(p0 *Prog, callIndex0 int, crash bool, pred0 func(*Prog, int) bool) p0 = p callIndex0 = callIndex } + return p0, callIndex0 +} - var triedPaths map[string]bool +type minimizeArgsCtx struct { + p0 **Prog + callIndex0 int + crash bool + pred func(*Prog, int) bool + triedPaths map[string]bool +} - var rec func(p *Prog, call *Call, arg Arg, path string) bool - rec = func(p *Prog, call *Call, arg Arg, path string) bool { - path += fmt.Sprintf("-%v", arg.Type().FieldName()) - switch typ := arg.Type().(type) { - case *StructType: - a := arg.(*GroupArg) - for _, innerArg := range a.Inner { - if rec(p, call, innerArg, path) { - return true - } - } - case *UnionType: - a := arg.(*UnionArg) - if rec(p, call, a.Option, path) { +func (ctx *minimizeArgsCtx) do(p *Prog, call *Call, arg Arg, path string) bool { + path += fmt.Sprintf("-%v", arg.Type().FieldName()) + switch typ := arg.Type().(type) { + case *StructType: + a := arg.(*GroupArg) + for _, innerArg := range a.Inner { + if ctx.do(p, call, innerArg, path) { return true } - case *PtrType: - // TODO: try to remove optional ptrs - a, ok := arg.(*PointerArg) - if !ok { - // Can also be *ConstArg. - return false - } - if a.Res != nil { - return rec(p, call, a.Res, path) - } - case *ArrayType: - a := arg.(*GroupArg) - for i, innerArg := range a.Inner { - innerPath := fmt.Sprintf("%v-%v", path, i) - if !triedPaths[innerPath] && !crash { - if (typ.Kind == ArrayRangeLen && len(a.Inner) > int(typ.RangeBegin)) || - (typ.Kind == ArrayRandLen) { - copy(a.Inner[i:], a.Inner[i+1:]) - a.Inner = a.Inner[:len(a.Inner)-1] - removeArg(innerArg) - p.Target.assignSizesCall(call) - - if pred(p, callIndex0) { - p0 = p - } else { - triedPaths[innerPath] = true - } + } + case *UnionType: + a := arg.(*UnionArg) + if ctx.do(p, call, a.Option, path) { + return true + } + case *PtrType: + // TODO: try to remove optional ptrs + a, ok := arg.(*PointerArg) + if !ok { + // Can also be *ConstArg. + return false + } + if a.Res != nil { + return ctx.do(p, call, a.Res, path) + } + case *ArrayType: + a := arg.(*GroupArg) + for i, innerArg := range a.Inner { + innerPath := fmt.Sprintf("%v-%v", path, i) + if !ctx.triedPaths[innerPath] && !ctx.crash { + if (typ.Kind == ArrayRangeLen && len(a.Inner) > int(typ.RangeBegin)) || + (typ.Kind == ArrayRandLen) { + copy(a.Inner[i:], a.Inner[i+1:]) + a.Inner = a.Inner[:len(a.Inner)-1] + removeArg(innerArg) + p.Target.assignSizesCall(call) - return true + if ctx.pred(p, ctx.callIndex0) { + *ctx.p0 = p + } else { + ctx.triedPaths[innerPath] = true } - } - if rec(p, call, innerArg, innerPath) { return true } } - case *IntType, *FlagsType, *ProcType: - // TODO: try to reset bits in ints - // TODO: try to set separate flags - if crash { - return false - } - if triedPaths[path] { - return false - } - triedPaths[path] = true - a := arg.(*ConstArg) - if a.Val == typ.Default() { - return false - } - v0 := a.Val - a.Val = typ.Default() - if pred(p, callIndex0) { - p0 = p - return true - } - a.Val = v0 - case *ResourceType: - if crash { - return false - } - if triedPaths[path] { - return false - } - triedPaths[path] = true - a := arg.(*ResultArg) - if a.Res == nil { - return false - } - r0 := a.Res - a.Res = nil - a.Val = typ.Default() - if pred(p, callIndex0) { - p0 = p + if ctx.do(p, call, innerArg, innerPath) { return true } - a.Res = r0 - a.Val = 0 - case *BufferType: - // TODO: try to set individual bytes to 0 - if triedPaths[path] { - return false - } - triedPaths[path] = true - if typ.Kind != BufferBlobRand && typ.Kind != BufferBlobRange || - typ.Dir() == DirOut { - return false - } - a := arg.(*DataArg) - minLen := int(typ.RangeBegin) - for step := len(a.Data()) - minLen; len(a.Data()) > minLen && step > 0; { - if len(a.Data())-step >= minLen { - a.data = a.Data()[:len(a.Data())-step] - p.Target.assignSizesCall(call) - if pred(p, callIndex0) { - continue - } - a.data = a.Data()[:len(a.Data())+step] - p.Target.assignSizesCall(call) - } - step /= 2 - if crash { - break - } - } - p0 = p - case *VmaType, *LenType, *CsumType, *ConstType: + } + case *IntType, *FlagsType, *ProcType: + // TODO: try to reset bits in ints + // TODO: try to set separate flags + if ctx.crash || ctx.triedPaths[path] { return false - default: - panic(fmt.Sprintf("unknown arg type '%+v'", typ)) } - return false - } - - // Try to minimize individual args. - for i := 0; i < len(p0.Calls); i++ { - triedPaths = make(map[string]bool) - again: - p := p0.Clone() - call := p.Calls[i] - for j, arg := range call.Args { - if rec(p, call, arg, fmt.Sprintf("%v", j)) { - goto again - } + ctx.triedPaths[path] = true + a := arg.(*ConstArg) + if a.Val == typ.Default() { + return false } - } - - if callIndex0 != -1 { - if callIndex0 < 0 || callIndex0 >= len(p0.Calls) || name0 != p0.Calls[callIndex0].Meta.Name { - panic(fmt.Sprintf("bad call index after minimization: ncalls=%v index=%v call=%v/%v", - len(p0.Calls), callIndex0, name0, p0.Calls[callIndex0].Meta.Name)) + v0 := a.Val + a.Val = typ.Default() + if ctx.pred(p, ctx.callIndex0) { + *ctx.p0 = p + return true + } + a.Val = v0 + case *ResourceType: + if ctx.crash || ctx.triedPaths[path] { + return false + } + ctx.triedPaths[path] = true + a := arg.(*ResultArg) + if a.Res == nil { + return false } + r0 := a.Res + a.Res = nil + a.Val = typ.Default() + if ctx.pred(p, ctx.callIndex0) { + *ctx.p0 = p + return true + } + a.Res = r0 + a.Val = 0 + case *BufferType: + // TODO: try to set individual bytes to 0 + if ctx.triedPaths[path] { + return false + } + ctx.triedPaths[path] = true + if typ.Kind != BufferBlobRand && typ.Kind != BufferBlobRange || + typ.Dir() == DirOut { + return false + } + a := arg.(*DataArg) + minLen := int(typ.RangeBegin) + for step := len(a.Data()) - minLen; len(a.Data()) > minLen && step > 0; { + if len(a.Data())-step >= minLen { + a.data = a.Data()[:len(a.Data())-step] + p.Target.assignSizesCall(call) + if ctx.pred(p, ctx.callIndex0) { + continue + } + a.data = a.Data()[:len(a.Data())+step] + p.Target.assignSizesCall(call) + } + step /= 2 + if ctx.crash { + break + } + } + *ctx.p0 = p + case *VmaType, *LenType, *CsumType, *ConstType: + return false + default: + panic(fmt.Sprintf("unknown arg type '%+v'", typ)) } - return p0, callIndex0 + return false } diff --git a/prog/mutation.go b/prog/mutation.go index 9ab47d8d6..1c4115731 100644 --- a/prog/mutation.go +++ b/prog/mutation.go @@ -343,223 +343,220 @@ func (ma *mutationArgs) collectArg(arg Arg, ctx *ArgCtx) { } func mutateData(r *randGen, data []byte, minLen, maxLen uint64) []byte { - const maxInc = 35 - retry := false -loop: - for stop := false; !stop || retry; stop = r.oneOf(3) { - retry = false - // TODO(dvyukov): duplicate part of data. - switch r.Intn(7) { - case 0: - // Flip bit in byte. - if len(data) == 0 { - retry = true - continue loop - } - byt := r.Intn(len(data)) - bit := r.Intn(8) - data[byt] ^= 1 << uint(bit) - case 1: - // Insert random bytes. - if len(data) == 0 || uint64(len(data)) >= maxLen { - retry = true - continue loop - } - n := r.Intn(16) + 1 - if r := int(maxLen) - len(data); n > r { - n = r - } - pos := r.Intn(len(data)) - for i := 0; i < n; i++ { - data = append(data, 0) - } - copy(data[pos+n:], data[pos:]) + for stop := false; !stop; stop = stop && r.oneOf(3) { + f := mutateDataFuncs[r.Intn(len(mutateDataFuncs))] + data, stop = f(r, data, minLen, maxLen) + } + return data +} + +const maxInc = 35 + +var mutateDataFuncs = [...]func(r *randGen, data []byte, minLen, maxLen uint64) ([]byte, bool){ + // TODO(dvyukov): duplicate part of data. + // Flip bit in byte. + func(r *randGen, data []byte, minLen, maxLen uint64) ([]byte, bool) { + if len(data) == 0 { + return data, false + } + byt := r.Intn(len(data)) + bit := r.Intn(8) + data[byt] ^= 1 << uint(bit) + return data, true + }, + // Insert random bytes. + func(r *randGen, data []byte, minLen, maxLen uint64) ([]byte, bool) { + if len(data) == 0 || uint64(len(data)) >= maxLen { + return data, false + } + n := r.Intn(16) + 1 + if r := int(maxLen) - len(data); n > r { + n = r + } + pos := r.Intn(len(data)) + for i := 0; i < n; i++ { + data = append(data, 0) + } + copy(data[pos+n:], data[pos:]) + for i := 0; i < n; i++ { + data[pos+i] = byte(r.Int31()) + } + if r.bin() { + data = data[:len(data)-n] // preserve original length + } + return data, true + }, + // Remove bytes. + func(r *randGen, data []byte, minLen, maxLen uint64) ([]byte, bool) { + if uint64(len(data)) <= minLen { + return data, false + } + n := r.Intn(16) + 1 + if n > len(data) { + n = len(data) + } + pos := 0 + if n < len(data) { + pos = r.Intn(len(data) - n) + } + copy(data[pos:], data[pos+n:]) + data = data[:len(data)-n] + if r.bin() { for i := 0; i < n; i++ { - data[pos+i] = byte(r.Int31()) + data = append(data, 0) // preserve original length } - if r.bin() { - data = data[:len(data)-n] // preserve original length - } - case 2: - // Remove bytes. - if uint64(len(data)) <= minLen { - retry = true - continue loop - } - n := r.Intn(16) + 1 - if n > len(data) { - n = len(data) - } - pos := 0 - if n < len(data) { - pos = r.Intn(len(data) - n) + } + return data, true + }, + // Append a bunch of bytes. + func(r *randGen, data []byte, minLen, maxLen uint64) ([]byte, bool) { + if uint64(len(data)) >= maxLen { + return data, false + } + const max = 256 + n := max - r.biasedRand(max, 10) + if r := int(maxLen) - len(data); n > r { + n = r + } + for i := 0; i < n; i++ { + data = append(data, byte(r.rand(256))) + } + return data, true + }, + // Replace int8/int16/int32/int64 with a random value. + func(r *randGen, data []byte, minLen, maxLen uint64) ([]byte, bool) { + switch r.Intn(4) { + case 0: // int8 + if len(data) == 0 { + return data, false + } + data[r.Intn(len(data))] = byte(r.rand(1 << 8)) + case 1: // int16 + if len(data) < 2 { + return data, false + } + i := r.Intn(len(data) - 1) + p := (*uint16)(unsafe.Pointer(&data[i])) + *p = uint16(r.rand(1 << 16)) + case 2: // int32 + if len(data) < 4 { + return data, false + } + i := r.Intn(len(data) - 3) + p := (*uint32)(unsafe.Pointer(&data[i])) + *p = r.Uint32() + case 3: // int64 + if len(data) < 8 { + return data, false + } + i := r.Intn(len(data) - 7) + p := (*uint64)(unsafe.Pointer(&data[i])) + *p = r.Uint64() + } + return data, true + }, + // Add/subtract from an int8/int16/int32/int64. + func(r *randGen, data []byte, minLen, maxLen uint64) ([]byte, bool) { + switch r.Intn(4) { + case 0: // int8 + if len(data) == 0 { + return data, false + } + i := r.Intn(len(data)) + delta := byte(r.rand(2*maxInc+1) - maxInc) + if delta == 0 { + delta = 1 + } + data[i] += delta + case 1: // int16 + if len(data) < 2 { + return data, false + } + i := r.Intn(len(data) - 1) + p := (*uint16)(unsafe.Pointer(&data[i])) + delta := uint16(r.rand(2*maxInc+1) - maxInc) + if delta == 0 { + delta = 1 + } + if r.oneOf(10) { + *p = swap16(swap16(*p) + delta) + } else { + *p += delta } - copy(data[pos:], data[pos+n:]) - data = data[:len(data)-n] - if r.bin() { - for i := 0; i < n; i++ { - data = append(data, 0) // preserve original length - } + case 2: // int32 + if len(data) < 4 { + return data, false } - case 3: - // Append a bunch of bytes. - if uint64(len(data)) >= maxLen { - retry = true - continue loop + i := r.Intn(len(data) - 3) + p := (*uint32)(unsafe.Pointer(&data[i])) + delta := uint32(r.rand(2*maxInc+1) - maxInc) + if delta == 0 { + delta = 1 } - const max = 256 - n := max - r.biasedRand(max, 10) - if r := int(maxLen) - len(data); n > r { - n = r + if r.oneOf(10) { + *p = swap32(swap32(*p) + delta) + } else { + *p += delta } - for i := 0; i < n; i++ { - data = append(data, byte(r.rand(256))) + case 3: // int64 + if len(data) < 8 { + return data, false } - case 4: - // Replace int8/int16/int32/int64 with a random value. - switch r.Intn(4) { - case 0: // int8 - if len(data) == 0 { - retry = true - continue loop - } - data[r.Intn(len(data))] = byte(r.rand(1 << 8)) - case 1: // int16 - if len(data) < 2 { - retry = true - continue loop - } - i := r.Intn(len(data) - 1) - p := (*uint16)(unsafe.Pointer(&data[i])) - *p = uint16(r.rand(1 << 16)) - case 2: // int32 - if len(data) < 4 { - retry = true - continue loop - } - i := r.Intn(len(data) - 3) - p := (*uint32)(unsafe.Pointer(&data[i])) - *p = r.Uint32() - case 3: // int64 - if len(data) < 8 { - retry = true - continue loop - } - i := r.Intn(len(data) - 7) - p := (*uint64)(unsafe.Pointer(&data[i])) - *p = r.Uint64() - } - case 5: - // Add/subtract from an int8/int16/int32/int64. - switch r.Intn(4) { - case 0: // int8 - if len(data) == 0 { - retry = true - continue loop - } - i := r.Intn(len(data)) - delta := byte(r.rand(2*maxInc+1) - maxInc) - if delta == 0 { - delta = 1 - } - data[i] += delta - case 1: // int16 - if len(data) < 2 { - retry = true - continue loop - } - i := r.Intn(len(data) - 1) - p := (*uint16)(unsafe.Pointer(&data[i])) - delta := uint16(r.rand(2*maxInc+1) - maxInc) - if delta == 0 { - delta = 1 - } - if r.oneOf(10) { - *p = swap16(swap16(*p) + delta) - } else { - *p += delta - } - case 2: // int32 - if len(data) < 4 { - retry = true - continue loop - } - i := r.Intn(len(data) - 3) - p := (*uint32)(unsafe.Pointer(&data[i])) - delta := uint32(r.rand(2*maxInc+1) - maxInc) - if delta == 0 { - delta = 1 - } - if r.oneOf(10) { - *p = swap32(swap32(*p) + delta) - } else { - *p += delta - } - case 3: // int64 - if len(data) < 8 { - retry = true - continue loop - } - i := r.Intn(len(data) - 7) - p := (*uint64)(unsafe.Pointer(&data[i])) - delta := r.rand(2*maxInc+1) - maxInc - if delta == 0 { - delta = 1 - } - if r.oneOf(10) { - *p = swap64(swap64(*p) + delta) - } else { - *p += delta - } + i := r.Intn(len(data) - 7) + p := (*uint64)(unsafe.Pointer(&data[i])) + delta := r.rand(2*maxInc+1) - maxInc + if delta == 0 { + delta = 1 } - case 6: - // Set int8/int16/int32/int64 to an interesting value. - switch r.Intn(4) { - case 0: // int8 - if len(data) == 0 { - retry = true - continue loop - } - data[r.Intn(len(data))] = byte(r.randInt()) - case 1: // int16 - if len(data) < 2 { - retry = true - continue loop - } - i := r.Intn(len(data) - 1) - value := uint16(r.randInt()) - if r.oneOf(10) { - value = swap16(value) - } - *(*uint16)(unsafe.Pointer(&data[i])) = value - case 2: // int32 - if len(data) < 4 { - retry = true - continue loop - } - i := r.Intn(len(data) - 3) - value := uint32(r.randInt()) - if r.oneOf(10) { - value = swap32(value) - } - *(*uint32)(unsafe.Pointer(&data[i])) = value - case 3: // int64 - if len(data) < 8 { - retry = true - continue loop - } - i := r.Intn(len(data) - 7) - value := r.randInt() - if r.oneOf(10) { - value = swap64(value) - } - *(*uint64)(unsafe.Pointer(&data[i])) = value + if r.oneOf(10) { + *p = swap64(swap64(*p) + delta) + } else { + *p += delta } - default: - panic("bad") } - } - return data + return data, true + }, + // Set int8/int16/int32/int64 to an interesting value. + func(r *randGen, data []byte, minLen, maxLen uint64) ([]byte, bool) { + switch r.Intn(4) { + case 0: // int8 + if len(data) == 0 { + return data, false + } + data[r.Intn(len(data))] = byte(r.randInt()) + case 1: // int16 + if len(data) < 2 { + return data, false + } + i := r.Intn(len(data) - 1) + value := uint16(r.randInt()) + if r.oneOf(10) { + value = swap16(value) + } + *(*uint16)(unsafe.Pointer(&data[i])) = value + case 2: // int32 + if len(data) < 4 { + return data, false + } + i := r.Intn(len(data) - 3) + value := uint32(r.randInt()) + if r.oneOf(10) { + value = swap32(value) + } + *(*uint32)(unsafe.Pointer(&data[i])) = value + case 3: // int64 + if len(data) < 8 { + return data, false + } + i := r.Intn(len(data) - 7) + value := r.randInt() + if r.oneOf(10) { + value = swap64(value) + } + *(*uint64)(unsafe.Pointer(&data[i])) = value + } + return data, true + }, } func swap16(v uint16) uint16 { diff --git a/prog/rand.go b/prog/rand.go index 0dbb23adc..956e7c7e1 100644 --- a/prog/rand.go +++ b/prog/rand.go @@ -511,163 +511,182 @@ func (r *randGen) generateArgImpl(s *state, typ Type, ignoreSpecial bool) (arg A } } - switch a := typ.(type) { - case *ResourceType: - switch { - case r.nOutOf(1000, 1011): - // Get an existing resource. - var allres []Arg - for name1, res1 := range s.resources { - if name1 == "iocbptr" { - continue - } - if r.target.isCompatibleResource(a.Desc.Name, name1) || - r.oneOf(20) && r.target.isCompatibleResource(a.Desc.Kind[0], name1) { - allres = append(allres, res1...) - } - } - if len(allres) != 0 { - arg = MakeResultArg(a, allres[r.Intn(len(allres))], 0) - } else { - arg, calls = r.createResource(s, a) + if !ignoreSpecial && typ.Dir() != DirOut { + switch typ.(type) { + case *StructType, *UnionType: + if gen := r.target.SpecialTypes[typ.Name()]; gen != nil { + return gen(&Gen{r, s}, typ, nil) } - case r.nOutOf(10, 11): - // Create a new resource. - arg, calls = r.createResource(s, a) - default: - special := a.SpecialValues() - arg = MakeResultArg(a, nil, special[r.Intn(len(special))]) } - return arg, calls - case *BufferType: - switch a.Kind { - case BufferBlobRand, BufferBlobRange: - sz := r.randBufLen() - if a.Kind == BufferBlobRange { - sz = r.randRange(a.RangeBegin, a.RangeEnd) - } - if a.Dir() == DirOut { - return MakeOutDataArg(a, sz), nil - } - data := make([]byte, sz) - for i := range data { - data[i] = byte(r.Intn(256)) - } - return MakeDataArg(a, data), nil - case BufferString: - data := r.randString(s, a) - if a.Dir() == DirOut { - return MakeOutDataArg(a, uint64(len(data))), nil - } - return MakeDataArg(a, data), nil - case BufferFilename: - if a.Dir() == DirOut { - var sz uint64 - switch { - case !a.Varlen(): - sz = a.Size() - case r.nOutOf(1, 3): - sz = r.rand(100) - case r.nOutOf(1, 2): - sz = 108 // UNIX_PATH_MAX - default: - sz = 4096 // PATH_MAX - } - return MakeOutDataArg(a, sz), nil + } + + return typ.generate(r, s) +} + +func (a *ResourceType) generate(r *randGen, s *state) (arg Arg, calls []*Call) { + switch { + case r.nOutOf(1000, 1011): + // Get an existing resource. + var allres []Arg + for name1, res1 := range s.resources { + if name1 == "iocbptr" { + continue } - return MakeDataArg(a, []byte(r.filename(s, a))), nil - case BufferText: - if a.Dir() == DirOut { - return MakeOutDataArg(a, uint64(r.Intn(100))), nil + if r.target.isCompatibleResource(a.Desc.Name, name1) || + r.oneOf(20) && r.target.isCompatibleResource(a.Desc.Kind[0], name1) { + allres = append(allres, res1...) } - return MakeDataArg(a, r.generateText(a.Text)), nil - default: - panic("unknown buffer kind") } - case *VmaType: - npages := r.randPageCount() - if a.RangeBegin != 0 || a.RangeEnd != 0 { - npages = a.RangeBegin + uint64(r.Intn(int(a.RangeEnd-a.RangeBegin+1))) + if len(allres) != 0 { + arg = MakeResultArg(a, allres[r.Intn(len(allres))], 0) + } else { + arg, calls = r.createResource(s, a) } - arg := r.allocVMA(s, a, npages) - return arg, nil - case *FlagsType: - return MakeConstArg(a, r.flags(a.Vals)), nil - case *ConstType: - return MakeConstArg(a, a.Val), nil - case *IntType: - v := r.randInt() - switch a.Kind { - case IntFileoff: - switch { - case r.nOutOf(90, 101): - v = 0 - case r.nOutOf(10, 11): - v = r.rand(100) - default: - v = r.randInt() - } - case IntRange: - v = r.randRangeInt(a.RangeBegin, a.RangeEnd) + case r.nOutOf(10, 11): + // Create a new resource. + arg, calls = r.createResource(s, a) + default: + special := a.SpecialValues() + arg = MakeResultArg(a, nil, special[r.Intn(len(special))]) + } + return arg, calls +} + +func (a *BufferType) generate(r *randGen, s *state) (arg Arg, calls []*Call) { + switch a.Kind { + case BufferBlobRand, BufferBlobRange: + sz := r.randBufLen() + if a.Kind == BufferBlobRange { + sz = r.randRange(a.RangeBegin, a.RangeEnd) } - return MakeConstArg(a, v), nil - case *ProcType: - return MakeConstArg(a, r.rand(int(a.ValuesPerProc))), nil - case *ArrayType: - var count uint64 - switch a.Kind { - case ArrayRandLen: - count = r.randArrayLen() - case ArrayRangeLen: - count = r.randRange(a.RangeBegin, a.RangeEnd) + if a.Dir() == DirOut { + return MakeOutDataArg(a, sz), nil } - var inner []Arg - var calls []*Call - for i := uint64(0); i < count; i++ { - arg1, calls1 := r.generateArg(s, a.Type) - inner = append(inner, arg1) - calls = append(calls, calls1...) + data := make([]byte, sz) + for i := range data { + data[i] = byte(r.Intn(256)) } - return MakeGroupArg(a, inner), calls - case *StructType: - if !ignoreSpecial { - if gen := r.target.SpecialTypes[a.Name()]; gen != nil && a.Dir() != DirOut { - arg, calls = gen(&Gen{r, s}, a, nil) - return - } + return MakeDataArg(a, data), nil + case BufferString: + data := r.randString(s, a) + if a.Dir() == DirOut { + return MakeOutDataArg(a, uint64(len(data))), nil } - args, calls := r.generateArgs(s, a.Fields) - group := MakeGroupArg(a, args) - return group, calls - case *UnionType: - if !ignoreSpecial { - if gen := r.target.SpecialTypes[a.Name()]; gen != nil && a.Dir() != DirOut { - arg, calls = gen(&Gen{r, s}, a, nil) - return + return MakeDataArg(a, data), nil + case BufferFilename: + if a.Dir() == DirOut { + var sz uint64 + switch { + case !a.Varlen(): + sz = a.Size() + case r.nOutOf(1, 3): + sz = r.rand(100) + case r.nOutOf(1, 2): + sz = 108 // UNIX_PATH_MAX + default: + sz = 4096 // PATH_MAX } + return MakeOutDataArg(a, sz), nil } - optType := a.Fields[r.Intn(len(a.Fields))] - opt, calls := r.generateArg(s, optType) - return MakeUnionArg(a, opt), calls - case *PtrType: - inner, calls := r.generateArg(s, a.Type) - // TODO(dvyukov): remove knowledge about iocb from prog. - if a.Type.Name() == "iocb" && len(s.resources["iocbptr"]) != 0 { - // It is weird, but these are actually identified by kernel by address. - // So try to reuse a previously used address. - addrs := s.resources["iocbptr"] - addr := addrs[r.Intn(len(addrs))].(*PointerArg) - arg = MakePointerArg(a, addr.Address, inner) - return arg, calls + return MakeDataArg(a, []byte(r.filename(s, a))), nil + case BufferText: + if a.Dir() == DirOut { + return MakeOutDataArg(a, uint64(r.Intn(100))), nil } - arg := r.allocAddr(s, a, inner.Size(), inner) - return arg, calls - case *LenType: - // Return placeholder value of 0 while generating len arg. - return MakeConstArg(a, 0), nil - case *CsumType: - return MakeConstArg(a, 0), nil + return MakeDataArg(a, r.generateText(a.Text)), nil default: - panic("unknown argument type") + panic("unknown buffer kind") + } +} + +func (a *VmaType) generate(r *randGen, s *state) (arg Arg, calls []*Call) { + npages := r.randPageCount() + if a.RangeBegin != 0 || a.RangeEnd != 0 { + npages = a.RangeBegin + uint64(r.Intn(int(a.RangeEnd-a.RangeBegin+1))) + } + return r.allocVMA(s, a, npages), nil +} + +func (a *FlagsType) generate(r *randGen, s *state) (arg Arg, calls []*Call) { + return MakeConstArg(a, r.flags(a.Vals)), nil +} + +func (a *ConstType) generate(r *randGen, s *state) (arg Arg, calls []*Call) { + return MakeConstArg(a, a.Val), nil +} + +func (a *IntType) generate(r *randGen, s *state) (arg Arg, calls []*Call) { + v := r.randInt() + switch a.Kind { + case IntFileoff: + switch { + case r.nOutOf(90, 101): + v = 0 + case r.nOutOf(10, 11): + v = r.rand(100) + default: + v = r.randInt() + } + case IntRange: + v = r.randRangeInt(a.RangeBegin, a.RangeEnd) } + return MakeConstArg(a, v), nil +} + +func (a *ProcType) generate(r *randGen, s *state) (arg Arg, calls []*Call) { + return MakeConstArg(a, r.rand(int(a.ValuesPerProc))), nil +} + +func (a *ArrayType) generate(r *randGen, s *state) (arg Arg, calls []*Call) { + var count uint64 + switch a.Kind { + case ArrayRandLen: + count = r.randArrayLen() + case ArrayRangeLen: + count = r.randRange(a.RangeBegin, a.RangeEnd) + } + var inner []Arg + for i := uint64(0); i < count; i++ { + arg1, calls1 := r.generateArg(s, a.Type) + inner = append(inner, arg1) + calls = append(calls, calls1...) + } + return MakeGroupArg(a, inner), calls +} + +func (a *StructType) generate(r *randGen, s *state) (arg Arg, calls []*Call) { + args, calls := r.generateArgs(s, a.Fields) + group := MakeGroupArg(a, args) + return group, calls +} + +func (a *UnionType) generate(r *randGen, s *state) (arg Arg, calls []*Call) { + optType := a.Fields[r.Intn(len(a.Fields))] + opt, calls := r.generateArg(s, optType) + return MakeUnionArg(a, opt), calls +} + +func (a *PtrType) generate(r *randGen, s *state) (arg Arg, calls []*Call) { + inner, calls := r.generateArg(s, a.Type) + // TODO(dvyukov): remove knowledge about iocb from prog. + if a.Type.Name() == "iocb" && len(s.resources["iocbptr"]) != 0 { + // It is weird, but these are actually identified by kernel by address. + // So try to reuse a previously used address. + addrs := s.resources["iocbptr"] + addr := addrs[r.Intn(len(addrs))].(*PointerArg) + arg = MakePointerArg(a, addr.Address, inner) + return arg, calls + } + arg = r.allocAddr(s, a, inner.Size(), inner) + return arg, calls +} + +func (a *LenType) generate(r *randGen, s *state) (arg Arg, calls []*Call) { + // Updated later in assignSizesCall. + return MakeConstArg(a, 0), nil +} + +func (a *CsumType) generate(r *randGen, s *state) (arg Arg, calls []*Call) { + // Updated later in calcChecksumsCall. + return MakeConstArg(a, 0), nil } diff --git a/prog/types.go b/prog/types.go index 37aff4e88..0297b003b 100644 --- a/prog/types.go +++ b/prog/types.go @@ -49,6 +49,8 @@ type Type interface { BitfieldOffset() uint64 BitfieldLength() uint64 BitfieldMiddle() bool // returns true for all but last bitfield in a group + + generate(r *randGen, s *state) (arg Arg, calls []*Call) } func IsPad(t Type) bool { diff --git a/prog/validation.go b/prog/validation.go index 462a9220e..5c9275164 100644 --- a/prog/validation.go +++ b/prog/validation.go @@ -10,12 +10,17 @@ import ( var debug = false // enabled in tests type validCtx struct { - args map[Arg]bool - uses map[Arg]Arg + target *Target + args map[Arg]bool + uses map[Arg]Arg } func (p *Prog) validate() error { - ctx := &validCtx{make(map[Arg]bool), make(map[Arg]Arg)} + ctx := &validCtx{ + target: p.Target, + args: make(map[Arg]bool), + uses: make(map[Arg]Arg), + } for _, c := range p.Calls { if err := p.validateCall(ctx, c); err != nil { return err @@ -37,290 +42,284 @@ func (p *Prog) validateCall(ctx *validCtx, c *Call) error { return fmt.Errorf("syscall %v: wrong number of arguments, want %v, got %v", c.Meta.Name, len(c.Meta.Args), len(c.Args)) } - var checkArg func(arg Arg) error - checkArg = func(arg Arg) error { - if arg == nil { - return fmt.Errorf("syscall %v: nil arg", c.Meta.Name) - } - if ctx.args[arg] { - return fmt.Errorf("syscall %v: arg %#v is referenced several times in the tree", - c.Meta.Name, arg) + for _, arg := range c.Args { + if _, ok := arg.(*ReturnArg); ok { + return fmt.Errorf("syscall %v: arg '%v' has wrong return kind", + c.Meta.Name, arg.Type().Name()) } - ctx.args[arg] = true - if used, ok := arg.(ArgUsed); ok { - for u := range *used.Used() { - if u == nil { - return fmt.Errorf("syscall %v: nil reference in uses for arg %+v", - c.Meta.Name, arg) - } - ctx.uses[u] = arg - } + if err := validateArg(ctx, c, arg); err != nil { + return err } - if arg.Type() == nil { - return fmt.Errorf("syscall %v: no type", c.Meta.Name) + } + if c.Ret == nil { + return fmt.Errorf("syscall %v: return value is absent", c.Meta.Name) + } + if _, ok := c.Ret.(*ReturnArg); !ok { + return fmt.Errorf("syscall %v: return value has wrong kind %v", c.Meta.Name, c.Ret) + } + if c.Meta.Ret != nil { + if err := validateArg(ctx, c, c.Ret); err != nil { + return err } - if arg.Type().Dir() == DirOut { - switch a := arg.(type) { - case *ConstArg: - // We generate output len arguments, which makes sense since it can be - // a length of a variable-length array which is not known otherwise. - if _, ok := a.Type().(*LenType); ok { - break - } - if a.Val != 0 && a.Val != a.Type().Default() { - return fmt.Errorf("syscall %v: output arg '%v'/'%v' has non default value '%+v'", - c.Meta.Name, a.Type().FieldName(), a.Type().Name(), a) - } - case *DataArg: - if len(a.data) != 0 { - return fmt.Errorf("syscall %v: output arg '%v' has data", - c.Meta.Name, a.Type().Name()) - } + } else if c.Ret.Type() != nil { + return fmt.Errorf("syscall %v: return value has spurious type: %+v", + c.Meta.Name, c.Ret.Type()) + } + return nil +} + +// nolint +func validateArg(ctx *validCtx, c *Call, arg Arg) error { + if arg == nil { + return fmt.Errorf("syscall %v: nil arg", c.Meta.Name) + } + if ctx.args[arg] { + return fmt.Errorf("syscall %v: arg %#v is referenced several times in the tree", + c.Meta.Name, arg) + } + ctx.args[arg] = true + if used, ok := arg.(ArgUsed); ok { + for u := range *used.Used() { + if u == nil { + return fmt.Errorf("syscall %v: nil reference in uses for arg %+v", + c.Meta.Name, arg) } + ctx.uses[u] = arg } - switch typ1 := arg.Type().(type) { - case *IntType: - switch a := arg.(type) { - case *ConstArg: - if a.Type().Dir() == DirOut && (a.Val != 0 && a.Val != a.Type().Default()) { - return fmt.Errorf("syscall %v: out int arg '%v' has bad const value %v", - c.Meta.Name, a.Type().Name(), a.Val) - } - case *ReturnArg: - default: - return fmt.Errorf("syscall %v: int arg '%v' has bad kind %v", - c.Meta.Name, arg.Type().Name(), arg) - } - case *ResourceType: - switch a := arg.(type) { - case *ResultArg: - if a.Type().Dir() == DirOut && (a.Val != 0 && a.Val != a.Type().Default()) { - return fmt.Errorf("syscall %v: out resource arg '%v' has bad const value %v", - c.Meta.Name, a.Type().Name(), a.Val) - } - case *ReturnArg: - default: - return fmt.Errorf("syscall %v: fd arg '%v' has bad kind %v", - c.Meta.Name, arg.Type().Name(), arg) + } + if arg.Type() == nil { + return fmt.Errorf("syscall %v: no type", c.Meta.Name) + } + if arg.Type().Dir() == DirOut { + switch a := arg.(type) { + case *ConstArg: + // We generate output len arguments, which makes sense since it can be + // a length of a variable-length array which is not known otherwise. + if _, ok := a.Type().(*LenType); ok { + break } - case *StructType, *ArrayType: - switch arg.(type) { - case *GroupArg: - default: - return fmt.Errorf("syscall %v: struct/array arg '%v' has bad kind %#v", - c.Meta.Name, arg.Type().Name(), arg) + if a.Val != 0 && a.Val != a.Type().Default() { + return fmt.Errorf("syscall %v: output arg '%v'/'%v' has non default value '%+v'", + c.Meta.Name, a.Type().FieldName(), a.Type().Name(), a) } - case *UnionType: - switch arg.(type) { - case *UnionArg: - default: - return fmt.Errorf("syscall %v: union arg '%v' has bad kind %v", - c.Meta.Name, arg.Type().Name(), arg) + case *DataArg: + if len(a.data) != 0 { + return fmt.Errorf("syscall %v: output arg '%v' has data", + c.Meta.Name, a.Type().Name()) } - case *ProcType: - switch a := arg.(type) { - case *ConstArg: - if a.Val >= typ1.ValuesPerProc && a.Val != typ1.Default() { - return fmt.Errorf("syscall %v: per proc arg '%v' has bad value '%v'", - c.Meta.Name, a.Type().Name(), a.Val) - } - default: - return fmt.Errorf("syscall %v: proc arg '%v' has bad kind %v", - c.Meta.Name, arg.Type().Name(), arg) + } + } + + switch typ1 := arg.Type().(type) { + case *IntType: + switch a := arg.(type) { + case *ConstArg: + if a.Type().Dir() == DirOut && (a.Val != 0 && a.Val != a.Type().Default()) { + return fmt.Errorf("syscall %v: out int arg '%v' has bad const value %v", + c.Meta.Name, a.Type().Name(), a.Val) } - case *BufferType: - switch a := arg.(type) { - case *DataArg: - switch typ1.Kind { - case BufferString: - if typ1.TypeSize != 0 && a.Size() != typ1.TypeSize { - return fmt.Errorf("syscall %v: string arg '%v' has size %v, which should be %v", - c.Meta.Name, a.Type().Name(), a.Size(), typ1.TypeSize) - } - } - default: - return fmt.Errorf("syscall %v: buffer arg '%v' has bad kind %v", - c.Meta.Name, arg.Type().Name(), arg) + default: + return fmt.Errorf("syscall %v: int arg '%v' has bad kind %v", + c.Meta.Name, arg.Type().Name(), arg) + } + case *ResourceType: + switch a := arg.(type) { + case *ResultArg: + if a.Type().Dir() == DirOut && (a.Val != 0 && a.Val != a.Type().Default()) { + return fmt.Errorf("syscall %v: out resource arg '%v' has bad const value %v", + c.Meta.Name, a.Type().Name(), a.Val) } - case *CsumType: - switch a := arg.(type) { - case *ConstArg: - if a.Val != 0 { - return fmt.Errorf("syscall %v: csum arg '%v' has nonzero value %v", - c.Meta.Name, a.Type().Name(), a.Val) - } - default: - return fmt.Errorf("syscall %v: csum arg '%v' has bad kind %v", - c.Meta.Name, arg.Type().Name(), arg) + case *ReturnArg: + default: + return fmt.Errorf("syscall %v: fd arg '%v' has bad kind %v", + c.Meta.Name, arg.Type().Name(), arg) + } + case *StructType, *ArrayType: + switch arg.(type) { + case *GroupArg: + default: + return fmt.Errorf("syscall %v: struct/array arg '%v' has bad kind %#v", + c.Meta.Name, arg.Type().Name(), arg) + } + case *UnionType: + switch arg.(type) { + case *UnionArg: + default: + return fmt.Errorf("syscall %v: union arg '%v' has bad kind %v", + c.Meta.Name, arg.Type().Name(), arg) + } + case *ProcType: + switch a := arg.(type) { + case *ConstArg: + if a.Val >= typ1.ValuesPerProc && a.Val != typ1.Default() { + return fmt.Errorf("syscall %v: per proc arg '%v' has bad value '%v'", + c.Meta.Name, a.Type().Name(), a.Val) } - case *PtrType: - switch a := arg.(type) { - case *PointerArg: - if a.Type().Dir() == DirOut { - return fmt.Errorf("syscall %v: pointer arg '%v' has output direction", - c.Meta.Name, a.Type().Name()) - } - if a.Res == nil && !a.Type().Optional() { - return fmt.Errorf("syscall %v: non optional pointer arg '%v' is nil", - c.Meta.Name, a.Type().Name()) + default: + return fmt.Errorf("syscall %v: proc arg '%v' has bad kind %v", + c.Meta.Name, arg.Type().Name(), arg) + } + case *BufferType: + switch a := arg.(type) { + case *DataArg: + switch typ1.Kind { + case BufferString: + if typ1.TypeSize != 0 && a.Size() != typ1.TypeSize { + return fmt.Errorf("syscall %v: string arg '%v' has size %v, which should be %v", + c.Meta.Name, a.Type().Name(), a.Size(), typ1.TypeSize) } - default: - return fmt.Errorf("syscall %v: ptr arg '%v' has bad kind %v", - c.Meta.Name, arg.Type().Name(), arg) } + default: + return fmt.Errorf("syscall %v: buffer arg '%v' has bad kind %v", + c.Meta.Name, arg.Type().Name(), arg) } + case *CsumType: switch a := arg.(type) { case *ConstArg: - case *PointerArg: - maxMem := p.Target.NumPages * p.Target.PageSize - size := a.VmaSize - if size == 0 && a.Res != nil { - size = a.Res.Size() - } - if a.Address >= maxMem || a.Address+size > maxMem { - return fmt.Errorf("syscall %v: ptr %v has bad address %v/%v/%v", - c.Meta.Name, a.Type().Name(), a.Address, a.VmaSize, size) + if a.Val != 0 { + return fmt.Errorf("syscall %v: csum arg '%v' has nonzero value %v", + c.Meta.Name, a.Type().Name(), a.Val) } - switch t := a.Type().(type) { - case *VmaType: - if a.Res != nil { - return fmt.Errorf("syscall %v: vma arg '%v' has data", - c.Meta.Name, a.Type().Name()) - } - if a.VmaSize == 0 && t.Dir() != DirOut && !t.Optional() { - return fmt.Errorf("syscall %v: vma arg '%v' has size 0", - c.Meta.Name, a.Type().Name()) - } - case *PtrType: - if a.Res != nil { - if err := checkArg(a.Res); err != nil { - return err - } - } - if a.VmaSize != 0 { - return fmt.Errorf("syscall %v: pointer arg '%v' has nonzero size", - c.Meta.Name, a.Type().Name()) - } - default: - return fmt.Errorf("syscall %v: pointer arg '%v' has bad meta type %+v", - c.Meta.Name, arg.Type().Name(), arg.Type()) - } - case *DataArg: - typ1 := a.Type() - if !typ1.Varlen() && typ1.Size() != a.Size() { - return fmt.Errorf("syscall %v: data arg %v has wrong size %v, want %v", - c.Meta.Name, arg.Type().Name(), a.Size(), typ1.Size()) + default: + return fmt.Errorf("syscall %v: csum arg '%v' has bad kind %v", + c.Meta.Name, arg.Type().Name(), arg) + } + case *PtrType: + switch a := arg.(type) { + case *PointerArg: + if a.Type().Dir() == DirOut { + return fmt.Errorf("syscall %v: pointer arg '%v' has output direction", + c.Meta.Name, a.Type().Name()) } - switch typ1 := a.Type().(type) { - case *ArrayType: - if typ2, ok := typ1.Type.(*IntType); !ok || typ2.Size() != 1 { - return fmt.Errorf("syscall %v: data arg '%v' should be an array", - c.Meta.Name, a.Type().Name()) - } + if a.Res == nil && !a.Type().Optional() { + return fmt.Errorf("syscall %v: non optional pointer arg '%v' is nil", + c.Meta.Name, a.Type().Name()) } - case *GroupArg: - switch typ1 := a.Type().(type) { - case *StructType: - if len(a.Inner) != len(typ1.Fields) { - return fmt.Errorf("syscall %v: struct arg '%v' has wrong number of fields: want %v, got %v", - c.Meta.Name, a.Type().Name(), len(typ1.Fields), len(a.Inner)) - } - for _, arg1 := range a.Inner { - if err := checkArg(arg1); err != nil { - return err - } - } - case *ArrayType: - if typ1.Kind == ArrayRangeLen && typ1.RangeBegin == typ1.RangeEnd && - uint64(len(a.Inner)) != typ1.RangeBegin { - return fmt.Errorf("syscall %v: array %v has wrong number"+ - " of elements %v, want %v", - c.Meta.Name, arg.Type().Name(), - len(a.Inner), typ1.RangeBegin) - } - for _, arg1 := range a.Inner { - if err := checkArg(arg1); err != nil { - return err - } - } - default: - return fmt.Errorf("syscall %v: group arg '%v' has bad underlying type %+v", - c.Meta.Name, arg.Type().Name(), arg.Type()) + default: + return fmt.Errorf("syscall %v: ptr arg '%v' has bad kind %v", + c.Meta.Name, arg.Type().Name(), arg) + } + } + + switch a := arg.(type) { + case *ConstArg: + case *PointerArg: + maxMem := ctx.target.NumPages * ctx.target.PageSize + size := a.VmaSize + if size == 0 && a.Res != nil { + size = a.Res.Size() + } + if a.Address >= maxMem || a.Address+size > maxMem { + return fmt.Errorf("syscall %v: ptr %v has bad address %v/%v/%v", + c.Meta.Name, a.Type().Name(), a.Address, a.VmaSize, size) + } + switch t := a.Type().(type) { + case *VmaType: + if a.Res != nil { + return fmt.Errorf("syscall %v: vma arg '%v' has data", + c.Meta.Name, a.Type().Name()) } - case *UnionArg: - typ1, ok := a.Type().(*UnionType) - if !ok { - return fmt.Errorf("syscall %v: union arg '%v' has bad type", + if a.VmaSize == 0 && t.Dir() != DirOut && !t.Optional() { + return fmt.Errorf("syscall %v: vma arg '%v' has size 0", c.Meta.Name, a.Type().Name()) } - found := false - for _, typ2 := range typ1.Fields { - if a.Option.Type().Name() == typ2.Name() { - found = true - break + case *PtrType: + if a.Res != nil { + if err := validateArg(ctx, c, a.Res); err != nil { + return err } } - if !found { - return fmt.Errorf("syscall %v: union arg '%v' has bad option", + if a.VmaSize != 0 { + return fmt.Errorf("syscall %v: pointer arg '%v' has nonzero size", c.Meta.Name, a.Type().Name()) } - if err := checkArg(a.Option); err != nil { - return err - } - case *ResultArg: - switch a.Type().(type) { - case *ResourceType: - default: - return fmt.Errorf("syscall %v: result arg '%v' has bad meta type %+v", - c.Meta.Name, arg.Type().Name(), arg.Type()) - } - if a.Res == nil { - break + default: + return fmt.Errorf("syscall %v: pointer arg '%v' has bad meta type %+v", + c.Meta.Name, arg.Type().Name(), arg.Type()) + } + case *DataArg: + typ1 := a.Type() + if !typ1.Varlen() && typ1.Size() != a.Size() { + return fmt.Errorf("syscall %v: data arg %v has wrong size %v, want %v", + c.Meta.Name, arg.Type().Name(), a.Size(), typ1.Size()) + } + case *GroupArg: + switch typ1 := a.Type().(type) { + case *StructType: + if len(a.Inner) != len(typ1.Fields) { + return fmt.Errorf("syscall %v: struct arg '%v' has wrong number of fields: want %v, got %v", + c.Meta.Name, a.Type().Name(), len(typ1.Fields), len(a.Inner)) } - if !ctx.args[a.Res] { - return fmt.Errorf("syscall %v: result arg %v references out-of-tree result: %#v -> %#v", - c.Meta.Name, a.Type().Name(), arg, a.Res) + for _, arg1 := range a.Inner { + if err := validateArg(ctx, c, arg1); err != nil { + return err + } } - if !(*a.Res.(ArgUsed).Used())[arg] { - return fmt.Errorf("syscall %v: result arg '%v' has broken link (%+v)", - c.Meta.Name, a.Type().Name(), *a.Res.(ArgUsed).Used()) + case *ArrayType: + if typ1.Kind == ArrayRangeLen && typ1.RangeBegin == typ1.RangeEnd && + uint64(len(a.Inner)) != typ1.RangeBegin { + return fmt.Errorf("syscall %v: array %v has wrong number"+ + " of elements %v, want %v", + c.Meta.Name, arg.Type().Name(), + len(a.Inner), typ1.RangeBegin) } - case *ReturnArg: - switch a.Type().(type) { - case *ResourceType: - case *VmaType: - default: - return fmt.Errorf("syscall %v: result arg '%v' has bad meta type %+v", - c.Meta.Name, arg.Type().Name(), arg.Type()) + for _, arg1 := range a.Inner { + if err := validateArg(ctx, c, arg1); err != nil { + return err + } } default: - return fmt.Errorf("syscall %v: unknown arg '%v' kind", - c.Meta.Name, arg.Type().Name()) + return fmt.Errorf("syscall %v: group arg '%v' has bad underlying type %+v", + c.Meta.Name, arg.Type().Name(), arg.Type()) } - return nil - } - for _, arg := range c.Args { - if _, ok := arg.(*ReturnArg); ok { - return fmt.Errorf("syscall %v: arg '%v' has wrong return kind", - c.Meta.Name, arg.Type().Name()) + case *UnionArg: + typ1, ok := a.Type().(*UnionType) + if !ok { + return fmt.Errorf("syscall %v: union arg '%v' has bad type", + c.Meta.Name, a.Type().Name()) } - if err := checkArg(arg); err != nil { - return err + found := false + for _, typ2 := range typ1.Fields { + if a.Option.Type().Name() == typ2.Name() { + found = true + break + } } - } - if c.Ret == nil { - return fmt.Errorf("syscall %v: return value is absent", c.Meta.Name) - } - if _, ok := c.Ret.(*ReturnArg); !ok { - return fmt.Errorf("syscall %v: return value has wrong kind %v", c.Meta.Name, c.Ret) - } - if c.Meta.Ret != nil { - if err := checkArg(c.Ret); err != nil { + if !found { + return fmt.Errorf("syscall %v: union arg '%v' has bad option", + c.Meta.Name, a.Type().Name()) + } + if err := validateArg(ctx, c, a.Option); err != nil { return err } - } else if c.Ret.Type() != nil { - return fmt.Errorf("syscall %v: return value has spurious type: %+v", - c.Meta.Name, c.Ret.Type()) + case *ResultArg: + switch a.Type().(type) { + case *ResourceType: + default: + return fmt.Errorf("syscall %v: result arg '%v' has bad meta type %+v", + c.Meta.Name, arg.Type().Name(), arg.Type()) + } + if a.Res == nil { + break + } + if !ctx.args[a.Res] { + return fmt.Errorf("syscall %v: result arg %v references out-of-tree result: %#v -> %#v", + c.Meta.Name, a.Type().Name(), arg, a.Res) + } + if !(*a.Res.(ArgUsed).Used())[arg] { + return fmt.Errorf("syscall %v: result arg '%v' has broken link (%+v)", + c.Meta.Name, a.Type().Name(), *a.Res.(ArgUsed).Used()) + } + case *ReturnArg: + switch a.Type().(type) { + case *ResourceType: + default: + return fmt.Errorf("syscall %v: result arg '%v' has bad meta type %+v", + c.Meta.Name, arg.Type().Name(), arg.Type()) + } + default: + return fmt.Errorf("syscall %v: unknown arg '%v' kind", + c.Meta.Name, arg.Type().Name()) } return nil } |
