diff options
| author | Aleksandr Nogikh <nogikh@google.com> | 2025-04-23 12:32:03 +0200 |
|---|---|---|
| committer | Aleksandr Nogikh <nogikh@google.com> | 2025-04-23 15:16:19 +0000 |
| commit | 73a168d010b3ba0a82f850b9fe73e6907539ff20 (patch) | |
| tree | d2142cc472c769438429d694020ae5dc07284c9a | |
| parent | d971f7e21bf575c68223c77d5bcb784ac4912aa1 (diff) | |
vm/dispatcher: make pool.Run cancellable
Make the pool.Run() function take a context.Context to be able to abort
the callback passed to it or abort its scheduling if it's not yet
running.
Otherwise, if the callback has not yet started and the pool's Loop is
aborted, we risk waiting in pool.Run() forever. Such a hang prevents the normal
shutdown of repro.Run() and, consequently, the DiffFuzzer functionality.
| -rw-r--r-- | pkg/manager/diff.go | 6 | ||||
| -rw-r--r-- | pkg/repro/repro.go | 5 | ||||
| -rw-r--r-- | pkg/repro/strace.go | 6 | ||||
| -rw-r--r-- | vm/dispatcher/pool.go | 37 | ||||
| -rw-r--r-- | vm/dispatcher/pool_test.go | 61 |
5 files changed, 98 insertions, 17 deletions
diff --git a/pkg/manager/diff.go b/pkg/manager/diff.go index 379fd246c..d7860cc9a 100644 --- a/pkg/manager/diff.go +++ b/pkg/manager/diff.go @@ -6,6 +6,7 @@ package manager import ( "context" "encoding/json" + "errors" "fmt" "math/rand" "net" @@ -583,7 +584,7 @@ func (rr *reproRunner) Run(ctx context.Context, r *repro.Result) { // The third time we leave it as is in case it was important. opts.Threaded = true } - pool.Run(func(ctx context.Context, inst *vm.Instance, updInfo dispatcher.UpdateInfo) { + runErr := pool.Run(ctx, func(ctx context.Context, inst *vm.Instance, updInfo dispatcher.UpdateInfo) { var ret *instance.ExecProgInstance ret, err = instance.SetupExecProg(inst, rr.kernel.cfg, rr.kernel.reporter, nil) if err != nil { @@ -595,6 +596,9 @@ func (rr *reproRunner) Run(ctx context.Context, r *repro.Result) { Opts: opts, }) }) + if errors.Is(runErr, context.Canceled) { + break + } crashed := result != nil && result.Report != nil log.Logf(1, "attempt #%d to run %q on base: crashed=%v", i, ret.origReport.Title, crashed) if crashed { diff --git a/pkg/repro/repro.go b/pkg/repro/repro.go index c196f71c7..1b3a70246 100644 --- a/pkg/repro/repro.go +++ b/pkg/repro/repro.go @@ -767,7 +767,7 @@ func (pw *poolWrapper) Run(ctx context.Context, params instance.ExecParams, var result *instance.RunResult var err error - pw.pool.Run(func(ctx context.Context, inst *vm.Instance, updInfo dispatcher.UpdateInfo) { + runErr := pw.pool.Run(ctx, func(ctx context.Context, inst *vm.Instance, updInfo dispatcher.UpdateInfo) { updInfo(func(info *dispatcher.Info) { typ := "syz" if params.CProg != nil { @@ -787,6 +787,9 @@ func (pw *poolWrapper) Run(ctx context.Context, params instance.ExecParams, result, err = ret.RunSyzProg(params) } }) + if runErr != nil { + return nil, runErr + } return result, err } diff --git a/pkg/repro/strace.go b/pkg/repro/strace.go index e101945cf..ceb31de93 100644 --- a/pkg/repro/strace.go +++ b/pkg/repro/strace.go @@ -31,7 +31,7 @@ func RunStrace(result *Result, 
cfg *mgrconfig.Config, reporter *report.Reporter, } var runRes *instance.RunResult var err error - pool.Run(func(ctx context.Context, inst *vm.Instance, updInfo dispatcher.UpdateInfo) { + runErr := pool.Run(context.Background(), func(ctx context.Context, inst *vm.Instance, updInfo dispatcher.UpdateInfo) { updInfo(func(info *dispatcher.Info) { info.Status = "running strace" }) @@ -58,7 +58,9 @@ func RunStrace(result *Result, cfg *mgrconfig.Config, reporter *report.Reporter, runRes, err = ret.RunSyzProg(params) } }) - if err != nil { + if runErr != nil { + return straceFailed(runErr) + } else if err != nil { return straceFailed(err) } return &StraceResult{ diff --git a/vm/dispatcher/pool.go b/vm/dispatcher/pool.go index 331bfd062..20e893499 100644 --- a/vm/dispatcher/pool.go +++ b/vm/dispatcher/pool.go @@ -114,7 +114,7 @@ func (p *Pool[T]) Loop(ctx context.Context) { func (p *Pool[T]) runInstance(ctx context.Context, inst *poolInstance[T]) { p.waitUnpaused() ctx, cancel := context.WithCancel(ctx) - + defer cancel() log.Logf(2, "pool: booting instance %d", inst.idx) inst.reset(cancel) @@ -187,13 +187,24 @@ func (p *Pool[T]) ReserveForRun(count int) { } // Run blocks until it has found an instance to execute job and until job has finished. -func (p *Pool[T]) Run(job Runner[T]) { - done := make(chan struct{}) - p.jobs <- func(ctx context.Context, inst T, upd UpdateInfo) { - job(ctx, inst, upd) - close(done) +// Returns an error if the job was aborted by cancelling the context. +func (p *Pool[T]) Run(ctx context.Context, job Runner[T]) error { + done := make(chan error) + // Submit the job. + select { + case p.jobs <- func(jobCtx context.Context, inst T, upd UpdateInfo) { + mergedCtx, cancel := mergeContextCancel(jobCtx, ctx) + defer cancel() + + job(mergedCtx, inst, upd) + done <- mergedCtx.Err() + }: + case <-ctx.Done(): + // If the loop is aborted, no one is going to pick up the job. + return ctx.Err() } - <-done + // Await the job. 
+ return <-done } func (p *Pool[T]) Total() int { @@ -311,3 +322,15 @@ func (pi *poolInstance[T]) free(job Runner[T]) { default: } } + +func mergeContextCancel(main, monitor context.Context) (context.Context, func()) { + withCancel, cancel := context.WithCancel(main) + go func() { + select { + case <-withCancel.Done(): + case <-monitor.Done(): + } + cancel() + }() + return withCancel, cancel +} diff --git a/vm/dispatcher/pool_test.go b/vm/dispatcher/pool_test.go index 452cd598a..cd85367c2 100644 --- a/vm/dispatcher/pool_test.go +++ b/vm/dispatcher/pool_test.go @@ -6,6 +6,7 @@ package dispatcher import ( "context" "runtime" + "sync" "sync/atomic" "testing" "time" @@ -87,7 +88,7 @@ func TestPoolSplit(t *testing.T) { case <-stopRuns: } } - go mgr.Run(job) + go mgr.Run(ctx, job) // So far, there are no reserved instances. for i := 0; i < count; i++ { @@ -113,7 +114,7 @@ func TestPoolSplit(t *testing.T) { // Now let's create and finish more jobs. for i := 0; i < 10; i++ { - go mgr.Run(job) + go mgr.Run(ctx, job) } mgr.ReserveForRun(2) for i := 0; i < 10; i++ { @@ -150,8 +151,7 @@ func TestPoolStress(t *testing.T) { } }() for i := 0; i < 128; i++ { - go mgr.Run(func(ctx context.Context, _ *nilInstance, _ UpdateInfo) { - }) + go mgr.Run(ctx, func(ctx context.Context, _ *nilInstance, _ UpdateInfo) {}) mgr.ReserveForRun(5 + i%5) } @@ -221,7 +221,7 @@ func TestPoolPause(t *testing.T) { }() run := make(chan bool, 1) - go mgr.Run(func(ctx context.Context, _ *nilInstance, _ UpdateInfo) { + go mgr.Run(ctx, func(ctx context.Context, _ *nilInstance, _ UpdateInfo) { run <- true }) time.Sleep(10 * time.Millisecond) @@ -231,12 +231,61 @@ func TestPoolPause(t *testing.T) { mgr.TogglePause(false) <-run - mgr.Run(func(ctx context.Context, _ *nilInstance, _ UpdateInfo) {}) + mgr.Run(ctx, func(ctx context.Context, _ *nilInstance, _ UpdateInfo) {}) cancel() <-done } +func TestPoolCancelRun(t *testing.T) { + // The test to aid the race detector. 
+ mgr := NewPool[*nilInstance]( + 10, + func(idx int) (*nilInstance, error) { + return &nilInstance{}, nil + }, + func(ctx context.Context, _ *nilInstance, _ UpdateInfo) { + <-ctx.Done() + }, + ) + var wg sync.WaitGroup + wg.Add(1) + ctx, cancel := context.WithCancel(context.Background()) + go func() { + mgr.Loop(ctx) + wg.Done() + }() + + mgr.ReserveForRun(2) + + started := make(chan struct{}) + // Schedule more jobs than could be processed simultaneously. + for i := 0; i < 15; i++ { + wg.Add(1) + go func() { + defer wg.Done() + mgr.Run(ctx, func(ctx context.Context, _ *nilInstance, _ UpdateInfo) { + select { + case <-ctx.Done(): + return + case started <- struct{}{}: + } + <-ctx.Done() + }) + }() + } + + // Two can be started. + <-started + <-started + + // Now stop the loop and the jobs. + cancel() + + // Everything must really stop. + wg.Wait() +} + func makePool(count int) []testInstance { var ret []testInstance for i := 0; i < count; i++ { |
