aboutsummaryrefslogtreecommitdiffstats
path: root/vm/proxyapp
diff options
context:
space:
mode:
authorTaras Madan <tarasmadan@google.com>2025-05-15 15:01:02 +0200
committerTaras Madan <tarasmadan@google.com>2025-05-19 09:39:47 +0000
commit27f689959decd391b047c8034d481267d500549e (patch)
tree79ce6364d592fd6841e25ec64ca645fc3c65cdcf /vm/proxyapp
parent8f9cf946b3733d0b4ad3124bce155a4fc3849c3a (diff)
vm: func Run accepts context
It allows to use context as a single termination signal source.
Diffstat (limited to 'vm/proxyapp')
-rw-r--r--vm/proxyapp/proxyappclient.go12
-rw-r--r--vm/proxyapp/proxyappclient_test.go38
2 files changed, 17 insertions, 33 deletions
diff --git a/vm/proxyapp/proxyappclient.go b/vm/proxyapp/proxyappclient.go
index 72faf5655..32a0b0b96 100644
--- a/vm/proxyapp/proxyappclient.go
+++ b/vm/proxyapp/proxyappclient.go
@@ -477,11 +477,7 @@ func buildMerger(names ...string) (*vmimpl.OutputMerger, []io.Writer) {
return merger, wPipes
}
-func (inst *instance) Run(
- timeout time.Duration,
- stop <-chan bool,
- command string,
-) (<-chan []byte, <-chan error, error) {
+func (inst *instance) Run(ctx context.Context, command string) (<-chan []byte, <-chan error, error) {
merger, wPipes := buildMerger("stdout", "stderr", "console")
receivedStdoutChunks := wPipes[0]
receivedStderrChunks := wPipes[1]
@@ -502,7 +498,6 @@ func (inst *instance) Run(
runID := reply.RunID
terminationError := make(chan error, 1)
- timeoutSignal := time.After(timeout)
signalClientErrorf := clientErrorf(receivedStderrChunks)
go func() {
@@ -531,13 +526,10 @@ func (inst *instance) Run(
} else {
continue
}
- case <-timeoutSignal:
+ case <-ctx.Done():
// It is the happy path.
inst.runStop(runID)
terminationError <- vmimpl.ErrTimeout
- case <-stop:
- inst.runStop(runID)
- terminationError <- vmimpl.ErrTimeout
}
break
}
diff --git a/vm/proxyapp/proxyappclient_test.go b/vm/proxyapp/proxyappclient_test.go
index 0f199c0c8..7053f2411 100644
--- a/vm/proxyapp/proxyappclient_test.go
+++ b/vm/proxyapp/proxyappclient_test.go
@@ -5,6 +5,7 @@ package proxyapp
import (
"bytes"
+ "context"
"fmt"
"io"
"net/rpc"
@@ -401,28 +402,13 @@ func TestInstance_Forward_Failure(t *testing.T) {
assert.Empty(t, remoteAddressToUse)
}
-func TestInstance_Run_SimpleOk(t *testing.T) {
- mockInstance, inst := createInstanceFixture(t)
- mockInstance.
- On("RunStart", mock.Anything, mock.Anything).
- Return(nil).
- On("RunReadProgress", mock.Anything, mock.Anything).
- Return(nil).
- Maybe()
-
- outc, errc, err := inst.Run(10*time.Second, make(chan bool), "command")
- assert.NotNil(t, outc)
- assert.NotNil(t, errc)
- assert.Nil(t, err)
-}
-
func TestInstance_Run_Failure(t *testing.T) {
mockInstance, inst := createInstanceFixture(t)
mockInstance.
On("RunStart", mock.Anything, mock.Anything).
Return(fmt.Errorf("run start error"))
- outc, errc, err := inst.Run(10*time.Second, make(chan bool), "command")
+ outc, errc, err := inst.Run(contextWithTimeout(t, 10*time.Second), "command")
assert.Nil(t, outc)
assert.Nil(t, errc)
assert.NotEmpty(t, err)
@@ -438,7 +424,7 @@ func TestInstance_Run_OnTimeout(t *testing.T) {
On("RunStop", mock.Anything, mock.Anything).
Return(nil)
- _, errc, _ := inst.Run(time.Second, make(chan bool), "command")
+ _, errc, _ := inst.Run(contextWithTimeout(t, time.Second), "command")
err := <-errc
assert.Equal(t, err, vmimpl.ErrTimeout)
@@ -455,9 +441,9 @@ func TestInstance_Run_OnStop(t *testing.T) {
On("RunStop", mock.Anything, mock.Anything).
Return(nil)
- stop := make(chan bool)
- _, errc, _ := inst.Run(10*time.Second, stop, "command")
- stop <- true
+ ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+ _, errc, _ := inst.Run(ctx, "command")
+ cancel()
err := <-errc
assert.Equal(t, err, vmimpl.ErrTimeout)
}
@@ -478,7 +464,7 @@ func TestInstance_RunReadProgress_OnErrorReceived(t *testing.T) {
Return(nil).
Once()
- outc, _, _ := inst.Run(10*time.Second, make(chan bool), "command")
+ outc, _, _ := inst.Run(contextWithTimeout(t, 10*time.Second), "command")
output := string(<-outc)
assert.Equal(t, "mock error\nSYZFAIL: proxy app plugin error\n", output)
@@ -500,7 +486,7 @@ func TestInstance_RunReadProgress_OnFinished(t *testing.T) {
Return(nil).
Once()
- _, errc, _ := inst.Run(10*time.Second, make(chan bool), "command")
+ _, errc, _ := inst.Run(contextWithTimeout(t, 10*time.Second), "command")
err := <-errc
assert.Equal(t, err, nil)
@@ -519,7 +505,7 @@ func TestInstance_RunReadProgress_Failed(t *testing.T) {
Return(fmt.Errorf("runreadprogresserror")).
Once()
- outc, _, _ := inst.Run(10*time.Second, make(chan bool), "command")
+ outc, _, _ := inst.Run(contextWithTimeout(t, 10*time.Second), "command")
output := string(<-outc)
assert.Equal(t,
@@ -532,3 +518,9 @@ func TestInstance_RunReadProgress_Failed(t *testing.T) {
// [option] check pool size was changed
// TODO: test pool.Close() calls plugin API and return error.
+
+func contextWithTimeout(t *testing.T, timeout time.Duration) context.Context {
+ ctx, cancel := context.WithTimeout(context.Background(), timeout)
+ t.Cleanup(cancel)
+ return ctx
+}