aboutsummaryrefslogtreecommitdiffstats
path: root/pkg/fuzzer/queue/distributor.go
diff options
context:
space:
mode:
authorDmitry Vyukov <dvyukov@google.com>2024-07-01 14:26:05 +0200
committerDmitry Vyukov <dvyukov@google.com>2024-08-02 13:16:51 +0000
commit66fcb0a84fcd55ad8e1444cdd0bc0ad6592f7329 (patch)
tree998e52d5569938e0251da1eb7c54c3746186b488 /pkg/fuzzer/queue/distributor.go
parent1e9c4cf3ae82ef82220af312606fffe65e124563 (diff)
pkg/fuzzer: try to triage on different VMs
Distribute triage requests to different VMs.
Diffstat (limited to 'pkg/fuzzer/queue/distributor.go')
-rw-r--r--pkg/fuzzer/queue/distributor.go132
1 files changed, 132 insertions, 0 deletions
diff --git a/pkg/fuzzer/queue/distributor.go b/pkg/fuzzer/queue/distributor.go
new file mode 100644
index 000000000..e6c22df79
--- /dev/null
+++ b/pkg/fuzzer/queue/distributor.go
@@ -0,0 +1,132 @@
+// Copyright 2024 syzkaller project authors. All rights reserved.
+// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
+
+package queue
+
+import (
+ "sync"
+ "sync/atomic"
+
+ "github.com/google/syzkaller/pkg/stat"
+)
+
+// Distributor distributes requests to different VMs during input triage
+// (allows to avoid already used VMs).
+type Distributor struct {
+ source Source
+ seq atomic.Uint64
+ empty atomic.Bool
+ active atomic.Pointer[[]atomic.Uint64]
+ mu sync.Mutex
+ queue []*Request
+ statDelayed *stat.Val
+ statUndelayed *stat.Val
+ statViolated *stat.Val
+}
+
+func Distribute(source Source) *Distributor {
+ return &Distributor{
+ source: source,
+ statDelayed: stat.New("distributor delayed", "Number of test programs delayed due to VM avoidance",
+ stat.Graph("distributor")),
+ statUndelayed: stat.New("distributor undelayed", "Number of test programs undelayed for VM avoidance",
+ stat.Graph("distributor")),
+ statViolated: stat.New("distributor violated", "Number of test programs violated VM avoidance",
+ stat.Graph("distributor")),
+ }
+}
+
+// Next returns the next request to execute on the given vm.
+func (dist *Distributor) Next(vm int) *Request {
+ dist.noteActive(vm)
+ if req := dist.delayed(vm); req != nil {
+ return req
+ }
+ for {
+ req := dist.source.Next()
+ if req == nil || !contains(req.Avoid, vm) || !dist.hasOtherActive(req.Avoid) {
+ return req
+ }
+ dist.delay(req)
+ }
+}
+
+func (dist *Distributor) delay(req *Request) {
+ dist.mu.Lock()
+ defer dist.mu.Unlock()
+ req.delayedSince = dist.seq.Load()
+ dist.queue = append(dist.queue, req)
+ dist.statDelayed.Add(1)
+ dist.empty.Store(false)
+}
+
+func (dist *Distributor) delayed(vm int) *Request {
+ if dist.empty.Load() {
+ return nil
+ }
+ dist.mu.Lock()
+ defer dist.mu.Unlock()
+ seq := dist.seq.Load()
+ for i, req := range dist.queue {
+ violation := contains(req.Avoid, vm)
+ // The delayedSince check protects from a situation when we had another VM available,
+ // and delayed a request, but then the VM was taken for reproduction and does not
+ // serve requests any more. If we could not dispatch a request in 1000 attempts,
+ // we gave up and give it to any VM.
+ if violation && req.delayedSince+1000 > seq {
+ continue
+ }
+ dist.statUndelayed.Add(1)
+ if violation {
+ dist.statViolated.Add(1)
+ }
+ last := len(dist.queue) - 1
+ dist.queue[i] = dist.queue[last]
+ dist.queue[last] = nil
+ dist.queue = dist.queue[:last]
+ dist.empty.Store(len(dist.queue) == 0)
+ return req
+ }
+ return nil
+}
+
+func (dist *Distributor) noteActive(vm int) {
+ active := dist.active.Load()
+ if active == nil || len(*active) <= vm {
+ dist.mu.Lock()
+ active = dist.active.Load()
+ if active == nil || len(*active) <= vm {
+ tmp := make([]atomic.Uint64, vm+10)
+ active = &tmp
+ dist.active.Store(active)
+ }
+ dist.mu.Unlock()
+ }
+ (*active)[vm].Store(dist.seq.Add(1))
+}
+
+// hasOtherActive says if we recently seen activity from VMs not in the set.
+func (dist *Distributor) hasOtherActive(set []ExecutorID) bool {
+ seq := dist.seq.Load()
+ active := *dist.active.Load()
+ for vm := range active {
+ if contains(set, vm) {
+ continue
+ }
+ // 1000 is semi-random notion of recency.
+ if active[vm].Load()+1000 < seq {
+ continue
+ }
+ return true
+ }
+ return false
+}
+
+func contains(set []ExecutorID, vm int) bool {
+ for _, id := range set {
+ if id.VM == vm {
+ return true
+ }
+ }
+ return false
+}