diff options
Diffstat (limited to 'vm/proxyapp')
| -rw-r--r-- | vm/proxyapp/proxyappclient.go | 12 | ||||
| -rw-r--r-- | vm/proxyapp/proxyappclient_test.go | 38 |
2 files changed, 17 insertions, 33 deletions
diff --git a/vm/proxyapp/proxyappclient.go b/vm/proxyapp/proxyappclient.go index 72faf5655..32a0b0b96 100644 --- a/vm/proxyapp/proxyappclient.go +++ b/vm/proxyapp/proxyappclient.go @@ -477,11 +477,7 @@ func buildMerger(names ...string) (*vmimpl.OutputMerger, []io.Writer) { return merger, wPipes } -func (inst *instance) Run( - timeout time.Duration, - stop <-chan bool, - command string, -) (<-chan []byte, <-chan error, error) { +func (inst *instance) Run(ctx context.Context, command string) (<-chan []byte, <-chan error, error) { merger, wPipes := buildMerger("stdout", "stderr", "console") receivedStdoutChunks := wPipes[0] receivedStderrChunks := wPipes[1] @@ -502,7 +498,6 @@ func (inst *instance) Run( runID := reply.RunID terminationError := make(chan error, 1) - timeoutSignal := time.After(timeout) signalClientErrorf := clientErrorf(receivedStderrChunks) go func() { @@ -531,13 +526,10 @@ func (inst *instance) Run( } else { continue } - case <-timeoutSignal: + case <-ctx.Done(): // It is the happy path. inst.runStop(runID) terminationError <- vmimpl.ErrTimeout - case <-stop: - inst.runStop(runID) - terminationError <- vmimpl.ErrTimeout } break } diff --git a/vm/proxyapp/proxyappclient_test.go b/vm/proxyapp/proxyappclient_test.go index 0f199c0c8..7053f2411 100644 --- a/vm/proxyapp/proxyappclient_test.go +++ b/vm/proxyapp/proxyappclient_test.go @@ -5,6 +5,7 @@ package proxyapp import ( "bytes" + "context" "fmt" "io" "net/rpc" @@ -401,28 +402,13 @@ func TestInstance_Forward_Failure(t *testing.T) { assert.Empty(t, remoteAddressToUse) } -func TestInstance_Run_SimpleOk(t *testing.T) { - mockInstance, inst := createInstanceFixture(t) - mockInstance. - On("RunStart", mock.Anything, mock.Anything). - Return(nil). - On("RunReadProgress", mock.Anything, mock.Anything). - Return(nil). - Maybe() - - outc, errc, err := inst.Run(10*time.Second, make(chan bool), "command") - assert.NotNil(t, outc) - assert.NotNil(t, errc) - assert.Nil(t, err) -} - func TestInstance_Run_Failure(t *testing.T) { mockInstance, inst := createInstanceFixture(t) mockInstance. On("RunStart", mock.Anything, mock.Anything). Return(fmt.Errorf("run start error")) - outc, errc, err := inst.Run(10*time.Second, make(chan bool), "command") + outc, errc, err := inst.Run(contextWithTimeout(t, 10*time.Second), "command") assert.Nil(t, outc) assert.Nil(t, errc) assert.NotEmpty(t, err) @@ -438,7 +424,7 @@ func TestInstance_Run_OnTimeout(t *testing.T) { On("RunStop", mock.Anything, mock.Anything). Return(nil) - _, errc, _ := inst.Run(time.Second, make(chan bool), "command") + _, errc, _ := inst.Run(contextWithTimeout(t, time.Second), "command") err := <-errc assert.Equal(t, err, vmimpl.ErrTimeout) @@ -455,9 +441,9 @@ func TestInstance_Run_OnStop(t *testing.T) { On("RunStop", mock.Anything, mock.Anything). Return(nil) - stop := make(chan bool) - _, errc, _ := inst.Run(10*time.Second, stop, "command") - stop <- true + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + _, errc, _ := inst.Run(ctx, "command") + cancel() err := <-errc assert.Equal(t, err, vmimpl.ErrTimeout) } @@ -478,7 +464,7 @@ func TestInstance_RunReadProgress_OnErrorReceived(t *testing.T) { Return(nil). Once() - outc, _, _ := inst.Run(10*time.Second, make(chan bool), "command") + outc, _, _ := inst.Run(contextWithTimeout(t, 10*time.Second), "command") output := string(<-outc) assert.Equal(t, "mock error\nSYZFAIL: proxy app plugin error\n", output) @@ -500,7 +486,7 @@ func TestInstance_RunReadProgress_OnFinished(t *testing.T) { Return(nil). Once() - _, errc, _ := inst.Run(10*time.Second, make(chan bool), "command") + _, errc, _ := inst.Run(contextWithTimeout(t, 10*time.Second), "command") err := <-errc assert.Equal(t, err, nil) @@ -519,7 +505,7 @@ func TestInstance_RunReadProgress_Failed(t *testing.T) { Return(fmt.Errorf("runreadprogresserror")). Once() - outc, _, _ := inst.Run(10*time.Second, make(chan bool), "command") + outc, _, _ := inst.Run(contextWithTimeout(t, 10*time.Second), "command") output := string(<-outc) assert.Equal(t, @@ -532,3 +518,9 @@ func TestInstance_RunReadProgress_Failed(t *testing.T) { // [option] check pool size was changed // TODO: test pool.Close() calls plugin API and return error. + +func contextWithTimeout(t *testing.T, timeout time.Duration) context.Context { + ctx, cancel := context.WithTimeout(context.Background(), timeout) + t.Cleanup(cancel) + return ctx +} |
