aboutsummaryrefslogtreecommitdiffstats
path: root/pkg
diff options
context:
space:
mode:
authorDmitry Vyukov <dvyukov@google.com>2024-07-11 12:49:19 +0200
committerDmitry Vyukov <dvyukov@google.com>2024-07-11 15:09:59 +0000
commit9a444d3613c4d0cdd2d75c68408175488fed0a0b (patch)
tree58cc6c37efade71191c9c8794b5bfeffc6aa304c /pkg
parent3cf1187a067e9bb9d9a3fe079e6942abb526ddb2 (diff)
pkg/flatrpc: verify executor messages better
This should prevent possible OOM kills. See the added comment for details.
Diffstat (limited to 'pkg')
-rw-r--r--pkg/flatrpc/conn.go105
-rw-r--r--pkg/flatrpc/conn_test.go32
2 files changed, 127 insertions, 10 deletions
diff --git a/pkg/flatrpc/conn.go b/pkg/flatrpc/conn.go
index 56afdfca4..f3a997e1c 100644
--- a/pkg/flatrpc/conn.go
+++ b/pkg/flatrpc/conn.go
@@ -11,8 +11,9 @@ import (
"reflect"
"slices"
"sync"
+ "unsafe"
- flatbuffers "github.com/google/flatbuffers/go"
+ "github.com/google/flatbuffers/go"
"github.com/google/syzkaller/pkg/log"
"github.com/google/syzkaller/pkg/stats"
)
@@ -120,12 +121,8 @@ type RecvType[T any] interface {
// only until the next Recv call (messages share the same underlying receive buffer).
func Recv[Raw RecvType[T], T any](c *Conn) (res *T, err0 error) {
defer func() {
- if err1 := recover(); err1 != nil {
- if err2, ok := err1.(error); ok {
- err0 = err2
- } else {
- err0 = fmt.Errorf("%v", err1)
- }
+ if err := recover(); err != nil {
+ err0 = fmt.Errorf("%v", err)
}
}()
raw, err := RecvRaw[Raw](c)
@@ -135,7 +132,12 @@ func Recv[Raw RecvType[T], T any](c *Conn) (res *T, err0 error) {
return raw.UnPack(), nil
}
-func RecvRaw[T flatbuffers.FlatBuffer](c *Conn) (T, error) {
+func RecvRaw[T flatbuffers.FlatBuffer](c *Conn) (res T, err0 error) {
+ defer func() {
+ if err := recover(); err != nil {
+ err0 = fmt.Errorf("%v", err)
+ }
+ }()
// First, discard the previous message.
// For simplicity we copy any data from the next message to the beginning of the buffer.
// Theoretically we could something more efficient, e.g. don't copy if we already
@@ -169,7 +171,7 @@ func RecvRaw[T flatbuffers.FlatBuffer](c *Conn) (T, error) {
msg = reflect.New(reflect.TypeOf(msg).Elem()).Interface().(T)
data := c.data[sizePrefixSize:c.lastMsg]
msg.Init(data, flatbuffers.GetUOffsetT(data))
- return msg, nil
+ return msg, verify(msg, size)
}
// recv ensures that we have at least 'size' bytes received in c.data.
@@ -188,3 +190,88 @@ func (c *Conn) recv(size int) error {
c.hasData += n
return nil
}
+
+func verify(raw any, rawSize int) error {
+ switch msg := raw.(type) {
+ case *ExecutorMessageRaw:
+ return verifyExecutorMessage(msg, rawSize)
+ }
+ return nil
+}
+
+func verifyExecutorMessage(raw *ExecutorMessageRaw, rawSize int) error {
+ // We receive the message into raw (non object API) type and carefully verify
+ // because the message from the test machine can be corrupted in all possible ways.
+ // Recovering from panics handles most corruptions (since flatbuffers does not use unsafe
+ // and panics on any OOB references). But it's still possible that UnPack may try to allocate
+ // unbounded amount of memory and crash with OOM. To prevent that we check that arrays have
+ // reasonable size. We don't need to check []byte/string b/c for them flatbuffers use
+ // Table.ByteVector which directly references the underlying byte slice and also panics
+ // if size is OOB. But we need to check all other arrays b/c for them flatbuffers will
+ // first do make([]T, size), filling that array later will panic, but it's already too late
+ // since the make will kill the process with OOM.
+ switch typ := raw.MsgType(); typ {
+ case ExecutorMessagesRawExecResult,
+ ExecutorMessagesRawExecuting,
+ ExecutorMessagesRawState:
+ default:
+ return fmt.Errorf("bad executor message type %v", typ)
+ }
+ var tab flatbuffers.Table
+ if !raw.Msg(&tab) {
+ return errors.New("received no message")
+ }
+ // Only ExecResult has arrays.
+ if raw.MsgType() == ExecutorMessagesRawExecResult {
+ var res ExecResultRaw
+ res.Init(tab.Bytes, tab.Pos)
+ return verifyExecResult(&res, rawSize)
+ }
+ return nil
+}
+
+func verifyExecResult(res *ExecResultRaw, rawSize int) error {
+ info := res.Info(nil)
+ if info == nil {
+ return nil
+ }
+ var tmp ComparisonRaw
+ // It's hard to impose good limit on each individual signal/cover/comps array,
+ // so instead we count total memory size for all calls and check that it's not
+ // larger than the total message size.
+ callSize := func(call *CallInfoRaw) int {
+ // Cap array size at 1G to prevent overflows during multiplication by size and addition.
+ const maxSize = 1 << 30
+ size := 0
+ if call.SignalLength() != 0 {
+ size += min(maxSize, call.SignalLength()) * int(unsafe.Sizeof(call.Signal(0)))
+ }
+ if call.CoverLength() != 0 {
+ size += min(maxSize, call.CoverLength()) * int(unsafe.Sizeof(call.Cover(0)))
+ }
+ if call.CompsLength() != 0 {
+ size += min(maxSize, call.CompsLength()) * int(unsafe.Sizeof(call.Comps(&tmp, 0)))
+ }
+ return size
+ }
+ size := 0
+ var call CallInfoRaw
+ for i := 0; i < info.CallsLength(); i++ {
+ if info.Calls(&call, i) {
+ size += callSize(&call)
+ }
+ }
+ for i := 0; i < info.ExtraRawLength(); i++ {
+ if info.ExtraRaw(&call, i) {
+ size += callSize(&call)
+ }
+ }
+ if info.Extra(&call) != nil {
+ size += callSize(&call)
+ }
+ if size > rawSize {
+ return fmt.Errorf("corrupted message: total size %v, size of elements %v",
+ rawSize, size)
+ }
+ return nil
+}
diff --git a/pkg/flatrpc/conn_test.go b/pkg/flatrpc/conn_test.go
index a6f7f23f9..c9448872c 100644
--- a/pkg/flatrpc/conn_test.go
+++ b/pkg/flatrpc/conn_test.go
@@ -6,10 +6,13 @@ package flatrpc
import (
"net"
"os"
+ "runtime/debug"
+ "sync"
"syscall"
"testing"
"time"
+ "github.com/google/flatbuffers/go"
"github.com/stretchr/testify/assert"
)
@@ -147,9 +150,36 @@ func dial(t testing.TB, addr string) *Conn {
return NewConn(conn)
}
+var memoryLimitOnce sync.Once
+
func FuzzRecv(f *testing.F) {
+ msg := &ExecutorMessage{
+ Msg: &ExecutorMessages{
+ Type: ExecutorMessagesRawExecResult,
+ Value: &ExecResult{
+ Id: 1,
+ Output: []byte("aaa"),
+ Error: "bbb",
+ Info: &ProgInfo{
+ ExtraRaw: []*CallInfo{
+ {
+ Signal: []uint64{1, 2},
+ },
+ },
+ },
+ },
+ },
+ }
+ builder := flatbuffers.NewBuilder(0)
+ builder.FinishSizePrefixed(msg.Pack(builder))
+ f.Add(builder.FinishedBytes())
f.Fuzz(func(t *testing.T, data []byte) {
- data = data[:min(len(data), 1<<10)]
+ memoryLimitOnce.Do(func() {
+ debug.SetMemoryLimit(64 << 20)
+ })
+ if len(data) > 1<<10 {
+ t.Skip()
+ }
fds, err := syscall.Socketpair(syscall.AF_LOCAL, syscall.SOCK_STREAM, 0)
if err != nil {
t.Fatal(err)