aboutsummaryrefslogtreecommitdiffstats
path: root/pkg/rpcserver
diff options
context:
space:
mode:
authorAlexander Potapenko <glider@google.com>2025-02-20 12:25:04 +0100
committerAlexander Potapenko <glider@google.com>2025-02-20 16:45:37 +0000
commit0808a665bc75ab0845906bfeca0d12fb520ae6eb (patch)
tree04e77371226d0433dd8a865b01bc1eeedebd3348 /pkg/rpcserver
parent506687987fc2f8f40b2918782fc2943285fdc602 (diff)
pkg/rpcserver: pkg/flatrpc: executor: add handshake stage 0
As we figured out in #5805, syz-manager treats random incoming RPC connections as trusted, and will crash if a non-executor client sends an invalid packet to it. To address this issue, we introduce another stage of handshake, which includes a cookie exchange: - upon connection from an executor, the manager sends a ConnectHello RPC message to it, which contains a random 64-bit cookie; - the executor calculates a hash of that cookie and includes it into its ConnectRequest together with the other information; - before checking the validity of ConnectRequest, the manager ensures client sanity (passed ID didn't change, hashed cookie has the expected value) We deliberately pick a random cookie instead of a magic number: if the fuzzer somehow learns to send packets to the manager, we don't want it to crash multiple managers on the same machine.
Diffstat (limited to 'pkg/rpcserver')
-rw-r--r--pkg/rpcserver/rpcserver.go64
-rw-r--r--pkg/rpcserver/rpcserver_test.go32
2 files changed, 67 insertions, 29 deletions
diff --git a/pkg/rpcserver/rpcserver.go b/pkg/rpcserver/rpcserver.go
index de664cb0b..43761b651 100644
--- a/pkg/rpcserver/rpcserver.go
+++ b/pkg/rpcserver/rpcserver.go
@@ -8,6 +8,7 @@ import (
"context"
"errors"
"fmt"
+ "math/rand"
"net/url"
"slices"
"sort"
@@ -232,13 +233,17 @@ func (serv *server) Listen() error {
return nil
}
+// Used for errors incompatible with further RPCServer operation.
+var errFatal = errors.New("aborting RPC server")
+
func (serv *server) Serve(ctx context.Context) error {
g, ctx := errgroup.WithContext(ctx)
g.Go(func() error {
return serv.serv.Serve(ctx, func(ctx context.Context, conn *flatrpc.Conn) error {
err := serv.handleConn(ctx, conn)
- if err != nil {
- log.Logf(0, "serv.handleConn returend %v", err)
+ if err != nil && !errors.Is(err, errFatal) {
+ log.Logf(2, "%v", err)
+ return nil
}
return err
})
@@ -261,24 +266,49 @@ func (serv *server) Port() int {
return serv.serv.Addr.Port
}
+// Must be simple enough to not require adding dependencies to the executor.
+func authHash(value uint64) uint64 {
+ prime1 := uint64(73856093)
+ prime2 := uint64(83492791)
+ hashValue := (value * prime1) ^ prime2
+
+ return hashValue
+}
+
func (serv *server) handleConn(ctx context.Context, conn *flatrpc.Conn) error {
+ // Use a random cookie, because we do not want the fuzzer to accidentally guess it and DDoS multiple managers.
+ helloCookie := rand.Uint64()
+ expectCookie := authHash(helloCookie)
+ connectHello := &flatrpc.ConnectHello{
+ Cookie: helloCookie,
+ }
+
+ if err := flatrpc.Send(conn, connectHello); err != nil {
+ // The other side is not an executor.
+ return fmt.Errorf("failed to establish connection with a remote runner")
+ }
+
connectReq, err := flatrpc.Recv[*flatrpc.ConnectRequestRaw](conn)
if err != nil {
- log.Logf(1, "%s", err)
- return nil
+ return err
}
id := int(connectReq.Id)
+
+ if connectReq.Cookie != expectCookie {
+ return fmt.Errorf("client failed to respond with a valid cookie: %v (expected %v)", connectReq.Cookie, expectCookie)
+ }
+
+ // From now on, assume that the client is well-behaving.
log.Logf(1, "runner %v connected", id)
if serv.cfg.VMLess {
- // There is no VM loop, so minic what it would do.
+ // There is no VM loop, so mimic what it would do.
serv.CreateInstance(id, nil, nil)
defer func() {
serv.StopFuzzing(id)
serv.ShutdownInstance(id, true)
}()
} else if err := checkRevisions(connectReq, serv.cfg.Target); err != nil {
- // This is a fatal error.
return err
}
serv.StatVMRestarts.Add(1)
@@ -287,18 +317,12 @@ func (serv *server) handleConn(ctx context.Context, conn *flatrpc.Conn) error {
runner := serv.runners[id]
serv.mu.Unlock()
if runner == nil {
- log.Logf(2, "unknown VM %v tries to connect", id)
- return nil
+ return fmt.Errorf("unknown VM %v tries to connect", id)
}
err = serv.handleRunnerConn(ctx, runner, conn)
log.Logf(2, "runner %v: %v", id, err)
- if err != nil && errors.Is(err, errFatal) {
- log.Logf(0, "%v", err)
- return err
- }
-
runner.resultCh <- err
return nil
}
@@ -337,9 +361,6 @@ func (serv *server) handleRunnerConn(ctx context.Context, runner *Runner, conn *
return serv.connectionLoop(ctx, runner)
}
-// Used for errors incompatible with further RPCServer operation.
-var errFatal = errors.New("aborting RPC server")
-
func (serv *server) handleMachineInfo(infoReq *flatrpc.InfoRequestRawT) (handshakeResult, error) {
modules, machineInfo, err := serv.checker.MachineInfo(infoReq.Files)
if err != nil {
@@ -419,15 +440,16 @@ func (serv *server) connectionLoop(baseCtx context.Context, runner *Runner) erro
func checkRevisions(a *flatrpc.ConnectRequest, target *prog.Target) error {
if target.Arch != a.Arch {
- return fmt.Errorf("mismatching manager/executor arches: %v vs %v (full request: `%#v`)", target.Arch, a.Arch, a)
+ return fmt.Errorf("%w: mismatching manager/executor arches: %v vs %v (full request: `%#v`)",
+ errFatal, target.Arch, a.Arch, a)
}
if prog.GitRevision != a.GitRevision {
- return fmt.Errorf("mismatching manager/executor git revisions: %v vs %v",
- prog.GitRevision, a.GitRevision)
+ return fmt.Errorf("%w: mismatching manager/executor git revisions: %v vs %v",
+ errFatal, prog.GitRevision, a.GitRevision)
}
if target.Revision != a.SyzRevision {
- return fmt.Errorf("mismatching manager/executor system call descriptions: %v vs %v",
- target.Revision, a.SyzRevision)
+ return fmt.Errorf("%w: mismatching manager/executor system call descriptions: %v vs %v",
+ errFatal, target.Revision, a.SyzRevision)
}
return nil
}
diff --git a/pkg/rpcserver/rpcserver_test.go b/pkg/rpcserver/rpcserver_test.go
index 2da916286..429b275ac 100644
--- a/pkg/rpcserver/rpcserver_test.go
+++ b/pkg/rpcserver/rpcserver_test.go
@@ -6,6 +6,7 @@ package rpcserver
import (
"context"
"net"
+ "strings"
"testing"
"github.com/stretchr/testify/assert"
@@ -18,6 +19,7 @@ import (
"github.com/google/syzkaller/pkg/vminfo"
"github.com/google/syzkaller/prog"
"github.com/google/syzkaller/sys/targets"
+ "golang.org/x/sync/errgroup"
)
func getTestDefaultCfg() mgrconfig.Config {
@@ -176,12 +178,14 @@ func TestHandleConn(t *testing.T) {
defaultCfg := getTestDefaultCfg()
tests := []struct {
- name string
- modifyCfg func() *mgrconfig.Config
- req *flatrpc.ConnectRequest
+ name string
+ wantErrMsg string
+ modifyCfg func() *mgrconfig.Config
+ req *flatrpc.ConnectRequest
}{
{
- name: "error, cfg.VMLess = false - unknown VM tries to connect",
+ name: "error, cfg.VMLess = false - unknown VM tries to connect",
+ wantErrMsg: "tries to connect",
modifyCfg: func() *mgrconfig.Config {
return &defaultCfg
},
@@ -214,10 +218,22 @@ func TestHandleConn(t *testing.T) {
injectExec := make(chan bool)
serv.CreateInstance(1, injectExec, nil)
-
- go flatrpc.Send(clientConn, tt.req)
- err = serv.handleConn(context.Background(), serverConn)
- if err != nil {
+ g := errgroup.Group{}
+ g.Go(func() error {
+ hello, err := flatrpc.Recv[*flatrpc.ConnectHelloRaw](clientConn)
+ if err != nil {
+ return err
+ }
+ tt.req.Cookie = authHash(hello.Cookie)
+ flatrpc.Send(clientConn, tt.req)
+ return nil
+ })
+ if err := serv.handleConn(context.Background(), serverConn); err != nil {
+ if !strings.Contains(err.Error(), tt.wantErrMsg) {
+ t.Fatal(err)
+ }
+ }
+ if err := g.Wait(); err != nil {
t.Fatal(err)
}
})