aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorAlexander Potapenko <glider@google.com>2025-02-20 12:25:04 +0100
committerAlexander Potapenko <glider@google.com>2025-02-20 16:45:37 +0000
commit0808a665bc75ab0845906bfeca0d12fb520ae6eb (patch)
tree04e77371226d0433dd8a865b01bc1eeedebd3348
parent506687987fc2f8f40b2918782fc2943285fdc602 (diff)
pkg/rpcserver: pkg/flatrpc: executor: add handshake stage 0
As we figured out in #5805, syz-manager treats random incoming RPC connections as trusted, and will crash if a non-executor client sends an invalid packet to it. To address this issue, we introduce another stage of handshake, which includes a cookie exchange: - upon connection from an executor, the manager sends a ConnectHello RPC message to it, which contains a random 64-bit cookie; - the executor calculates a hash of that cookie and includes it into its ConnectRequest together with the other information; - before checking the validity of ConnectRequest, the manager ensures client sanity (passed ID didn't change, hashed cookie has the expected value) We deliberately pick a random cookie instead of a magic number: if the fuzzer somehow learns to send packets to the manager, we don't want it to crash multiple managers on the same machine.
-rw-r--r--executor/executor_runner.h16
-rw-r--r--pkg/flatrpc/conn_test.go29
-rw-r--r--pkg/flatrpc/flatrpc.fbs5
-rw-r--r--pkg/flatrpc/flatrpc.go113
-rw-r--r--pkg/flatrpc/flatrpc.h106
-rw-r--r--pkg/flatrpc/helpers.go1
-rw-r--r--pkg/rpcserver/rpcserver.go64
-rw-r--r--pkg/rpcserver/rpcserver_test.go32
8 files changed, 321 insertions, 45 deletions
diff --git a/executor/executor_runner.h b/executor/executor_runner.h
index a3b668893..88fc785db 100644
--- a/executor/executor_runner.h
+++ b/executor/executor_runner.h
@@ -629,9 +629,24 @@ private:
failmsg("bad restarting", "restarting=%d", restarting_);
}
+ // Implementation must match that in pkg/rpcserver/rpcserver.go.
+ uint64 HashAuthCookie(uint64 cookie)
+ {
+ const uint64_t prime1 = 73856093;
+ const uint64_t prime2 = 83492791;
+
+ return (cookie * prime1) ^ prime2;
+ }
+
int Handshake()
{
+ // Handshake stage 0: get a cookie from the manager.
+ rpc::ConnectHelloRawT conn_hello;
+ conn_.Recv(conn_hello);
+
+ // Handshake stage 1: share basic information about the client.
rpc::ConnectRequestRawT conn_req;
+ conn_req.cookie = HashAuthCookie(conn_hello.cookie);
conn_req.id = vm_index_;
conn_req.arch = GOARCH;
conn_req.git_revision = GIT_REVISION;
@@ -656,6 +671,7 @@ private:
if (conn_reply.cover)
max_signal_.emplace();
+ // Handshake stage 2: share information requested by the manager.
rpc::InfoRequestRawT info_req;
info_req.files = ReadFiles(conn_reply.files);
diff --git a/pkg/flatrpc/conn_test.go b/pkg/flatrpc/conn_test.go
index 4b108a5a4..87d5e9a8a 100644
--- a/pkg/flatrpc/conn_test.go
+++ b/pkg/flatrpc/conn_test.go
@@ -20,7 +20,11 @@ import (
)
func TestConn(t *testing.T) {
+ connectHello := &ConnectHello{
+ Cookie: 1,
+ }
connectReq := &ConnectRequest{
+ Cookie: 73856093,
Id: 1,
Arch: "arch",
GitRevision: "rev1",
@@ -52,6 +56,9 @@ func TestConn(t *testing.T) {
go func() {
done <- serv.Serve(context.Background(),
func(_ context.Context, c *Conn) error {
+ if err := Send(c, connectHello); err != nil {
+ return err
+ }
connectReqGot, err := Recv[*ConnectRequestRaw](c)
if err != nil {
return err
@@ -79,6 +86,12 @@ func TestConn(t *testing.T) {
c := dial(t, serv.Addr.String())
defer c.Close()
+ connectHelloGot, err := Recv[*ConnectHelloRaw](c)
+ if err != nil {
+ t.Fatal(err)
+ }
+ assert.Equal(t, connectHello, connectHelloGot)
+
if err := Send(c, connectReq); err != nil {
t.Fatal(err)
}
@@ -102,7 +115,11 @@ func TestConn(t *testing.T) {
}
func BenchmarkConn(b *testing.B) {
+ connectHello := &ConnectHello{
+ Cookie: 1,
+ }
connectReq := &ConnectRequest{
+ Cookie: 73856093,
Id: 1,
Arch: "arch",
GitRevision: "rev1",
@@ -125,7 +142,11 @@ func BenchmarkConn(b *testing.B) {
done <- serv.Serve(context.Background(),
func(_ context.Context, c *Conn) error {
for i := 0; i < b.N; i++ {
- _, err := Recv[*ConnectRequestRaw](c)
+ if err := Send(c, connectHello); err != nil {
+ return err
+ }
+
+ _, err = Recv[*ConnectRequestRaw](c)
if err != nil {
return err
}
@@ -143,10 +164,14 @@ func BenchmarkConn(b *testing.B) {
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
+ _, err := Recv[*ConnectHelloRaw](c)
+ if err != nil {
+ b.Fatal(err)
+ }
if err := Send(c, connectReq); err != nil {
b.Fatal(err)
}
- _, err := Recv[*ConnectReplyRaw](c)
+ _, err = Recv[*ConnectReplyRaw](c)
if err != nil {
b.Fatal(err)
}
diff --git a/pkg/flatrpc/flatrpc.fbs b/pkg/flatrpc/flatrpc.fbs
index 6d2307d6a..58dc7b292 100644
--- a/pkg/flatrpc/flatrpc.fbs
+++ b/pkg/flatrpc/flatrpc.fbs
@@ -34,8 +34,13 @@ enum Feature : uint64 (bit_flags) {
BinFmtMisc,
Swap,
}
+
+table ConnectHelloRaw {
+ cookie :uint64;
+}
table ConnectRequestRaw {
+ cookie :uint64;
id :int64;
arch :string;
git_revision :string;
diff --git a/pkg/flatrpc/flatrpc.go b/pkg/flatrpc/flatrpc.go
index fd5b4f614..d04d5c531 100644
--- a/pkg/flatrpc/flatrpc.go
+++ b/pkg/flatrpc/flatrpc.go
@@ -485,7 +485,83 @@ func (v SnapshotState) String() string {
return "SnapshotState(" + strconv.FormatInt(int64(v), 10) + ")"
}
+type ConnectHelloRawT struct {
+ Cookie uint64 `json:"cookie"`
+}
+
+func (t *ConnectHelloRawT) Pack(builder *flatbuffers.Builder) flatbuffers.UOffsetT {
+ if t == nil {
+ return 0
+ }
+ ConnectHelloRawStart(builder)
+ ConnectHelloRawAddCookie(builder, t.Cookie)
+ return ConnectHelloRawEnd(builder)
+}
+
+func (rcv *ConnectHelloRaw) UnPackTo(t *ConnectHelloRawT) {
+ t.Cookie = rcv.Cookie()
+}
+
+func (rcv *ConnectHelloRaw) UnPack() *ConnectHelloRawT {
+ if rcv == nil {
+ return nil
+ }
+ t := &ConnectHelloRawT{}
+ rcv.UnPackTo(t)
+ return t
+}
+
+type ConnectHelloRaw struct {
+ _tab flatbuffers.Table
+}
+
+func GetRootAsConnectHelloRaw(buf []byte, offset flatbuffers.UOffsetT) *ConnectHelloRaw {
+ n := flatbuffers.GetUOffsetT(buf[offset:])
+ x := &ConnectHelloRaw{}
+ x.Init(buf, n+offset)
+ return x
+}
+
+func GetSizePrefixedRootAsConnectHelloRaw(buf []byte, offset flatbuffers.UOffsetT) *ConnectHelloRaw {
+ n := flatbuffers.GetUOffsetT(buf[offset+flatbuffers.SizeUint32:])
+ x := &ConnectHelloRaw{}
+ x.Init(buf, n+offset+flatbuffers.SizeUint32)
+ return x
+}
+
+func (rcv *ConnectHelloRaw) Init(buf []byte, i flatbuffers.UOffsetT) {
+ rcv._tab.Bytes = buf
+ rcv._tab.Pos = i
+}
+
+func (rcv *ConnectHelloRaw) Table() flatbuffers.Table {
+ return rcv._tab
+}
+
+func (rcv *ConnectHelloRaw) Cookie() uint64 {
+ o := flatbuffers.UOffsetT(rcv._tab.Offset(4))
+ if o != 0 {
+ return rcv._tab.GetUint64(o + rcv._tab.Pos)
+ }
+ return 0
+}
+
+func (rcv *ConnectHelloRaw) MutateCookie(n uint64) bool {
+ return rcv._tab.MutateUint64Slot(4, n)
+}
+
+func ConnectHelloRawStart(builder *flatbuffers.Builder) {
+ builder.StartObject(1)
+}
+func ConnectHelloRawAddCookie(builder *flatbuffers.Builder, cookie uint64) {
+ builder.PrependUint64Slot(0, cookie, 0)
+}
+func ConnectHelloRawEnd(builder *flatbuffers.Builder) flatbuffers.UOffsetT {
+ return builder.EndObject()
+}
+
type ConnectRequestRawT struct {
+ Cookie uint64 `json:"cookie"`
Id int64 `json:"id"`
Arch string `json:"arch"`
GitRevision string `json:"git_revision"`
@@ -500,6 +576,7 @@ func (t *ConnectRequestRawT) Pack(builder *flatbuffers.Builder) flatbuffers.UOff
gitRevisionOffset := builder.CreateString(t.GitRevision)
syzRevisionOffset := builder.CreateString(t.SyzRevision)
ConnectRequestRawStart(builder)
+ ConnectRequestRawAddCookie(builder, t.Cookie)
ConnectRequestRawAddId(builder, t.Id)
ConnectRequestRawAddArch(builder, archOffset)
ConnectRequestRawAddGitRevision(builder, gitRevisionOffset)
@@ -508,6 +585,7 @@ func (t *ConnectRequestRawT) Pack(builder *flatbuffers.Builder) flatbuffers.UOff
}
func (rcv *ConnectRequestRaw) UnPackTo(t *ConnectRequestRawT) {
+ t.Cookie = rcv.Cookie()
t.Id = rcv.Id()
t.Arch = string(rcv.Arch())
t.GitRevision = string(rcv.GitRevision())
@@ -550,20 +628,32 @@ func (rcv *ConnectRequestRaw) Table() flatbuffers.Table {
return rcv._tab
}
-func (rcv *ConnectRequestRaw) Id() int64 {
+func (rcv *ConnectRequestRaw) Cookie() uint64 {
o := flatbuffers.UOffsetT(rcv._tab.Offset(4))
if o != 0 {
+ return rcv._tab.GetUint64(o + rcv._tab.Pos)
+ }
+ return 0
+}
+
+func (rcv *ConnectRequestRaw) MutateCookie(n uint64) bool {
+ return rcv._tab.MutateUint64Slot(4, n)
+}
+
+func (rcv *ConnectRequestRaw) Id() int64 {
+ o := flatbuffers.UOffsetT(rcv._tab.Offset(6))
+ if o != 0 {
return rcv._tab.GetInt64(o + rcv._tab.Pos)
}
return 0
}
func (rcv *ConnectRequestRaw) MutateId(n int64) bool {
- return rcv._tab.MutateInt64Slot(4, n)
+ return rcv._tab.MutateInt64Slot(6, n)
}
func (rcv *ConnectRequestRaw) Arch() []byte {
- o := flatbuffers.UOffsetT(rcv._tab.Offset(6))
+ o := flatbuffers.UOffsetT(rcv._tab.Offset(8))
if o != 0 {
return rcv._tab.ByteVector(o + rcv._tab.Pos)
}
@@ -571,7 +661,7 @@ func (rcv *ConnectRequestRaw) Arch() []byte {
}
func (rcv *ConnectRequestRaw) GitRevision() []byte {
- o := flatbuffers.UOffsetT(rcv._tab.Offset(8))
+ o := flatbuffers.UOffsetT(rcv._tab.Offset(10))
if o != 0 {
return rcv._tab.ByteVector(o + rcv._tab.Pos)
}
@@ -579,7 +669,7 @@ func (rcv *ConnectRequestRaw) GitRevision() []byte {
}
func (rcv *ConnectRequestRaw) SyzRevision() []byte {
- o := flatbuffers.UOffsetT(rcv._tab.Offset(10))
+ o := flatbuffers.UOffsetT(rcv._tab.Offset(12))
if o != 0 {
return rcv._tab.ByteVector(o + rcv._tab.Pos)
}
@@ -587,19 +677,22 @@ func (rcv *ConnectRequestRaw) SyzRevision() []byte {
}
func ConnectRequestRawStart(builder *flatbuffers.Builder) {
- builder.StartObject(4)
+ builder.StartObject(5)
+}
+func ConnectRequestRawAddCookie(builder *flatbuffers.Builder, cookie uint64) {
+ builder.PrependUint64Slot(0, cookie, 0)
}
func ConnectRequestRawAddId(builder *flatbuffers.Builder, id int64) {
- builder.PrependInt64Slot(0, id, 0)
+ builder.PrependInt64Slot(1, id, 0)
}
func ConnectRequestRawAddArch(builder *flatbuffers.Builder, arch flatbuffers.UOffsetT) {
- builder.PrependUOffsetTSlot(1, flatbuffers.UOffsetT(arch), 0)
+ builder.PrependUOffsetTSlot(2, flatbuffers.UOffsetT(arch), 0)
}
func ConnectRequestRawAddGitRevision(builder *flatbuffers.Builder, gitRevision flatbuffers.UOffsetT) {
- builder.PrependUOffsetTSlot(2, flatbuffers.UOffsetT(gitRevision), 0)
+ builder.PrependUOffsetTSlot(3, flatbuffers.UOffsetT(gitRevision), 0)
}
func ConnectRequestRawAddSyzRevision(builder *flatbuffers.Builder, syzRevision flatbuffers.UOffsetT) {
- builder.PrependUOffsetTSlot(3, flatbuffers.UOffsetT(syzRevision), 0)
+ builder.PrependUOffsetTSlot(4, flatbuffers.UOffsetT(syzRevision), 0)
}
func ConnectRequestRawEnd(builder *flatbuffers.Builder) flatbuffers.UOffsetT {
return builder.EndObject()
diff --git a/pkg/flatrpc/flatrpc.h b/pkg/flatrpc/flatrpc.h
index 12d905c12..def2648ad 100644
--- a/pkg/flatrpc/flatrpc.h
+++ b/pkg/flatrpc/flatrpc.h
@@ -15,6 +15,10 @@ static_assert(FLATBUFFERS_VERSION_MAJOR == 2 &&
namespace rpc {
+struct ConnectHelloRaw;
+struct ConnectHelloRawBuilder;
+struct ConnectHelloRawT;
+
struct ConnectRequestRaw;
struct ConnectRequestRawBuilder;
struct ConnectRequestRawT;
@@ -846,8 +850,61 @@ FLATBUFFERS_MANUALLY_ALIGNED_STRUCT(8) ComparisonRaw FLATBUFFERS_FINAL_CLASS {
};
FLATBUFFERS_STRUCT_END(ComparisonRaw, 32);
+struct ConnectHelloRawT : public flatbuffers::NativeTable {
+ typedef ConnectHelloRaw TableType;
+ uint64_t cookie = 0;
+};
+
+struct ConnectHelloRaw FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef ConnectHelloRawT NativeTableType;
+ typedef ConnectHelloRawBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_COOKIE = 4
+ };
+ uint64_t cookie() const {
+ return GetField<uint64_t>(VT_COOKIE, 0);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<uint64_t>(verifier, VT_COOKIE, 8) &&
+ verifier.EndTable();
+ }
+ ConnectHelloRawT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(ConnectHelloRawT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<ConnectHelloRaw> Pack(flatbuffers::FlatBufferBuilder &_fbb, const ConnectHelloRawT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct ConnectHelloRawBuilder {
+ typedef ConnectHelloRaw Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_cookie(uint64_t cookie) {
+ fbb_.AddElement<uint64_t>(ConnectHelloRaw::VT_COOKIE, cookie, 0);
+ }
+ explicit ConnectHelloRawBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ flatbuffers::Offset<ConnectHelloRaw> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<ConnectHelloRaw>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<ConnectHelloRaw> CreateConnectHelloRaw(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ uint64_t cookie = 0) {
+ ConnectHelloRawBuilder builder_(_fbb);
+ builder_.add_cookie(cookie);
+ return builder_.Finish();
+}
+
+flatbuffers::Offset<ConnectHelloRaw> CreateConnectHelloRaw(flatbuffers::FlatBufferBuilder &_fbb, const ConnectHelloRawT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
struct ConnectRequestRawT : public flatbuffers::NativeTable {
typedef ConnectRequestRaw TableType;
+ uint64_t cookie = 0;
int64_t id = 0;
std::string arch{};
std::string git_revision{};
@@ -858,11 +915,15 @@ struct ConnectRequestRaw FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
typedef ConnectRequestRawT NativeTableType;
typedef ConnectRequestRawBuilder Builder;
enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
- VT_ID = 4,
- VT_ARCH = 6,
- VT_GIT_REVISION = 8,
- VT_SYZ_REVISION = 10
+ VT_COOKIE = 4,
+ VT_ID = 6,
+ VT_ARCH = 8,
+ VT_GIT_REVISION = 10,
+ VT_SYZ_REVISION = 12
};
+ uint64_t cookie() const {
+ return GetField<uint64_t>(VT_COOKIE, 0);
+ }
int64_t id() const {
return GetField<int64_t>(VT_ID, 0);
}
@@ -877,6 +938,7 @@ struct ConnectRequestRaw FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
}
bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
+ VerifyField<uint64_t>(verifier, VT_COOKIE, 8) &&
VerifyField<int64_t>(verifier, VT_ID, 8) &&
VerifyOffset(verifier, VT_ARCH) &&
verifier.VerifyString(arch()) &&
@@ -895,6 +957,9 @@ struct ConnectRequestRawBuilder {
typedef ConnectRequestRaw Table;
flatbuffers::FlatBufferBuilder &fbb_;
flatbuffers::uoffset_t start_;
+ void add_cookie(uint64_t cookie) {
+ fbb_.AddElement<uint64_t>(ConnectRequestRaw::VT_COOKIE, cookie, 0);
+ }
void add_id(int64_t id) {
fbb_.AddElement<int64_t>(ConnectRequestRaw::VT_ID, id, 0);
}
@@ -920,12 +985,14 @@ struct ConnectRequestRawBuilder {
inline flatbuffers::Offset<ConnectRequestRaw> CreateConnectRequestRaw(
flatbuffers::FlatBufferBuilder &_fbb,
+ uint64_t cookie = 0,
int64_t id = 0,
flatbuffers::Offset<flatbuffers::String> arch = 0,
flatbuffers::Offset<flatbuffers::String> git_revision = 0,
flatbuffers::Offset<flatbuffers::String> syz_revision = 0) {
ConnectRequestRawBuilder builder_(_fbb);
builder_.add_id(id);
+ builder_.add_cookie(cookie);
builder_.add_syz_revision(syz_revision);
builder_.add_git_revision(git_revision);
builder_.add_arch(arch);
@@ -934,6 +1001,7 @@ inline flatbuffers::Offset<ConnectRequestRaw> CreateConnectRequestRaw(
inline flatbuffers::Offset<ConnectRequestRaw> CreateConnectRequestRawDirect(
flatbuffers::FlatBufferBuilder &_fbb,
+ uint64_t cookie = 0,
int64_t id = 0,
const char *arch = nullptr,
const char *git_revision = nullptr,
@@ -943,6 +1011,7 @@ inline flatbuffers::Offset<ConnectRequestRaw> CreateConnectRequestRawDirect(
auto syz_revision__ = syz_revision ? _fbb.CreateString(syz_revision) : 0;
return rpc::CreateConnectRequestRaw(
_fbb,
+ cookie,
id,
arch__,
git_revision__,
@@ -2896,6 +2965,32 @@ inline flatbuffers::Offset<SnapshotRequest> CreateSnapshotRequestDirect(
flatbuffers::Offset<SnapshotRequest> CreateSnapshotRequest(flatbuffers::FlatBufferBuilder &_fbb, const SnapshotRequestT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+inline ConnectHelloRawT *ConnectHelloRaw::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ auto _o = std::unique_ptr<ConnectHelloRawT>(new ConnectHelloRawT());
+ UnPackTo(_o.get(), _resolver);
+ return _o.release();
+}
+
+inline void ConnectHelloRaw::UnPackTo(ConnectHelloRawT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+ { auto _e = cookie(); _o->cookie = _e; }
+}
+
+inline flatbuffers::Offset<ConnectHelloRaw> ConnectHelloRaw::Pack(flatbuffers::FlatBufferBuilder &_fbb, const ConnectHelloRawT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreateConnectHelloRaw(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<ConnectHelloRaw> CreateConnectHelloRaw(flatbuffers::FlatBufferBuilder &_fbb, const ConnectHelloRawT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const ConnectHelloRawT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ auto _cookie = _o->cookie;
+ return rpc::CreateConnectHelloRaw(
+ _fbb,
+ _cookie);
+}
+
inline ConnectRequestRawT *ConnectRequestRaw::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
auto _o = std::unique_ptr<ConnectRequestRawT>(new ConnectRequestRawT());
UnPackTo(_o.get(), _resolver);
@@ -2905,6 +3000,7 @@ inline ConnectRequestRawT *ConnectRequestRaw::UnPack(const flatbuffers::resolver
inline void ConnectRequestRaw::UnPackTo(ConnectRequestRawT *_o, const flatbuffers::resolver_function_t *_resolver) const {
(void)_o;
(void)_resolver;
+ { auto _e = cookie(); _o->cookie = _e; }
{ auto _e = id(); _o->id = _e; }
{ auto _e = arch(); if (_e) _o->arch = _e->str(); }
{ auto _e = git_revision(); if (_e) _o->git_revision = _e->str(); }
@@ -2919,12 +3015,14 @@ inline flatbuffers::Offset<ConnectRequestRaw> CreateConnectRequestRaw(flatbuffer
(void)_rehasher;
(void)_o;
struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const ConnectRequestRawT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ auto _cookie = _o->cookie;
auto _id = _o->id;
auto _arch = _o->arch.empty() ? 0 : _fbb.CreateString(_o->arch);
auto _git_revision = _o->git_revision.empty() ? 0 : _fbb.CreateString(_o->git_revision);
auto _syz_revision = _o->syz_revision.empty() ? 0 : _fbb.CreateString(_o->syz_revision);
return rpc::CreateConnectRequestRaw(
_fbb,
+ _cookie,
_id,
_arch,
_git_revision,
diff --git a/pkg/flatrpc/helpers.go b/pkg/flatrpc/helpers.go
index 9a5463b24..5aa5cfe74 100644
--- a/pkg/flatrpc/helpers.go
+++ b/pkg/flatrpc/helpers.go
@@ -18,6 +18,7 @@ const AllFeatures = ^Feature(0)
// Flatbuffers compiler adds T suffix to object API types, which are actual structs representing types.
// This leads to non-idiomatic Go code, e.g. we would have to use []FileInfoT in Go code.
// So we use Raw suffix for all flatbuffers tables and rename object API types here to idiomatic names.
+type ConnectHello = ConnectHelloRawT
type ConnectRequest = ConnectRequestRawT
type ConnectReply = ConnectReplyRawT
type InfoRequest = InfoRequestRawT
diff --git a/pkg/rpcserver/rpcserver.go b/pkg/rpcserver/rpcserver.go
index de664cb0b..43761b651 100644
--- a/pkg/rpcserver/rpcserver.go
+++ b/pkg/rpcserver/rpcserver.go
@@ -8,6 +8,7 @@ import (
"context"
"errors"
"fmt"
+ "math/rand"
"net/url"
"slices"
"sort"
@@ -232,13 +233,17 @@ func (serv *server) Listen() error {
return nil
}
+// Used for errors incompatible with further RPCServer operation.
+var errFatal = errors.New("aborting RPC server")
+
func (serv *server) Serve(ctx context.Context) error {
g, ctx := errgroup.WithContext(ctx)
g.Go(func() error {
return serv.serv.Serve(ctx, func(ctx context.Context, conn *flatrpc.Conn) error {
err := serv.handleConn(ctx, conn)
- if err != nil {
- log.Logf(0, "serv.handleConn returend %v", err)
+ if err != nil && !errors.Is(err, errFatal) {
+ log.Logf(2, "%v", err)
+ return nil
}
return err
})
@@ -261,24 +266,49 @@ func (serv *server) Port() int {
return serv.serv.Addr.Port
}
+// Must be simple enough to not require adding dependencies to the executor.
+func authHash(value uint64) uint64 {
+ prime1 := uint64(73856093)
+ prime2 := uint64(83492791)
+ hashValue := (value * prime1) ^ prime2
+
+ return hashValue
+}
+
func (serv *server) handleConn(ctx context.Context, conn *flatrpc.Conn) error {
+ // Use a random cookie, because we do not want the fuzzer to accidentally guess it and DDoS multiple managers.
+ helloCookie := rand.Uint64()
+ expectCookie := authHash(helloCookie)
+ connectHello := &flatrpc.ConnectHello{
+ Cookie: helloCookie,
+ }
+
+ if err := flatrpc.Send(conn, connectHello); err != nil {
+ // The other side is not an executor.
+ return fmt.Errorf("failed to establish connection with a remote runner")
+ }
+
connectReq, err := flatrpc.Recv[*flatrpc.ConnectRequestRaw](conn)
if err != nil {
- log.Logf(1, "%s", err)
- return nil
+ return err
}
id := int(connectReq.Id)
+
+ if connectReq.Cookie != expectCookie {
+ return fmt.Errorf("client failed to respond with a valid cookie: %v (expected %v)", connectReq.Cookie, expectCookie)
+ }
+
+ // From now on, assume that the client is well-behaving.
log.Logf(1, "runner %v connected", id)
if serv.cfg.VMLess {
- // There is no VM loop, so minic what it would do.
+ // There is no VM loop, so mimic what it would do.
serv.CreateInstance(id, nil, nil)
defer func() {
serv.StopFuzzing(id)
serv.ShutdownInstance(id, true)
}()
} else if err := checkRevisions(connectReq, serv.cfg.Target); err != nil {
- // This is a fatal error.
return err
}
serv.StatVMRestarts.Add(1)
@@ -287,18 +317,12 @@ func (serv *server) handleConn(ctx context.Context, conn *flatrpc.Conn) error {
runner := serv.runners[id]
serv.mu.Unlock()
if runner == nil {
- log.Logf(2, "unknown VM %v tries to connect", id)
- return nil
+ return fmt.Errorf("unknown VM %v tries to connect", id)
}
err = serv.handleRunnerConn(ctx, runner, conn)
log.Logf(2, "runner %v: %v", id, err)
- if err != nil && errors.Is(err, errFatal) {
- log.Logf(0, "%v", err)
- return err
- }
-
runner.resultCh <- err
return nil
}
@@ -337,9 +361,6 @@ func (serv *server) handleRunnerConn(ctx context.Context, runner *Runner, conn *
return serv.connectionLoop(ctx, runner)
}
-// Used for errors incompatible with further RPCServer operation.
-var errFatal = errors.New("aborting RPC server")
-
func (serv *server) handleMachineInfo(infoReq *flatrpc.InfoRequestRawT) (handshakeResult, error) {
modules, machineInfo, err := serv.checker.MachineInfo(infoReq.Files)
if err != nil {
@@ -419,15 +440,16 @@ func (serv *server) connectionLoop(baseCtx context.Context, runner *Runner) erro
func checkRevisions(a *flatrpc.ConnectRequest, target *prog.Target) error {
if target.Arch != a.Arch {
- return fmt.Errorf("mismatching manager/executor arches: %v vs %v (full request: `%#v`)", target.Arch, a.Arch, a)
+ return fmt.Errorf("%w: mismatching manager/executor arches: %v vs %v (full request: `%#v`)",
+ errFatal, target.Arch, a.Arch, a)
}
if prog.GitRevision != a.GitRevision {
- return fmt.Errorf("mismatching manager/executor git revisions: %v vs %v",
- prog.GitRevision, a.GitRevision)
+ return fmt.Errorf("%w: mismatching manager/executor git revisions: %v vs %v",
+ errFatal, prog.GitRevision, a.GitRevision)
}
if target.Revision != a.SyzRevision {
- return fmt.Errorf("mismatching manager/executor system call descriptions: %v vs %v",
- target.Revision, a.SyzRevision)
+ return fmt.Errorf("%w: mismatching manager/executor system call descriptions: %v vs %v",
+ errFatal, target.Revision, a.SyzRevision)
}
return nil
}
diff --git a/pkg/rpcserver/rpcserver_test.go b/pkg/rpcserver/rpcserver_test.go
index 2da916286..429b275ac 100644
--- a/pkg/rpcserver/rpcserver_test.go
+++ b/pkg/rpcserver/rpcserver_test.go
@@ -6,6 +6,7 @@ package rpcserver
import (
"context"
"net"
+ "strings"
"testing"
"github.com/stretchr/testify/assert"
@@ -18,6 +19,7 @@ import (
"github.com/google/syzkaller/pkg/vminfo"
"github.com/google/syzkaller/prog"
"github.com/google/syzkaller/sys/targets"
+ "golang.org/x/sync/errgroup"
)
func getTestDefaultCfg() mgrconfig.Config {
@@ -176,12 +178,14 @@ func TestHandleConn(t *testing.T) {
defaultCfg := getTestDefaultCfg()
tests := []struct {
- name string
- modifyCfg func() *mgrconfig.Config
- req *flatrpc.ConnectRequest
+ name string
+ wantErrMsg string
+ modifyCfg func() *mgrconfig.Config
+ req *flatrpc.ConnectRequest
}{
{
- name: "error, cfg.VMLess = false - unknown VM tries to connect",
+ name: "error, cfg.VMLess = false - unknown VM tries to connect",
+ wantErrMsg: "tries to connect",
modifyCfg: func() *mgrconfig.Config {
return &defaultCfg
},
@@ -214,10 +218,22 @@ func TestHandleConn(t *testing.T) {
injectExec := make(chan bool)
serv.CreateInstance(1, injectExec, nil)
-
- go flatrpc.Send(clientConn, tt.req)
- err = serv.handleConn(context.Background(), serverConn)
- if err != nil {
+ g := errgroup.Group{}
+ g.Go(func() error {
+ hello, err := flatrpc.Recv[*flatrpc.ConnectHelloRaw](clientConn)
+ if err != nil {
+ return err
+ }
+ tt.req.Cookie = authHash(hello.Cookie)
+ flatrpc.Send(clientConn, tt.req)
+ return nil
+ })
+ if err := serv.handleConn(context.Background(), serverConn); err != nil {
+ if !strings.Contains(err.Error(), tt.wantErrMsg) {
+ t.Fatal(err)
+ }
+ }
+ if err := g.Wait(); err != nil {
t.Fatal(err)
}
})