diff options
| author | Dmitry Vyukov <dvyukov@google.com> | 2024-08-12 14:31:04 +0200 |
|---|---|---|
| committer | Dmitry Vyukov <dvyukov@google.com> | 2024-08-13 15:13:04 +0000 |
| commit | 068ad4fdb4cc546e45708cc40115e2baf3c49830 (patch) | |
| tree | ecce6fc92638d165acb718e303244eb8b045bbfb | |
| parent | ebd71c054bab8fad556745259f1d7974ed60095d (diff) | |
syz-manager: make snapshot result parsing robust
Use logic similar to flatrpc to avoid panics during result parsing.
| -rw-r--r-- | pkg/flatrpc/conn.go | 51 | ||||
| -rw-r--r-- | syz-manager/snapshot.go | 70 |
2 files changed, 60 insertions, 61 deletions
diff --git a/pkg/flatrpc/conn.go b/pkg/flatrpc/conn.go index c5e1cb1a4..33eca07e1 100644 --- a/pkg/flatrpc/conn.go +++ b/pkg/flatrpc/conn.go @@ -120,24 +120,6 @@ type RecvType[T any] interface { // Receiving should be done from a single goroutine, the received message is valid // only until the next Recv call (messages share the same underlying receive buffer). func Recv[Raw RecvType[T], T any](c *Conn) (res *T, err0 error) { - defer func() { - if err := recover(); err != nil { - err0 = fmt.Errorf("%v", err) - } - }() - raw, err := RecvRaw[Raw](c) - if err != nil { - return nil, err - } - return raw.UnPack(), nil -} - -func RecvRaw[T flatbuffers.FlatBuffer](c *Conn) (res T, err0 error) { - defer func() { - if err := recover(); err != nil { - err0 = fmt.Errorf("%v", err) - } - }() // First, discard the previous message. // For simplicity we copy any data from the next message to the beginning of the buffer. // Theoretically we could something more efficient, e.g. don't copy if we already @@ -151,27 +133,20 @@ func RecvRaw[T flatbuffers.FlatBuffer](c *Conn) (res T, err0 error) { sizePrefixSize = flatbuffers.SizeUint32 maxMessageSize = 64 << 20 ) - var msg T // Then, receive at least the size prefix (4 bytes). // And then the full message, if we have not got it yet. if err := c.recv(sizePrefixSize); err != nil { - return msg, fmt.Errorf("failed to recv %T: %w", msg, err) + return nil, fmt.Errorf("failed to recv %T: %w", (*T)(nil), err) } size := int(flatbuffers.GetSizePrefix(c.data, 0)) if size > maxMessageSize { - return msg, fmt.Errorf("message %T has too large size %v", msg, size) + return nil, fmt.Errorf("message %T has too large size %v", (*T)(nil), size) } c.lastMsg = sizePrefixSize + size if err := c.recv(c.lastMsg); err != nil { - return msg, fmt.Errorf("failed to recv %T: %w", msg, err) + return nil, fmt.Errorf("failed to recv %T: %w", (*T)(nil), err) } - statRecv.Add(c.lastMsg) - // This probably can't be expressed w/o reflect as "new U" where U is *T, - // but I failed to express that as generic constraints. - msg = reflect.New(reflect.TypeOf(msg).Elem()).Interface().(T) - data := c.data[sizePrefixSize:c.lastMsg] - msg.Init(data, flatbuffers.GetUOffsetT(data)) - return msg, verify(msg, size) + return Parse[Raw](c.data[sizePrefixSize:c.lastMsg]) } // recv ensures that we have at least 'size' bytes received in c.data. @@ -191,6 +166,24 @@ func (c *Conn) recv(size int) error { return nil } +func Parse[Raw RecvType[T], T any](data []byte) (res *T, err0 error) { + defer func() { + if err := recover(); err != nil { + err0 = fmt.Errorf("%v", err) + } + }() + statRecv.Add(len(data)) + // This probably can be expressed w/o reflect as "new U" where U is *T, + // but I failed to express that as generic constraints. + var msg Raw + msg = reflect.New(reflect.TypeOf(msg).Elem()).Interface().(Raw) + msg.Init(data, flatbuffers.GetUOffsetT(data)) + if err := verify(msg, len(data)); err != nil { + return nil, err + } + return msg.UnPack(), nil +} + func verify(raw any, rawSize int) error { switch msg := raw.(type) { case *ExecutorMessageRaw: diff --git a/syz-manager/snapshot.go b/syz-manager/snapshot.go index b76c58583..52d3f02e2 100644 --- a/syz-manager/snapshot.go +++ b/syz-manager/snapshot.go @@ -121,53 +121,59 @@ func (mgr *Manager) snapshotRun(inst *vm.Instance, builder *flatbuffers.Builder, builder.Finish(msg.Pack(builder)) start := time.Now() - res, output, err := inst.RunSnapshot(builder.FinishedBytes()) + resData, output, err := inst.RunSnapshot(builder.FinishedBytes()) if err != nil { return nil, nil, err } elapsed := time.Since(start) - execError := "" - var info *flatrpc.ProgInfo - if len(res) > 4 { - res = res[4:] - // TODO: use more robust parsing from pkg/flatrpc/conn.go. - var raw flatrpc.ExecutorMessageRaw - raw.Init(res, flatbuffers.GetUOffsetT(res)) - union := raw.UnPack() - if union.Msg != nil && union.Msg.Value != nil { - msg := union.Msg.Value.(*flatrpc.ExecResult) - if msg.Info != nil { - msg.Info.Elapsed = uint64(elapsed) - for len(msg.Info.Calls) < len(req.Prog.Calls) { - msg.Info.Calls = append(msg.Info.Calls, &flatrpc.CallInfo{ - Error: 999, - }) - } - msg.Info.Calls = msg.Info.Calls[:len(req.Prog.Calls)] - if len(msg.Info.ExtraRaw) != 0 { - msg.Info.Extra = msg.Info.ExtraRaw[0] - for _, info := range msg.Info.ExtraRaw[1:] { - msg.Info.Extra.Cover = append(msg.Info.Extra.Cover, info.Cover...) - msg.Info.Extra.Signal = append(msg.Info.Extra.Signal, info.Signal...) - } - msg.Info.ExtraRaw = nil - } + res := parseExecResult(resData) + if res.Info != nil { + res.Info.Elapsed = uint64(elapsed) + for len(res.Info.Calls) < len(req.Prog.Calls) { + res.Info.Calls = append(res.Info.Calls, &flatrpc.CallInfo{ + Error: 999, + }) + } + res.Info.Calls = res.Info.Calls[:len(req.Prog.Calls)] + if len(res.Info.ExtraRaw) != 0 { + res.Info.Extra = res.Info.ExtraRaw[0] + for _, info := range res.Info.ExtraRaw[1:] { + res.Info.Extra.Cover = append(res.Info.Extra.Cover, info.Cover...) + res.Info.Extra.Signal = append(res.Info.Extra.Signal, info.Signal...) } - info = msg.Info - execError = msg.Error + res.Info.ExtraRaw = nil } } + ret := &queue.Result{ Status: queue.Success, - Info: info, + Info: res.Info, } - if execError != "" { + if res.Error != "" { ret.Status = queue.ExecFailure - ret.Err = errors.New(execError) + ret.Err = errors.New(res.Error) } if req.ReturnOutput { ret.Output = output } return ret, output, nil } + +func parseExecResult(data []byte) *flatrpc.ExecResult { + raw, err := flatrpc.Parse[*flatrpc.ExecutorMessageRaw](data[flatbuffers.SizeUint32:]) + if err != nil { + // Don't consider result parsing error as an infrastructure error, + // it's just the test program corrupted memory. + return &flatrpc.ExecResult{ + Error: err.Error(), + } + } + res, ok := raw.Msg.Value.(*flatrpc.ExecResult) + if !ok { + return &flatrpc.ExecResult{ + Error: "result is not ExecResult", + } + } + return res +} |
