From 9a444d3613c4d0cdd2d75c68408175488fed0a0b Mon Sep 17 00:00:00 2001 From: Dmitry Vyukov Date: Thu, 11 Jul 2024 12:49:19 +0200 Subject: pkg/flatrpc: verify executor messages better This should prevent possible OOM kills. See the added comment for details. --- pkg/flatrpc/conn.go | 105 +++++++++++++++++++++++++++++++++++++++++++---- pkg/flatrpc/conn_test.go | 32 ++++++++++++++- 2 files changed, 127 insertions(+), 10 deletions(-) (limited to 'pkg') diff --git a/pkg/flatrpc/conn.go b/pkg/flatrpc/conn.go index 56afdfca4..f3a997e1c 100644 --- a/pkg/flatrpc/conn.go +++ b/pkg/flatrpc/conn.go @@ -11,8 +11,9 @@ import ( "reflect" "slices" "sync" + "unsafe" - flatbuffers "github.com/google/flatbuffers/go" + "github.com/google/flatbuffers/go" "github.com/google/syzkaller/pkg/log" "github.com/google/syzkaller/pkg/stats" ) @@ -120,12 +121,8 @@ type RecvType[T any] interface { // only until the next Recv call (messages share the same underlying receive buffer). func Recv[Raw RecvType[T], T any](c *Conn) (res *T, err0 error) { defer func() { - if err1 := recover(); err1 != nil { - if err2, ok := err1.(error); ok { - err0 = err2 - } else { - err0 = fmt.Errorf("%v", err1) - } + if err := recover(); err != nil { + err0 = fmt.Errorf("%v", err) } }() raw, err := RecvRaw[Raw](c) @@ -135,7 +132,12 @@ func Recv[Raw RecvType[T], T any](c *Conn) (res *T, err0 error) { return raw.UnPack(), nil } -func RecvRaw[T flatbuffers.FlatBuffer](c *Conn) (T, error) { +func RecvRaw[T flatbuffers.FlatBuffer](c *Conn) (res T, err0 error) { + defer func() { + if err := recover(); err != nil { + err0 = fmt.Errorf("%v", err) + } + }() // First, discard the previous message. // For simplicity we copy any data from the next message to the beginning of the buffer. // Theoretically we could something more efficient, e.g. don't copy if we already @@ -169,7 +171,7 @@ func RecvRaw[T flatbuffers.FlatBuffer](c *Conn) (T, error) { msg = reflect.New(reflect.TypeOf(msg).Elem()).Interface().(T) data := c.data[sizePrefixSize:c.lastMsg] msg.Init(data, flatbuffers.GetUOffsetT(data)) - return msg, nil + return msg, verify(msg, size) } // recv ensures that we have at least 'size' bytes received in c.data. @@ -188,3 +190,88 @@ func (c *Conn) recv(size int) error { c.hasData += n return nil } + +func verify(raw any, rawSize int) error { + switch msg := raw.(type) { + case *ExecutorMessageRaw: + return verifyExecutorMessage(msg, rawSize) + } + return nil +} + +func verifyExecutorMessage(raw *ExecutorMessageRaw, rawSize int) error { + // We receive the message into raw (non object API) type and carefully verify + // because the message from the test machine can be corrupted in all possible ways. + // Recovering from panics handles most corruptions (since flatbuffers does not use unsafe + // and panics on any OOB references). But it's still possible that UnPack may try to allocate + // unbounded amount of memory and crash with OOM. To prevent that we check that arrays have + // reasonable size. We don't need to check []byte/string b/c for them flatbuffers use + // Table.ByteVector which directly references the underlying byte slice and also panics + // if size is OOB. But we need to check all other arrays b/c for them flatbuffers will + // first do make([]T, size), filling that array later will panic, but it's already too late + // since the make will kill the process with OOM. + switch typ := raw.MsgType(); typ { + case ExecutorMessagesRawExecResult, + ExecutorMessagesRawExecuting, + ExecutorMessagesRawState: + default: + return fmt.Errorf("bad executor message type %v", typ) + } + var tab flatbuffers.Table + if !raw.Msg(&tab) { + return errors.New("received no message") + } + // Only ExecResult has arrays. + if raw.MsgType() == ExecutorMessagesRawExecResult { + var res ExecResultRaw + res.Init(tab.Bytes, tab.Pos) + return verifyExecResult(&res, rawSize) + } + return nil +} + +func verifyExecResult(res *ExecResultRaw, rawSize int) error { + info := res.Info(nil) + if info == nil { + return nil + } + var tmp ComparisonRaw + // It's hard to impose good limit on each individual signal/cover/comps array, + // so instead we count total memory size for all calls and check that it's not + // larger than the total message size. + callSize := func(call *CallInfoRaw) int { + // Cap array size at 1G to prevent overflows during multiplication by size and addition. + const maxSize = 1 << 30 + size := 0 + if call.SignalLength() != 0 { + size += min(maxSize, call.SignalLength()) * int(unsafe.Sizeof(call.Signal(0))) + } + if call.CoverLength() != 0 { + size += min(maxSize, call.CoverLength()) * int(unsafe.Sizeof(call.Cover(0))) + } + if call.CompsLength() != 0 { + size += min(maxSize, call.CompsLength()) * int(unsafe.Sizeof(call.Comps(&tmp, 0))) + } + return size + } + size := 0 + var call CallInfoRaw + for i := 0; i < info.CallsLength(); i++ { + if info.Calls(&call, i) { + size += callSize(&call) + } + } + for i := 0; i < info.ExtraRawLength(); i++ { + if info.ExtraRaw(&call, i) { + size += callSize(&call) + } + } + if info.Extra(&call) != nil { + size += callSize(&call) + } + if size > rawSize { + return fmt.Errorf("corrupted message: total size %v, size of elements %v", + rawSize, size) + } + return nil +} diff --git a/pkg/flatrpc/conn_test.go b/pkg/flatrpc/conn_test.go index a6f7f23f9..c9448872c 100644 --- a/pkg/flatrpc/conn_test.go +++ b/pkg/flatrpc/conn_test.go @@ -6,10 +6,13 @@ package flatrpc import ( "net" "os" + "runtime/debug" + "sync" "syscall" "testing" "time" + "github.com/google/flatbuffers/go" "github.com/stretchr/testify/assert" ) @@ -147,9 +150,36 @@ func dial(t testing.TB, addr string) *Conn { return NewConn(conn) } +var memoryLimitOnce sync.Once + func FuzzRecv(f *testing.F) { + msg := &ExecutorMessage{ + Msg: &ExecutorMessages{ + Type: ExecutorMessagesRawExecResult, + Value: &ExecResult{ + Id: 1, + Output: []byte("aaa"), + Error: "bbb", + Info: &ProgInfo{ + ExtraRaw: []*CallInfo{ + { + Signal: []uint64{1, 2}, + }, + }, + }, + }, + }, + } + builder := flatbuffers.NewBuilder(0) + builder.FinishSizePrefixed(msg.Pack(builder)) + f.Add(builder.FinishedBytes()) f.Fuzz(func(t *testing.T, data []byte) { - data = data[:min(len(data), 1<<10)] + memoryLimitOnce.Do(func() { + debug.SetMemoryLimit(64 << 20) + }) + if len(data) > 1<<10 { + t.Skip() + } fds, err := syscall.Socketpair(syscall.AF_LOCAL, syscall.SOCK_STREAM, 0) if err != nil { t.Fatal(err) -- cgit mrf-deployment