aboutsummaryrefslogtreecommitdiffstats
path: root/vm/vmimpl
diff options
context:
space:
mode:
authorAleksandr Nogikh <nogikh@google.com>2026-02-07 19:07:44 +0100
committerAleksandr Nogikh <nogikh@google.com>2026-02-17 14:55:28 +0000
commit9ca304f096ce425e3556bbd05745f896f5e0b268 (patch)
treef17ce5341faa2f56bc7a6a1f1b34aad6e7442525 /vm/vmimpl
parentf4288eed9c51fed44b853c252711c00de8761336 (diff)
vm/vmimpl: refactor Merger error processing
Introduce an Error() method to avoid capturing the errors of already overridden decoders.
Diffstat (limited to 'vm/vmimpl')
-rw-r--r--vm/vmimpl/merger.go150
-rw-r--r--vm/vmimpl/merger_test.go63
-rw-r--r--vm/vmimpl/util.go2
-rw-r--r--vm/vmimpl/vmimpl.go4
4 files changed, 160 insertions, 59 deletions
diff --git a/vm/vmimpl/merger.go b/vm/vmimpl/merger.go
index b7a7910fb..a111da389 100644
--- a/vm/vmimpl/merger.go
+++ b/vm/vmimpl/merger.go
@@ -5,9 +5,12 @@ package vmimpl
import (
"bytes"
+ "context"
"fmt"
"io"
"sync"
+
+ "golang.org/x/sync/errgroup"
)
type OutputType int
@@ -23,12 +26,17 @@ type Chunk struct {
Type OutputType
}
+type decoderState struct {
+ done chan struct{} // Closed when the decoder exits.
+ err error
+}
+
type OutputMerger struct {
- Output chan Chunk
- Err chan error
- teeMu sync.Mutex
- tee io.Writer
- wg sync.WaitGroup
+ Output chan Chunk
+ decoderErr map[string]*decoderState
+ teeMu sync.Mutex
+ tee io.Writer
+ wg sync.WaitGroup
}
type MergerError struct {
@@ -43,9 +51,9 @@ func (err MergerError) Error() string {
func NewOutputMerger(tee io.Writer) *OutputMerger {
return &OutputMerger{
- Output: make(chan Chunk, 1000),
- Err: make(chan error, 1),
- tee: tee,
+ Output: make(chan Chunk, 1000),
+ decoderErr: map[string]*decoderState{},
+ tee: tee,
}
}
@@ -54,70 +62,100 @@ func (merger *OutputMerger) Wait() {
close(merger.Output)
}
+// Errors returns a channel that will receive errors from the curretly active decoderErr.
+func (merger *OutputMerger) Errors(ctx context.Context) <-chan error {
+ eg, egCtx := errgroup.WithContext(ctx)
+ for _, decoder := range merger.decoderErr {
+ eg.Go(func() error {
+ select {
+ case <-egCtx.Done():
+ return nil
+ case <-decoder.done:
+ return decoder.err
+ }
+ })
+ }
+ ret := make(chan error, 1)
+ go func() {
+ err := eg.Wait()
+ if err != nil {
+ ret <- err
+ }
+ close(ret)
+ }()
+ return ret
+}
+
func (merger *OutputMerger) Add(name string, typ OutputType, r io.ReadCloser) {
merger.AddDecoder(name, typ, r, nil)
}
func (merger *OutputMerger) AddDecoder(name string, typ OutputType, r io.ReadCloser,
decoder func(data []byte) (start, size int, decoded []byte)) {
+ state := &decoderState{
+ done: make(chan struct{}),
+ }
+ merger.decoderErr[name] = state
merger.wg.Add(1)
go func() {
- var pending []byte
- var proto []byte
- var buf [4 << 10]byte
- for {
- n, err := r.Read(buf[:])
- if n != 0 {
- if decoder != nil {
- proto = append(proto, buf[:n]...)
- start, size, decoded := decoder(proto)
- proto = proto[start+size:]
- if len(decoded) != 0 {
- merger.Output <- Chunk{decoded, typ} // note: this can block
- }
+ defer merger.wg.Done()
+ defer close(state.done)
+ err := merger.runDecoder(typ, r, decoder)
+ state.err = MergerError{name, r, err}
+ }()
+}
+func (merger *OutputMerger) runDecoder(typ OutputType, r io.ReadCloser,
+ decoder func(data []byte) (start, size int, decoded []byte)) error {
+ var pending []byte
+ var proto []byte
+ var buf [4 << 10]byte
+ for {
+ n, err := r.Read(buf[:])
+ if n != 0 {
+ if decoder != nil {
+ proto = append(proto, buf[:n]...)
+ start, size, decoded := decoder(proto)
+ proto = proto[start+size:]
+ if len(decoded) != 0 {
+ merger.Output <- Chunk{decoded, typ} // note: this can block
}
- // Remove all carriage returns.
- buf := buf[:n]
- if bytes.IndexByte(buf, '\r') != -1 {
- buf = bytes.ReplaceAll(buf, []byte("\r"), nil)
+ }
+ // Remove all carriage returns.
+ buf := buf[:n]
+ if bytes.IndexByte(buf, '\r') != -1 {
+ buf = bytes.ReplaceAll(buf, []byte("\r"), nil)
+ }
+ pending = append(pending, buf...)
+ if pos := bytes.LastIndexByte(pending, '\n'); pos != -1 {
+ out := pending[:pos+1]
+ if merger.tee != nil {
+ merger.teeMu.Lock()
+ merger.tee.Write(out)
+ merger.teeMu.Unlock()
}
- pending = append(pending, buf...)
- if pos := bytes.LastIndexByte(pending, '\n'); pos != -1 {
- out := pending[:pos+1]
- if merger.tee != nil {
- merger.teeMu.Lock()
- merger.tee.Write(out)
- merger.teeMu.Unlock()
- }
- select {
- case merger.Output <- Chunk{append([]byte{}, out...), typ}:
- r := copy(pending, pending[pos+1:])
- pending = pending[:r]
- default:
- }
+ select {
+ case merger.Output <- Chunk{append([]byte{}, out...), typ}:
+ r := copy(pending, pending[pos+1:])
+ pending = pending[:r]
+ default:
}
}
- if err != nil {
- if len(pending) != 0 {
- pending = append(pending, '\n')
- if merger.tee != nil {
- merger.teeMu.Lock()
- merger.tee.Write(pending)
- merger.teeMu.Unlock()
- }
- select {
- case merger.Output <- Chunk{pending, typ}:
- default:
- }
+ }
+ if err != nil {
+ if len(pending) != 0 {
+ pending = append(pending, '\n')
+ if merger.tee != nil {
+ merger.teeMu.Lock()
+ merger.tee.Write(pending)
+ merger.teeMu.Unlock()
}
- r.Close()
select {
- case merger.Err <- MergerError{name, r, err}:
+ case merger.Output <- Chunk{pending, typ}:
default:
}
- merger.wg.Done()
- return
}
+ r.Close()
+ return err
}
- }()
+ }
}
diff --git a/vm/vmimpl/merger_test.go b/vm/vmimpl/merger_test.go
index 57c89321d..e231e5c98 100644
--- a/vm/vmimpl/merger_test.go
+++ b/vm/vmimpl/merger_test.go
@@ -5,12 +5,14 @@ package vmimpl
import (
"bytes"
+ "context"
"errors"
"io"
"testing"
"time"
"github.com/google/syzkaller/pkg/osutil"
+ "github.com/stretchr/testify/assert"
)
func TestMerger(t *testing.T) {
@@ -67,7 +69,9 @@ func TestMerger(t *testing.T) {
}
var merr MergerError
- if err := <-merger.Err; err == nil {
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+ if err := <-merger.Errors(ctx); err == nil {
t.Fatalf("merger did not produce an error on pipe close")
} else if !errors.As(err, &merr) || merr.Name != "pipe1" || merr.R != rp1 || merr.Err != io.EOF {
t.Fatalf("merger produced wrong error: %v", err)
@@ -85,3 +89,60 @@ func TestMerger(t *testing.T) {
t.Fatalf("bad tee: '%s', want '%s'", got, want)
}
}
+
+type brokenReader struct {
+ err error
+}
+
+func (r *brokenReader) Read(p []byte) (int, error) {
+ return 0, r.err
+}
+
+func (r *brokenReader) Close() error { return nil }
+
+func TestMergerErrors(t *testing.T) {
+ merger := NewOutputMerger(nil)
+
+ r1 := &brokenReader{errors.New("foo")}
+ merger.Add("foo", OutputConsole, r1)
+
+ ctx := context.Background()
+ var merr MergerError
+
+ // Add a background reader that will just hang.
+ rHang, wHang, err := osutil.LongPipe()
+ if err != nil {
+ t.Fatal(err)
+ }
+ merger.Add("background", OutputConsole, rHang)
+
+ err = <-merger.Errors(ctx)
+ if assert.Error(t, err) {
+ assert.True(t, errors.As(err, &merr))
+ assert.Equal(t, "foo", merr.Name)
+ assert.EqualError(t, merr.Err, "foo")
+ }
+
+ // The error must persist.
+ err = <-merger.Errors(ctx)
+ if assert.Error(t, err) {
+ assert.True(t, errors.As(err, &merr))
+ assert.Equal(t, "foo", merr.Name)
+ assert.EqualError(t, merr.Err, "foo")
+ }
+
+ // We re-add the decoder as "foo".
+ // The previous error should be gone.
+ r2 := &brokenReader{errors.New("bar")}
+ merger.Add("foo", OutputConsole, r2)
+
+ err = <-merger.Errors(ctx)
+ if assert.Error(t, err) {
+ assert.True(t, errors.As(err, &merr))
+ assert.Equal(t, "foo", merr.Name)
+ assert.EqualError(t, merr.Err, "bar")
+ }
+
+ wHang.Close()
+ merger.Wait()
+}
diff --git a/vm/vmimpl/util.go b/vm/vmimpl/util.go
index a0e3ce841..0d56a3cdc 100644
--- a/vm/vmimpl/util.go
+++ b/vm/vmimpl/util.go
@@ -30,7 +30,7 @@ type SSHOptions struct {
Key string
}
-func WaitForSSH(timeout time.Duration, opts SSHOptions, OS string, stop chan error, systemSSHCfg, debug bool) error {
+func WaitForSSH(timeout time.Duration, opts SSHOptions, OS string, stop <-chan error, systemSSHCfg, debug bool) error {
pwd := "pwd"
if OS == targets.Windows {
pwd = "dir"
diff --git a/vm/vmimpl/vmimpl.go b/vm/vmimpl/vmimpl.go
index f54e4e2b5..b1745d192 100644
--- a/vm/vmimpl/vmimpl.go
+++ b/vm/vmimpl/vmimpl.go
@@ -190,6 +190,8 @@ func Multiplex(ctx context.Context, cmd *exec.Cmd, merger *OutputMerger, config
}
}
go func() {
+ errCtx, cancel := context.WithCancel(context.Background())
+ defer cancel()
select {
case <-ctx.Done():
signal(ErrTimeout)
@@ -198,7 +200,7 @@ func Multiplex(ctx context.Context, cmd *exec.Cmd, merger *OutputMerger, config
log.Logf(0, "instance closed")
}
signal(fmt.Errorf("instance closed"))
- case err := <-merger.Err:
+ case err := <-merger.Errors(errCtx):
// EOF is not always in perfect sync with exit, so we should wait a bit.
if cmdErr := waitAndKill(ctx, cmd); cmdErr == nil {
// If the command exited successfully, we got EOF error from merger.