aboutsummaryrefslogtreecommitdiffstats
path: root/pkg/fuzzer/queue/distributor.go
blob: e6c22df79f6a87d5c224a6b7a2c93430f6079869 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
// Copyright 2024 syzkaller project authors. All rights reserved.
// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.

package queue

import (
	"sync"
	"sync/atomic"

	"github.com/google/syzkaller/pkg/stat"
)

// Distributor distributes requests to different VMs during input triage
// (allows to avoid already used VMs).
//
// It sits in front of a Source: requests whose Avoid list names the asking VM
// are parked in an internal queue and handed out later, either to a VM that
// is not on the list or, after enough Next calls have passed, to any VM.
type Distributor struct {
	// source is the upstream that Next pulls fresh requests from.
	source        Source
	// seq is a logical clock: it is advanced by one on every Next call and is
	// the unit in which all "recency" thresholds in this type are measured.
	seq           atomic.Uint64
	// empty is a hint that lets delayed return without taking mu: when it is
	// true, the queue holds no parked requests.
	empty         atomic.Bool
	// active is indexed by VM number and stores the seq value of that VM's most
	// recent Next call. The slice is replaced under mu when a larger VM index
	// shows up.
	active        atomic.Pointer[[]atomic.Uint64]
	// mu guards queue and serializes replacement of the active slice.
	mu            sync.Mutex
	// queue holds the parked requests in no particular order.
	queue         []*Request
	// statDelayed counts requests that were parked.
	statDelayed   *stat.Val
	// statUndelayed counts parked requests that were eventually handed out.
	statUndelayed *stat.Val
	// statViolated counts parked requests handed to a VM on their Avoid list.
	statViolated  *stat.Val
}

// Distribute wraps source in a Distributor that steers requests away from
// the VMs listed in their Avoid field, and registers the distributor stats.
func Distribute(source Source) *Distributor {
	dist := &Distributor{
		source: source,
		statDelayed: stat.New("distributor delayed", "Number of test programs delayed due to VM avoidance",
			stat.Graph("distributor")),
		statUndelayed: stat.New("distributor undelayed", "Number of test programs undelayed for VM avoidance",
			stat.Graph("distributor")),
		statViolated: stat.New("distributor violated", "Number of test programs violated VM avoidance",
			stat.Graph("distributor")),
	}
	// The delayed queue starts out empty, so say so up front. With the flag left
	// at its zero value (false) every Next call would take the mutex just to
	// find an empty queue, until a first request had been delayed and drained.
	dist.empty.Store(true)
	return dist
}

// Next returns the next request to execute on the given vm.
// Previously delayed requests take priority over fresh ones from the source.
// A nil result means the source had nothing (more) to offer right now.
func (dist *Distributor) Next(vm int) *Request {
	dist.noteActive(vm)
	if parked := dist.delayed(vm); parked != nil {
		return parked
	}
	for {
		fresh := dist.source.Next()
		// Park the request only if it asks to avoid this VM and some other
		// VM was recently seen that could take it instead.
		if fresh != nil && contains(fresh.Avoid, vm) && dist.hasOtherActive(fresh.Avoid) {
			dist.delay(fresh)
			continue
		}
		return fresh
	}
}

// delay parks req in the queue until a suitable VM asks for work.
// The current sequence number is recorded on the request so that delayed
// can tell how long it has been waiting.
func (dist *Distributor) delay(req *Request) {
	dist.mu.Lock()
	defer dist.mu.Unlock()

	req.delayedSince = dist.seq.Load()
	dist.queue = append(dist.queue, req)
	// Clear the flag so that delayed stops skipping the queue.
	dist.empty.Store(false)
	dist.statDelayed.Add(1)
}

// delayed picks a parked request that may run on the given vm and removes it
// from the queue. It returns nil if no parked request is eligible.
func (dist *Distributor) delayed(vm int) *Request {
	if dist.empty.Load() {
		return nil
	}
	dist.mu.Lock()
	defer dist.mu.Unlock()

	now := dist.seq.Load()
	pick, avoided := -1, false
	for i, candidate := range dist.queue {
		avoided = contains(candidate.Avoid, vm)
		// A request is eligible if this VM is not on its avoid list, or if it
		// has waited for 1000 sequence steps. The latter covers the case where
		// the other VM we were counting on was taken for reproduction and does
		// not serve requests any more: after 1000 attempts we give up and hand
		// the request to any VM.
		if !avoided || candidate.delayedSince+1000 <= now {
			pick = i
			break
		}
	}
	if pick < 0 {
		return nil
	}

	dist.statUndelayed.Add(1)
	if avoided {
		dist.statViolated.Add(1)
	}
	// Order of the queue does not matter: move the tail into the hole and shrink.
	req := dist.queue[pick]
	tail := len(dist.queue) - 1
	dist.queue[pick] = dist.queue[tail]
	dist.queue[tail] = nil
	dist.queue = dist.queue[:tail]
	dist.empty.Store(len(dist.queue) == 0)
	return req
}

// noteActive advances the sequence counter and records the new value as the
// moment the given vm was last seen asking for work.
// The per-VM table is grown on demand; vm must be non-negative.
func (dist *Distributor) noteActive(vm int) {
	active := dist.active.Load()
	if active == nil || len(*active) <= vm {
		dist.mu.Lock()
		// Re-check under the lock: another goroutine may have grown the table already.
		active = dist.active.Load()
		if active == nil || len(*active) <= vm {
			grown := make([]atomic.Uint64, vm+10)
			if active != nil {
				// Carry over what we know about the other VMs. Starting from a zeroed
				// table would make hasOtherActive treat them as long inactive until
				// each of them calls Next again. Atomics must not be copied by value,
				// hence the element-wise Load/Store.
				for i := range *active {
					grown[i].Store((*active)[i].Load())
				}
			}
			active = &grown
			dist.active.Store(active)
		}
		dist.mu.Unlock()
	}
	(*active)[vm].Store(dist.seq.Add(1))
}

// hasOtherActive reports whether any VM outside of set has been seen recently.
// It must only be called after noteActive has initialized the activity table.
func (dist *Distributor) hasOtherActive(set []ExecutorID) bool {
	now := dist.seq.Load()
	lastSeen := *dist.active.Load()
	for vm := 0; vm < len(lastSeen); vm++ {
		// "Recently" means within the last 1000 sequence steps; the number is
		// a semi-random notion of recency.
		if !contains(set, vm) && lastSeen[vm].Load()+1000 >= now {
			return true
		}
	}
	return false
}

// contains reports whether any executor in set belongs to the given vm.
func contains(set []ExecutorID, vm int) bool {
	for i := range set {
		if set[i].VM == vm {
			return true
		}
	}
	return false
}