aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorDmitry Vyukov <dvyukov@google.com>2024-08-12 14:31:04 +0200
committerDmitry Vyukov <dvyukov@google.com>2024-08-13 15:13:04 +0000
commit068ad4fdb4cc546e45708cc40115e2baf3c49830 (patch)
treeecce6fc92638d165acb718e303244eb8b045bbfb
parentebd71c054bab8fad556745259f1d7974ed60095d (diff)
syz-manager: make snapshot result parsing robust
Use logic similar to flatrpc to avoid panics during result parsing.
-rw-r--r--pkg/flatrpc/conn.go51
-rw-r--r--syz-manager/snapshot.go70
2 files changed, 60 insertions, 61 deletions
diff --git a/pkg/flatrpc/conn.go b/pkg/flatrpc/conn.go
index c5e1cb1a4..33eca07e1 100644
--- a/pkg/flatrpc/conn.go
+++ b/pkg/flatrpc/conn.go
@@ -120,24 +120,6 @@ type RecvType[T any] interface {
// Receiving should be done from a single goroutine, the received message is valid
// only until the next Recv call (messages share the same underlying receive buffer).
func Recv[Raw RecvType[T], T any](c *Conn) (res *T, err0 error) {
- defer func() {
- if err := recover(); err != nil {
- err0 = fmt.Errorf("%v", err)
- }
- }()
- raw, err := RecvRaw[Raw](c)
- if err != nil {
- return nil, err
- }
- return raw.UnPack(), nil
-}
-
-func RecvRaw[T flatbuffers.FlatBuffer](c *Conn) (res T, err0 error) {
- defer func() {
- if err := recover(); err != nil {
- err0 = fmt.Errorf("%v", err)
- }
- }()
// First, discard the previous message.
// For simplicity we copy any data from the next message to the beginning of the buffer.
// Theoretically we could something more efficient, e.g. don't copy if we already
@@ -151,27 +133,20 @@ func RecvRaw[T flatbuffers.FlatBuffer](c *Conn) (res T, err0 error) {
sizePrefixSize = flatbuffers.SizeUint32
maxMessageSize = 64 << 20
)
- var msg T
// Then, receive at least the size prefix (4 bytes).
// And then the full message, if we have not got it yet.
if err := c.recv(sizePrefixSize); err != nil {
- return msg, fmt.Errorf("failed to recv %T: %w", msg, err)
+ return nil, fmt.Errorf("failed to recv %T: %w", (*T)(nil), err)
}
size := int(flatbuffers.GetSizePrefix(c.data, 0))
if size > maxMessageSize {
- return msg, fmt.Errorf("message %T has too large size %v", msg, size)
+ return nil, fmt.Errorf("message %T has too large size %v", (*T)(nil), size)
}
c.lastMsg = sizePrefixSize + size
if err := c.recv(c.lastMsg); err != nil {
- return msg, fmt.Errorf("failed to recv %T: %w", msg, err)
+ return nil, fmt.Errorf("failed to recv %T: %w", (*T)(nil), err)
}
- statRecv.Add(c.lastMsg)
- // This probably can't be expressed w/o reflect as "new U" where U is *T,
- // but I failed to express that as generic constraints.
- msg = reflect.New(reflect.TypeOf(msg).Elem()).Interface().(T)
- data := c.data[sizePrefixSize:c.lastMsg]
- msg.Init(data, flatbuffers.GetUOffsetT(data))
- return msg, verify(msg, size)
+ return Parse[Raw](c.data[sizePrefixSize:c.lastMsg])
}
// recv ensures that we have at least 'size' bytes received in c.data.
@@ -191,6 +166,24 @@ func (c *Conn) recv(size int) error {
return nil
}
+func Parse[Raw RecvType[T], T any](data []byte) (res *T, err0 error) {
+ defer func() {
+ if err := recover(); err != nil {
+ err0 = fmt.Errorf("%v", err)
+ }
+ }()
+ statRecv.Add(len(data))
+ // This probably can be expressed w/o reflect as "new U" where U is *T,
+ // but I failed to express that as generic constraints.
+ var msg Raw
+ msg = reflect.New(reflect.TypeOf(msg).Elem()).Interface().(Raw)
+ msg.Init(data, flatbuffers.GetUOffsetT(data))
+ if err := verify(msg, len(data)); err != nil {
+ return nil, err
+ }
+ return msg.UnPack(), nil
+}
+
func verify(raw any, rawSize int) error {
switch msg := raw.(type) {
case *ExecutorMessageRaw:
diff --git a/syz-manager/snapshot.go b/syz-manager/snapshot.go
index b76c58583..52d3f02e2 100644
--- a/syz-manager/snapshot.go
+++ b/syz-manager/snapshot.go
@@ -121,53 +121,59 @@ func (mgr *Manager) snapshotRun(inst *vm.Instance, builder *flatbuffers.Builder,
builder.Finish(msg.Pack(builder))
start := time.Now()
- res, output, err := inst.RunSnapshot(builder.FinishedBytes())
+ resData, output, err := inst.RunSnapshot(builder.FinishedBytes())
if err != nil {
return nil, nil, err
}
elapsed := time.Since(start)
- execError := ""
- var info *flatrpc.ProgInfo
- if len(res) > 4 {
- res = res[4:]
- // TODO: use more robust parsing from pkg/flatrpc/conn.go.
- var raw flatrpc.ExecutorMessageRaw
- raw.Init(res, flatbuffers.GetUOffsetT(res))
- union := raw.UnPack()
- if union.Msg != nil && union.Msg.Value != nil {
- msg := union.Msg.Value.(*flatrpc.ExecResult)
- if msg.Info != nil {
- msg.Info.Elapsed = uint64(elapsed)
- for len(msg.Info.Calls) < len(req.Prog.Calls) {
- msg.Info.Calls = append(msg.Info.Calls, &flatrpc.CallInfo{
- Error: 999,
- })
- }
- msg.Info.Calls = msg.Info.Calls[:len(req.Prog.Calls)]
- if len(msg.Info.ExtraRaw) != 0 {
- msg.Info.Extra = msg.Info.ExtraRaw[0]
- for _, info := range msg.Info.ExtraRaw[1:] {
- msg.Info.Extra.Cover = append(msg.Info.Extra.Cover, info.Cover...)
- msg.Info.Extra.Signal = append(msg.Info.Extra.Signal, info.Signal...)
- }
- msg.Info.ExtraRaw = nil
- }
+ res := parseExecResult(resData)
+ if res.Info != nil {
+ res.Info.Elapsed = uint64(elapsed)
+ for len(res.Info.Calls) < len(req.Prog.Calls) {
+ res.Info.Calls = append(res.Info.Calls, &flatrpc.CallInfo{
+ Error: 999,
+ })
+ }
+ res.Info.Calls = res.Info.Calls[:len(req.Prog.Calls)]
+ if len(res.Info.ExtraRaw) != 0 {
+ res.Info.Extra = res.Info.ExtraRaw[0]
+ for _, info := range res.Info.ExtraRaw[1:] {
+ res.Info.Extra.Cover = append(res.Info.Extra.Cover, info.Cover...)
+ res.Info.Extra.Signal = append(res.Info.Extra.Signal, info.Signal...)
}
- info = msg.Info
- execError = msg.Error
+ res.Info.ExtraRaw = nil
}
}
+
ret := &queue.Result{
Status: queue.Success,
- Info: info,
+ Info: res.Info,
}
- if execError != "" {
+ if res.Error != "" {
ret.Status = queue.ExecFailure
- ret.Err = errors.New(execError)
+ ret.Err = errors.New(res.Error)
}
if req.ReturnOutput {
ret.Output = output
}
return ret, output, nil
}
+
+func parseExecResult(data []byte) *flatrpc.ExecResult {
+ raw, err := flatrpc.Parse[*flatrpc.ExecutorMessageRaw](data[flatbuffers.SizeUint32:])
+ if err != nil {
+ // Don't consider result parsing error as an infrastructure error,
+ // it's just the test program corrupted memory.
+ return &flatrpc.ExecResult{
+ Error: err.Error(),
+ }
+ }
+ res, ok := raw.Msg.Value.(*flatrpc.ExecResult)
+ if !ok {
+ return &flatrpc.ExecResult{
+ Error: "result is not ExecResult",
+ }
+ }
+ return res
+}