diff options
| author | Aleksandr Nogikh <nogikh@google.com> | 2024-07-01 18:24:06 +0200 |
|---|---|---|
| committer | Aleksandr Nogikh <nogikh@google.com> | 2024-07-04 12:34:14 +0000 |
| commit | 891bf15f3ceb3f69eaf590882ba3f245811c1698 (patch) | |
| tree | 7cee7d7754929871f1337cb509cb1fa76ee477bb /pkg | |
| parent | 092c1914a191f5858db674b4e367c6848500429e (diff) | |
pkg/rpcserver: remove direct accesses to Runner fields
Diffstat (limited to 'pkg')
| -rw-r--r-- | pkg/rpcserver/rpcserver.go | 31 | ||||
| -rw-r--r-- | pkg/rpcserver/runner.go | 61 |
2 files changed, 62 insertions, 30 deletions
diff --git a/pkg/rpcserver/rpcserver.go b/pkg/rpcserver/rpcserver.go index 35b628715..da4faa19a 100644 --- a/pkg/rpcserver/rpcserver.go +++ b/pkg/rpcserver/rpcserver.go @@ -198,24 +198,18 @@ func (serv *Server) VMState() map[string]VMState { func (serv *Server) MachineInfo(name string) []byte { serv.mu.Lock() runner := serv.runners[name] - if runner != nil && (runner.conn == nil || runner.stopped) { - runner = nil - } serv.mu.Unlock() - if runner == nil { + if runner == nil || !runner.alive() { return []byte("VM is not alive") } - return runner.machineInfo + return runner.getMachineInfo() } func (serv *Server) RunnerStatus(name string) []byte { serv.mu.Lock() runner := serv.runners[name] - if runner != nil && (runner.conn == nil || runner.stopped) { - runner = nil - } serv.mu.Unlock() - if runner == nil { + if runner == nil || !runner.alive() { return []byte("VM is not alive") } return runner.queryStatus() @@ -244,13 +238,11 @@ func (serv *Server) handleConn(conn *flatrpc.Conn) { serv.mu.Lock() runner := serv.runners[name] - if runner == nil || runner.stopped { - serv.mu.Unlock() - log.Logf(2, "VM %v shut down before connect", name) + serv.mu.Unlock() + if runner == nil { + log.Logf(2, "unknown VM %v tries to connect", name) return } - serv.mu.Unlock() - defer close(runner.finished) opts := &handshakeConfig{ VMLess: serv.cfg.VMLess, @@ -447,7 +439,6 @@ func (serv *Server) CreateInstance(name string, injectExec chan<- bool) { sysTarget: serv.sysTarget, injectExec: injectExec, infoc: make(chan chan []byte), - finished: make(chan bool), requests: make(map[int64]*queue.Request), executing: make(map[int64]bool), lastExec: MakeLastExecuting(serv.cfg.Procs, 6), @@ -469,13 +460,9 @@ func (serv *Server) CreateInstance(name string, injectExec chan<- bool) { func (serv *Server) StopFuzzing(name string) { serv.mu.Lock() runner := serv.runners[name] - runner.stopped = true - conn := runner.conn serv.info[name] = VMState{StateStopping, time.Now()} serv.mu.Unlock() - if conn != nil { - conn.Close() - } + runner.stop() } func (serv *Server) ShutdownInstance(name string, crashed bool) ([]ExecRecord, []byte) { @@ -484,7 +471,7 @@ func (serv *Server) ShutdownInstance(name string, crashed bool) ([]ExecRecord, [ delete(serv.runners, name) serv.info[name] = VMState{StateOffline, time.Now()} serv.mu.Unlock() - return runner.shutdown(crashed) + return runner.shutdown(crashed), runner.getMachineInfo() } func (serv *Server) DistributeSignalDelta(plus signal.Signal) { @@ -508,7 +495,7 @@ func (serv *Server) foreachRunnerAsync(fn func(runner *Runner)) { serv.mu.Lock() defer serv.mu.Unlock() for _, runner := range serv.runners { - if runner.conn != nil { + if runner.alive() { go fn(runner) } } diff --git a/pkg/rpcserver/runner.go b/pkg/rpcserver/runner.go index ed45e2e31..3886e5ffc 100644 --- a/pkg/rpcserver/runner.go +++ b/pkg/rpcserver/runner.go @@ -10,6 +10,7 @@ import ( "math/rand" "os" "slices" + "sync" "time" "github.com/google/syzkaller/pkg/cover" @@ -30,18 +31,21 @@ type Runner struct { debug bool sysTarget *targets.Target stats *runnerStats - stopped bool finished chan bool injectExec chan<- bool infoc chan chan []byte - conn *flatrpc.Conn - machineInfo []byte canonicalizer *cover.CanonicalizerInstance nextRequestID int64 requests map[int64]*queue.Request executing map[int64]bool lastExec *LastExecuting rnd *rand.Rand + + // The mutex protects all the fields below. + mu sync.Mutex + conn *flatrpc.Conn + stopped bool + machineInfo []byte } type runnerStats struct { @@ -106,13 +110,28 @@ func (runner *Runner) handshake(conn *flatrpc.Conn, cfg *handshakeConfig) error if err := flatrpc.Send(conn, infoReply); err != nil { return err } + runner.mu.Lock() runner.conn = conn runner.machineInfo = ret.MachineInfo runner.canonicalizer = ret.Canonicalizer + runner.mu.Unlock() return nil } func (runner *Runner) connectionLoop() error { + runner.mu.Lock() + stopped := runner.stopped + if !stopped { + runner.finished = make(chan bool) + } + runner.mu.Unlock() + + if stopped { + // The instance was shut down in between, see the shutdown code. + return nil + } + defer close(runner.finished) + var infoc chan []byte defer func() { if infoc != nil { @@ -393,11 +412,25 @@ func (runner *Runner) sendCorpusTriaged() error { return flatrpc.Send(runner.conn, msg) } -func (runner *Runner) shutdown(crashed bool) ([]ExecRecord, []byte) { - if runner.conn != nil { +func (runner *Runner) stop() { + runner.mu.Lock() + runner.stopped = true + conn := runner.conn + runner.mu.Unlock() + if conn != nil { + conn.Close() + } +} + +func (runner *Runner) shutdown(crashed bool) []ExecRecord { + runner.mu.Lock() + runner.stopped = true + finished := runner.finished + runner.mu.Unlock() + + if finished != nil { // Wait for the connection goroutine to finish and stop touching data. - // If conn is nil before we removed the runner, then it won't touch anything. - <-runner.finished + <-finished } for id, req := range runner.requests { status := queue.Restarted @@ -406,7 +439,13 @@ func (runner *Runner) shutdown(crashed bool) ([]ExecRecord, []byte) { } req.Done(&queue.Result{Status: status}) } - return runner.lastExec.Collect(), runner.machineInfo + return runner.lastExec.Collect() +} + +func (runner *Runner) getMachineInfo() []byte { + runner.mu.Lock() + defer runner.mu.Unlock() + return runner.machineInfo } func (runner *Runner) queryStatus() []byte { @@ -425,6 +464,12 @@ func (runner *Runner) queryStatus() []byte { } } +func (runner *Runner) alive() bool { + runner.mu.Lock() + defer runner.mu.Unlock() + return runner.conn != nil && !runner.stopped +} + // addFallbackSignal computes simple fallback signal in cases we don't have real coverage signal. // We use syscall number or-ed with returned errno value as signal. // At least this gives us all combinations of syscall+errno. |
