aboutsummaryrefslogtreecommitdiffstats
path: root/vm/vmimpl/merger.go
blob: 3691dd0c01a25db6c052b20e94f62e9db028e366 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
// Copyright 2016 syzkaller project authors. All rights reserved.
// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.

package vmimpl

import (
	"bytes"
	"fmt"
	"io"
	"sync"
)

// OutputMerger multiplexes data read from several sources, registered with Add
// or AddDecoder, onto a single Output channel.
type OutputMerger struct {
	// Output carries the merged data read from all sources. It is closed by
	// Wait once every source goroutine has finished.
	Output chan []byte
	// Err receives a MergerError when reading from a source fails. Sends to
	// it are non-blocking, so an error is dropped if the channel has no room.
	Err    chan error
	// teeMu serializes writes to tee from the per-source reader goroutines.
	teeMu  sync.Mutex
	// tee, if non-nil, receives a copy of the line output read from sources.
	tee    io.Writer
	// wg counts the reader goroutines that are still running; Wait blocks on it.
	wg     sync.WaitGroup
}

// MergerError reports a failed Read on one of an OutputMerger's sources.
type MergerError struct {
	// Name is the source name that was passed to Add/AddDecoder.
	Name string
	// R is the reader whose Read failed.
	R    io.ReadCloser
	// Err is the error returned by R.Read (typically io.EOF when the stream ends).
	Err  error
}

// Error formats the source name together with the underlying read error.
func (e MergerError) Error() string {
	return fmt.Sprintf("failed to read from %v: %v", e.Name, e.Err)
}

// Unwrap returns the underlying read error so that errors.Is and errors.As can
// inspect it through the wrapper, e.g. errors.Is(err, io.EOF).
func (e MergerError) Unwrap() error {
	return e.Err
}

// NewOutputMerger creates a merger whose Output channel buffers up to 1000
// chunks and whose Err channel buffers a single error. If tee is non-nil, the
// line output read from sources is also copied to it.
func NewOutputMerger(tee io.Writer) *OutputMerger {
	merger := new(OutputMerger)
	merger.Output = make(chan []byte, 1000)
	merger.Err = make(chan error, 1)
	merger.tee = tee
	return merger
}

// Wait blocks until the reader goroutine of every added source has exited
// (each exits after its Read returns an error), then closes Output.
func (merger *OutputMerger) Wait() {
	defer close(merger.Output)
	merger.wg.Wait()
}

// Add starts forwarding output from r, identified by name in any MergerError,
// without a protocol decoder. It is AddDecoder with a nil decoder.
func (merger *OutputMerger) Add(name string, r io.ReadCloser) {
	var noDecoder func(data []byte) (start, size int, decoded []byte)
	merger.AddDecoder(name, r, noDecoder)
}

// AddDecoder starts a goroutine that reads from r until Read returns an error
// and forwards what it reads to merger.Output.
//
// Carriage returns are stripped and text is forwarded only in whole lines:
// everything up to the last '\n' seen so far is sent as one chunk, while a
// trailing partial line waits for more data. Sends of line data are
// non-blocking; if Output has no room, the lines stay buffered and are retried
// together with the next chunk. Each byte of line data is written to the tee
// (if any) exactly once, even when a send has to be retried.
//
// If decoder is non-nil, the raw (unstripped) bytes are also accumulated and
// passed to decoder. Bytes before start+size are dropped from that buffer
// before the next call, and any non-empty decoded result is sent to Output
// with a blocking send.
//
// When Read fails (including at end of stream), any remaining partial line is
// flushed with a terminating '\n', r is closed, and a MergerError is offered
// on Err without blocking.
func (merger *OutputMerger) AddDecoder(name string, r io.ReadCloser,
	decoder func(data []byte) (start, size int, decoded []byte)) {
	merger.wg.Add(1)
	go func() {
		defer merger.wg.Done()
		// writeTee copies data to the tee writer, if one is set. Tee write
		// errors are deliberately ignored: the tee is a best-effort side copy.
		writeTee := func(data []byte) {
			if merger.tee == nil || len(data) == 0 {
				return
			}
			merger.teeMu.Lock()
			merger.tee.Write(data)
			merger.teeMu.Unlock()
		}
		// pending holds stripped data not yet delivered to Output.
		var pending []byte
		// teed is the length of the prefix of pending already written to tee.
		teed := 0
		// proto holds raw data not yet consumed by decoder.
		var proto []byte
		var buf [4 << 10]byte
		for {
			n, err := r.Read(buf[:])
			if n != 0 {
				if decoder != nil {
					proto = append(proto, buf[:n]...)
					start, size, decoded := decoder(proto)
					proto = proto[start+size:]
					if len(decoded) != 0 {
						merger.Output <- decoded // note: this can block
					}
				}
				// Remove all carriage returns.
				chunk := buf[:n]
				if bytes.IndexByte(chunk, '\r') != -1 {
					chunk = bytes.ReplaceAll(chunk, []byte("\r"), nil)
				}
				pending = append(pending, chunk...)
				if pos := bytes.LastIndexByte(pending, '\n'); pos != -1 {
					// Tee only the lines that an earlier, failed send has not
					// already teed.
					writeTee(pending[teed : pos+1])
					teed = pos + 1
					select {
					case merger.Output <- append([]byte{}, pending[:pos+1]...):
						rest := copy(pending, pending[pos+1:])
						pending = pending[:rest]
						teed = 0
					default:
						// Output is full: keep the lines in pending so they
						// are retried along with the next chunk.
					}
				}
			}
			if err != nil {
				if len(pending) != 0 {
					// pending already ends in '\n' only when an earlier send
					// failed; don't add an extra blank line in that case.
					if pending[len(pending)-1] != '\n' {
						pending = append(pending, '\n')
					}
					writeTee(pending[teed:])
					select {
					case merger.Output <- pending:
					default:
					}
				}
				r.Close()
				select {
				case merger.Err <- MergerError{Name: name, R: r, Err: err}:
				default:
				}
				return
			}
		}
	}()
}