diff options
Diffstat (limited to 'pkg/fuzzer/queue')
| -rw-r--r-- | pkg/fuzzer/queue/distributor.go | 132 | ||||
| -rw-r--r-- | pkg/fuzzer/queue/distributor_test.go | 48 | ||||
| -rw-r--r-- | pkg/fuzzer/queue/queue.go | 21 |
3 files changed, 196 insertions, 5 deletions
diff --git a/pkg/fuzzer/queue/distributor.go b/pkg/fuzzer/queue/distributor.go new file mode 100644 index 000000000..e6c22df79 --- /dev/null +++ b/pkg/fuzzer/queue/distributor.go @@ -0,0 +1,132 @@ +// Copyright 2024 syzkaller project authors. All rights reserved. +// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file. + +package queue + +import ( + "sync" + "sync/atomic" + + "github.com/google/syzkaller/pkg/stat" +) + +// Distributor distributes requests to different VMs during input triage +// (allows to avoid already used VMs). +type Distributor struct { + source Source + seq atomic.Uint64 + empty atomic.Bool + active atomic.Pointer[[]atomic.Uint64] + mu sync.Mutex + queue []*Request + statDelayed *stat.Val + statUndelayed *stat.Val + statViolated *stat.Val +} + +func Distribute(source Source) *Distributor { + return &Distributor{ + source: source, + statDelayed: stat.New("distributor delayed", "Number of test programs delayed due to VM avoidance", + stat.Graph("distributor")), + statUndelayed: stat.New("distributor undelayed", "Number of test programs undelayed for VM avoidance", + stat.Graph("distributor")), + statViolated: stat.New("distributor violated", "Number of test programs violated VM avoidance", + stat.Graph("distributor")), + } +} + +// Next returns the next request to execute on the given vm. +func (dist *Distributor) Next(vm int) *Request { + dist.noteActive(vm) + if req := dist.delayed(vm); req != nil { + return req + } + for { + req := dist.source.Next() + if req == nil || !contains(req.Avoid, vm) || !dist.hasOtherActive(req.Avoid) { + return req + } + dist.delay(req) + } +} + +func (dist *Distributor) delay(req *Request) { + dist.mu.Lock() + defer dist.mu.Unlock() + req.delayedSince = dist.seq.Load() + dist.queue = append(dist.queue, req) + dist.statDelayed.Add(1) + dist.empty.Store(false) +} + +func (dist *Distributor) delayed(vm int) *Request { + if dist.empty.Load() { + return nil + } + dist.mu.Lock() + defer dist.mu.Unlock() + seq := dist.seq.Load() + for i, req := range dist.queue { + violation := contains(req.Avoid, vm) + // The delayedSince check protects from a situation when we had another VM available, + // and delayed a request, but then the VM was taken for reproduction and does not + // serve requests any more. If we could not dispatch a request in 1000 attempts, + // we gave up and give it to any VM. + if violation && req.delayedSince+1000 > seq { + continue + } + dist.statUndelayed.Add(1) + if violation { + dist.statViolated.Add(1) + } + last := len(dist.queue) - 1 + dist.queue[i] = dist.queue[last] + dist.queue[last] = nil + dist.queue = dist.queue[:last] + dist.empty.Store(len(dist.queue) == 0) + return req + } + return nil +} + +func (dist *Distributor) noteActive(vm int) { + active := dist.active.Load() + if active == nil || len(*active) <= vm { + dist.mu.Lock() + active = dist.active.Load() + if active == nil || len(*active) <= vm { + tmp := make([]atomic.Uint64, vm+10) + active = &tmp + dist.active.Store(active) + } + dist.mu.Unlock() + } + (*active)[vm].Store(dist.seq.Add(1)) +} + +// hasOtherActive says if we recently seen activity from VMs not in the set. +func (dist *Distributor) hasOtherActive(set []ExecutorID) bool { + seq := dist.seq.Load() + active := *dist.active.Load() + for vm := range active { + if contains(set, vm) { + continue + } + // 1000 is semi-random notion of recency. + if active[vm].Load()+1000 < seq { + continue + } + return true + } + return false +} + +func contains(set []ExecutorID, vm int) bool { + for _, id := range set { + if id.VM == vm { + return true + } + } + return false +} diff --git a/pkg/fuzzer/queue/distributor_test.go b/pkg/fuzzer/queue/distributor_test.go new file mode 100644 index 000000000..7bbf9c2e7 --- /dev/null +++ b/pkg/fuzzer/queue/distributor_test.go @@ -0,0 +1,48 @@ +// Copyright 2024 syzkaller project authors. All rights reserved. +// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file. + +package queue + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestDistributor(t *testing.T) { + q := Plain() + dist := Distribute(q) + + req := &Request{} + q.Submit(req) + assert.Equal(t, req, dist.Next(0)) + + q.Submit(req) + assert.Equal(t, req, dist.Next(1)) + + // Avoid VM 0. + req.Avoid = []ExecutorID{{VM: 0}} + q.Submit(req) + var noReq *Request + assert.Equal(t, noReq, dist.Next(0)) + assert.Equal(t, noReq, dist.Next(0)) + assert.Equal(t, req, dist.Next(1)) + + // If only VM 0 queries requests, it should eventually got it. + q.Submit(req) + assert.Equal(t, noReq, dist.Next(0)) + for { + got := dist.Next(0) + if got == req { + break + } + assert.Equal(t, noReq, got) + } + + // If all active VMs are in the avoid set, then they should get + // the request immidiatly. + assert.Equal(t, noReq, dist.Next(1)) + req.Avoid = []ExecutorID{{VM: 0}, {VM: 1}} + q.Submit(req) + assert.Equal(t, req, dist.Next(1)) +} diff --git a/pkg/fuzzer/queue/queue.go b/pkg/fuzzer/queue/queue.go index aadbaade8..cbdb2ba19 100644 --- a/pkg/fuzzer/queue/queue.go +++ b/pkg/fuzzer/queue/queue.go @@ -41,18 +41,28 @@ type Request struct { // Important requests will be retried even from crashed VMs. Important bool + // Avoid specifies set of executors that are preferable to avoid when executing this request. + // The restriction is soft since there can be only one executor at all or available right now. + Avoid []ExecutorID + // The callback will be called on request completion in the LIFO order. // If it returns false, all further processing will be stopped. // It allows wrappers to intercept Done() requests. callback DoneCallback - onceCrashed bool + onceCrashed bool + delayedSince uint64 mu sync.Mutex result *Result done chan struct{} } +type ExecutorID struct { + VM int + Proc int +} + type DoneCallback func(*Request, *Result) bool func (r *Request) OnDone(cb DoneCallback) { @@ -137,10 +147,11 @@ func (r *Request) initChannel() { } type Result struct { - Info *flatrpc.ProgInfo - Output []byte - Status Status - Err error // More details in case of ExecFailure. + Info *flatrpc.ProgInfo + Executor ExecutorID + Output []byte + Status Status + Err error // More details in case of ExecFailure. } func (r *Result) clone() *Result { |
