diff options
| author | Alexander Potapenko <glider@google.com> | 2025-02-20 12:25:04 +0100 |
|---|---|---|
| committer | Alexander Potapenko <glider@google.com> | 2025-02-20 16:45:37 +0000 |
| commit | 0808a665bc75ab0845906bfeca0d12fb520ae6eb (patch) | |
| tree | 04e77371226d0433dd8a865b01bc1eeedebd3348 /pkg/rpcserver | |
| parent | 506687987fc2f8f40b2918782fc2943285fdc602 (diff) | |
pkg/rpcserver: pkg/flatrpc: executor: add handshake stage 0
As we figured out in #5805, syz-manager treats random incoming RPC
connections as trusted, and will crash if a non-executor client sends
an invalid packet to it.
To address this issue, we introduce another stage of handshake, which
includes a cookie exchange:
- upon connection from an executor, the manager sends a ConnectHello RPC
message to it, which contains a random 64-bit cookie;
- the executor calculates a hash of that cookie and includes it into
its ConnectRequest together with the other information;
- before checking the validity of ConnectRequest, the manager ensures
client sanity (passed ID didn't change, hashed cookie has the expected
value)
We deliberately pick a random cookie instead of a magic number: if the
fuzzer somehow learns to send packets to the manager, we don't want it to
crash multiple managers on the same machine.
Diffstat (limited to 'pkg/rpcserver')
| -rw-r--r-- | pkg/rpcserver/rpcserver.go | 64 | ||||
| -rw-r--r-- | pkg/rpcserver/rpcserver_test.go | 32 |
2 files changed, 67 insertions, 29 deletions
diff --git a/pkg/rpcserver/rpcserver.go b/pkg/rpcserver/rpcserver.go index de664cb0b..43761b651 100644 --- a/pkg/rpcserver/rpcserver.go +++ b/pkg/rpcserver/rpcserver.go @@ -8,6 +8,7 @@ import ( "context" "errors" "fmt" + "math/rand" "net/url" "slices" "sort" @@ -232,13 +233,17 @@ func (serv *server) Listen() error { return nil } +// Used for errors incompatible with further RPCServer operation. +var errFatal = errors.New("aborting RPC server") + func (serv *server) Serve(ctx context.Context) error { g, ctx := errgroup.WithContext(ctx) g.Go(func() error { return serv.serv.Serve(ctx, func(ctx context.Context, conn *flatrpc.Conn) error { err := serv.handleConn(ctx, conn) - if err != nil { - log.Logf(0, "serv.handleConn returend %v", err) + if err != nil && !errors.Is(err, errFatal) { + log.Logf(2, "%v", err) + return nil } return err }) @@ -261,24 +266,49 @@ func (serv *server) Port() int { return serv.serv.Addr.Port } +// Must be simple enough to not require adding dependencies to the executor. +func authHash(value uint64) uint64 { + prime1 := uint64(73856093) + prime2 := uint64(83492791) + hashValue := (value * prime1) ^ prime2 + + return hashValue +} + func (serv *server) handleConn(ctx context.Context, conn *flatrpc.Conn) error { + // Use a random cookie, because we do not want the fuzzer to accidentally guess it and DDoS multiple managers. + helloCookie := rand.Uint64() + expectCookie := authHash(helloCookie) + connectHello := &flatrpc.ConnectHello{ + Cookie: helloCookie, + } + + if err := flatrpc.Send(conn, connectHello); err != nil { + // The other side is not an executor. + return fmt.Errorf("failed to establish connection with a remote runner") + } + connectReq, err := flatrpc.Recv[*flatrpc.ConnectRequestRaw](conn) if err != nil { - log.Logf(1, "%s", err) - return nil + return err } id := int(connectReq.Id) + + if connectReq.Cookie != expectCookie { + return fmt.Errorf("client failed to respond with a valid cookie: %v (expected %v)", connectReq.Cookie, expectCookie) + } + + // From now on, assume that the client is well-behaving. log.Logf(1, "runner %v connected", id) if serv.cfg.VMLess { - // There is no VM loop, so minic what it would do. + // There is no VM loop, so mimic what it would do. serv.CreateInstance(id, nil, nil) defer func() { serv.StopFuzzing(id) serv.ShutdownInstance(id, true) }() } else if err := checkRevisions(connectReq, serv.cfg.Target); err != nil { - // This is a fatal error. return err } serv.StatVMRestarts.Add(1) @@ -287,18 +317,12 @@ func (serv *server) handleConn(ctx context.Context, conn *flatrpc.Conn) error { runner := serv.runners[id] serv.mu.Unlock() if runner == nil { - log.Logf(2, "unknown VM %v tries to connect", id) - return nil + return fmt.Errorf("unknown VM %v tries to connect", id) } err = serv.handleRunnerConn(ctx, runner, conn) log.Logf(2, "runner %v: %v", id, err) - if err != nil && errors.Is(err, errFatal) { - log.Logf(0, "%v", err) - return err - } - runner.resultCh <- err return nil } @@ -337,9 +361,6 @@ func (serv *server) handleRunnerConn(ctx context.Context, runner *Runner, conn * return serv.connectionLoop(ctx, runner) } -// Used for errors incompatible with further RPCServer operation. -var errFatal = errors.New("aborting RPC server") - func (serv *server) handleMachineInfo(infoReq *flatrpc.InfoRequestRawT) (handshakeResult, error) { modules, machineInfo, err := serv.checker.MachineInfo(infoReq.Files) if err != nil { @@ -419,15 +440,16 @@ func (serv *server) connectionLoop(baseCtx context.Context, runner *Runner) erro func checkRevisions(a *flatrpc.ConnectRequest, target *prog.Target) error { if target.Arch != a.Arch { - return fmt.Errorf("mismatching manager/executor arches: %v vs %v (full request: `%#v`)", target.Arch, a.Arch, a) + return fmt.Errorf("%w: mismatching manager/executor arches: %v vs %v (full request: `%#v`)", + errFatal, target.Arch, a.Arch, a) } if prog.GitRevision != a.GitRevision { - return fmt.Errorf("mismatching manager/executor git revisions: %v vs %v", - prog.GitRevision, a.GitRevision) + return fmt.Errorf("%w: mismatching manager/executor git revisions: %v vs %v", + errFatal, prog.GitRevision, a.GitRevision) } if target.Revision != a.SyzRevision { - return fmt.Errorf("mismatching manager/executor system call descriptions: %v vs %v", - target.Revision, a.SyzRevision) + return fmt.Errorf("%w: mismatching manager/executor system call descriptions: %v vs %v", + errFatal, target.Revision, a.SyzRevision) } return nil } diff --git a/pkg/rpcserver/rpcserver_test.go b/pkg/rpcserver/rpcserver_test.go index 2da916286..429b275ac 100644 --- a/pkg/rpcserver/rpcserver_test.go +++ b/pkg/rpcserver/rpcserver_test.go @@ -6,6 +6,7 @@ package rpcserver import ( "context" "net" + "strings" "testing" "github.com/stretchr/testify/assert" @@ -18,6 +19,7 @@ import ( "github.com/google/syzkaller/pkg/vminfo" "github.com/google/syzkaller/prog" "github.com/google/syzkaller/sys/targets" + "golang.org/x/sync/errgroup" ) func getTestDefaultCfg() mgrconfig.Config { @@ -176,12 +178,14 @@ func TestHandleConn(t *testing.T) { defaultCfg := getTestDefaultCfg() tests := []struct { - name string - modifyCfg func() *mgrconfig.Config - req *flatrpc.ConnectRequest + name string + wantErrMsg string + modifyCfg func() *mgrconfig.Config + req *flatrpc.ConnectRequest }{ { - name: "error, cfg.VMLess = false - unknown VM tries to connect", + name: "error, cfg.VMLess = false - unknown VM tries to connect", + wantErrMsg: "tries to connect", modifyCfg: func() *mgrconfig.Config { return &defaultCfg }, @@ -214,10 +218,22 @@ func TestHandleConn(t *testing.T) { injectExec := make(chan bool) serv.CreateInstance(1, injectExec, nil) - - go flatrpc.Send(clientConn, tt.req) - err = serv.handleConn(context.Background(), serverConn) - if err != nil { + g := errgroup.Group{} + g.Go(func() error { + hello, err := flatrpc.Recv[*flatrpc.ConnectHelloRaw](clientConn) + if err != nil { + return err + } + tt.req.Cookie = authHash(hello.Cookie) + flatrpc.Send(clientConn, tt.req) + return nil + }) + if err := serv.handleConn(context.Background(), serverConn); err != nil { + if !strings.Contains(err.Error(), tt.wantErrMsg) { + t.Fatal(err) + } + } + if err := g.Wait(); err != nil { t.Fatal(err) } }) |
