aboutsummaryrefslogtreecommitdiffstats
path: root/pkg/rpcserver/rpcserver.go
diff options
context:
space:
mode:
authorSabyrzhan Tasbolatov <snovitoll@gmail.com>2024-09-04 21:33:28 +0500
committerAleksandr Nogikh <nogikh@google.com>2024-09-09 16:49:28 +0000
commitcb2d8b3aef0920cbb5521f948e262598efc3fc1c (patch)
treea5dd2e30ebaa63b2835fa6656e8e7078054829a0 /pkg/rpcserver/rpcserver.go
parent10df4c09063bf091d9d003880e4d1044b0ec163d (diff)
pkg/rpcserver: add unit tests, Manager mocks
Added more test coverage of the package and created an interface of rpcserver to use it as the dependency (for syz-manager). Also tried to cover with tests a private method handleConn(), though it calls handleRunnerConn which has a separate logic in Handshake(), which within handleConn() unit test we should've mocked. This will require a refactoring of `runners map[int]*Runner` and runner.go in general with a separate interface which we can mock as well. General idea is to have interfaces of Server (rpc), Runner etc. and mock a compound logic like Handshake during a separate public (or private if it has callable, if-else logic) method unit-testing.
Diffstat (limited to 'pkg/rpcserver/rpcserver.go')
-rw-r--r--pkg/rpcserver/rpcserver.go81
1 files changed, 49 insertions, 32 deletions
diff --git a/pkg/rpcserver/rpcserver.go b/pkg/rpcserver/rpcserver.go
index 5a104d81c..367e3b5c6 100644
--- a/pkg/rpcserver/rpcserver.go
+++ b/pkg/rpcserver/rpcserver.go
@@ -48,6 +48,7 @@ type Config struct {
localModules []*vminfo.KernelModule
}
+//go:generate ../../tools/mockery.sh --name Manager --output ./mocks
type Manager interface {
MaxSignal() signal.Signal
BugFrames() (leaks []string, races []string)
@@ -55,9 +56,18 @@ type Manager interface {
CoverageFilter(modules []*vminfo.KernelModule) []uint64
}
-type Server struct {
- Port int
+type Server interface {
+ Listen() error
+ Close() error
+ Port() int
+ TriagedCorpus()
+ CreateInstance(id int, injectExec chan<- bool, updInfo dispatcher.UpdateInfo) chan error
+ ShutdownInstance(id int, crashed bool) ([]ExecRecord, []byte)
+ StopFuzzing(id int)
+ DistributeSignalDelta(plus signal.Signal)
+}
+type server struct {
cfg *Config
mgr Manager
serv *flatrpc.Serv
@@ -82,7 +92,7 @@ type Server struct {
*runnerStats
}
-func New(cfg *mgrconfig.Config, mgr Manager, debug bool) (*Server, error) {
+func New(cfg *mgrconfig.Config, mgr Manager, debug bool) (Server, error) {
var pcBase uint64
if cfg.KernelObj != "" {
var err error
@@ -122,16 +132,16 @@ func New(cfg *mgrconfig.Config, mgr Manager, debug bool) (*Server, error) {
Slowdown: cfg.Timeouts.Slowdown,
pcBase: pcBase,
localModules: cfg.LocalModules,
- }, mgr)
+ }, mgr), nil
}
-func newImpl(ctx context.Context, cfg *Config, mgr Manager) (*Server, error) {
+func newImpl(ctx context.Context, cfg *Config, mgr Manager) *server {
+ // Note that we use VMArch, rather than Arch. We need the kernel address ranges and bitness.
+ sysTarget := targets.Get(cfg.Target.OS, cfg.VMArch)
cfg.Procs = min(cfg.Procs, prog.MaxPids)
checker := vminfo.New(ctx, &cfg.Config)
baseSource := queue.DynamicSource(checker)
- // Note that we use VMArch, rather than Arch. We need the kernel address ranges and bitness.
- sysTarget := targets.Get(cfg.Target.OS, cfg.VMArch)
- serv := &Server{
+ return &server{
cfg: cfg,
mgr: mgr,
target: cfg.Target,
@@ -156,20 +166,26 @@ func newImpl(ctx context.Context, cfg *Config, mgr Manager) (*Server, error) {
statNoExecDuration: queue.StatNoExecDuration,
},
}
- s, err := flatrpc.ListenAndServe(cfg.RPC, serv.handleConn)
+}
+
+func (serv *server) Close() error {
+ return serv.serv.Close()
+}
+
+func (serv *server) Listen() error {
+ s, err := flatrpc.ListenAndServe(serv.cfg.RPC, serv.handleConn)
if err != nil {
- return nil, err
+ return err
}
serv.serv = s
- serv.Port = s.Addr.Port
- return serv, nil
+ return nil
}
-func (serv *Server) Close() error {
- return serv.serv.Close()
+func (serv *server) Port() int {
+ return serv.serv.Addr.Port
}
-func (serv *Server) handleConn(conn *flatrpc.Conn) {
+func (serv *server) handleConn(conn *flatrpc.Conn) {
connectReq, err := flatrpc.Recv[*flatrpc.ConnectRequestRaw](conn)
if err != nil {
log.Logf(1, "%s", err)
@@ -185,8 +201,8 @@ func (serv *Server) handleConn(conn *flatrpc.Conn) {
serv.StopFuzzing(id)
serv.ShutdownInstance(id, true)
}()
- } else {
- checkRevisions(connectReq, serv.cfg.Target)
+ } else if err := checkRevisions(connectReq, serv.cfg.Target); err != nil {
+ log.Fatal(err)
}
serv.statVMRestarts.Add(1)
@@ -203,7 +219,7 @@ func (serv *Server) handleConn(conn *flatrpc.Conn) {
runner.resultCh <- err
}
-func (serv *Server) handleRunnerConn(runner *Runner, conn *flatrpc.Conn) error {
+func (serv *server) handleRunnerConn(runner *Runner, conn *flatrpc.Conn) error {
opts := &handshakeConfig{
VMLess: serv.cfg.VMLess,
Files: serv.checker.RequiredFiles(),
@@ -235,7 +251,7 @@ func (serv *Server) handleRunnerConn(runner *Runner, conn *flatrpc.Conn) error {
return serv.connectionLoop(runner)
}
-func (serv *Server) handleMachineInfo(infoReq *flatrpc.InfoRequestRawT) (handshakeResult, error) {
+func (serv *server) handleMachineInfo(infoReq *flatrpc.InfoRequestRawT) (handshakeResult, error) {
modules, machineInfo, err := serv.checker.MachineInfo(infoReq.Files)
if err != nil {
log.Logf(0, "parsing of machine info failed: %v", err)
@@ -280,7 +296,7 @@ func (serv *Server) handleMachineInfo(infoReq *flatrpc.InfoRequestRawT) (handsha
}, nil
}
-func (serv *Server) connectionLoop(runner *Runner) error {
+func (serv *server) connectionLoop(runner *Runner) error {
if serv.cfg.Cover {
maxSignal := serv.mgr.MaxSignal().ToRaw()
for len(maxSignal) != 0 {
@@ -301,21 +317,22 @@ func (serv *Server) connectionLoop(runner *Runner) error {
return runner.ConnectionLoop()
}
-func checkRevisions(a *flatrpc.ConnectRequest, target *prog.Target) {
+func checkRevisions(a *flatrpc.ConnectRequest, target *prog.Target) error {
if target.Arch != a.Arch {
- log.Fatalf("mismatching manager/executor arches: %v vs %v", target.Arch, a.Arch)
+ return fmt.Errorf("mismatching manager/executor arches: %v vs %v", target.Arch, a.Arch)
}
if prog.GitRevision != a.GitRevision {
- log.Fatalf("mismatching manager/executor git revisions: %v vs %v",
+ return fmt.Errorf("mismatching manager/executor git revisions: %v vs %v",
prog.GitRevision, a.GitRevision)
}
if target.Revision != a.SyzRevision {
- log.Fatalf("mismatching manager/executor system call descriptions: %v vs %v",
+ return fmt.Errorf("mismatching manager/executor system call descriptions: %v vs %v",
target.Revision, a.SyzRevision)
}
+ return nil
}
-func (serv *Server) runCheck(checkFilesInfo []*flatrpc.FileInfo, checkFeatureInfo []*flatrpc.FeatureInfo) error {
+func (serv *server) runCheck(checkFilesInfo []*flatrpc.FileInfo, checkFeatureInfo []*flatrpc.FeatureInfo) error {
enabledCalls, disabledCalls, features, checkErr := serv.checker.Run(checkFilesInfo, checkFeatureInfo)
enabledCalls, transitivelyDisabled := serv.target.TransitivelyEnabledCalls(enabledCalls)
// Note: need to print disbled syscalls before failing due to an error.
@@ -334,7 +351,7 @@ func (serv *Server) runCheck(checkFilesInfo []*flatrpc.FileInfo, checkFeatureInf
return nil
}
-func (serv *Server) printMachineCheck(checkFilesInfo []*flatrpc.FileInfo, enabledCalls map[*prog.Syscall]bool,
+func (serv *server) printMachineCheck(checkFilesInfo []*flatrpc.FileInfo, enabledCalls map[*prog.Syscall]bool,
disabledCalls, transitivelyDisabled map[*prog.Syscall]string, features vminfo.Features) {
buf := new(bytes.Buffer)
if len(serv.cfg.Syscalls) != 0 || log.V(1) {
@@ -384,7 +401,7 @@ func (serv *Server) printMachineCheck(checkFilesInfo []*flatrpc.FileInfo, enable
log.Logf(0, "machine check:\n%s", buf.Bytes())
}
-func (serv *Server) CreateInstance(id int, injectExec chan<- bool, updInfo dispatcher.UpdateInfo) chan error {
+func (serv *server) CreateInstance(id int, injectExec chan<- bool, updInfo dispatcher.UpdateInfo) chan error {
runner := &Runner{
id: id,
source: serv.execSource,
@@ -415,7 +432,7 @@ func (serv *Server) CreateInstance(id int, injectExec chan<- bool, updInfo dispa
// stopInstance prevents further request exchange requests.
// To make RPCServer fully forget an instance, shutdownInstance() must be called.
-func (serv *Server) StopFuzzing(id int) {
+func (serv *server) StopFuzzing(id int) {
serv.mu.Lock()
runner := serv.runners[id]
serv.mu.Unlock()
@@ -427,7 +444,7 @@ func (serv *Server) StopFuzzing(id int) {
runner.Stop()
}
-func (serv *Server) ShutdownInstance(id int, crashed bool) ([]ExecRecord, []byte) {
+func (serv *server) ShutdownInstance(id int, crashed bool) ([]ExecRecord, []byte) {
serv.mu.Lock()
runner := serv.runners[id]
delete(serv.runners, id)
@@ -435,14 +452,14 @@ func (serv *Server) ShutdownInstance(id int, crashed bool) ([]ExecRecord, []byte
return runner.Shutdown(crashed), runner.MachineInfo()
}
-func (serv *Server) DistributeSignalDelta(plus signal.Signal) {
+func (serv *server) DistributeSignalDelta(plus signal.Signal) {
plusRaw := plus.ToRaw()
serv.foreachRunnerAsync(func(runner *Runner) {
runner.SendSignalUpdate(plusRaw)
})
}
-func (serv *Server) TriagedCorpus() {
+func (serv *server) TriagedCorpus() {
serv.triagedCorpus.Store(true)
serv.foreachRunnerAsync(func(runner *Runner) {
runner.SendCorpusTriaged()
@@ -452,7 +469,7 @@ func (serv *Server) TriagedCorpus() {
// foreachRunnerAsync runs callback fn for each connected runner asynchronously.
// If a VM has hanged w/o reading out the socket, we want to avoid blocking
// important goroutines on the send operations.
-func (serv *Server) foreachRunnerAsync(fn func(runner *Runner)) {
+func (serv *server) foreachRunnerAsync(fn func(runner *Runner)) {
serv.mu.Lock()
defer serv.mu.Unlock()
for _, runner := range serv.runners {