aboutsummaryrefslogtreecommitdiffstats
path: root/pkg
diff options
context:
space:
mode:
authorAleksandr Nogikh <nogikh@google.com>2024-07-01 18:24:06 +0200
committerAleksandr Nogikh <nogikh@google.com>2024-07-04 12:34:14 +0000
commit891bf15f3ceb3f69eaf590882ba3f245811c1698 (patch)
tree7cee7d7754929871f1337cb509cb1fa76ee477bb /pkg
parent092c1914a191f5858db674b4e367c6848500429e (diff)
pkg/rpcserver: remove direct accesses to Runner fields
Diffstat (limited to 'pkg')
-rw-r--r--pkg/rpcserver/rpcserver.go31
-rw-r--r--pkg/rpcserver/runner.go61
2 files changed, 62 insertions, 30 deletions
diff --git a/pkg/rpcserver/rpcserver.go b/pkg/rpcserver/rpcserver.go
index 35b628715..da4faa19a 100644
--- a/pkg/rpcserver/rpcserver.go
+++ b/pkg/rpcserver/rpcserver.go
@@ -198,24 +198,18 @@ func (serv *Server) VMState() map[string]VMState {
func (serv *Server) MachineInfo(name string) []byte {
serv.mu.Lock()
runner := serv.runners[name]
- if runner != nil && (runner.conn == nil || runner.stopped) {
- runner = nil
- }
serv.mu.Unlock()
- if runner == nil {
+ if runner == nil || !runner.alive() {
return []byte("VM is not alive")
}
- return runner.machineInfo
+ return runner.getMachineInfo()
}
func (serv *Server) RunnerStatus(name string) []byte {
serv.mu.Lock()
runner := serv.runners[name]
- if runner != nil && (runner.conn == nil || runner.stopped) {
- runner = nil
- }
serv.mu.Unlock()
- if runner == nil {
+ if runner == nil || !runner.alive() {
return []byte("VM is not alive")
}
return runner.queryStatus()
@@ -244,13 +238,11 @@ func (serv *Server) handleConn(conn *flatrpc.Conn) {
serv.mu.Lock()
runner := serv.runners[name]
- if runner == nil || runner.stopped {
- serv.mu.Unlock()
- log.Logf(2, "VM %v shut down before connect", name)
+ serv.mu.Unlock()
+ if runner == nil {
+ log.Logf(2, "unknown VM %v tries to connect", name)
return
}
- serv.mu.Unlock()
- defer close(runner.finished)
opts := &handshakeConfig{
VMLess: serv.cfg.VMLess,
@@ -447,7 +439,6 @@ func (serv *Server) CreateInstance(name string, injectExec chan<- bool) {
sysTarget: serv.sysTarget,
injectExec: injectExec,
infoc: make(chan chan []byte),
- finished: make(chan bool),
requests: make(map[int64]*queue.Request),
executing: make(map[int64]bool),
lastExec: MakeLastExecuting(serv.cfg.Procs, 6),
@@ -469,13 +460,9 @@ func (serv *Server) CreateInstance(name string, injectExec chan<- bool) {
func (serv *Server) StopFuzzing(name string) {
serv.mu.Lock()
runner := serv.runners[name]
- runner.stopped = true
- conn := runner.conn
serv.info[name] = VMState{StateStopping, time.Now()}
serv.mu.Unlock()
- if conn != nil {
- conn.Close()
- }
+ runner.stop()
}
func (serv *Server) ShutdownInstance(name string, crashed bool) ([]ExecRecord, []byte) {
@@ -484,7 +471,7 @@ func (serv *Server) ShutdownInstance(name string, crashed bool) ([]ExecRecord, [
delete(serv.runners, name)
serv.info[name] = VMState{StateOffline, time.Now()}
serv.mu.Unlock()
- return runner.shutdown(crashed)
+ return runner.shutdown(crashed), runner.getMachineInfo()
}
func (serv *Server) DistributeSignalDelta(plus signal.Signal) {
@@ -508,7 +495,7 @@ func (serv *Server) foreachRunnerAsync(fn func(runner *Runner)) {
serv.mu.Lock()
defer serv.mu.Unlock()
for _, runner := range serv.runners {
- if runner.conn != nil {
+ if runner.alive() {
go fn(runner)
}
}
diff --git a/pkg/rpcserver/runner.go b/pkg/rpcserver/runner.go
index ed45e2e31..3886e5ffc 100644
--- a/pkg/rpcserver/runner.go
+++ b/pkg/rpcserver/runner.go
@@ -10,6 +10,7 @@ import (
"math/rand"
"os"
"slices"
+ "sync"
"time"
"github.com/google/syzkaller/pkg/cover"
@@ -30,18 +31,21 @@ type Runner struct {
debug bool
sysTarget *targets.Target
stats *runnerStats
- stopped bool
finished chan bool
injectExec chan<- bool
infoc chan chan []byte
- conn *flatrpc.Conn
- machineInfo []byte
canonicalizer *cover.CanonicalizerInstance
nextRequestID int64
requests map[int64]*queue.Request
executing map[int64]bool
lastExec *LastExecuting
rnd *rand.Rand
+
+ // The mutex protects all the fields below.
+ mu sync.Mutex
+ conn *flatrpc.Conn
+ stopped bool
+ machineInfo []byte
}
type runnerStats struct {
@@ -106,13 +110,28 @@ func (runner *Runner) handshake(conn *flatrpc.Conn, cfg *handshakeConfig) error
if err := flatrpc.Send(conn, infoReply); err != nil {
return err
}
+ runner.mu.Lock()
runner.conn = conn
runner.machineInfo = ret.MachineInfo
runner.canonicalizer = ret.Canonicalizer
+ runner.mu.Unlock()
return nil
}
func (runner *Runner) connectionLoop() error {
+ runner.mu.Lock()
+ stopped := runner.stopped
+ if !stopped {
+ runner.finished = make(chan bool)
+ }
+ runner.mu.Unlock()
+
+ if stopped {
+ // The instance was shut down in between, see the shutdown code.
+ return nil
+ }
+ defer close(runner.finished)
+
var infoc chan []byte
defer func() {
if infoc != nil {
@@ -393,11 +412,25 @@ func (runner *Runner) sendCorpusTriaged() error {
return flatrpc.Send(runner.conn, msg)
}
-func (runner *Runner) shutdown(crashed bool) ([]ExecRecord, []byte) {
- if runner.conn != nil {
+func (runner *Runner) stop() {
+ runner.mu.Lock()
+ runner.stopped = true
+ conn := runner.conn
+ runner.mu.Unlock()
+ if conn != nil {
+ conn.Close()
+ }
+}
+
+func (runner *Runner) shutdown(crashed bool) []ExecRecord {
+ runner.mu.Lock()
+ runner.stopped = true
+ finished := runner.finished
+ runner.mu.Unlock()
+
+ if finished != nil {
// Wait for the connection goroutine to finish and stop touching data.
- // If conn is nil before we removed the runner, then it won't touch anything.
- <-runner.finished
+ <-finished
}
for id, req := range runner.requests {
status := queue.Restarted
@@ -406,7 +439,13 @@ func (runner *Runner) shutdown(crashed bool) ([]ExecRecord, []byte) {
}
req.Done(&queue.Result{Status: status})
}
- return runner.lastExec.Collect(), runner.machineInfo
+ return runner.lastExec.Collect()
+}
+
+func (runner *Runner) getMachineInfo() []byte {
+ runner.mu.Lock()
+ defer runner.mu.Unlock()
+ return runner.machineInfo
}
func (runner *Runner) queryStatus() []byte {
@@ -425,6 +464,12 @@ func (runner *Runner) queryStatus() []byte {
}
}
+func (runner *Runner) alive() bool {
+ runner.mu.Lock()
+ defer runner.mu.Unlock()
+ return runner.conn != nil && !runner.stopped
+}
+
// addFallbackSignal computes simple fallback signal in cases we don't have real coverage signal.
// We use syscall number or-ed with returned errno value as signal.
// At least this gives us all combinations of syscall+errno.