about | summary | refs | log | tree | commit | diff | stats
path: root/pkg
diff options
context:
space:
mode:
author Aleksandr Nogikh <nogikh@google.com> 2025-01-30 15:29:05 +0100
committer Aleksandr Nogikh <nogikh@google.com> 2025-02-03 16:09:45 +0000
commit 8f267cefd3660f9d5640ebbbd42e295a61774469 (patch)
tree 6f0efc247fe55d27f88296517cb3204a8a910cb7 /pkg
parent 8f276ef29583e363bb886170f2f424f2d2a0e244 (diff)
pkg/rpcserver: run machine check from the global context
Running it from the VM context causes it to be canceled each time a VM crashes or the connection is aborted.
Diffstat (limited to 'pkg')
-rw-r--r-- pkg/rpcserver/rpcserver.go 81
-rw-r--r-- pkg/rpcserver/rpcserver_test.go 6
2 files changed, 51 insertions, 36 deletions
diff --git a/pkg/rpcserver/rpcserver.go b/pkg/rpcserver/rpcserver.go
index d0e6a15f1..bb9bbbd07 100644
--- a/pkg/rpcserver/rpcserver.go
+++ b/pkg/rpcserver/rpcserver.go
@@ -93,9 +93,9 @@ type server struct {
checker *vminfo.Checker
infoOnce sync.Once
- checkOnce sync.Once
checkDone atomic.Bool
checkFailures int
+ onHandshake chan *handshakeResult
baseSource *queue.DynamicSourceCtl
setupFeatures flatrpc.Feature
canonicalModules *cover.Canonicalizer
@@ -193,15 +193,16 @@ func newImpl(cfg *Config, mgr Manager) *server {
checker := vminfo.New(&cfg.Config)
baseSource := queue.DynamicSource(checker)
return &server{
- cfg: cfg,
- mgr: mgr,
- target: cfg.Target,
- sysTarget: sysTarget,
- timeouts: sysTarget.Timeouts(cfg.Slowdown),
- runners: make(map[int]*Runner),
- checker: checker,
- baseSource: baseSource,
- execSource: queue.Distribute(queue.Retry(baseSource)),
+ cfg: cfg,
+ mgr: mgr,
+ target: cfg.Target,
+ sysTarget: sysTarget,
+ timeouts: sysTarget.Timeouts(cfg.Slowdown),
+ runners: make(map[int]*Runner),
+ checker: checker,
+ baseSource: baseSource,
+ execSource: queue.Distribute(queue.Retry(baseSource)),
+ onHandshake: make(chan *handshakeResult, 1),
Stats: cfg.Stats,
runnerStats: &runnerStats{
@@ -235,9 +236,24 @@ func (serv *server) Serve(ctx context.Context) error {
g, ctx := errgroup.WithContext(ctx)
g.Go(func() error {
return serv.serv.Serve(ctx, func(ctx context.Context, conn *flatrpc.Conn) error {
- return serv.handleConn(ctx, g, conn)
+ err := serv.handleConn(ctx, conn)
+ if err != nil {
+ log.Logf(0, "serv.handleConn returend %v", err)
+ }
+ return err
})
})
+ g.Go(func() error {
+ var info *handshakeResult
+ select {
+ case <-ctx.Done():
+ return nil
+ case info = <-serv.onHandshake:
+ }
+ // We run the machine check specifically from the top level context,
+ // not from the per-connection one.
+ return serv.runCheck(ctx, info)
+ })
return g.Wait()
}
@@ -245,7 +261,7 @@ func (serv *server) Port() int {
return serv.serv.Addr.Port
}
-func (serv *server) handleConn(ctx context.Context, eg *errgroup.Group, conn *flatrpc.Conn) error {
+func (serv *server) handleConn(ctx context.Context, conn *flatrpc.Conn) error {
connectReq, err := flatrpc.Recv[*flatrpc.ConnectRequestRaw](conn)
if err != nil {
log.Logf(1, "%s", err)
@@ -275,7 +291,7 @@ func (serv *server) handleConn(ctx context.Context, eg *errgroup.Group, conn *fl
return nil
}
- err = serv.handleRunnerConn(ctx, eg, runner, conn)
+ err = serv.handleRunnerConn(ctx, runner, conn)
log.Logf(2, "runner %v: %v", id, err)
if err != nil && errors.Is(err, errFatal) {
@@ -287,8 +303,7 @@ func (serv *server) handleConn(ctx context.Context, eg *errgroup.Group, conn *fl
return nil
}
-func (serv *server) handleRunnerConn(ctx context.Context, eg *errgroup.Group,
- runner *Runner, conn *flatrpc.Conn) error {
+func (serv *server) handleRunnerConn(ctx context.Context, runner *Runner, conn *flatrpc.Conn) error {
opts := &handshakeConfig{
VMLess: serv.cfg.VMLess,
Files: serv.checker.RequiredFiles(),
@@ -309,25 +324,17 @@ func (serv *server) handleRunnerConn(ctx context.Context, eg *errgroup.Group,
return err
}
- serv.checkOnce.Do(func() {
- // Run the machine check.
- eg.Go(func() error {
- if err := serv.runCheck(ctx, &info); err != nil {
- return fmt.Errorf("%w: %w", errFatal, err)
- }
- return nil
- })
- })
+ select {
+ case serv.onHandshake <- &info:
+ default:
+ }
if serv.triagedCorpus.Load() {
- eg.Go(runner.SendCorpusTriaged)
+ if err := runner.SendCorpusTriaged(); err != nil {
+ return err
+ }
}
-
- go func() {
- <-ctx.Done()
- runner.Stop()
- }()
- return serv.connectionLoop(runner)
+ return serv.connectionLoop(ctx, runner)
}
// Used for errors incompatible with further RPCServer operation.
@@ -379,7 +386,17 @@ func (serv *server) handleMachineInfo(infoReq *flatrpc.InfoRequestRawT) (handsha
}, nil
}
-func (serv *server) connectionLoop(runner *Runner) error {
+func (serv *server) connectionLoop(baseCtx context.Context, runner *Runner) error {
+ // To "cancel" the runner's loop we need to call runner.Stop().
+ // At the same time, we don't want to leak the goroutine that monitors it,
+ // so we derive a new context and cancel it on function exit.
+ ctx, cancel := context.WithCancel(baseCtx)
+ defer cancel()
+ go func() {
+ <-ctx.Done()
+ runner.Stop()
+ }()
+
if serv.cfg.Cover {
maxSignal := serv.mgr.MaxSignal().ToRaw()
for len(maxSignal) != 0 {
diff --git a/pkg/rpcserver/rpcserver_test.go b/pkg/rpcserver/rpcserver_test.go
index 69379cc98..2da916286 100644
--- a/pkg/rpcserver/rpcserver_test.go
+++ b/pkg/rpcserver/rpcserver_test.go
@@ -9,7 +9,6 @@ import (
"testing"
"github.com/stretchr/testify/assert"
- "golang.org/x/sync/errgroup"
"github.com/google/syzkaller/pkg/csource"
"github.com/google/syzkaller/pkg/flatrpc"
@@ -217,9 +216,8 @@ func TestHandleConn(t *testing.T) {
serv.CreateInstance(1, injectExec, nil)
go flatrpc.Send(clientConn, tt.req)
- var eg errgroup.Group
- serv.handleConn(context.Background(), &eg, serverConn)
- if err := eg.Wait(); err != nil {
+ err = serv.handleConn(context.Background(), serverConn)
+ if err != nil {
t.Fatal(err)
}
})