aboutsummaryrefslogtreecommitdiffstats
path: root/pkg/fuzzer/queue
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/fuzzer/queue')
-rw-r--r--pkg/fuzzer/queue/distributor.go132
-rw-r--r--pkg/fuzzer/queue/distributor_test.go48
-rw-r--r--pkg/fuzzer/queue/queue.go21
3 files changed, 196 insertions, 5 deletions
diff --git a/pkg/fuzzer/queue/distributor.go b/pkg/fuzzer/queue/distributor.go
new file mode 100644
index 000000000..e6c22df79
--- /dev/null
+++ b/pkg/fuzzer/queue/distributor.go
@@ -0,0 +1,132 @@
+// Copyright 2024 syzkaller project authors. All rights reserved.
+// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
+
+package queue
+
+import (
+ "sync"
+ "sync/atomic"
+
+ "github.com/google/syzkaller/pkg/stat"
+)
+
+// Distributor distributes requests to different VMs during input triage
+// (allows to avoid already used VMs).
+type Distributor struct {
+ source Source
+ seq atomic.Uint64
+ empty atomic.Bool
+ active atomic.Pointer[[]atomic.Uint64]
+ mu sync.Mutex
+ queue []*Request
+ statDelayed *stat.Val
+ statUndelayed *stat.Val
+ statViolated *stat.Val
+}
+
+func Distribute(source Source) *Distributor {
+ return &Distributor{
+ source: source,
+ statDelayed: stat.New("distributor delayed", "Number of test programs delayed due to VM avoidance",
+ stat.Graph("distributor")),
+ statUndelayed: stat.New("distributor undelayed", "Number of test programs undelayed for VM avoidance",
+ stat.Graph("distributor")),
+ statViolated: stat.New("distributor violated", "Number of test programs violated VM avoidance",
+ stat.Graph("distributor")),
+ }
+}
+
+// Next returns the next request to execute on the given vm.
+func (dist *Distributor) Next(vm int) *Request {
+ dist.noteActive(vm)
+ if req := dist.delayed(vm); req != nil {
+ return req
+ }
+ for {
+ req := dist.source.Next()
+ if req == nil || !contains(req.Avoid, vm) || !dist.hasOtherActive(req.Avoid) {
+ return req
+ }
+ dist.delay(req)
+ }
+}
+
+func (dist *Distributor) delay(req *Request) {
+ dist.mu.Lock()
+ defer dist.mu.Unlock()
+ req.delayedSince = dist.seq.Load()
+ dist.queue = append(dist.queue, req)
+ dist.statDelayed.Add(1)
+ dist.empty.Store(false)
+}
+
+func (dist *Distributor) delayed(vm int) *Request {
+ if dist.empty.Load() {
+ return nil
+ }
+ dist.mu.Lock()
+ defer dist.mu.Unlock()
+ seq := dist.seq.Load()
+ for i, req := range dist.queue {
+ violation := contains(req.Avoid, vm)
+ // The delayedSince check protects from a situation when we had another VM available,
+ // and delayed a request, but then the VM was taken for reproduction and does not
+ // serve requests any more. If we could not dispatch a request in 1000 attempts,
+ // we gave up and give it to any VM.
+ if violation && req.delayedSince+1000 > seq {
+ continue
+ }
+ dist.statUndelayed.Add(1)
+ if violation {
+ dist.statViolated.Add(1)
+ }
+ last := len(dist.queue) - 1
+ dist.queue[i] = dist.queue[last]
+ dist.queue[last] = nil
+ dist.queue = dist.queue[:last]
+ dist.empty.Store(len(dist.queue) == 0)
+ return req
+ }
+ return nil
+}
+
+func (dist *Distributor) noteActive(vm int) {
+ active := dist.active.Load()
+ if active == nil || len(*active) <= vm {
+ dist.mu.Lock()
+ active = dist.active.Load()
+ if active == nil || len(*active) <= vm {
+ tmp := make([]atomic.Uint64, vm+10)
+ active = &tmp
+ dist.active.Store(active)
+ }
+ dist.mu.Unlock()
+ }
+ (*active)[vm].Store(dist.seq.Add(1))
+}
+
+// hasOtherActive says if we recently seen activity from VMs not in the set.
+func (dist *Distributor) hasOtherActive(set []ExecutorID) bool {
+ seq := dist.seq.Load()
+ active := *dist.active.Load()
+ for vm := range active {
+ if contains(set, vm) {
+ continue
+ }
+ // 1000 is semi-random notion of recency.
+ if active[vm].Load()+1000 < seq {
+ continue
+ }
+ return true
+ }
+ return false
+}
+
+func contains(set []ExecutorID, vm int) bool {
+ for _, id := range set {
+ if id.VM == vm {
+ return true
+ }
+ }
+ return false
+}
diff --git a/pkg/fuzzer/queue/distributor_test.go b/pkg/fuzzer/queue/distributor_test.go
new file mode 100644
index 000000000..7bbf9c2e7
--- /dev/null
+++ b/pkg/fuzzer/queue/distributor_test.go
@@ -0,0 +1,48 @@
+// Copyright 2024 syzkaller project authors. All rights reserved.
+// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
+
+package queue
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+)
+
+func TestDistributor(t *testing.T) {
+ q := Plain()
+ dist := Distribute(q)
+
+ req := &Request{}
+ q.Submit(req)
+ assert.Equal(t, req, dist.Next(0))
+
+ q.Submit(req)
+ assert.Equal(t, req, dist.Next(1))
+
+ // Avoid VM 0.
+ req.Avoid = []ExecutorID{{VM: 0}}
+ q.Submit(req)
+ var noReq *Request
+ assert.Equal(t, noReq, dist.Next(0))
+ assert.Equal(t, noReq, dist.Next(0))
+ assert.Equal(t, req, dist.Next(1))
+
+ // If only VM 0 queries requests, it should eventually got it.
+ q.Submit(req)
+ assert.Equal(t, noReq, dist.Next(0))
+ for {
+ got := dist.Next(0)
+ if got == req {
+ break
+ }
+ assert.Equal(t, noReq, got)
+ }
+
+ // If all active VMs are in the avoid set, then they should get
+ // the request immidiatly.
+ assert.Equal(t, noReq, dist.Next(1))
+ req.Avoid = []ExecutorID{{VM: 0}, {VM: 1}}
+ q.Submit(req)
+ assert.Equal(t, req, dist.Next(1))
+}
diff --git a/pkg/fuzzer/queue/queue.go b/pkg/fuzzer/queue/queue.go
index aadbaade8..cbdb2ba19 100644
--- a/pkg/fuzzer/queue/queue.go
+++ b/pkg/fuzzer/queue/queue.go
@@ -41,18 +41,28 @@ type Request struct {
// Important requests will be retried even from crashed VMs.
Important bool
+ // Avoid specifies set of executors that are preferable to avoid when executing this request.
+ // The restriction is soft since there can be only one executor at all or available right now.
+ Avoid []ExecutorID
+
// The callback will be called on request completion in the LIFO order.
// If it returns false, all further processing will be stopped.
// It allows wrappers to intercept Done() requests.
callback DoneCallback
- onceCrashed bool
+ onceCrashed bool
+ delayedSince uint64
mu sync.Mutex
result *Result
done chan struct{}
}
+type ExecutorID struct {
+ VM int
+ Proc int
+}
+
type DoneCallback func(*Request, *Result) bool
func (r *Request) OnDone(cb DoneCallback) {
@@ -137,10 +147,11 @@ func (r *Request) initChannel() {
}
type Result struct {
- Info *flatrpc.ProgInfo
- Output []byte
- Status Status
- Err error // More details in case of ExecFailure.
+ Info *flatrpc.ProgInfo
+ Executor ExecutorID
+ Output []byte
+ Status Status
+ Err error // More details in case of ExecFailure.
}
func (r *Result) clone() *Result {