diff options
| -rw-r--r-- | Makefile | 2 | ||||
| -rw-r--r-- | pkg/rpcserver/local.go | 8 | ||||
| -rw-r--r-- | pkg/rpcserver/mocks/Manager.go | 127 | ||||
| -rw-r--r-- | pkg/rpcserver/rpcserver.go | 81 | ||||
| -rw-r--r-- | pkg/rpcserver/rpcserver_test.go | 209 | ||||
| -rw-r--r-- | syz-manager/manager.go | 11 |
6 files changed, 397 insertions, 41 deletions
@@ -232,7 +232,7 @@ generate: $(MAKE) format generate_go: format_cpp - $(GO) generate ./executor ./pkg/ifuzz ./pkg/build + $(GO) generate ./executor ./pkg/ifuzz ./pkg/build ./pkg/rpcserver $(GO) generate ./vm/proxyapp generate_rpc: diff --git a/pkg/rpcserver/local.go b/pkg/rpcserver/local.go index e1522aa79..09cd1868d 100644 --- a/pkg/rpcserver/local.go +++ b/pkg/rpcserver/local.go @@ -47,8 +47,8 @@ func RunLocal(cfg *LocalConfig) error { cfg: cfg, setupDone: make(chan bool), } - serv, err := newImpl(cfg.Context, &cfg.Config, ctx) - if err != nil { + serv := newImpl(cfg.Context, &cfg.Config, ctx) + if err := serv.Listen(); err != nil { return err } defer serv.Close() @@ -62,7 +62,7 @@ func RunLocal(cfg *LocalConfig) error { defer serv.ShutdownInstance(id, true) bin := cfg.Executor - args := []string{"runner", fmt.Sprint(id), "localhost", fmt.Sprint(serv.Port)} + args := []string{"runner", fmt.Sprint(id), "localhost", fmt.Sprint(serv.Port())} if cfg.GDB { bin = "gdb" args = append([]string{ @@ -107,7 +107,7 @@ func RunLocal(cfg *LocalConfig) error { type local struct { cfg *LocalConfig - serv *Server + serv Server setupDone chan bool } diff --git a/pkg/rpcserver/mocks/Manager.go b/pkg/rpcserver/mocks/Manager.go new file mode 100644 index 000000000..a5662ad89 --- /dev/null +++ b/pkg/rpcserver/mocks/Manager.go @@ -0,0 +1,127 @@ +// Code generated by mockery v2.40.3. DO NOT EDIT. + +package mocks + +import ( + flatrpc "github.com/google/syzkaller/pkg/flatrpc" + mock "github.com/stretchr/testify/mock" + + prog "github.com/google/syzkaller/prog" + + queue "github.com/google/syzkaller/pkg/fuzzer/queue" + + signal "github.com/google/syzkaller/pkg/signal" + + vminfo "github.com/google/syzkaller/pkg/vminfo" +) + +// Manager is an autogenerated mock type for the Manager type +type Manager struct { + mock.Mock +} + +// BugFrames provides a mock function with given fields: +func (_m *Manager) BugFrames() ([]string, []string) { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for BugFrames") + } + + var r0 []string + var r1 []string + if rf, ok := ret.Get(0).(func() ([]string, []string)); ok { + return rf() + } + if rf, ok := ret.Get(0).(func() []string); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]string) + } + } + + if rf, ok := ret.Get(1).(func() []string); ok { + r1 = rf() + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).([]string) + } + } + + return r0, r1 +} + +// CoverageFilter provides a mock function with given fields: modules +func (_m *Manager) CoverageFilter(modules []*vminfo.KernelModule) []uint64 { + ret := _m.Called(modules) + + if len(ret) == 0 { + panic("no return value specified for CoverageFilter") + } + + var r0 []uint64 + if rf, ok := ret.Get(0).(func([]*vminfo.KernelModule) []uint64); ok { + r0 = rf(modules) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]uint64) + } + } + + return r0 +} + +// MachineChecked provides a mock function with given fields: features, syscalls +func (_m *Manager) MachineChecked(features flatrpc.Feature, syscalls map[*prog.Syscall]bool) queue.Source { + ret := _m.Called(features, syscalls) + + if len(ret) == 0 { + panic("no return value specified for MachineChecked") + } + + var r0 queue.Source + if rf, ok := ret.Get(0).(func(flatrpc.Feature, map[*prog.Syscall]bool) queue.Source); ok { + r0 = rf(features, syscalls) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(queue.Source) + } + } + + return r0 +} + +// MaxSignal provides a mock function with given fields: +func (_m *Manager) MaxSignal() signal.Signal { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for MaxSignal") + } + + var r0 signal.Signal + if rf, ok := ret.Get(0).(func() signal.Signal); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(signal.Signal) + } + } + + return r0 +} + +// NewManager creates a new instance of Manager. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewManager(t interface { + mock.TestingT + Cleanup(func()) +}) *Manager { + mock := &Manager{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/pkg/rpcserver/rpcserver.go b/pkg/rpcserver/rpcserver.go index 5a104d81c..367e3b5c6 100644 --- a/pkg/rpcserver/rpcserver.go +++ b/pkg/rpcserver/rpcserver.go @@ -48,6 +48,7 @@ type Config struct { localModules []*vminfo.KernelModule } +//go:generate ../../tools/mockery.sh --name Manager --output ./mocks type Manager interface { MaxSignal() signal.Signal BugFrames() (leaks []string, races []string) @@ -55,9 +56,18 @@ type Manager interface { CoverageFilter(modules []*vminfo.KernelModule) []uint64 } -type Server struct { - Port int +type Server interface { + Listen() error + Close() error + Port() int + TriagedCorpus() + CreateInstance(id int, injectExec chan<- bool, updInfo dispatcher.UpdateInfo) chan error + ShutdownInstance(id int, crashed bool) ([]ExecRecord, []byte) + StopFuzzing(id int) + DistributeSignalDelta(plus signal.Signal) +} +type server struct { cfg *Config mgr Manager serv *flatrpc.Serv @@ -82,7 +92,7 @@ type Server struct { *runnerStats } -func New(cfg *mgrconfig.Config, mgr Manager, debug bool) (*Server, error) { +func New(cfg *mgrconfig.Config, mgr Manager, debug bool) (Server, error) { var pcBase uint64 if cfg.KernelObj != "" { var err error @@ -122,16 +132,16 @@ func New(cfg *mgrconfig.Config, mgr Manager, debug bool) (*Server, error) { Slowdown: cfg.Timeouts.Slowdown, pcBase: pcBase, localModules: cfg.LocalModules, - }, mgr) + }, mgr), nil } -func newImpl(ctx context.Context, cfg *Config, mgr Manager) (*Server, error) { +func newImpl(ctx context.Context, cfg *Config, mgr Manager) *server { + // Note that we use VMArch, rather than Arch. We need the kernel address ranges and bitness. + sysTarget := targets.Get(cfg.Target.OS, cfg.VMArch) cfg.Procs = min(cfg.Procs, prog.MaxPids) checker := vminfo.New(ctx, &cfg.Config) baseSource := queue.DynamicSource(checker) - // Note that we use VMArch, rather than Arch. We need the kernel address ranges and bitness. - sysTarget := targets.Get(cfg.Target.OS, cfg.VMArch) - serv := &Server{ + return &server{ cfg: cfg, mgr: mgr, target: cfg.Target, @@ -156,20 +166,26 @@ func newImpl(ctx context.Context, cfg *Config, mgr Manager) (*Server, error) { statNoExecDuration: queue.StatNoExecDuration, }, } - s, err := flatrpc.ListenAndServe(cfg.RPC, serv.handleConn) +} + +func (serv *server) Close() error { + return serv.serv.Close() +} + +func (serv *server) Listen() error { + s, err := flatrpc.ListenAndServe(serv.cfg.RPC, serv.handleConn) if err != nil { - return nil, err + return err } serv.serv = s - serv.Port = s.Addr.Port - return serv, nil + return nil } -func (serv *Server) Close() error { - return serv.serv.Close() +func (serv *server) Port() int { + return serv.serv.Addr.Port } -func (serv *Server) handleConn(conn *flatrpc.Conn) { +func (serv *server) handleConn(conn *flatrpc.Conn) { connectReq, err := flatrpc.Recv[*flatrpc.ConnectRequestRaw](conn) if err != nil { log.Logf(1, "%s", err) @@ -185,8 +201,8 @@ func (serv *Server) handleConn(conn *flatrpc.Conn) { serv.StopFuzzing(id) serv.ShutdownInstance(id, true) }() - } else { - checkRevisions(connectReq, serv.cfg.Target) + } else if err := checkRevisions(connectReq, serv.cfg.Target); err != nil { + log.Fatal(err) } serv.statVMRestarts.Add(1) @@ -203,7 +219,7 @@ func (serv *Server) handleConn(conn *flatrpc.Conn) { runner.resultCh <- err } -func (serv *Server) handleRunnerConn(runner *Runner, conn *flatrpc.Conn) error { +func (serv *server) handleRunnerConn(runner *Runner, conn *flatrpc.Conn) error { opts := &handshakeConfig{ VMLess: serv.cfg.VMLess, Files: serv.checker.RequiredFiles(), @@ -235,7 +251,7 @@ func (serv *Server) handleRunnerConn(runner *Runner, conn *flatrpc.Conn) error { return serv.connectionLoop(runner) } -func (serv *Server) handleMachineInfo(infoReq *flatrpc.InfoRequestRawT) (handshakeResult, error) { +func (serv *server) handleMachineInfo(infoReq *flatrpc.InfoRequestRawT) (handshakeResult, error) { modules, machineInfo, err := serv.checker.MachineInfo(infoReq.Files) if err != nil { log.Logf(0, "parsing of machine info failed: %v", err) @@ -280,7 +296,7 @@ func (serv *Server) handleMachineInfo(infoReq *flatrpc.InfoRequestRawT) (handsha }, nil } -func (serv *Server) connectionLoop(runner *Runner) error { +func (serv *server) connectionLoop(runner *Runner) error { if serv.cfg.Cover { maxSignal := serv.mgr.MaxSignal().ToRaw() for len(maxSignal) != 0 { @@ -301,21 +317,22 @@ func (serv *Server) connectionLoop(runner *Runner) error { return runner.ConnectionLoop() } -func checkRevisions(a *flatrpc.ConnectRequest, target *prog.Target) { +func checkRevisions(a *flatrpc.ConnectRequest, target *prog.Target) error { if target.Arch != a.Arch { - log.Fatalf("mismatching manager/executor arches: %v vs %v", target.Arch, a.Arch) + return fmt.Errorf("mismatching manager/executor arches: %v vs %v", target.Arch, a.Arch) } if prog.GitRevision != a.GitRevision { - log.Fatalf("mismatching manager/executor git revisions: %v vs %v", + return fmt.Errorf("mismatching manager/executor git revisions: %v vs %v", prog.GitRevision, a.GitRevision) } if target.Revision != a.SyzRevision { - log.Fatalf("mismatching manager/executor system call descriptions: %v vs %v", + return fmt.Errorf("mismatching manager/executor system call descriptions: %v vs %v", target.Revision, a.SyzRevision) } + return nil } -func (serv *Server) runCheck(checkFilesInfo []*flatrpc.FileInfo, checkFeatureInfo []*flatrpc.FeatureInfo) error { +func (serv *server) runCheck(checkFilesInfo []*flatrpc.FileInfo, checkFeatureInfo []*flatrpc.FeatureInfo) error { enabledCalls, disabledCalls, features, checkErr := serv.checker.Run(checkFilesInfo, checkFeatureInfo) enabledCalls, transitivelyDisabled := serv.target.TransitivelyEnabledCalls(enabledCalls) // Note: need to print disbled syscalls before failing due to an error. @@ -334,7 +351,7 @@ func (serv *Server) runCheck(checkFilesInfo []*flatrpc.FileInfo, checkFeatureInf return nil } -func (serv *Server) printMachineCheck(checkFilesInfo []*flatrpc.FileInfo, enabledCalls map[*prog.Syscall]bool, +func (serv *server) printMachineCheck(checkFilesInfo []*flatrpc.FileInfo, enabledCalls map[*prog.Syscall]bool, disabledCalls, transitivelyDisabled map[*prog.Syscall]string, features vminfo.Features) { buf := new(bytes.Buffer) if len(serv.cfg.Syscalls) != 0 || log.V(1) { @@ -384,7 +401,7 @@ func (serv *Server) printMachineCheck(checkFilesInfo []*flatrpc.FileInfo, enable log.Logf(0, "machine check:\n%s", buf.Bytes()) } -func (serv *Server) CreateInstance(id int, injectExec chan<- bool, updInfo dispatcher.UpdateInfo) chan error { +func (serv *server) CreateInstance(id int, injectExec chan<- bool, updInfo dispatcher.UpdateInfo) chan error { runner := &Runner{ id: id, source: serv.execSource, @@ -415,7 +432,7 @@ func (serv *Server) CreateInstance(id int, injectExec chan<- bool, updInfo dispa // stopInstance prevents further request exchange requests. // To make RPCServer fully forget an instance, shutdownInstance() must be called. -func (serv *Server) StopFuzzing(id int) { +func (serv *server) StopFuzzing(id int) { serv.mu.Lock() runner := serv.runners[id] serv.mu.Unlock() @@ -427,7 +444,7 @@ func (serv *Server) StopFuzzing(id int) { runner.Stop() } -func (serv *Server) ShutdownInstance(id int, crashed bool) ([]ExecRecord, []byte) { +func (serv *server) ShutdownInstance(id int, crashed bool) ([]ExecRecord, []byte) { serv.mu.Lock() runner := serv.runners[id] delete(serv.runners, id) @@ -435,14 +452,14 @@ func (serv *Server) ShutdownInstance(id int, crashed bool) ([]ExecRecord, []byte return runner.Shutdown(crashed), runner.MachineInfo() } -func (serv *Server) DistributeSignalDelta(plus signal.Signal) { +func (serv *server) DistributeSignalDelta(plus signal.Signal) { plusRaw := plus.ToRaw() serv.foreachRunnerAsync(func(runner *Runner) { runner.SendSignalUpdate(plusRaw) }) } -func (serv *Server) TriagedCorpus() { +func (serv *server) TriagedCorpus() { serv.triagedCorpus.Store(true) serv.foreachRunnerAsync(func(runner *Runner) { runner.SendCorpusTriaged() @@ -452,7 +469,7 @@ func (serv *Server) TriagedCorpus() { // foreachRunnerAsync runs callback fn for each connected runner asynchronously. // If a VM has hanged w/o reading out the socket, we want to avoid blocking // important goroutines on the send operations. -func (serv *Server) foreachRunnerAsync(fn func(runner *Runner)) { +func (serv *server) foreachRunnerAsync(fn func(runner *Runner)) { serv.mu.Lock() defer serv.mu.Unlock() for _, runner := range serv.runners { diff --git a/pkg/rpcserver/rpcserver_test.go b/pkg/rpcserver/rpcserver_test.go new file mode 100644 index 000000000..3252ddd4f --- /dev/null +++ b/pkg/rpcserver/rpcserver_test.go @@ -0,0 +1,209 @@ +// Copyright 2024 syzkaller project authors. All rights reserved. +// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file. + +package rpcserver + +import ( + "net" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/google/syzkaller/pkg/flatrpc" + "github.com/google/syzkaller/pkg/mgrconfig" + "github.com/google/syzkaller/pkg/rpcserver/mocks" + "github.com/google/syzkaller/prog" + "github.com/google/syzkaller/sys/targets" +) + +func getTestDefaultCfg() mgrconfig.Config { + return mgrconfig.Config{ + Type: targets.Linux, + Sandbox: "none", + Derived: mgrconfig.Derived{ + TargetOS: targets.TestOS, + TargetArch: targets.TestArch64, + TargetVMArch: targets.TestArch64, + Timeouts: targets.Timeouts{Slowdown: 1}, + }, + } +} + +func TestNew(t *testing.T) { + defaultCfg := getTestDefaultCfg() + + nilServer := func(s Server) { + assert.Nil(t, s) + } + + tests := []struct { + name string + modifyCfg func() *mgrconfig.Config + debug bool + expectedServCheck func(Server) + expectsErr bool + expectedErr error + }{ + { + name: "unknown Sandbox", + modifyCfg: func() *mgrconfig.Config { + cfg := defaultCfg + cfg.Sandbox = "unknown" + return &cfg + }, + expectedServCheck: nilServer, + expectsErr: true, + }, + { + name: "experimental features", + modifyCfg: func() *mgrconfig.Config { + cfg := defaultCfg + cfg.Experimental = mgrconfig.Experimental{ + RemoteCover: false, + CoverEdges: true, + } + return &cfg + }, + expectedServCheck: func(srv Server) { + s := srv.(*server) + assert.Equal(t, s.cfg.Config.Features, flatrpc.AllFeatures&(^flatrpc.FeatureExtraCoverage)) + assert.Nil(t, s.serv) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := tt.modifyCfg() + + var err error + cfg.Target, err = prog.GetTarget(cfg.TargetOS, cfg.TargetArch) + assert.NoError(t, err) + + serv, err := New(cfg, nil, tt.debug) + if tt.expectedErr != nil { + assert.Equal(t, tt.expectedErr, err) + } else if tt.expectsErr { + assert.Error(t, err) + } else { + assert.Nil(t, err) + } + tt.expectedServCheck(serv) + }) + } +} + +func TestCheckRevisions(t *testing.T) { + tests := []struct { + name string + req *flatrpc.ConnectRequest + target *prog.Target + noError bool + }{ + { + name: "error - different Arch", + req: &flatrpc.ConnectRequest{ + Arch: "arch", + }, + target: &prog.Target{ + Arch: "arch2", + }, + }, + { + name: "error - different GitRevision", + req: &flatrpc.ConnectRequest{ + Arch: "arch", + GitRevision: "different", + }, + target: &prog.Target{ + Arch: "arch", + }, + }, + { + name: "error - different SyzRevision", + req: &flatrpc.ConnectRequest{ + Arch: "arch", + GitRevision: prog.GitRevision, + SyzRevision: "1", + }, + target: &prog.Target{ + Arch: "arch", + Revision: "2", + }, + }, + { + name: "ok", + req: &flatrpc.ConnectRequest{ + Arch: "arch", + GitRevision: prog.GitRevision, + SyzRevision: "1", + }, + target: &prog.Target{ + Arch: "arch", + Revision: "1", + }, + noError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := checkRevisions(tt.req, tt.target) + if tt.noError { + assert.NoError(t, err) + } else { + assert.Error(t, err) + } + }) + } +} + +func TestHandleConn(t *testing.T) { + inConn, outConn := net.Pipe() + serverConn := flatrpc.NewConn(inConn) + clientConn := flatrpc.NewConn(outConn) + + managerMock := mocks.NewManager(t) + debug := false + defaultCfg := getTestDefaultCfg() + + tests := []struct { + name string + modifyCfg func() *mgrconfig.Config + req *flatrpc.ConnectRequest + }{ + { + name: "error, cfg.VMLess = false - unknown VM tries to connect", + modifyCfg: func() *mgrconfig.Config { + return &defaultCfg + }, + req: &flatrpc.ConnectRequest{ + Id: 2, // Valid Runner id is 1. + Arch: "64", + GitRevision: prog.GitRevision, + SyzRevision: "1", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := tt.modifyCfg() + + var err error + cfg.Target, err = prog.GetTarget(cfg.TargetOS, cfg.TargetArch) + cfg.Target.Revision = tt.req.SyzRevision + assert.NoError(t, err) + + s, err := New(cfg, managerMock, debug) + assert.NoError(t, err) + serv := s.(*server) + + injectExec := make(chan bool) + serv.CreateInstance(1, injectExec, nil) + + go flatrpc.Send(clientConn, tt.req) + serv.handleConn(serverConn) + }) + } +} diff --git a/syz-manager/manager.go b/syz-manager/manager.go index db181d3a3..4b89277d4 100644 --- a/syz-manager/manager.go +++ b/syz-manager/manager.go @@ -80,7 +80,7 @@ type Manager struct { sysTarget *targets.Target reporter *report.Reporter crashdir string - serv *rpcserver.Server + serv rpcserver.Server corpus *corpus.Corpus corpusDB *db.DB corpusDBMu sync.Mutex // for concurrent operations on corpusDB @@ -252,7 +252,10 @@ func RunManager(mode Mode, cfg *mgrconfig.Config) { if err != nil { log.Fatalf("failed to create rpc server: %v", err) } - log.Logf(0, "serving rpc on tcp://%v", mgr.serv.Port) + if err := mgr.serv.Listen(); err != nil { + log.Fatalf("failed to start rpc server: %v", err) + } + log.Logf(0, "serving rpc on tcp://%v", mgr.serv.Port()) if cfg.DashboardAddr != "" { opts := []dashapi.DashboardOpts{} @@ -287,7 +290,7 @@ func RunManager(mode Mode, cfg *mgrconfig.Config) { if mgr.vmPool == nil { log.Logf(0, "no VMs started (type=none)") log.Logf(0, "you are supposed to start syz-executor manually as:") - log.Logf(0, "syz-executor runner local manager.ip %v", mgr.serv.Port) + log.Logf(0, "syz-executor runner local manager.ip %v", mgr.serv.Port()) <-vm.Shutdown return } @@ -536,7 +539,7 @@ func (mgr *Manager) fuzzerInstance(ctx context.Context, inst *vm.Instance, updIn func (mgr *Manager) runInstanceInner(ctx context.Context, inst *vm.Instance, injectExec <-chan bool, finishCb vm.EarlyFinishCb) (*report.Report, []byte, error) { - fwdAddr, err := inst.Forward(mgr.serv.Port) + fwdAddr, err := inst.Forward(mgr.serv.Port()) if err != nil { return nil, nil, fmt.Errorf("failed to setup port forwarding: %w", err) } |
