diff options
| author | Aleksandr Nogikh <nogikh@google.com> | 2026-02-07 19:07:44 +0100 |
|---|---|---|
| committer | Aleksandr Nogikh <nogikh@google.com> | 2026-02-17 14:55:28 +0000 |
| commit | 9ca304f096ce425e3556bbd05745f896f5e0b268 (patch) | |
| tree | f17ce5341faa2f56bc7a6a1f1b34aad6e7442525 /vm/vmimpl | |
| parent | f4288eed9c51fed44b853c252711c00de8761336 (diff) | |
vm/vmimpl: refactor Merger error processing
Introduce an Error() method to avoid capturing the errors of already
overridden decoders.
Diffstat (limited to 'vm/vmimpl')
| -rw-r--r-- | vm/vmimpl/merger.go | 150 | ||||
| -rw-r--r-- | vm/vmimpl/merger_test.go | 63 | ||||
| -rw-r--r-- | vm/vmimpl/util.go | 2 | ||||
| -rw-r--r-- | vm/vmimpl/vmimpl.go | 4 |
4 files changed, 160 insertions, 59 deletions
diff --git a/vm/vmimpl/merger.go b/vm/vmimpl/merger.go index b7a7910fb..a111da389 100644 --- a/vm/vmimpl/merger.go +++ b/vm/vmimpl/merger.go @@ -5,9 +5,12 @@ package vmimpl import ( "bytes" + "context" "fmt" "io" "sync" + + "golang.org/x/sync/errgroup" ) type OutputType int @@ -23,12 +26,17 @@ type Chunk struct { Type OutputType } +type decoderState struct { + done chan struct{} // Closed when the decoder exits. + err error +} + type OutputMerger struct { - Output chan Chunk - Err chan error - teeMu sync.Mutex - tee io.Writer - wg sync.WaitGroup + Output chan Chunk + decoderErr map[string]*decoderState + teeMu sync.Mutex + tee io.Writer + wg sync.WaitGroup } type MergerError struct { @@ -43,9 +51,9 @@ func (err MergerError) Error() string { func NewOutputMerger(tee io.Writer) *OutputMerger { return &OutputMerger{ - Output: make(chan Chunk, 1000), - Err: make(chan error, 1), - tee: tee, + Output: make(chan Chunk, 1000), + decoderErr: map[string]*decoderState{}, + tee: tee, } } @@ -54,70 +62,100 @@ func (merger *OutputMerger) Wait() { close(merger.Output) } +// Errors returns a channel that will receive errors from the curretly active decoderErr. +func (merger *OutputMerger) Errors(ctx context.Context) <-chan error { + eg, egCtx := errgroup.WithContext(ctx) + for _, decoder := range merger.decoderErr { + eg.Go(func() error { + select { + case <-egCtx.Done(): + return nil + case <-decoder.done: + return decoder.err + } + }) + } + ret := make(chan error, 1) + go func() { + err := eg.Wait() + if err != nil { + ret <- err + } + close(ret) + }() + return ret +} + func (merger *OutputMerger) Add(name string, typ OutputType, r io.ReadCloser) { merger.AddDecoder(name, typ, r, nil) } func (merger *OutputMerger) AddDecoder(name string, typ OutputType, r io.ReadCloser, decoder func(data []byte) (start, size int, decoded []byte)) { + state := &decoderState{ + done: make(chan struct{}), + } + merger.decoderErr[name] = state merger.wg.Add(1) go func() { - var pending []byte - var proto []byte - var buf [4 << 10]byte - for { - n, err := r.Read(buf[:]) - if n != 0 { - if decoder != nil { - proto = append(proto, buf[:n]...) - start, size, decoded := decoder(proto) - proto = proto[start+size:] - if len(decoded) != 0 { - merger.Output <- Chunk{decoded, typ} // note: this can block - } + defer merger.wg.Done() + defer close(state.done) + err := merger.runDecoder(typ, r, decoder) + state.err = MergerError{name, r, err} + }() +} +func (merger *OutputMerger) runDecoder(typ OutputType, r io.ReadCloser, + decoder func(data []byte) (start, size int, decoded []byte)) error { + var pending []byte + var proto []byte + var buf [4 << 10]byte + for { + n, err := r.Read(buf[:]) + if n != 0 { + if decoder != nil { + proto = append(proto, buf[:n]...) + start, size, decoded := decoder(proto) + proto = proto[start+size:] + if len(decoded) != 0 { + merger.Output <- Chunk{decoded, typ} // note: this can block } - // Remove all carriage returns. - buf := buf[:n] - if bytes.IndexByte(buf, '\r') != -1 { - buf = bytes.ReplaceAll(buf, []byte("\r"), nil) + } + // Remove all carriage returns. + buf := buf[:n] + if bytes.IndexByte(buf, '\r') != -1 { + buf = bytes.ReplaceAll(buf, []byte("\r"), nil) + } + pending = append(pending, buf...) + if pos := bytes.LastIndexByte(pending, '\n'); pos != -1 { + out := pending[:pos+1] + if merger.tee != nil { + merger.teeMu.Lock() + merger.tee.Write(out) + merger.teeMu.Unlock() } - pending = append(pending, buf...) - if pos := bytes.LastIndexByte(pending, '\n'); pos != -1 { - out := pending[:pos+1] - if merger.tee != nil { - merger.teeMu.Lock() - merger.tee.Write(out) - merger.teeMu.Unlock() - } - select { - case merger.Output <- Chunk{append([]byte{}, out...), typ}: - r := copy(pending, pending[pos+1:]) - pending = pending[:r] - default: - } + select { + case merger.Output <- Chunk{append([]byte{}, out...), typ}: + r := copy(pending, pending[pos+1:]) + pending = pending[:r] + default: } } - if err != nil { - if len(pending) != 0 { - pending = append(pending, '\n') - if merger.tee != nil { - merger.teeMu.Lock() - merger.tee.Write(pending) - merger.teeMu.Unlock() - } - select { - case merger.Output <- Chunk{pending, typ}: - default: - } + } + if err != nil { + if len(pending) != 0 { + pending = append(pending, '\n') + if merger.tee != nil { + merger.teeMu.Lock() + merger.tee.Write(pending) + merger.teeMu.Unlock() } - r.Close() select { - case merger.Err <- MergerError{name, r, err}: + case merger.Output <- Chunk{pending, typ}: default: } - merger.wg.Done() - return } + r.Close() + return err } - }() + } } diff --git a/vm/vmimpl/merger_test.go b/vm/vmimpl/merger_test.go index 57c89321d..e231e5c98 100644 --- a/vm/vmimpl/merger_test.go +++ b/vm/vmimpl/merger_test.go @@ -5,12 +5,14 @@ package vmimpl import ( "bytes" + "context" "errors" "io" "testing" "time" "github.com/google/syzkaller/pkg/osutil" + "github.com/stretchr/testify/assert" ) func TestMerger(t *testing.T) { @@ -67,7 +69,9 @@ func TestMerger(t *testing.T) { } var merr MergerError - if err := <-merger.Err; err == nil { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + if err := <-merger.Errors(ctx); err == nil { t.Fatalf("merger did not produce an error on pipe close") } else if !errors.As(err, &merr) || merr.Name != "pipe1" || merr.R != rp1 || merr.Err != io.EOF { t.Fatalf("merger produced wrong error: %v", err) @@ -85,3 +89,60 @@ func TestMerger(t *testing.T) { t.Fatalf("bad tee: '%s', want '%s'", got, want) } } + +type brokenReader struct { + err error +} + +func (r *brokenReader) Read(p []byte) (int, error) { + return 0, r.err +} + +func (r *brokenReader) Close() error { return nil } + +func TestMergerErrors(t *testing.T) { + merger := NewOutputMerger(nil) + + r1 := &brokenReader{errors.New("foo")} + merger.Add("foo", OutputConsole, r1) + + ctx := context.Background() + var merr MergerError + + // Add a background reader that will just hang. + rHang, wHang, err := osutil.LongPipe() + if err != nil { + t.Fatal(err) + } + merger.Add("background", OutputConsole, rHang) + + err = <-merger.Errors(ctx) + if assert.Error(t, err) { + assert.True(t, errors.As(err, &merr)) + assert.Equal(t, "foo", merr.Name) + assert.EqualError(t, merr.Err, "foo") + } + + // The error must persist. + err = <-merger.Errors(ctx) + if assert.Error(t, err) { + assert.True(t, errors.As(err, &merr)) + assert.Equal(t, "foo", merr.Name) + assert.EqualError(t, merr.Err, "foo") + } + + // We re-add the decoder as "foo". + // The previous error should be gone. + r2 := &brokenReader{errors.New("bar")} + merger.Add("foo", OutputConsole, r2) + + err = <-merger.Errors(ctx) + if assert.Error(t, err) { + assert.True(t, errors.As(err, &merr)) + assert.Equal(t, "foo", merr.Name) + assert.EqualError(t, merr.Err, "bar") + } + + wHang.Close() + merger.Wait() +} diff --git a/vm/vmimpl/util.go b/vm/vmimpl/util.go index a0e3ce841..0d56a3cdc 100644 --- a/vm/vmimpl/util.go +++ b/vm/vmimpl/util.go @@ -30,7 +30,7 @@ type SSHOptions struct { Key string } -func WaitForSSH(timeout time.Duration, opts SSHOptions, OS string, stop chan error, systemSSHCfg, debug bool) error { +func WaitForSSH(timeout time.Duration, opts SSHOptions, OS string, stop <-chan error, systemSSHCfg, debug bool) error { pwd := "pwd" if OS == targets.Windows { pwd = "dir" diff --git a/vm/vmimpl/vmimpl.go b/vm/vmimpl/vmimpl.go index f54e4e2b5..b1745d192 100644 --- a/vm/vmimpl/vmimpl.go +++ b/vm/vmimpl/vmimpl.go @@ -190,6 +190,8 @@ func Multiplex(ctx context.Context, cmd *exec.Cmd, merger *OutputMerger, config } } go func() { + errCtx, cancel := context.WithCancel(context.Background()) + defer cancel() select { case <-ctx.Done(): signal(ErrTimeout) @@ -198,7 +200,7 @@ func Multiplex(ctx context.Context, cmd *exec.Cmd, merger *OutputMerger, config log.Logf(0, "instance closed") } signal(fmt.Errorf("instance closed")) - case err := <-merger.Err: + case err := <-merger.Errors(errCtx): // EOF is not always in perfect sync with exit, so we should wait a bit. if cmdErr := waitAndKill(ctx, cmd); cmdErr == nil { // If the command exited successfully, we got EOF error from merger. |
