From 27f689959decd391b047c8034d481267d500549e Mon Sep 17 00:00:00 2001 From: Taras Madan Date: Thu, 15 May 2025 15:01:02 +0200 Subject: vm: func Run accepts context It allows to use context as a single termination signal source. --- vm/adb/adb.go | 6 +++--- vm/bhyve/bhyve.go | 7 +++---- vm/cuttlefish/cuttlefish.go | 5 +++-- vm/gce/gce.go | 5 ++--- vm/gvisor/gvisor.go | 7 +++---- vm/isolated/isolated.go | 6 +++--- vm/proxyapp/proxyappclient.go | 12 ++---------- vm/proxyapp/proxyappclient_test.go | 38 +++++++++++++++----------------------- vm/qemu/qemu.go | 6 +++--- vm/starnix/starnix.go | 6 +++--- vm/vm.go | 16 +++------------- vm/vm_test.go | 5 +++-- vm/vmimpl/vmimpl.go | 12 +++++------- vm/vmm/vmm.go | 7 +++---- vm/vmware/vmware.go | 6 +++--- 15 files changed, 57 insertions(+), 87 deletions(-) (limited to 'vm') diff --git a/vm/adb/adb.go b/vm/adb/adb.go index 1a7ed7e17..202f352c0 100644 --- a/vm/adb/adb.go +++ b/vm/adb/adb.go @@ -7,6 +7,7 @@ package adb import ( "bytes" + "context" "encoding/json" "fmt" "io" @@ -521,7 +522,7 @@ func isRemoteCuttlefish(dev string) (bool, string) { return true, ip } -func (inst *instance) Run(timeout time.Duration, stop <-chan bool, command string) ( +func (inst *instance) Run(ctx context.Context, command string) ( <-chan []byte, <-chan error, error) { var tty io.ReadCloser var err error @@ -566,9 +567,8 @@ func (inst *instance) Run(timeout time.Duration, stop <-chan bool, command strin merger.Add("console", tty) merger.Add("adb", adbRpipe) - return vmimpl.Multiplex(adb, merger, timeout, vmimpl.MultiplexConfig{ + return vmimpl.Multiplex(ctx, adb, merger, vmimpl.MultiplexConfig{ Console: tty, - Stop: stop, Close: inst.closed, Debug: inst.debug, Scale: inst.timeouts.Scale, diff --git a/vm/bhyve/bhyve.go b/vm/bhyve/bhyve.go index 74025e63d..6cccf7a9f 100644 --- a/vm/bhyve/bhyve.go +++ b/vm/bhyve/bhyve.go @@ -4,6 +4,7 @@ package bhyve import ( + "context" "fmt" "io" "os" @@ -324,7 +325,7 @@ func (inst *instance) Copy(hostSrc string) (string, error) { return vmDst, nil } -func (inst *instance) Run(timeout time.Duration, stop <-chan bool, command string) ( +func (inst *instance) Run(ctx context.Context, command string) ( <-chan []byte, <-chan error, error) { rpipe, wpipe, err := osutil.LongPipe() if err != nil { @@ -360,9 +361,7 @@ func (inst *instance) Run(timeout time.Duration, stop <-chan bool, command strin go func() { select { - case <-time.After(timeout): - signal(vmimpl.ErrTimeout) - case <-stop: + case <-ctx.Done(): signal(vmimpl.ErrTimeout) case err := <-inst.merger.Err: cmd.Process.Kill() diff --git a/vm/cuttlefish/cuttlefish.go b/vm/cuttlefish/cuttlefish.go index dcc825fbf..55e56d2e5 100644 --- a/vm/cuttlefish/cuttlefish.go +++ b/vm/cuttlefish/cuttlefish.go @@ -11,6 +11,7 @@ package cuttlefish import ( + "context" "fmt" "os/exec" "path/filepath" @@ -167,9 +168,9 @@ func (inst *instance) Close() error { return inst.gceInst.Close() } -func (inst *instance) Run(timeout time.Duration, stop <-chan bool, command string) ( +func (inst *instance) Run(ctx context.Context, command string) ( <-chan []byte, <-chan error, error) { - return inst.gceInst.Run(timeout, stop, fmt.Sprintf("adb shell 'cd %s; %s'", deviceRoot, command)) + return inst.gceInst.Run(ctx, fmt.Sprintf("adb shell 'cd %s; %s'", deviceRoot, command)) } func (inst *instance) Diagnose(rep *report.Report) ([]byte, bool) { diff --git a/vm/gce/gce.go b/vm/gce/gce.go index 0da3781ba..568eab582 100644 --- a/vm/gce/gce.go +++ b/vm/gce/gce.go @@ -271,7 +271,7 @@ func (inst *instance) Copy(hostSrc string) (string, error) { return vmDst, nil } -func (inst *instance) Run(timeout time.Duration, stop <-chan bool, command string) ( +func (inst *instance) Run(ctx context.Context, command string) ( <-chan []byte, <-chan error, error) { conRpipe, conWpipe, err := osutil.LongPipe() if err != nil { @@ -340,9 +340,8 @@ func (inst *instance) Run(timeout time.Duration, stop <-chan bool, command strin sshWpipe.Close() merger.Add("ssh", sshRpipe) - return vmimpl.Multiplex(ssh, merger, timeout, vmimpl.MultiplexConfig{ + return vmimpl.Multiplex(ctx, ssh, merger, vmimpl.MultiplexConfig{ Console: vmimpl.CmdCloser{Cmd: con}, - Stop: stop, Close: inst.closed, Debug: inst.debug, Scale: inst.timeouts.Scale, diff --git a/vm/gvisor/gvisor.go b/vm/gvisor/gvisor.go index 8336ee450..b62dcb790 100644 --- a/vm/gvisor/gvisor.go +++ b/vm/gvisor/gvisor.go @@ -7,6 +7,7 @@ package gvisor import ( "bytes" + "context" "fmt" "io" "net" @@ -286,7 +287,7 @@ func (inst *instance) Copy(hostSrc string) (string, error) { return filepath.Join("/", fname), nil } -func (inst *instance) Run(timeout time.Duration, stop <-chan bool, command string) ( +func (inst *instance) Run(ctx context.Context, command string) ( <-chan []byte, <-chan error, error) { args := []string{"exec", "-user=0:0"} for _, c := range sandboxCaps { @@ -327,9 +328,7 @@ func (inst *instance) Run(timeout time.Duration, stop <-chan bool, command strin go func() { select { - case <-time.After(timeout): - signal(vmimpl.ErrTimeout) - case <-stop: + case <-ctx.Done(): signal(vmimpl.ErrTimeout) case err := <-inst.merger.Err: cmd.Process.Kill() diff --git a/vm/isolated/isolated.go b/vm/isolated/isolated.go index eb70cf369..6e57fef9c 100755 --- a/vm/isolated/isolated.go +++ b/vm/isolated/isolated.go @@ -5,6 +5,7 @@ package isolated import ( "bytes" + "context" "fmt" "io" "os" @@ -311,7 +312,7 @@ func (inst *instance) Copy(hostSrc string) (string, error) { return vmDst, nil } -func (inst *instance) Run(timeout time.Duration, stop <-chan bool, command string) ( +func (inst *instance) Run(ctx context.Context, command string) ( <-chan []byte, <-chan error, error) { args := append(vmimpl.SSHArgs(inst.debug, inst.Key, inst.Port, inst.cfg.SystemSSHCfg), inst.User+"@"+inst.Addr) @@ -354,9 +355,8 @@ func (inst *instance) Run(timeout time.Duration, stop <-chan bool, command strin merger.Add("dmesg", dmesg) merger.Add("ssh", rpipe) - return vmimpl.Multiplex(cmd, merger, timeout, vmimpl.MultiplexConfig{ + return vmimpl.Multiplex(ctx, cmd, merger, vmimpl.MultiplexConfig{ Console: dmesg, - Stop: stop, Close: inst.closed, Debug: inst.debug, Scale: inst.timeouts.Scale, diff --git a/vm/proxyapp/proxyappclient.go b/vm/proxyapp/proxyappclient.go index 72faf5655..32a0b0b96 100644 --- a/vm/proxyapp/proxyappclient.go +++ b/vm/proxyapp/proxyappclient.go @@ -477,11 +477,7 @@ func buildMerger(names ...string) (*vmimpl.OutputMerger, []io.Writer) { return merger, wPipes } -func (inst *instance) Run( - timeout time.Duration, - stop <-chan bool, - command string, -) (<-chan []byte, <-chan error, error) { +func (inst *instance) Run(ctx context.Context, command string) (<-chan []byte, <-chan error, error) { merger, wPipes := buildMerger("stdout", "stderr", "console") receivedStdoutChunks := wPipes[0] receivedStderrChunks := wPipes[1] @@ -502,7 +498,6 @@ func (inst *instance) Run( runID := reply.RunID terminationError := make(chan error, 1) - timeoutSignal := time.After(timeout) signalClientErrorf := clientErrorf(receivedStderrChunks) go func() { @@ -531,13 +526,10 @@ func (inst *instance) Run( } else { continue } - case <-timeoutSignal: + case <-ctx.Done(): // It is the happy path. inst.runStop(runID) terminationError <- vmimpl.ErrTimeout - case <-stop: - inst.runStop(runID) - terminationError <- vmimpl.ErrTimeout } break } diff --git a/vm/proxyapp/proxyappclient_test.go b/vm/proxyapp/proxyappclient_test.go index 0f199c0c8..7053f2411 100644 --- a/vm/proxyapp/proxyappclient_test.go +++ b/vm/proxyapp/proxyappclient_test.go @@ -5,6 +5,7 @@ package proxyapp import ( "bytes" + "context" "fmt" "io" "net/rpc" @@ -401,28 +402,13 @@ func TestInstance_Forward_Failure(t *testing.T) { assert.Empty(t, remoteAddressToUse) } -func TestInstance_Run_SimpleOk(t *testing.T) { - mockInstance, inst := createInstanceFixture(t) - mockInstance. - On("RunStart", mock.Anything, mock.Anything). - Return(nil). - On("RunReadProgress", mock.Anything, mock.Anything). - Return(nil). - Maybe() - - outc, errc, err := inst.Run(10*time.Second, make(chan bool), "command") - assert.NotNil(t, outc) - assert.NotNil(t, errc) - assert.Nil(t, err) -} - func TestInstance_Run_Failure(t *testing.T) { mockInstance, inst := createInstanceFixture(t) mockInstance. On("RunStart", mock.Anything, mock.Anything). Return(fmt.Errorf("run start error")) - outc, errc, err := inst.Run(10*time.Second, make(chan bool), "command") + outc, errc, err := inst.Run(contextWithTimeout(t, 10*time.Second), "command") assert.Nil(t, outc) assert.Nil(t, errc) assert.NotEmpty(t, err) @@ -438,7 +424,7 @@ func TestInstance_Run_OnTimeout(t *testing.T) { On("RunStop", mock.Anything, mock.Anything). Return(nil) - _, errc, _ := inst.Run(time.Second, make(chan bool), "command") + _, errc, _ := inst.Run(contextWithTimeout(t, time.Second), "command") err := <-errc assert.Equal(t, err, vmimpl.ErrTimeout) @@ -455,9 +441,9 @@ func TestInstance_Run_OnStop(t *testing.T) { On("RunStop", mock.Anything, mock.Anything). Return(nil) - stop := make(chan bool) - _, errc, _ := inst.Run(10*time.Second, stop, "command") - stop <- true + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + _, errc, _ := inst.Run(ctx, "command") + cancel() err := <-errc assert.Equal(t, err, vmimpl.ErrTimeout) } @@ -478,7 +464,7 @@ func TestInstance_RunReadProgress_OnErrorReceived(t *testing.T) { Return(nil). Once() - outc, _, _ := inst.Run(10*time.Second, make(chan bool), "command") + outc, _, _ := inst.Run(contextWithTimeout(t, 10*time.Second), "command") output := string(<-outc) assert.Equal(t, "mock error\nSYZFAIL: proxy app plugin error\n", output) @@ -500,7 +486,7 @@ func TestInstance_RunReadProgress_OnFinished(t *testing.T) { Return(nil). Once() - _, errc, _ := inst.Run(10*time.Second, make(chan bool), "command") + _, errc, _ := inst.Run(contextWithTimeout(t, 10*time.Second), "command") err := <-errc assert.Equal(t, err, nil) @@ -519,7 +505,7 @@ func TestInstance_RunReadProgress_Failed(t *testing.T) { Return(fmt.Errorf("runreadprogresserror")). Once() - outc, _, _ := inst.Run(10*time.Second, make(chan bool), "command") + outc, _, _ := inst.Run(contextWithTimeout(t, 10*time.Second), "command") output := string(<-outc) assert.Equal(t, @@ -532,3 +518,9 @@ func TestInstance_RunReadProgress_Failed(t *testing.T) { // [option] check pool size was changed // TODO: test pool.Close() calls plugin API and return error. + +func contextWithTimeout(t *testing.T, timeout time.Duration) context.Context { + ctx, cancel := context.WithTimeout(context.Background(), timeout) + t.Cleanup(cancel) + return ctx +} diff --git a/vm/qemu/qemu.go b/vm/qemu/qemu.go index 398eb8047..3fb78e9fb 100644 --- a/vm/qemu/qemu.go +++ b/vm/qemu/qemu.go @@ -5,6 +5,7 @@ package qemu import ( "bytes" + "context" "encoding/json" "fmt" "io" @@ -667,7 +668,7 @@ func (inst *instance) Copy(hostSrc string) (string, error) { return vmDst, nil } -func (inst *instance) Run(timeout time.Duration, stop <-chan bool, command string) ( +func (inst *instance) Run(ctx context.Context, command string) ( <-chan []byte, <-chan error, error) { rpipe, wpipe, err := osutil.LongPipe() if err != nil { @@ -707,8 +708,7 @@ func (inst *instance) Run(timeout time.Duration, stop <-chan bool, command strin return nil, nil, err } wpipe.Close() - return vmimpl.Multiplex(cmd, inst.merger, timeout, vmimpl.MultiplexConfig{ - Stop: stop, + return vmimpl.Multiplex(ctx, cmd, inst.merger, vmimpl.MultiplexConfig{ Debug: inst.debug, Scale: inst.timeouts.Scale, }) diff --git a/vm/starnix/starnix.go b/vm/starnix/starnix.go index a90cc9d96..c1921aff5 100644 --- a/vm/starnix/starnix.go +++ b/vm/starnix/starnix.go @@ -5,6 +5,7 @@ package starnix import ( "bytes" + "context" "encoding/json" "fmt" "io" @@ -403,7 +404,7 @@ func (inst *instance) Copy(hostSrc string) (string, error) { return vmDst, fmt.Errorf("instance %s: can't push binary %s to instance over scp", inst.name, base) } -func (inst *instance) Run(timeout time.Duration, stop <-chan bool, command string) ( +func (inst *instance) Run(ctx context.Context, command string) ( <-chan []byte, <-chan error, error) { rpipe, wpipe, err := osutil.LongPipe() if err != nil { @@ -430,8 +431,7 @@ func (inst *instance) Run(timeout time.Duration, stop <-chan bool, command strin return nil, nil, err } wpipe.Close() - return vmimpl.Multiplex(cmd, inst.merger, timeout, vmimpl.MultiplexConfig{ - Stop: stop, + return vmimpl.Multiplex(ctx, cmd, inst.merger, vmimpl.MultiplexConfig{ Debug: inst.debug, Scale: inst.timeouts.Scale, }) diff --git a/vm/vm.go b/vm/vm.go index 522cf5aa8..2fa7d0016 100644 --- a/vm/vm.go +++ b/vm/vm.go @@ -253,7 +253,6 @@ const ( ExitError ) -type StopContext context.Context type InjectExecuting <-chan bool type OutputSize int @@ -264,13 +263,11 @@ type EarlyFinishCb func() // and the kernel console output. It detects kernel oopses in output, lost connections, hangs, etc. // Returns command+kernel output and a non-symbolized crash report (nil if no error happens). // Accepted options: -// - StopContext: the context to be used to prematurely stop the command // - ExitCondition: says which exit modes should be considered as errors/OK // - OutputSize: how much output to keep/return -func (inst *Instance) Run(timeout time.Duration, reporter *report.Reporter, command string, opts ...any) ( +func (inst *Instance) Run(ctx context.Context, reporter *report.Reporter, command string, opts ...any) ( []byte, *report.Report, error) { exit := ExitNormal - var stop <-chan bool var injected <-chan bool var finished func() outputSize := beforeContextDefault @@ -278,24 +275,17 @@ func (inst *Instance) Run(timeout time.Duration, reporter *report.Reporter, comm switch opt := o.(type) { case ExitCondition: exit = opt - case StopContext: - stopCh := make(chan bool) - go func() { - <-opt.Done() - close(stopCh) - }() - stop = stopCh case OutputSize: outputSize = int(opt) case InjectExecuting: - injected = (<-chan bool)(opt) + injected = opt case EarlyFinishCb: finished = opt default: panic(fmt.Sprintf("unknown option %#v", opt)) } } - outc, errc, err := inst.impl.Run(timeout, stop, command) + outc, errc, err := inst.impl.Run(ctx, command) if err != nil { return nil, nil, err } diff --git a/vm/vm_test.go b/vm/vm_test.go index 34b656799..f550cc89a 100644 --- a/vm/vm_test.go +++ b/vm/vm_test.go @@ -5,6 +5,7 @@ package vm import ( "bytes" + "context" "fmt" "testing" "time" @@ -49,7 +50,7 @@ func (inst *testInstance) Forward(port int) (string, error) { return "", nil } -func (inst *testInstance) Run(timeout time.Duration, stop <-chan bool, command string) ( +func (inst *testInstance) Run(ctx context.Context, command string) ( outc <-chan []byte, errc <-chan error, err error) { return inst.outc, inst.errc, nil } @@ -395,7 +396,7 @@ func testMonitorExecution(t *testing.T, test *Test) { test.BodyExecuting(testInst.outc, testInst.errc, inject) done <- true }() - _, rep, err := inst.Run(time.Second, reporter, "", opts...) + _, rep, err := inst.Run(context.Background(), reporter, "", opts...) if err != nil { t.Fatal(err) } diff --git a/vm/vmimpl/vmimpl.go b/vm/vmimpl/vmimpl.go index 6f2416494..68e7030c3 100644 --- a/vm/vmimpl/vmimpl.go +++ b/vm/vmimpl/vmimpl.go @@ -8,6 +8,7 @@ package vmimpl import ( + "context" "crypto/rand" "errors" "fmt" @@ -47,8 +48,8 @@ type Instance interface { // Run runs cmd inside of the VM (think of ssh cmd). // outc receives combined cmd and kernel console output. // errc receives either command Wait return error or vmimpl.ErrTimeout. - // Command is terminated after timeout. Send on the stop chan can be used to terminate it earlier. - Run(timeout time.Duration, stop <-chan bool, command string) (outc <-chan []byte, errc <-chan error, err error) + // Command terminates with context. Use context.WithTimeout to terminate it earlier. + Run(ctx context.Context, command string) (outc <-chan []byte, errc <-chan error, err error) // Diagnose retrieves additional debugging info from the VM // (e.g. by sending some sys-rq's or SIGABORT'ing a Go program). @@ -170,14 +171,13 @@ var WaitForOutputTimeout = 10 * time.Second type MultiplexConfig struct { Console io.Closer - Stop <-chan bool Close <-chan bool Debug bool Scale time.Duration IgnoreError func(err error) bool } -func Multiplex(cmd *exec.Cmd, merger *OutputMerger, timeout time.Duration, config MultiplexConfig) ( +func Multiplex(ctx context.Context, cmd *exec.Cmd, merger *OutputMerger, config MultiplexConfig) ( <-chan []byte, <-chan error, error) { if config.Scale <= 0 { panic("slowdown must be set") @@ -191,9 +191,7 @@ func Multiplex(cmd *exec.Cmd, merger *OutputMerger, timeout time.Duration, confi } go func() { select { - case <-time.After(timeout): - signal(ErrTimeout) - case <-config.Stop: + case <-ctx.Done(): signal(ErrTimeout) case <-config.Close: if config.Debug { diff --git a/vm/vmm/vmm.go b/vm/vmm/vmm.go index 35fdf7650..5190e32b3 100644 --- a/vm/vmm/vmm.go +++ b/vm/vmm/vmm.go @@ -5,6 +5,7 @@ package vmm import ( + "context" "fmt" "io" "os" @@ -250,7 +251,7 @@ func (inst *instance) Copy(hostSrc string) (string, error) { return vmDst, nil } -func (inst *instance) Run(timeout time.Duration, stop <-chan bool, command string) ( +func (inst *instance) Run(ctx context.Context, command string) ( <-chan []byte, <-chan error, error) { rpipe, wpipe, err := osutil.LongPipe() if err != nil { @@ -281,9 +282,7 @@ func (inst *instance) Run(timeout time.Duration, stop <-chan bool, command strin go func() { select { - case <-time.After(timeout): - signal(vmimpl.ErrTimeout) - case <-stop: + case <-ctx.Done(): signal(vmimpl.ErrTimeout) case err := <-inst.merger.Err: cmd.Process.Kill() diff --git a/vm/vmware/vmware.go b/vm/vmware/vmware.go index 104c3d6a9..e4379717b 100644 --- a/vm/vmware/vmware.go +++ b/vm/vmware/vmware.go @@ -4,6 +4,7 @@ package vmware import ( + "context" "fmt" "io" "net" @@ -173,7 +174,7 @@ func (inst *instance) Copy(hostSrc string) (string, error) { return vmDst, nil } -func (inst *instance) Run(timeout time.Duration, stop <-chan bool, command string) ( +func (inst *instance) Run(ctx context.Context, command string) ( <-chan []byte, <-chan error, error) { vmxDir := filepath.Dir(inst.vmx) serial := filepath.Join(vmxDir, "serial") @@ -217,9 +218,8 @@ func (inst *instance) Run(timeout time.Duration, stop <-chan bool, command strin merger.Add("dmesg", dmesg) merger.Add("ssh", rpipe) - return vmimpl.Multiplex(cmd, merger, timeout, vmimpl.MultiplexConfig{ + return vmimpl.Multiplex(ctx, cmd, merger, vmimpl.MultiplexConfig{ Console: dmesg, - Stop: stop, Close: inst.closed, Debug: inst.debug, Scale: inst.timeouts.Scale, -- cgit mrf-deployment