From a7e4a49fae26bf52b4b8f26aeebc50d947dc1abc Mon Sep 17 00:00:00 2001 From: Dmitry Vyukov Date: Fri, 20 Jan 2017 23:55:25 +0100 Subject: all: spot optimizations A bunch of spot optmizations after cpu/memory profiling: 1. Optimize hot-path coverage comparison in fuzzer. 2. Don't allocate and copy serialized program, serialize directly into shmem. 3. Reduce allocations during parsing of output shmem (encoding/binary sucks). 4. Don't allocate and copy coverage arrays, refer directly to the shmem region (we are not going to mutate them). 5. Don't validate programs outside of tests, validation allocates tons of memory. 6. Replace the choose primitive with simpler switches. Choose allocates fullload of memory (for int, func, and everything the func refers). 7. Other minor optimizations. --- prog/clone.go | 6 +- prog/encoding.go | 8 +- prog/encodingexec.go | 61 +++-- prog/encodingexec_test.go | 14 +- prog/mutation.go | 662 +++++++++++++++++++++++----------------------- prog/prog_test.go | 4 + prog/rand.go | 337 +++++++++++------------ prog/validation.go | 5 + 8 files changed, 577 insertions(+), 520 deletions(-) (limited to 'prog') diff --git a/prog/clone.go b/prog/clone.go index 69a54cd4e..6a6837148 100644 --- a/prog/clone.go +++ b/prog/clone.go @@ -44,7 +44,9 @@ func (arg *Arg) clone(c *Call, newargs map[*Arg]*Arg) *Arg { for _, arg2 := range arg.Inner { arg1.Inner = append(arg1.Inner, arg2.clone(c, newargs)) } - arg1.Uses = nil // filled when we clone the referent - newargs[arg] = arg1 + if len(arg1.Uses) != 0 { + arg1.Uses = nil // filled when we clone the referent + newargs[arg] = arg1 + } return arg1 } diff --git a/prog/encoding.go b/prog/encoding.go index b04f667ba..8c7c1674a 100644 --- a/prog/encoding.go +++ b/prog/encoding.go @@ -27,11 +27,9 @@ func (p *Prog) String() string { } func (p *Prog) Serialize() []byte { - /* - if err := p.validate(); err != nil { - panic("serializing invalid program") - } - */ + if err := p.validate(); err != nil { + panic("serializing invalid program") + } buf := new(bytes.Buffer) vars := make(map[*Arg]int) varSeq := 0 diff --git a/prog/encodingexec.go b/prog/encodingexec.go index 517e8eb20..5afd83fd6 100644 --- a/prog/encodingexec.go +++ b/prog/encodingexec.go @@ -25,17 +25,25 @@ const ( ) const ( + ExecBufferSize = 2 << 20 + ptrSize = 8 pageSize = 4 << 10 dataOffset = 512 << 20 ) -func (p *Prog) SerializeForExec(pid int) []byte { +// SerializeForExec serializes program p for execution by process pid into the provided buffer. +// If the provided buffer is too small for the program an error is returned. +func (p *Prog) SerializeForExec(buffer []byte, pid int) error { if err := p.validate(); err != nil { panic(fmt.Errorf("serializing invalid program: %v", err)) } var instrSeq uintptr - w := &execContext{args: make(map[*Arg]*argInfo)} + w := &execContext{ + buf: buffer, + eof: false, + args: make(map[*Arg]argInfo), + } for _, c := range p.Calls { // Calculate arg offsets within structs. // Generate copyin instructions that fill in data into pointer arguments. @@ -43,7 +51,9 @@ func (p *Prog) SerializeForExec(pid int) []byte { if arg.Kind == ArgPointer && arg.Res != nil { var rec func(*Arg, uintptr) uintptr rec = func(arg1 *Arg, offset uintptr) uintptr { - w.args[arg1] = &argInfo{Offset: offset} + if len(arg1.Uses) != 0 { + w.args[arg1] = argInfo{Offset: offset} + } if arg1.Kind == ArgGroup { var totalSize uintptr for _, arg2 := range arg1.Inner { @@ -85,7 +95,9 @@ func (p *Prog) SerializeForExec(pid int) []byte { for _, arg := range c.Args { w.writeArg(arg, pid) } - w.args[c.Ret] = &argInfo{Idx: instrSeq} + if len(c.Ret.Uses) != 0 { + w.args[c.Ret] = argInfo{Idx: instrSeq} + } instrSeq++ // Generate copyout instructions that persist interesting return values. foreachArg(c, func(arg, base *Arg, _ *[]*Arg) { @@ -103,6 +115,7 @@ func (p *Prog) SerializeForExec(pid int) []byte { info := w.args[arg] info.Idx = instrSeq instrSeq++ + w.args[arg] = info w.write(ExecInstrCopyout) w.write(physicalAddr(base) + info.Offset) w.write(arg.Size()) @@ -112,7 +125,10 @@ func (p *Prog) SerializeForExec(pid int) []byte { }) } w.write(ExecInstrEOF) - return w.buf + if w.eof { + return fmt.Errorf("provided buffer is too small") + } + return nil } func physicalAddr(arg *Arg) uintptr { @@ -130,7 +146,8 @@ func physicalAddr(arg *Arg) uintptr { type execContext struct { buf []byte - args map[*Arg]*argInfo + eof bool + args map[*Arg]argInfo } type argInfo struct { @@ -139,7 +156,19 @@ type argInfo struct { } func (w *execContext) write(v uintptr) { - w.buf = append(w.buf, byte(v>>0), byte(v>>8), byte(v>>16), byte(v>>24), byte(v>>32), byte(v>>40), byte(v>>48), byte(v>>56)) + if len(w.buf) < 8 { + w.eof = true + return + } + w.buf[0] = byte(v >> 0) + w.buf[1] = byte(v >> 8) + w.buf[2] = byte(v >> 16) + w.buf[3] = byte(v >> 24) + w.buf[4] = byte(v >> 32) + w.buf[5] = byte(v >> 40) + w.buf[6] = byte(v >> 48) + w.buf[7] = byte(v >> 56) + w.buf = w.buf[8:] } func (w *execContext) writeArg(arg *Arg, pid int) { @@ -171,15 +200,15 @@ func (w *execContext) writeArg(arg *Arg, pid int) { case ArgData: w.write(ExecArgData) w.write(uintptr(len(arg.Data))) - for i := 0; i < len(arg.Data); i += 8 { - var v uintptr - for j := 0; j < 8; j++ { - if i+j >= len(arg.Data) { - break - } - v |= uintptr(arg.Data[i+j]) << uint(j*8) - } - w.write(v) + padded := len(arg.Data) + if pad := 8 - len(arg.Data)%8; pad != 8 { + padded += pad + } + if len(w.buf) < padded { + w.eof = true + } else { + copy(w.buf, arg.Data) + w.buf = w.buf[padded:] } default: panic("unknown arg type") diff --git a/prog/encodingexec_test.go b/prog/encodingexec_test.go index f2476351c..b7a6b9463 100644 --- a/prog/encodingexec_test.go +++ b/prog/encodingexec_test.go @@ -14,9 +14,12 @@ import ( func TestSerializeForExecRandom(t *testing.T) { rs, iters := initTest(t) + buf := make([]byte, ExecBufferSize) for i := 0; i < iters; i++ { p := Generate(rs, 10, nil) - p.SerializeForExec(i % 16) + if err := p.SerializeForExec(buf, i%16); err != nil { + t.Fatalf("failed to serialize: %v", err) + } } } @@ -249,15 +252,22 @@ func TestSerializeForExec(t *testing.T) { }, } + buf := make([]byte, ExecBufferSize) for i, test := range tests { p, err := Deserialize([]byte(test.prog)) if err != nil { t.Fatalf("failed to deserialize prog %v: %v", i, err) } t.Run(fmt.Sprintf("%v:%v", i, p.String()), func(t *testing.T) { - data := p.SerializeForExec(i % 16) + if err := p.SerializeForExec(buf, i%16); err != nil { + t.Fatalf("failed to serialize: %v", err) + } w := new(bytes.Buffer) binary.Write(w, binary.LittleEndian, test.serialized) + data := buf + if len(data) > len(w.Bytes()) { + data = data[:len(w.Bytes())] + } if !bytes.Equal(data, w.Bytes()) { got := make([]uint64, len(data)/8) binary.Read(bytes.NewReader(data), binary.LittleEndian, &got) diff --git a/prog/mutation.go b/prog/mutation.go index 928fe1fdb..2136c86f8 100644 --- a/prog/mutation.go +++ b/prog/mutation.go @@ -26,203 +26,203 @@ func (p *Prog) Mutate(rs rand.Source, ncalls int, ct *ChoiceTable, corpus []*Pro retry := false for stop := false; !stop || retry; stop = r.oneOf(3) { retry = false - r.choose( - 20, func() { - // Insert a new call. - if len(p.Calls) >= ncalls { - retry = true - return - } - idx := r.biasedRand(len(p.Calls)+1, 5) - var c *Call - if idx < len(p.Calls) { - c = p.Calls[idx] - } - s := analyze(ct, p, c) - calls := r.generateCall(s, p) - p.insertBefore(c, calls) - }, - 10, func() { - // Change args of a call. - if len(p.Calls) == 0 { - retry = true - return - } - c := p.Calls[r.Intn(len(p.Calls))] - if len(c.Args) == 0 { + switch { + case r.nOutOf(20, 31): + // Insert a new call. + if len(p.Calls) >= ncalls { + retry = true + continue + } + idx := r.biasedRand(len(p.Calls)+1, 5) + var c *Call + if idx < len(p.Calls) { + c = p.Calls[idx] + } + s := analyze(ct, p, c) + calls := r.generateCall(s, p) + p.insertBefore(c, calls) + case r.nOutOf(10, 11): + // Change args of a call. + if len(p.Calls) == 0 { + retry = true + continue + } + c := p.Calls[r.Intn(len(p.Calls))] + if len(c.Args) == 0 { + retry = true + continue + } + s := analyze(ct, p, c) + for stop := false; !stop; stop = r.oneOf(3) { + args, bases := mutationArgs(c) + if len(args) == 0 { retry = true - return + continue } - s := analyze(ct, p, c) - for stop := false; !stop; stop = r.oneOf(3) { - args, bases := mutationArgs(c) - if len(args) == 0 { - retry = true - return + idx := r.Intn(len(args)) + arg, base := args[idx], bases[idx] + var baseSize uintptr + if base != nil { + if base.Kind != ArgPointer || base.Res == nil { + panic("bad base arg") } - idx := r.Intn(len(args)) - arg, base := args[idx], bases[idx] - var baseSize uintptr - if base != nil { - if base.Kind != ArgPointer || base.Res == nil { - panic("bad base arg") - } - baseSize = base.Res.Size() - } - switch a := arg.Type.(type) { - case *sys.IntType, *sys.FlagsType: - if r.bin() { - arg1, calls1 := r.generateArg(s, arg.Type) - p.replaceArg(c, arg, arg1, calls1) - } else { - r.choose( - 1, func() { arg.Val += uintptr(r.Intn(4)) + 1 }, - 1, func() { arg.Val -= uintptr(r.Intn(4)) + 1 }, - 1, func() { arg.Val ^= 1 << uintptr(r.Intn(64)) }, - ) - } - case *sys.ResourceType, *sys.VmaType, *sys.ProcType: + baseSize = base.Res.Size() + } + switch a := arg.Type.(type) { + case *sys.IntType, *sys.FlagsType: + if r.bin() { arg1, calls1 := r.generateArg(s, arg.Type) p.replaceArg(c, arg, arg1, calls1) - case *sys.BufferType: - switch a.Kind { - case sys.BufferBlobRand, sys.BufferBlobRange: - var data []byte - switch arg.Kind { - case ArgData: - data = append([]byte{}, arg.Data...) - case ArgConst: - // 0 is OK for optional args. - if arg.Val != 0 { - panic(fmt.Sprintf("BufferType has non-zero const value: %v", arg.Val)) - } - default: - panic(fmt.Sprintf("bad arg kind for BufferType: %v", arg.Kind)) - } - minLen := int(0) - maxLen := math.MaxInt32 - if a.Kind == sys.BufferBlobRange { - minLen = int(a.RangeBegin) - maxLen = int(a.RangeEnd) - } - arg.Data = mutateData(r, data, minLen, maxLen) - case sys.BufferString: - if r.bin() { - minLen := int(0) - maxLen := math.MaxInt32 - if a.Length != 0 { - minLen = int(a.Length) - maxLen = int(a.Length) - } - arg.Data = mutateData(r, append([]byte{}, arg.Data...), minLen, maxLen) - } else { - arg.Data = r.randString(s, a.Values, a.Dir()) - } - case sys.BufferFilename: - arg.Data = []byte(r.filename(s)) - case sys.BufferText: - arg.Data = r.mutateText(a.Text, arg.Data) + } else { + switch { + case r.nOutOf(1, 3): + arg.Val += uintptr(r.Intn(4)) + 1 + case r.nOutOf(1, 2): + arg.Val -= uintptr(r.Intn(4)) + 1 default: - panic("unknown buffer kind") + arg.Val ^= 1 << uintptr(r.Intn(64)) } - case *sys.ArrayType: - count := uintptr(0) - switch a.Kind { - case sys.ArrayRandLen: - for count == uintptr(len(arg.Inner)) { - count = r.randArrayLen() - } - case sys.ArrayRangeLen: - if a.RangeBegin == a.RangeEnd { - panic("trying to mutate fixed length array") - } - for count == uintptr(len(arg.Inner)) { - count = r.randRange(int(a.RangeBegin), int(a.RangeEnd)) + } + case *sys.ResourceType, *sys.VmaType, *sys.ProcType: + arg1, calls1 := r.generateArg(s, arg.Type) + p.replaceArg(c, arg, arg1, calls1) + case *sys.BufferType: + switch a.Kind { + case sys.BufferBlobRand, sys.BufferBlobRange: + var data []byte + switch arg.Kind { + case ArgData: + data = append([]byte{}, arg.Data...) + case ArgConst: + // 0 is OK for optional args. + if arg.Val != 0 { + panic(fmt.Sprintf("BufferType has non-zero const value: %v", arg.Val)) } + default: + panic(fmt.Sprintf("bad arg kind for BufferType: %v", arg.Kind)) } - if count > uintptr(len(arg.Inner)) { - var calls []*Call - for count > uintptr(len(arg.Inner)) { - arg1, calls1 := r.generateArg(s, a.Type) - arg.Inner = append(arg.Inner, arg1) - for _, c1 := range calls1 { - calls = append(calls, c1) - s.analyze(c1) - } - } - for _, c1 := range calls { - sanitizeCall(c1) - } - sanitizeCall(c) - p.insertBefore(c, calls) - } else if count < uintptr(len(arg.Inner)) { - for _, arg := range arg.Inner[count:] { - p.removeArg(c, arg) - } - arg.Inner = arg.Inner[:count] + minLen := int(0) + maxLen := math.MaxInt32 + if a.Kind == sys.BufferBlobRange { + minLen = int(a.RangeBegin) + maxLen = int(a.RangeEnd) } - // TODO: swap elements of the array - case *sys.PtrType: - // TODO: we don't know size for out args - size := uintptr(1) - if arg.Res != nil { - size = arg.Res.Size() + arg.Data = mutateData(r, data, minLen, maxLen) + case sys.BufferString: + if r.bin() { + minLen := int(0) + maxLen := math.MaxInt32 + if a.Length != 0 { + minLen = int(a.Length) + maxLen = int(a.Length) + } + arg.Data = mutateData(r, append([]byte{}, arg.Data...), minLen, maxLen) + } else { + arg.Data = r.randString(s, a.Values, a.Dir()) } - arg1, calls1 := r.addr(s, a, size, arg.Res) - p.replaceArg(c, arg, arg1, calls1) - case *sys.StructType: - ctor := isSpecialStruct(a) - if ctor == nil { - panic("bad arg returned by mutationArgs: StructType") + case sys.BufferFilename: + arg.Data = []byte(r.filename(s)) + case sys.BufferText: + arg.Data = r.mutateText(a.Text, arg.Data) + default: + panic("unknown buffer kind") + } + case *sys.ArrayType: + count := uintptr(0) + switch a.Kind { + case sys.ArrayRandLen: + for count == uintptr(len(arg.Inner)) { + count = r.randArrayLen() } - arg1, calls1 := ctor(r, s) - for i, f := range arg1.Inner { - p.replaceArg(c, arg.Inner[i], f, calls1) - calls1 = nil + case sys.ArrayRangeLen: + if a.RangeBegin == a.RangeEnd { + panic("trying to mutate fixed length array") } - case *sys.UnionType: - optType := a.Options[r.Intn(len(a.Options))] - for optType.Name() == arg.OptionType.Name() { - optType = a.Options[r.Intn(len(a.Options))] + for count == uintptr(len(arg.Inner)) { + count = r.randRange(int(a.RangeBegin), int(a.RangeEnd)) } - p.removeArg(c, arg.Option) - opt, calls := r.generateArg(s, optType) - arg1 := unionArg(a, opt, optType) - p.replaceArg(c, arg, arg1, calls) - case *sys.LenType: - panic("bad arg returned by mutationArgs: LenType") - case *sys.ConstType: - panic("bad arg returned by mutationArgs: ConstType") - default: - panic(fmt.Sprintf("bad arg returned by mutationArgs: %#v, type=%#v", *arg, arg.Type)) } - - // Update base pointer if size has increased. - if base != nil && baseSize < base.Res.Size() { - arg1, calls1 := r.addr(s, base.Type, base.Res.Size(), base.Res) - for _, c1 := range calls1 { + if count > uintptr(len(arg.Inner)) { + var calls []*Call + for count > uintptr(len(arg.Inner)) { + arg1, calls1 := r.generateArg(s, a.Type) + arg.Inner = append(arg.Inner, arg1) + for _, c1 := range calls1 { + calls = append(calls, c1) + s.analyze(c1) + } + } + for _, c1 := range calls { sanitizeCall(c1) } - p.insertBefore(c, calls1) - arg.AddrPage = arg1.AddrPage - arg.AddrOffset = arg1.AddrOffset - arg.AddrPagesNum = arg1.AddrPagesNum + sanitizeCall(c) + p.insertBefore(c, calls) + } else if count < uintptr(len(arg.Inner)) { + for _, arg := range arg.Inner[count:] { + p.removeArg(c, arg) + } + arg.Inner = arg.Inner[:count] } - - // Update all len fields. - assignSizesCall(c) + // TODO: swap elements of the array + case *sys.PtrType: + // TODO: we don't know size for out args + size := uintptr(1) + if arg.Res != nil { + size = arg.Res.Size() + } + arg1, calls1 := r.addr(s, a, size, arg.Res) + p.replaceArg(c, arg, arg1, calls1) + case *sys.StructType: + ctor := isSpecialStruct(a) + if ctor == nil { + panic("bad arg returned by mutationArgs: StructType") + } + arg1, calls1 := ctor(r, s) + for i, f := range arg1.Inner { + p.replaceArg(c, arg.Inner[i], f, calls1) + calls1 = nil + } + case *sys.UnionType: + optType := a.Options[r.Intn(len(a.Options))] + for optType.Name() == arg.OptionType.Name() { + optType = a.Options[r.Intn(len(a.Options))] + } + p.removeArg(c, arg.Option) + opt, calls := r.generateArg(s, optType) + arg1 := unionArg(a, opt, optType) + p.replaceArg(c, arg, arg1, calls) + case *sys.LenType: + panic("bad arg returned by mutationArgs: LenType") + case *sys.ConstType: + panic("bad arg returned by mutationArgs: ConstType") + default: + panic(fmt.Sprintf("bad arg returned by mutationArgs: %#v, type=%#v", *arg, arg.Type)) } - }, - 1, func() { - // Remove a random call. - if len(p.Calls) == 0 { - retry = true - return + + // Update base pointer if size has increased. + if base != nil && baseSize < base.Res.Size() { + arg1, calls1 := r.addr(s, base.Type, base.Res.Size(), base.Res) + for _, c1 := range calls1 { + sanitizeCall(c1) + } + p.insertBefore(c, calls1) + arg.AddrPage = arg1.AddrPage + arg.AddrOffset = arg1.AddrOffset + arg.AddrPagesNum = arg1.AddrPagesNum } - idx := r.Intn(len(p.Calls)) - p.removeCall(idx) - }, - ) + + // Update all len fields. + assignSizesCall(c) + } + default: + // Remove a random call. + if len(p.Calls) == 0 { + retry = true + continue + } + idx := r.Intn(len(p.Calls)) + p.removeCall(idx) + } } } @@ -522,159 +522,161 @@ func swap64(v uint64) uint64 { func mutateData(r *randGen, data []byte, minLen, maxLen int) []byte { const maxInc = 35 - for stop := false; !stop; stop = r.bin() { - r.choose( - 100, func() { - // Append byte. - if len(data) >= maxLen { - return - } - data = append(data, byte(r.rand(256))) - }, - 100, func() { - // Remove byte. - if len(data) <= minLen { - return - } - if len(data) == 0 { - return - } - i := r.Intn(len(data)) - copy(data[i:], data[i+1:]) - data = data[:len(data)-1] - }, - 100, func() { - // Replace byte with random value. - if len(data) == 0 { - return - } - data[r.Intn(len(data))] = byte(r.rand(256)) - }, - 100, func() { - // Flip bit in byte. - if len(data) == 0 { - return - } - byt := r.Intn(len(data)) - bit := r.Intn(8) - data[byt] ^= 1 << uint(bit) - }, - 100, func() { - // Swap two bytes. - if len(data) < 2 { - return - } - i1 := r.Intn(len(data)) - i2 := r.Intn(len(data)) - data[i1], data[i2] = data[i2], data[i1] - }, - 100, func() { - // Add / subtract from a byte. - if len(data) == 0 { - return - } - i := r.Intn(len(data)) - delta := byte(r.rand(2*maxInc+1) - maxInc) - if delta == 0 { - delta = 1 - } - data[i] += delta - }, - 100, func() { - // Add / subtract from a uint16. - if len(data) < 2 { - return - } - i := r.Intn(len(data) - 1) - p := (*uint16)(unsafe.Pointer(&data[i])) - delta := uint16(r.rand(2*maxInc+1) - maxInc) - if delta == 0 { - delta = 1 - } - if r.bin() { - *p += delta - } else { - *p = swap16(swap16(*p) + delta) - } - }, - 100, func() { - // Add / subtract from a uint32. - if len(data) < 4 { - return - } - i := r.Intn(len(data) - 3) - p := (*uint32)(unsafe.Pointer(&data[i])) - delta := uint32(r.rand(2*maxInc+1) - maxInc) - if delta == 0 { - delta = 1 - } - if r.bin() { - *p += delta - } else { - *p = swap32(swap32(*p) + delta) - } - }, - 100, func() { - // Add / subtract from a uint64. - if len(data) < 8 { - return - } - i := r.Intn(len(data) - 7) - p := (*uint64)(unsafe.Pointer(&data[i])) - delta := uint64(r.rand(2*maxInc+1) - maxInc) - if delta == 0 { - delta = 1 - } - if r.bin() { - *p += delta - } else { - *p = swap64(swap64(*p) + delta) - } - }, - 100, func() { - // Set byte to an interesting value. - if len(data) == 0 { - return - } - data[r.Intn(len(data))] = byte(r.randInt()) - }, - 100, func() { - // Set uint16 to an interesting value. - if len(data) < 2 { - return - } - i := r.Intn(len(data) - 1) - value := uint16(r.randInt()) - if r.bin() { - value = swap16(value) - } - *(*uint16)(unsafe.Pointer(&data[i])) = value - }, - 100, func() { - // Set uint32 to an interesting value. - if len(data) < 4 { - return - } - i := r.Intn(len(data) - 3) - value := uint32(r.randInt()) - if r.bin() { - value = swap32(value) - } - *(*uint32)(unsafe.Pointer(&data[i])) = value - }, - 100, func() { - // Set uint64 to an interesting value. - if len(data) < 8 { - return - } - i := r.Intn(len(data) - 7) - value := uint64(r.randInt()) - if r.bin() { - value = swap64(value) - } - *(*uint64)(unsafe.Pointer(&data[i])) = value - }, - ) + retry := false +loop: + for stop := false; !stop || retry; stop = r.oneOf(3) { + retry = false + switch r.Intn(13) { + case 0: + // Append byte. + if len(data) >= maxLen { + retry = true + continue loop + } + data = append(data, byte(r.rand(256))) + case 1: + // Remove byte. + if len(data) == 0 || len(data) <= minLen { + retry = true + continue loop + } + i := r.Intn(len(data)) + copy(data[i:], data[i+1:]) + data = data[:len(data)-1] + case 2: + // Replace byte with random value. + if len(data) == 0 { + retry = true + continue loop + } + data[r.Intn(len(data))] = byte(r.rand(256)) + case 3: + // Flip bit in byte. + if len(data) == 0 { + retry = true + continue loop + } + byt := r.Intn(len(data)) + bit := r.Intn(8) + data[byt] ^= 1 << uint(bit) + case 4: + // Swap two bytes. + if len(data) < 2 { + retry = true + continue loop + } + i1 := r.Intn(len(data)) + i2 := r.Intn(len(data)) + data[i1], data[i2] = data[i2], data[i1] + case 5: + // Add / subtract from a byte. + if len(data) == 0 { + retry = true + continue loop + } + i := r.Intn(len(data)) + delta := byte(r.rand(2*maxInc+1) - maxInc) + if delta == 0 { + delta = 1 + } + data[i] += delta + case 6: + // Add / subtract from a uint16. + if len(data) < 2 { + retry = true + continue loop + } + i := r.Intn(len(data) - 1) + p := (*uint16)(unsafe.Pointer(&data[i])) + delta := uint16(r.rand(2*maxInc+1) - maxInc) + if delta == 0 { + delta = 1 + } + if r.bin() { + *p += delta + } else { + *p = swap16(swap16(*p) + delta) + } + case 7: + // Add / subtract from a uint32. + if len(data) < 4 { + retry = true + continue loop + } + i := r.Intn(len(data) - 3) + p := (*uint32)(unsafe.Pointer(&data[i])) + delta := uint32(r.rand(2*maxInc+1) - maxInc) + if delta == 0 { + delta = 1 + } + if r.bin() { + *p += delta + } else { + *p = swap32(swap32(*p) + delta) + } + case 8: + // Add / subtract from a uint64. + if len(data) < 8 { + retry = true + continue loop + } + i := r.Intn(len(data) - 7) + p := (*uint64)(unsafe.Pointer(&data[i])) + delta := uint64(r.rand(2*maxInc+1) - maxInc) + if delta == 0 { + delta = 1 + } + if r.bin() { + *p += delta + } else { + *p = swap64(swap64(*p) + delta) + } + case 9: + // Set byte to an interesting value. + if len(data) == 0 { + retry = true + continue loop + } + data[r.Intn(len(data))] = byte(r.randInt()) + case 10: + // Set uint16 to an interesting value. + if len(data) < 2 { + retry = true + continue loop + } + i := r.Intn(len(data) - 1) + value := uint16(r.randInt()) + if r.bin() { + value = swap16(value) + } + *(*uint16)(unsafe.Pointer(&data[i])) = value + case 11: + // Set uint32 to an interesting value. + if len(data) < 4 { + retry = true + continue loop + } + i := r.Intn(len(data) - 3) + value := uint32(r.randInt()) + if r.bin() { + value = swap32(value) + } + *(*uint32)(unsafe.Pointer(&data[i])) = value + case 12: + // Set uint64 to an interesting value. + if len(data) < 8 { + retry = true + continue loop + } + i := r.Intn(len(data) - 7) + value := uint64(r.randInt()) + if r.bin() { + value = swap64(value) + } + *(*uint64)(unsafe.Pointer(&data[i])) = value + default: + panic("bad") + } } return data } diff --git a/prog/prog_test.go b/prog/prog_test.go index 218a76821..b2655ff25 100644 --- a/prog/prog_test.go +++ b/prog/prog_test.go @@ -12,6 +12,10 @@ import ( "github.com/google/syzkaller/sys" ) +func init() { + debug = true +} + func initTest(t *testing.T) (rand.Source, int) { t.Parallel() iters := 10000 diff --git a/prog/rand.go b/prog/rand.go index 7e981d1f0..4645db811 100644 --- a/prog/rand.go +++ b/prog/rand.go @@ -9,11 +9,14 @@ import ( "math" "math/rand" "strings" + "sync" "github.com/google/syzkaller/ifuzz" "github.com/google/syzkaller/sys" ) +var pageStartPool = sync.Pool{New: func() interface{} { return new([]uintptr) }} + type randGen struct { *rand.Rand inCreateResource bool @@ -60,20 +63,27 @@ var specialInts = []uintptr{ func (r *randGen) randInt() uintptr { v := r.rand64() - r.choose( - 100, func() { v %= 10 }, - 50, func() { v = specialInts[r.Intn(len(specialInts))] }, - 10, func() { v %= 256 }, - 10, func() { v %= 4 << 10 }, - 10, func() { v %= 64 << 10 }, - 1, func() { v %= 1 << 31 }, - 1, func() {}, - ) - r.choose( - 100, func() {}, - 5, func() { v = uintptr(-int(v)) }, - 2, func() { v <<= uint(r.Intn(63)) }, - ) + switch { + case r.nOutOf(100, 182): + v %= 10 + case r.nOutOf(50, 82): + v = specialInts[r.Intn(len(specialInts))] + case r.nOutOf(10, 32): + v %= 256 + case r.nOutOf(10, 22): + v %= 4 << 10 + case r.nOutOf(10, 12): + v %= 64 << 10 + default: + v %= 1 << 31 + } + switch { + case r.nOutOf(100, 107): + case r.nOutOf(5, 7): + v = uintptr(-int(v)) + default: + v <<= uint(r.Intn(63)) + } return v } @@ -101,36 +111,41 @@ func (r *randGen) randArrayLen() uintptr { } func (r *randGen) randBufLen() (n uintptr) { - r.choose( - 1, func() { n = 0 }, - 50, func() { n = r.rand(256) }, - 5, func() { n = 4 << 10 }, - ) + switch { + case r.nOutOf(50, 56): + n = r.rand(256) + case r.nOutOf(5, 6): + n = 4 << 10 + } return } func (r *randGen) randPageCount() (n uintptr) { - r.choose( - 100, func() { n = r.rand(4) + 1 }, - 5, func() { n = r.rand(20) + 1 }, - 1, func() { n = (r.rand(3) + 1) * 1024 }, - ) + switch { + case r.nOutOf(100, 106): + n = r.rand(4) + 1 + case r.nOutOf(5, 6): + n = r.rand(20) + 1 + default: + n = (r.rand(3) + 1) * 1024 + } return } -func (r *randGen) flags(vv []uintptr) uintptr { - var v uintptr - r.choose( - 10, func() { v = 0 }, - 10, func() { v = vv[r.rand(len(vv))] }, - 90, func() { - for stop := false; !stop; stop = r.bin() { - v |= vv[r.rand(len(vv))] - } - }, - 1, func() { v = r.rand64() }, - ) - return v +func (r *randGen) flags(vv []uintptr) (v uintptr) { + switch { + case r.nOutOf(90, 111): + for stop := false; !stop; stop = r.bin() { + v |= vv[r.rand(len(vv))] + } + case r.nOutOf(10, 21): + v = vv[r.rand(len(vv))] + case r.nOutOf(10, 11): + v = 0 + default: + v = r.rand64() + } + return } func (r *randGen) filename(s *state) string { @@ -192,12 +207,15 @@ func (r *randGen) randStringImpl(s *state, vals []string) []byte { punct := []byte{'!', '@', '#', '$', '%', '^', '&', '*', '(', ')', '-', '+', '\\', '/', ':', '.', ',', '-', '\'', '[', ']', '{', '}'} buf := new(bytes.Buffer) - for !r.oneOf(4) { - r.choose( - 10, func() { buf.WriteString(dict[r.Intn(len(dict))]) }, - 10, func() { buf.Write([]byte{punct[r.Intn(len(punct))]}) }, - 1, func() { buf.Write([]byte{byte(r.Intn(256))}) }, - ) + for r.nOutOf(3, 4) { + switch { + case r.nOutOf(10, 21): + buf.WriteString(dict[r.Intn(len(dict))]) + case r.nOutOf(10, 11): + buf.Write([]byte{punct[r.Intn(len(punct))]}) + default: + buf.Write([]byte{byte(r.Intn(256))}) + } } if !r.oneOf(100) { buf.Write([]byte{0}) @@ -227,63 +245,59 @@ func (r *randGen) timespec(s *state, typ *sys.StructType, usec bool) (arg *Arg, // We need to generate timespec/timeval that are either (1) definitely in the past, // or (2) definitely in unreachable fututre, or (3) few ms ahead of now. // Note timespec/timeval can be absolute or relative to now. - r.choose( - 1, func() { - // now for relative, past for absolute - arg = groupArg(typ, []*Arg{ - constArg(typ.Fields[0], 0), - constArg(typ.Fields[1], 0), - }) - }, - 1, func() { - // few ms ahead for relative, past for absolute - nsec := uintptr(10 * 1e6) - if usec { - nsec /= 1e3 - } - arg = groupArg(typ, []*Arg{ - constArg(typ.Fields[0], 0), - constArg(typ.Fields[1], nsec), - }) - }, - 1, func() { - // unreachable fututre for both relative and absolute - arg = groupArg(typ, []*Arg{ - constArg(typ.Fields[0], 2e9), - constArg(typ.Fields[1], 0), - }) - }, - 1, func() { - // few ms ahead for absolute - meta := sys.CallMap["clock_gettime"] - ptrArgType := meta.Args[1].(*sys.PtrType) - argType := ptrArgType.Type.(*sys.StructType) - tp := groupArg(argType, []*Arg{ - constArg(argType.Fields[0], 0), - constArg(argType.Fields[1], 0), - }) - var tpaddr *Arg - tpaddr, calls = r.addr(s, ptrArgType, 2*ptrSize, tp) - gettime := &Call{ - Meta: meta, - Args: []*Arg{ - constArg(meta.Args[0], sys.CLOCK_REALTIME), - tpaddr, - }, - Ret: returnArg(meta.Ret), - } - calls = append(calls, gettime) - sec := resultArg(typ.Fields[0], tp.Inner[0]) - nsec := resultArg(typ.Fields[1], tp.Inner[1]) - if usec { - nsec.OpDiv = 1e3 - nsec.OpAdd = 10 * 1e3 - } else { - nsec.OpAdd = 10 * 1e6 - } - arg = groupArg(typ, []*Arg{sec, nsec}) - }, - ) + switch { + case r.nOutOf(1, 4): + // now for relative, past for absolute + arg = groupArg(typ, []*Arg{ + constArg(typ.Fields[0], 0), + constArg(typ.Fields[1], 0), + }) + case r.nOutOf(1, 3): + // few ms ahead for relative, past for absolute + nsec := uintptr(10 * 1e6) + if usec { + nsec /= 1e3 + } + arg = groupArg(typ, []*Arg{ + constArg(typ.Fields[0], 0), + constArg(typ.Fields[1], nsec), + }) + case r.nOutOf(1, 2): + // unreachable fututre for both relative and absolute + arg = groupArg(typ, []*Arg{ + constArg(typ.Fields[0], 2e9), + constArg(typ.Fields[1], 0), + }) + default: + // few ms ahead for absolute + meta := sys.CallMap["clock_gettime"] + ptrArgType := meta.Args[1].(*sys.PtrType) + argType := ptrArgType.Type.(*sys.StructType) + tp := groupArg(argType, []*Arg{ + constArg(argType.Fields[0], 0), + constArg(argType.Fields[1], 0), + }) + var tpaddr *Arg + tpaddr, calls = r.addr(s, ptrArgType, 2*ptrSize, tp) + gettime := &Call{ + Meta: meta, + Args: []*Arg{ + constArg(meta.Args[0], sys.CLOCK_REALTIME), + tpaddr, + }, + Ret: returnArg(meta.Ret), + } + calls = append(calls, gettime) + sec := resultArg(typ.Fields[0], tp.Inner[0]) + nsec := resultArg(typ.Fields[1], tp.Inner[1]) + if usec { + nsec.OpDiv = 1e3 + nsec.OpAdd = 10 * 1e3 + } else { + nsec.OpAdd = 10 * 1e6 + } + arg = groupArg(typ, []*Arg{sec, nsec}) + } return } @@ -310,7 +324,7 @@ func (r *randGen) addr1(s *state, typ sys.Type, size uintptr, data *Arg) (*Arg, if npages == 0 { npages = 1 } - if r.oneOf(10) { + if r.bin() { return r.randPageAddr(s, typ, npages, data, false), nil } for i := uintptr(0); i < maxPages-npages; i++ { @@ -336,21 +350,23 @@ func (r *randGen) addr(s *state, typ sys.Type, size uintptr, data *Arg) (*Arg, [ panic("bad") } // Patch offset of the address. - r.choose( - 50, func() {}, - 50, func() { arg.AddrOffset = -int(size) }, - 1, func() { - if size > 0 { - arg.AddrOffset = -r.Intn(int(size)) - } - }, - 1, func() { arg.AddrOffset = r.Intn(pageSize) }, - ) + switch { + case r.nOutOf(50, 102): + case r.nOutOf(50, 52): + arg.AddrOffset = -int(size) + case r.nOutOf(1, 2): + arg.AddrOffset = r.Intn(pageSize) + default: + if size > 0 { + arg.AddrOffset = -r.Intn(int(size)) + } + } return arg, calls } func (r *randGen) randPageAddr(s *state, typ sys.Type, npages uintptr, data *Arg, vma bool) *Arg { - var starts []uintptr + poolPtr := pageStartPool.Get().(*[]uintptr) + starts := (*poolPtr)[:0] for i := uintptr(0); i < maxPages-npages; i++ { busy := true for j := uintptr(0); j < npages; j++ { @@ -366,6 +382,8 @@ func (r *randGen) randPageAddr(s *state, typ sys.Type, npages uintptr, data *Arg } starts = append(starts, i) } + *poolPtr = starts + pageStartPool.Put(poolPtr) var page uintptr if len(starts) != 0 { page = starts[r.rand(len(starts))] @@ -508,27 +526,13 @@ func createIfuzzConfig(kind sys.TextKind) *ifuzz.Config { return cfg } -func (r *randGen) choose(args ...interface{}) { - if len(args) == 0 || len(args)%2 != 0 { - panic("bad number of args to choose") - } - n := len(args) / 2 - weights := make([]int, n) - funcs := make([]func(), n) - total := 0 - for i := 0; i < n; i++ { - weights[i] = total + args[i*2].(int) - funcs[i] = args[i*2+1].(func()) - total = weights[i] +// nOutOf returns true n out of outOf times. +func (r *randGen) nOutOf(n, outOf int) bool { + if n <= 0 || n >= outOf { + panic("bad probability") } - x := r.Intn(total) - for i, w := range weights { - if x < w { - funcs[i]() - return - } - } - panic("choose is broken") + v := r.Intn(outOf) + return v < n } func (r *randGen) generateCall(s *state, p *Prog) []*Call { @@ -624,31 +628,28 @@ func (r *randGen) generateArg(s *state, typ sys.Type) (arg *Arg, calls []*Call) switch a := typ.(type) { case *sys.ResourceType: - r.choose( - 1, func() { - special := a.SpecialValues() - arg = constArg(a, special[r.Intn(len(special))]) - }, - 1000, func() { - // Get an existing resource. - var allres []*Arg - for name1, res1 := range s.resources { - if sys.IsCompatibleResource(a.Desc.Name, name1) || - r.oneOf(20) && sys.IsCompatibleResource(a.Desc.Kind[0], name1) { - allres = append(allres, res1...) - } - } - if len(allres) != 0 { - arg = resultArg(a, allres[r.Intn(len(allres))]) - } else { - arg, calls = r.createResource(s, a) + switch { + case r.nOutOf(1000, 1011): + // Get an existing resource. + var allres []*Arg + for name1, res1 := range s.resources { + if sys.IsCompatibleResource(a.Desc.Name, name1) || + r.oneOf(20) && sys.IsCompatibleResource(a.Desc.Kind[0], name1) { + allres = append(allres, res1...) } - }, - 10, func() { - // Create a new resource. + } + if len(allres) != 0 { + arg = resultArg(a, allres[r.Intn(len(allres))]) + } else { arg, calls = r.createResource(s, a) - }, - ) + } + case r.nOutOf(10, 11): + // Create a new resource. + arg, calls = r.createResource(s, a) + default: + special := a.SpecialValues() + arg = constArg(a, special[r.Intn(len(special))]) + } return arg, calls case *sys.BufferType: switch a.Kind { @@ -670,11 +671,14 @@ func (r *randGen) generateArg(s *state, typ sys.Type) (arg *Arg, calls []*Call) case sys.BufferFilename: var data []byte if a.Dir() == sys.DirOut { - r.choose( - 10, func() { data = make([]byte, r.Intn(100)) }, - 10, func() { data = make([]byte, r.Intn(108)) }, // UNIX_PATH_MAX - 10, func() { data = make([]byte, r.Intn(4096)) }, // PATH_MAX - ) + switch { + case r.nOutOf(1, 3): + data = make([]byte, r.Intn(100)) + case r.nOutOf(1, 2): + data = make([]byte, 108) // UNIX_PATH_MAX + default: + data = make([]byte, 4096) // PATH_MAX + } } else { data = []byte(r.filename(s)) } @@ -701,11 +705,14 @@ func (r *randGen) generateArg(s *state, typ sys.Type) (arg *Arg, calls []*Call) case sys.IntSignalno: v %= 130 case sys.IntFileoff: - r.choose( - 90, func() { v = 0 }, - 10, func() { v = r.rand(100) }, - 1, func() { v = r.randInt() }, - ) + switch { + case r.nOutOf(90, 101): + v = 0 + case r.nOutOf(10, 11): + v = r.rand(100) + default: + v = r.randInt() + } case sys.IntRange: v = r.randRangeInt(a.RangeBegin, a.RangeEnd) } diff --git a/prog/validation.go b/prog/validation.go index 75d8e95a4..5eea42e65 100644 --- a/prog/validation.go +++ b/prog/validation.go @@ -9,12 +9,17 @@ import ( "github.com/google/syzkaller/sys" ) +var debug = false // enabled in tests + type validCtx struct { args map[*Arg]bool uses map[*Arg]*Arg } func (p *Prog) validate() error { + if !debug { + return nil + } ctx := &validCtx{make(map[*Arg]bool), make(map[*Arg]*Arg)} for _, c := range p.Calls { if err := c.validate(ctx); err != nil { -- cgit mrf-deployment