diff options
| author | Alexander Potapenko <glider@google.com> | 2025-02-20 12:25:04 +0100 |
|---|---|---|
| committer | Alexander Potapenko <glider@google.com> | 2025-02-20 16:45:37 +0000 |
| commit | 0808a665bc75ab0845906bfeca0d12fb520ae6eb (patch) | |
| tree | 04e77371226d0433dd8a865b01bc1eeedebd3348 /pkg | |
| parent | 506687987fc2f8f40b2918782fc2943285fdc602 (diff) | |
pkg/rpcserver: pkg/flatrpc: executor: add handshake stage 0
As we figured out in #5805, syz-manager treats random incoming RPC
connections as trusted, and will crash if a non-executor client sends
an invalid packet to it.
To address this issue, we introduce another stage of handshake, which
includes a cookie exchange:
- upon connection from an executor, the manager sends a ConnectHello RPC
message to it, which contains a random 64-bit cookie;
- the executor calculates a hash of that cookie and includes it into
its ConnectRequest together with the other information;
- before checking the validity of ConnectRequest, the manager ensures
client sanity (passed ID didn't change, hashed cookie has the expected
value)
We deliberately pick a random cookie instead of a magic number: if the
fuzzer somehow learns to send packets to the manager, we don't want it to
crash multiple managers on the same machine.
Diffstat (limited to 'pkg')
| -rw-r--r-- | pkg/flatrpc/conn_test.go | 29 | ||||
| -rw-r--r-- | pkg/flatrpc/flatrpc.fbs | 5 | ||||
| -rw-r--r-- | pkg/flatrpc/flatrpc.go | 113 | ||||
| -rw-r--r-- | pkg/flatrpc/flatrpc.h | 106 | ||||
| -rw-r--r-- | pkg/flatrpc/helpers.go | 1 | ||||
| -rw-r--r-- | pkg/rpcserver/rpcserver.go | 64 | ||||
| -rw-r--r-- | pkg/rpcserver/rpcserver_test.go | 32 |
7 files changed, 305 insertions, 45 deletions
diff --git a/pkg/flatrpc/conn_test.go b/pkg/flatrpc/conn_test.go index 4b108a5a4..87d5e9a8a 100644 --- a/pkg/flatrpc/conn_test.go +++ b/pkg/flatrpc/conn_test.go @@ -20,7 +20,11 @@ import ( ) func TestConn(t *testing.T) { + connectHello := &ConnectHello{ + Cookie: 1, + } connectReq := &ConnectRequest{ + Cookie: 73856093, Id: 1, Arch: "arch", GitRevision: "rev1", @@ -52,6 +56,9 @@ func TestConn(t *testing.T) { go func() { done <- serv.Serve(context.Background(), func(_ context.Context, c *Conn) error { + if err := Send(c, connectHello); err != nil { + return err + } connectReqGot, err := Recv[*ConnectRequestRaw](c) if err != nil { return err @@ -79,6 +86,12 @@ func TestConn(t *testing.T) { c := dial(t, serv.Addr.String()) defer c.Close() + connectHelloGot, err := Recv[*ConnectHelloRaw](c) + if err != nil { + t.Fatal(err) + } + assert.Equal(t, connectHello, connectHelloGot) + if err := Send(c, connectReq); err != nil { t.Fatal(err) } @@ -102,7 +115,11 @@ func TestConn(t *testing.T) { } func BenchmarkConn(b *testing.B) { + connectHello := &ConnectHello{ + Cookie: 1, + } connectReq := &ConnectRequest{ + Cookie: 73856093, Id: 1, Arch: "arch", GitRevision: "rev1", @@ -125,7 +142,11 @@ func BenchmarkConn(b *testing.B) { done <- serv.Serve(context.Background(), func(_ context.Context, c *Conn) error { for i := 0; i < b.N; i++ { - _, err := Recv[*ConnectRequestRaw](c) + if err := Send(c, connectHello); err != nil { + return err + } + + _, err = Recv[*ConnectRequestRaw](c) if err != nil { return err } @@ -143,10 +164,14 @@ func BenchmarkConn(b *testing.B) { b.ReportAllocs() b.ResetTimer() for i := 0; i < b.N; i++ { + _, err := Recv[*ConnectHelloRaw](c) + if err != nil { + b.Fatal(err) + } if err := Send(c, connectReq); err != nil { b.Fatal(err) } - _, err := Recv[*ConnectReplyRaw](c) + _, err = Recv[*ConnectReplyRaw](c) if err != nil { b.Fatal(err) } diff --git a/pkg/flatrpc/flatrpc.fbs b/pkg/flatrpc/flatrpc.fbs index 6d2307d6a..58dc7b292 100644 --- a/pkg/flatrpc/flatrpc.fbs +++ b/pkg/flatrpc/flatrpc.fbs @@ -34,8 +34,13 @@ enum Feature : uint64 (bit_flags) { BinFmtMisc, Swap, } + +table ConnectHelloRaw { + cookie :uint64; +} table ConnectRequestRaw { + cookie :uint64; id :int64; arch :string; git_revision :string; diff --git a/pkg/flatrpc/flatrpc.go b/pkg/flatrpc/flatrpc.go index fd5b4f614..d04d5c531 100644 --- a/pkg/flatrpc/flatrpc.go +++ b/pkg/flatrpc/flatrpc.go @@ -485,7 +485,83 @@ func (v SnapshotState) String() string { return "SnapshotState(" + strconv.FormatInt(int64(v), 10) + ")" } +type ConnectHelloRawT struct { + Cookie uint64 `json:"cookie"` +} + +func (t *ConnectHelloRawT) Pack(builder *flatbuffers.Builder) flatbuffers.UOffsetT { + if t == nil { + return 0 + } + ConnectHelloRawStart(builder) + ConnectHelloRawAddCookie(builder, t.Cookie) + return ConnectHelloRawEnd(builder) +} + +func (rcv *ConnectHelloRaw) UnPackTo(t *ConnectHelloRawT) { + t.Cookie = rcv.Cookie() +} + +func (rcv *ConnectHelloRaw) UnPack() *ConnectHelloRawT { + if rcv == nil { + return nil + } + t := &ConnectHelloRawT{} + rcv.UnPackTo(t) + return t +} + +type ConnectHelloRaw struct { + _tab flatbuffers.Table +} + +func GetRootAsConnectHelloRaw(buf []byte, offset flatbuffers.UOffsetT) *ConnectHelloRaw { + n := flatbuffers.GetUOffsetT(buf[offset:]) + x := &ConnectHelloRaw{} + x.Init(buf, n+offset) + return x +} + +func GetSizePrefixedRootAsConnectHelloRaw(buf []byte, offset flatbuffers.UOffsetT) *ConnectHelloRaw { + n := flatbuffers.GetUOffsetT(buf[offset+flatbuffers.SizeUint32:]) + x := &ConnectHelloRaw{} + x.Init(buf, n+offset+flatbuffers.SizeUint32) + return x +} + +func (rcv *ConnectHelloRaw) Init(buf []byte, i flatbuffers.UOffsetT) { + rcv._tab.Bytes = buf + rcv._tab.Pos = i +} + +func (rcv *ConnectHelloRaw) Table() flatbuffers.Table { + return rcv._tab +} + +func (rcv *ConnectHelloRaw) Cookie() uint64 { + o := flatbuffers.UOffsetT(rcv._tab.Offset(4)) + if o != 0 { + return rcv._tab.GetUint64(o + rcv._tab.Pos) + } + return 0 +} + +func (rcv *ConnectHelloRaw) MutateCookie(n uint64) bool { + return rcv._tab.MutateUint64Slot(4, n) +} + +func ConnectHelloRawStart(builder *flatbuffers.Builder) { + builder.StartObject(1) +} +func ConnectHelloRawAddCookie(builder *flatbuffers.Builder, cookie uint64) { + builder.PrependUint64Slot(0, cookie, 0) +} +func ConnectHelloRawEnd(builder *flatbuffers.Builder) flatbuffers.UOffsetT { + return builder.EndObject() +} + type ConnectRequestRawT struct { + Cookie uint64 `json:"cookie"` Id int64 `json:"id"` Arch string `json:"arch"` GitRevision string `json:"git_revision"` @@ -500,6 +576,7 @@ func (t *ConnectRequestRawT) Pack(builder *flatbuffers.Builder) flatbuffers.UOff gitRevisionOffset := builder.CreateString(t.GitRevision) syzRevisionOffset := builder.CreateString(t.SyzRevision) ConnectRequestRawStart(builder) + ConnectRequestRawAddCookie(builder, t.Cookie) ConnectRequestRawAddId(builder, t.Id) ConnectRequestRawAddArch(builder, archOffset) ConnectRequestRawAddGitRevision(builder, gitRevisionOffset) @@ -508,6 +585,7 @@ func (t *ConnectRequestRawT) Pack(builder *flatbuffers.Builder) flatbuffers.UOff } func (rcv *ConnectRequestRaw) UnPackTo(t *ConnectRequestRawT) { + t.Cookie = rcv.Cookie() t.Id = rcv.Id() t.Arch = string(rcv.Arch()) t.GitRevision = string(rcv.GitRevision()) @@ -550,20 +628,32 @@ func (rcv *ConnectRequestRaw) Table() flatbuffers.Table { return rcv._tab } -func (rcv *ConnectRequestRaw) Id() int64 { +func (rcv *ConnectRequestRaw) Cookie() uint64 { o := flatbuffers.UOffsetT(rcv._tab.Offset(4)) if o != 0 { + return rcv._tab.GetUint64(o + rcv._tab.Pos) + } + return 0 +} + +func (rcv *ConnectRequestRaw) MutateCookie(n uint64) bool { + return rcv._tab.MutateUint64Slot(4, n) +} + +func (rcv *ConnectRequestRaw) Id() int64 { + o := flatbuffers.UOffsetT(rcv._tab.Offset(6)) + if o != 0 { return rcv._tab.GetInt64(o + rcv._tab.Pos) } return 0 } func (rcv *ConnectRequestRaw) MutateId(n int64) bool { - return rcv._tab.MutateInt64Slot(4, n) + return rcv._tab.MutateInt64Slot(6, n) } func (rcv *ConnectRequestRaw) Arch() []byte { - o := flatbuffers.UOffsetT(rcv._tab.Offset(6)) + o := flatbuffers.UOffsetT(rcv._tab.Offset(8)) if o != 0 { return rcv._tab.ByteVector(o + rcv._tab.Pos) } @@ -571,7 +661,7 @@ func (rcv *ConnectRequestRaw) Arch() []byte { } func (rcv *ConnectRequestRaw) GitRevision() []byte { - o := flatbuffers.UOffsetT(rcv._tab.Offset(8)) + o := flatbuffers.UOffsetT(rcv._tab.Offset(10)) if o != 0 { return rcv._tab.ByteVector(o + rcv._tab.Pos) } @@ -579,7 +669,7 @@ func (rcv *ConnectRequestRaw) GitRevision() []byte { } func (rcv *ConnectRequestRaw) SyzRevision() []byte { - o := flatbuffers.UOffsetT(rcv._tab.Offset(10)) + o := flatbuffers.UOffsetT(rcv._tab.Offset(12)) if o != 0 { return rcv._tab.ByteVector(o + rcv._tab.Pos) } @@ -587,19 +677,22 @@ func (rcv *ConnectRequestRaw) SyzRevision() []byte { } func ConnectRequestRawStart(builder *flatbuffers.Builder) { - builder.StartObject(4) + builder.StartObject(5) +} +func ConnectRequestRawAddCookie(builder *flatbuffers.Builder, cookie uint64) { + builder.PrependUint64Slot(0, cookie, 0) } func ConnectRequestRawAddId(builder *flatbuffers.Builder, id int64) { - builder.PrependInt64Slot(0, id, 0) + builder.PrependInt64Slot(1, id, 0) } func ConnectRequestRawAddArch(builder *flatbuffers.Builder, arch flatbuffers.UOffsetT) { - builder.PrependUOffsetTSlot(1, flatbuffers.UOffsetT(arch), 0) + builder.PrependUOffsetTSlot(2, flatbuffers.UOffsetT(arch), 0) } func ConnectRequestRawAddGitRevision(builder *flatbuffers.Builder, gitRevision flatbuffers.UOffsetT) { - builder.PrependUOffsetTSlot(2, flatbuffers.UOffsetT(gitRevision), 0) + builder.PrependUOffsetTSlot(3, flatbuffers.UOffsetT(gitRevision), 0) } func ConnectRequestRawAddSyzRevision(builder *flatbuffers.Builder, syzRevision flatbuffers.UOffsetT) { - builder.PrependUOffsetTSlot(3, flatbuffers.UOffsetT(syzRevision), 0) + builder.PrependUOffsetTSlot(4, flatbuffers.UOffsetT(syzRevision), 0) } func ConnectRequestRawEnd(builder *flatbuffers.Builder) flatbuffers.UOffsetT { return builder.EndObject() diff --git a/pkg/flatrpc/flatrpc.h b/pkg/flatrpc/flatrpc.h index 12d905c12..def2648ad 100644 --- a/pkg/flatrpc/flatrpc.h +++ b/pkg/flatrpc/flatrpc.h @@ -15,6 +15,10 @@ static_assert(FLATBUFFERS_VERSION_MAJOR == 2 && namespace rpc { +struct ConnectHelloRaw; +struct ConnectHelloRawBuilder; +struct ConnectHelloRawT; + struct ConnectRequestRaw; struct ConnectRequestRawBuilder; struct ConnectRequestRawT; @@ -846,8 +850,61 @@ FLATBUFFERS_MANUALLY_ALIGNED_STRUCT(8) ComparisonRaw FLATBUFFERS_FINAL_CLASS { }; FLATBUFFERS_STRUCT_END(ComparisonRaw, 32); +struct ConnectHelloRawT : public flatbuffers::NativeTable { + typedef ConnectHelloRaw TableType; + uint64_t cookie = 0; +}; + +struct ConnectHelloRaw FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef ConnectHelloRawT NativeTableType; + typedef ConnectHelloRawBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_COOKIE = 4 + }; + uint64_t cookie() const { + return GetField<uint64_t>(VT_COOKIE, 0); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField<uint64_t>(verifier, VT_COOKIE, 8) && + verifier.EndTable(); + } + ConnectHelloRawT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(ConnectHelloRawT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset<ConnectHelloRaw> Pack(flatbuffers::FlatBufferBuilder &_fbb, const ConnectHelloRawT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct ConnectHelloRawBuilder { + typedef ConnectHelloRaw Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_cookie(uint64_t cookie) { + fbb_.AddElement<uint64_t>(ConnectHelloRaw::VT_COOKIE, cookie, 0); + } + explicit ConnectHelloRawBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + flatbuffers::Offset<ConnectHelloRaw> Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset<ConnectHelloRaw>(end); + return o; + } +}; + +inline flatbuffers::Offset<ConnectHelloRaw> CreateConnectHelloRaw( + flatbuffers::FlatBufferBuilder &_fbb, + uint64_t cookie = 0) { + ConnectHelloRawBuilder builder_(_fbb); + builder_.add_cookie(cookie); + return builder_.Finish(); +} + +flatbuffers::Offset<ConnectHelloRaw> CreateConnectHelloRaw(flatbuffers::FlatBufferBuilder &_fbb, const ConnectHelloRawT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + struct ConnectRequestRawT : public flatbuffers::NativeTable { typedef ConnectRequestRaw TableType; + uint64_t cookie = 0; int64_t id = 0; std::string arch{}; std::string git_revision{}; @@ -858,11 +915,15 @@ struct ConnectRequestRaw FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { typedef ConnectRequestRawT NativeTableType; typedef ConnectRequestRawBuilder Builder; enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { - VT_ID = 4, - VT_ARCH = 6, - VT_GIT_REVISION = 8, - VT_SYZ_REVISION = 10 + VT_COOKIE = 4, + VT_ID = 6, + VT_ARCH = 8, + VT_GIT_REVISION = 10, + VT_SYZ_REVISION = 12 }; + uint64_t cookie() const { + return GetField<uint64_t>(VT_COOKIE, 0); + } int64_t id() const { return GetField<int64_t>(VT_ID, 0); } @@ -877,6 +938,7 @@ struct ConnectRequestRaw FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && + VerifyField<uint64_t>(verifier, VT_COOKIE, 8) && VerifyField<int64_t>(verifier, VT_ID, 8) && VerifyOffset(verifier, VT_ARCH) && verifier.VerifyString(arch()) && @@ -895,6 +957,9 @@ struct ConnectRequestRawBuilder { typedef ConnectRequestRaw Table; flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; + void add_cookie(uint64_t cookie) { + fbb_.AddElement<uint64_t>(ConnectRequestRaw::VT_COOKIE, cookie, 0); + } void add_id(int64_t id) { fbb_.AddElement<int64_t>(ConnectRequestRaw::VT_ID, id, 0); } @@ -920,12 +985,14 @@ struct ConnectRequestRawBuilder { inline flatbuffers::Offset<ConnectRequestRaw> CreateConnectRequestRaw( flatbuffers::FlatBufferBuilder &_fbb, + uint64_t cookie = 0, int64_t id = 0, flatbuffers::Offset<flatbuffers::String> arch = 0, flatbuffers::Offset<flatbuffers::String> git_revision = 0, flatbuffers::Offset<flatbuffers::String> syz_revision = 0) { ConnectRequestRawBuilder builder_(_fbb); builder_.add_id(id); + builder_.add_cookie(cookie); builder_.add_syz_revision(syz_revision); builder_.add_git_revision(git_revision); builder_.add_arch(arch); @@ -934,6 +1001,7 @@ inline flatbuffers::Offset<ConnectRequestRaw> CreateConnectRequestRaw( inline flatbuffers::Offset<ConnectRequestRaw> CreateConnectRequestRawDirect( flatbuffers::FlatBufferBuilder &_fbb, + uint64_t cookie = 0, int64_t id = 0, const char *arch = nullptr, const char *git_revision = nullptr, @@ -943,6 +1011,7 @@ inline flatbuffers::Offset<ConnectRequestRaw> CreateConnectRequestRawDirect( auto syz_revision__ = syz_revision ? _fbb.CreateString(syz_revision) : 0; return rpc::CreateConnectRequestRaw( _fbb, + cookie, id, arch__, git_revision__, @@ -2896,6 +2965,32 @@ inline flatbuffers::Offset<SnapshotRequest> CreateSnapshotRequestDirect( flatbuffers::Offset<SnapshotRequest> CreateSnapshotRequest(flatbuffers::FlatBufferBuilder &_fbb, const SnapshotRequestT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +inline ConnectHelloRawT *ConnectHelloRaw::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr<ConnectHelloRawT>(new ConnectHelloRawT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void ConnectHelloRaw::UnPackTo(ConnectHelloRawT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = cookie(); _o->cookie = _e; } +} + +inline flatbuffers::Offset<ConnectHelloRaw> ConnectHelloRaw::Pack(flatbuffers::FlatBufferBuilder &_fbb, const ConnectHelloRawT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateConnectHelloRaw(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset<ConnectHelloRaw> CreateConnectHelloRaw(flatbuffers::FlatBufferBuilder &_fbb, const ConnectHelloRawT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const ConnectHelloRawT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _cookie = _o->cookie; + return rpc::CreateConnectHelloRaw( + _fbb, + _cookie); +} + inline ConnectRequestRawT *ConnectRequestRaw::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = std::unique_ptr<ConnectRequestRawT>(new ConnectRequestRawT()); UnPackTo(_o.get(), _resolver); @@ -2905,6 +3000,7 @@ inline ConnectRequestRawT *ConnectRequestRaw::UnPack(const flatbuffers::resolver inline void ConnectRequestRaw::UnPackTo(ConnectRequestRawT *_o, const flatbuffers::resolver_function_t *_resolver) const { (void)_o; (void)_resolver; + { auto _e = cookie(); _o->cookie = _e; } { auto _e = id(); _o->id = _e; } { auto _e = arch(); if (_e) _o->arch = _e->str(); } { auto _e = git_revision(); if (_e) _o->git_revision = _e->str(); } @@ -2919,12 +3015,14 @@ inline flatbuffers::Offset<ConnectRequestRaw> CreateConnectRequestRaw(flatbuffer (void)_rehasher; (void)_o; struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const ConnectRequestRawT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _cookie = _o->cookie; auto _id = _o->id; auto _arch = _o->arch.empty() ? 0 : _fbb.CreateString(_o->arch); auto _git_revision = _o->git_revision.empty() ? 0 : _fbb.CreateString(_o->git_revision); auto _syz_revision = _o->syz_revision.empty() ? 0 : _fbb.CreateString(_o->syz_revision); return rpc::CreateConnectRequestRaw( _fbb, + _cookie, _id, _arch, _git_revision, diff --git a/pkg/flatrpc/helpers.go b/pkg/flatrpc/helpers.go index 9a5463b24..5aa5cfe74 100644 --- a/pkg/flatrpc/helpers.go +++ b/pkg/flatrpc/helpers.go @@ -18,6 +18,7 @@ const AllFeatures = ^Feature(0) // Flatbuffers compiler adds T suffix to object API types, which are actual structs representing types. // This leads to non-idiomatic Go code, e.g. we would have to use []FileInfoT in Go code. // So we use Raw suffix for all flatbuffers tables and rename object API types here to idiomatic names. +type ConnectHello = ConnectHelloRawT type ConnectRequest = ConnectRequestRawT type ConnectReply = ConnectReplyRawT type InfoRequest = InfoRequestRawT diff --git a/pkg/rpcserver/rpcserver.go b/pkg/rpcserver/rpcserver.go index de664cb0b..43761b651 100644 --- a/pkg/rpcserver/rpcserver.go +++ b/pkg/rpcserver/rpcserver.go @@ -8,6 +8,7 @@ import ( "context" "errors" "fmt" + "math/rand" "net/url" "slices" "sort" @@ -232,13 +233,17 @@ func (serv *server) Listen() error { return nil } +// Used for errors incompatible with further RPCServer operation. +var errFatal = errors.New("aborting RPC server") + func (serv *server) Serve(ctx context.Context) error { g, ctx := errgroup.WithContext(ctx) g.Go(func() error { return serv.serv.Serve(ctx, func(ctx context.Context, conn *flatrpc.Conn) error { err := serv.handleConn(ctx, conn) - if err != nil { - log.Logf(0, "serv.handleConn returend %v", err) + if err != nil && !errors.Is(err, errFatal) { + log.Logf(2, "%v", err) + return nil } return err }) @@ -261,24 +266,49 @@ func (serv *server) Port() int { return serv.serv.Addr.Port } +// Must be simple enough to not require adding dependencies to the executor. +func authHash(value uint64) uint64 { + prime1 := uint64(73856093) + prime2 := uint64(83492791) + hashValue := (value * prime1) ^ prime2 + + return hashValue +} + func (serv *server) handleConn(ctx context.Context, conn *flatrpc.Conn) error { + // Use a random cookie, because we do not want the fuzzer to accidentally guess it and DDoS multiple managers. + helloCookie := rand.Uint64() + expectCookie := authHash(helloCookie) + connectHello := &flatrpc.ConnectHello{ + Cookie: helloCookie, + } + + if err := flatrpc.Send(conn, connectHello); err != nil { + // The other side is not an executor. + return fmt.Errorf("failed to establish connection with a remote runner") + } + connectReq, err := flatrpc.Recv[*flatrpc.ConnectRequestRaw](conn) if err != nil { - log.Logf(1, "%s", err) - return nil + return err } id := int(connectReq.Id) + + if connectReq.Cookie != expectCookie { + return fmt.Errorf("client failed to respond with a valid cookie: %v (expected %v)", connectReq.Cookie, expectCookie) + } + + // From now on, assume that the client is well-behaving. log.Logf(1, "runner %v connected", id) if serv.cfg.VMLess { - // There is no VM loop, so minic what it would do. + // There is no VM loop, so mimic what it would do. serv.CreateInstance(id, nil, nil) defer func() { serv.StopFuzzing(id) serv.ShutdownInstance(id, true) }() } else if err := checkRevisions(connectReq, serv.cfg.Target); err != nil { - // This is a fatal error. return err } serv.StatVMRestarts.Add(1) @@ -287,18 +317,12 @@ func (serv *server) handleConn(ctx context.Context, conn *flatrpc.Conn) error { runner := serv.runners[id] serv.mu.Unlock() if runner == nil { - log.Logf(2, "unknown VM %v tries to connect", id) - return nil + return fmt.Errorf("unknown VM %v tries to connect", id) } err = serv.handleRunnerConn(ctx, runner, conn) log.Logf(2, "runner %v: %v", id, err) - if err != nil && errors.Is(err, errFatal) { - log.Logf(0, "%v", err) - return err - } - runner.resultCh <- err return nil } @@ -337,9 +361,6 @@ func (serv *server) handleRunnerConn(ctx context.Context, runner *Runner, conn * return serv.connectionLoop(ctx, runner) } -// Used for errors incompatible with further RPCServer operation. -var errFatal = errors.New("aborting RPC server") - func (serv *server) handleMachineInfo(infoReq *flatrpc.InfoRequestRawT) (handshakeResult, error) { modules, machineInfo, err := serv.checker.MachineInfo(infoReq.Files) if err != nil { @@ -419,15 +440,16 @@ func (serv *server) connectionLoop(baseCtx context.Context, runner *Runner) erro func checkRevisions(a *flatrpc.ConnectRequest, target *prog.Target) error { if target.Arch != a.Arch { - return fmt.Errorf("mismatching manager/executor arches: %v vs %v (full request: `%#v`)", target.Arch, a.Arch, a) + return fmt.Errorf("%w: mismatching manager/executor arches: %v vs %v (full request: `%#v`)", + errFatal, target.Arch, a.Arch, a) } if prog.GitRevision != a.GitRevision { - return fmt.Errorf("mismatching manager/executor git revisions: %v vs %v", - prog.GitRevision, a.GitRevision) + return fmt.Errorf("%w: mismatching manager/executor git revisions: %v vs %v", + errFatal, prog.GitRevision, a.GitRevision) } if target.Revision != a.SyzRevision { - return fmt.Errorf("mismatching manager/executor system call descriptions: %v vs %v", - target.Revision, a.SyzRevision) + return fmt.Errorf("%w: mismatching manager/executor system call descriptions: %v vs %v", + errFatal, target.Revision, a.SyzRevision) } return nil } diff --git a/pkg/rpcserver/rpcserver_test.go b/pkg/rpcserver/rpcserver_test.go index 2da916286..429b275ac 100644 --- a/pkg/rpcserver/rpcserver_test.go +++ b/pkg/rpcserver/rpcserver_test.go @@ -6,6 +6,7 @@ package rpcserver import ( "context" "net" + "strings" "testing" "github.com/stretchr/testify/assert" @@ -18,6 +19,7 @@ import ( "github.com/google/syzkaller/pkg/vminfo" "github.com/google/syzkaller/prog" "github.com/google/syzkaller/sys/targets" + "golang.org/x/sync/errgroup" ) func getTestDefaultCfg() mgrconfig.Config { @@ -176,12 +178,14 @@ func TestHandleConn(t *testing.T) { defaultCfg := getTestDefaultCfg() tests := []struct { - name string - modifyCfg func() *mgrconfig.Config - req *flatrpc.ConnectRequest + name string + wantErrMsg string + modifyCfg func() *mgrconfig.Config + req *flatrpc.ConnectRequest }{ { - name: "error, cfg.VMLess = false - unknown VM tries to connect", + name: "error, cfg.VMLess = false - unknown VM tries to connect", + wantErrMsg: "tries to connect", modifyCfg: func() *mgrconfig.Config { return &defaultCfg }, @@ -214,10 +218,22 @@ func TestHandleConn(t *testing.T) { injectExec := make(chan bool) serv.CreateInstance(1, injectExec, nil) - - go flatrpc.Send(clientConn, tt.req) - err = serv.handleConn(context.Background(), serverConn) - if err != nil { + g := errgroup.Group{} + g.Go(func() error { + hello, err := flatrpc.Recv[*flatrpc.ConnectHelloRaw](clientConn) + if err != nil { + return err + } + tt.req.Cookie = authHash(hello.Cookie) + flatrpc.Send(clientConn, tt.req) + return nil + }) + if err := serv.handleConn(context.Background(), serverConn); err != nil { + if !strings.Contains(err.Error(), tt.wantErrMsg) { + t.Fatal(err) + } + } + if err := g.Wait(); err != nil { t.Fatal(err) } }) |
