diff options
| author | Sabyrzhan Tasbolatov <snovitoll@gmail.com> | 2024-09-04 21:33:28 +0500 |
|---|---|---|
| committer | Aleksandr Nogikh <nogikh@google.com> | 2024-09-09 16:49:28 +0000 |
| commit | cb2d8b3aef0920cbb5521f948e262598efc3fc1c (patch) | |
| tree | a5dd2e30ebaa63b2835fa6656e8e7078054829a0 /pkg/rpcserver/rpcserver.go | |
| parent | 10df4c09063bf091d9d003880e4d1044b0ec163d (diff) | |
pkg/rpcserver: add unit tests, Manager mocks
Added more test coverage for the package and created an interface for
rpcserver so it can be used as a dependency (by syz-manager).
Also tried to cover the private method handleConn() with tests, though
it calls handleRunnerConn(), which has separate logic in Handshake()
that would have to be mocked in a handleConn() unit test.
This will require refactoring `runners map[int]*Runner` and runner.go
in general, introducing a separate interface that we can mock as
well.
The general idea is to have interfaces for Server (rpc), Runner, etc., and
to mock compound logic like Handshake when unit-testing each public (or
private, if it contains callable, branching logic) method separately.
Diffstat (limited to 'pkg/rpcserver/rpcserver.go')
| -rw-r--r-- | pkg/rpcserver/rpcserver.go | 81 |
1 file changed, 49 insertions, 32 deletions
diff --git a/pkg/rpcserver/rpcserver.go b/pkg/rpcserver/rpcserver.go index 5a104d81c..367e3b5c6 100644 --- a/pkg/rpcserver/rpcserver.go +++ b/pkg/rpcserver/rpcserver.go @@ -48,6 +48,7 @@ type Config struct { localModules []*vminfo.KernelModule } +//go:generate ../../tools/mockery.sh --name Manager --output ./mocks type Manager interface { MaxSignal() signal.Signal BugFrames() (leaks []string, races []string) @@ -55,9 +56,18 @@ type Manager interface { CoverageFilter(modules []*vminfo.KernelModule) []uint64 } -type Server struct { - Port int +type Server interface { + Listen() error + Close() error + Port() int + TriagedCorpus() + CreateInstance(id int, injectExec chan<- bool, updInfo dispatcher.UpdateInfo) chan error + ShutdownInstance(id int, crashed bool) ([]ExecRecord, []byte) + StopFuzzing(id int) + DistributeSignalDelta(plus signal.Signal) +} +type server struct { cfg *Config mgr Manager serv *flatrpc.Serv @@ -82,7 +92,7 @@ type Server struct { *runnerStats } -func New(cfg *mgrconfig.Config, mgr Manager, debug bool) (*Server, error) { +func New(cfg *mgrconfig.Config, mgr Manager, debug bool) (Server, error) { var pcBase uint64 if cfg.KernelObj != "" { var err error @@ -122,16 +132,16 @@ func New(cfg *mgrconfig.Config, mgr Manager, debug bool) (*Server, error) { Slowdown: cfg.Timeouts.Slowdown, pcBase: pcBase, localModules: cfg.LocalModules, - }, mgr) + }, mgr), nil } -func newImpl(ctx context.Context, cfg *Config, mgr Manager) (*Server, error) { +func newImpl(ctx context.Context, cfg *Config, mgr Manager) *server { + // Note that we use VMArch, rather than Arch. We need the kernel address ranges and bitness. + sysTarget := targets.Get(cfg.Target.OS, cfg.VMArch) cfg.Procs = min(cfg.Procs, prog.MaxPids) checker := vminfo.New(ctx, &cfg.Config) baseSource := queue.DynamicSource(checker) - // Note that we use VMArch, rather than Arch. We need the kernel address ranges and bitness. 
- sysTarget := targets.Get(cfg.Target.OS, cfg.VMArch) - serv := &Server{ + return &server{ cfg: cfg, mgr: mgr, target: cfg.Target, @@ -156,20 +166,26 @@ func newImpl(ctx context.Context, cfg *Config, mgr Manager) (*Server, error) { statNoExecDuration: queue.StatNoExecDuration, }, } - s, err := flatrpc.ListenAndServe(cfg.RPC, serv.handleConn) +} + +func (serv *server) Close() error { + return serv.serv.Close() +} + +func (serv *server) Listen() error { + s, err := flatrpc.ListenAndServe(serv.cfg.RPC, serv.handleConn) if err != nil { - return nil, err + return err } serv.serv = s - serv.Port = s.Addr.Port - return serv, nil + return nil } -func (serv *Server) Close() error { - return serv.serv.Close() +func (serv *server) Port() int { + return serv.serv.Addr.Port } -func (serv *Server) handleConn(conn *flatrpc.Conn) { +func (serv *server) handleConn(conn *flatrpc.Conn) { connectReq, err := flatrpc.Recv[*flatrpc.ConnectRequestRaw](conn) if err != nil { log.Logf(1, "%s", err) @@ -185,8 +201,8 @@ func (serv *Server) handleConn(conn *flatrpc.Conn) { serv.StopFuzzing(id) serv.ShutdownInstance(id, true) }() - } else { - checkRevisions(connectReq, serv.cfg.Target) + } else if err := checkRevisions(connectReq, serv.cfg.Target); err != nil { + log.Fatal(err) } serv.statVMRestarts.Add(1) @@ -203,7 +219,7 @@ func (serv *Server) handleConn(conn *flatrpc.Conn) { runner.resultCh <- err } -func (serv *Server) handleRunnerConn(runner *Runner, conn *flatrpc.Conn) error { +func (serv *server) handleRunnerConn(runner *Runner, conn *flatrpc.Conn) error { opts := &handshakeConfig{ VMLess: serv.cfg.VMLess, Files: serv.checker.RequiredFiles(), @@ -235,7 +251,7 @@ func (serv *Server) handleRunnerConn(runner *Runner, conn *flatrpc.Conn) error { return serv.connectionLoop(runner) } -func (serv *Server) handleMachineInfo(infoReq *flatrpc.InfoRequestRawT) (handshakeResult, error) { +func (serv *server) handleMachineInfo(infoReq *flatrpc.InfoRequestRawT) (handshakeResult, error) { modules, 
machineInfo, err := serv.checker.MachineInfo(infoReq.Files) if err != nil { log.Logf(0, "parsing of machine info failed: %v", err) @@ -280,7 +296,7 @@ func (serv *Server) handleMachineInfo(infoReq *flatrpc.InfoRequestRawT) (handsha }, nil } -func (serv *Server) connectionLoop(runner *Runner) error { +func (serv *server) connectionLoop(runner *Runner) error { if serv.cfg.Cover { maxSignal := serv.mgr.MaxSignal().ToRaw() for len(maxSignal) != 0 { @@ -301,21 +317,22 @@ func (serv *Server) connectionLoop(runner *Runner) error { return runner.ConnectionLoop() } -func checkRevisions(a *flatrpc.ConnectRequest, target *prog.Target) { +func checkRevisions(a *flatrpc.ConnectRequest, target *prog.Target) error { if target.Arch != a.Arch { - log.Fatalf("mismatching manager/executor arches: %v vs %v", target.Arch, a.Arch) + return fmt.Errorf("mismatching manager/executor arches: %v vs %v", target.Arch, a.Arch) } if prog.GitRevision != a.GitRevision { - log.Fatalf("mismatching manager/executor git revisions: %v vs %v", + return fmt.Errorf("mismatching manager/executor git revisions: %v vs %v", prog.GitRevision, a.GitRevision) } if target.Revision != a.SyzRevision { - log.Fatalf("mismatching manager/executor system call descriptions: %v vs %v", + return fmt.Errorf("mismatching manager/executor system call descriptions: %v vs %v", target.Revision, a.SyzRevision) } + return nil } -func (serv *Server) runCheck(checkFilesInfo []*flatrpc.FileInfo, checkFeatureInfo []*flatrpc.FeatureInfo) error { +func (serv *server) runCheck(checkFilesInfo []*flatrpc.FileInfo, checkFeatureInfo []*flatrpc.FeatureInfo) error { enabledCalls, disabledCalls, features, checkErr := serv.checker.Run(checkFilesInfo, checkFeatureInfo) enabledCalls, transitivelyDisabled := serv.target.TransitivelyEnabledCalls(enabledCalls) // Note: need to print disbled syscalls before failing due to an error. 
@@ -334,7 +351,7 @@ func (serv *Server) runCheck(checkFilesInfo []*flatrpc.FileInfo, checkFeatureInf return nil } -func (serv *Server) printMachineCheck(checkFilesInfo []*flatrpc.FileInfo, enabledCalls map[*prog.Syscall]bool, +func (serv *server) printMachineCheck(checkFilesInfo []*flatrpc.FileInfo, enabledCalls map[*prog.Syscall]bool, disabledCalls, transitivelyDisabled map[*prog.Syscall]string, features vminfo.Features) { buf := new(bytes.Buffer) if len(serv.cfg.Syscalls) != 0 || log.V(1) { @@ -384,7 +401,7 @@ func (serv *Server) printMachineCheck(checkFilesInfo []*flatrpc.FileInfo, enable log.Logf(0, "machine check:\n%s", buf.Bytes()) } -func (serv *Server) CreateInstance(id int, injectExec chan<- bool, updInfo dispatcher.UpdateInfo) chan error { +func (serv *server) CreateInstance(id int, injectExec chan<- bool, updInfo dispatcher.UpdateInfo) chan error { runner := &Runner{ id: id, source: serv.execSource, @@ -415,7 +432,7 @@ func (serv *Server) CreateInstance(id int, injectExec chan<- bool, updInfo dispa // stopInstance prevents further request exchange requests. // To make RPCServer fully forget an instance, shutdownInstance() must be called. 
-func (serv *Server) StopFuzzing(id int) { +func (serv *server) StopFuzzing(id int) { serv.mu.Lock() runner := serv.runners[id] serv.mu.Unlock() @@ -427,7 +444,7 @@ func (serv *Server) StopFuzzing(id int) { runner.Stop() } -func (serv *Server) ShutdownInstance(id int, crashed bool) ([]ExecRecord, []byte) { +func (serv *server) ShutdownInstance(id int, crashed bool) ([]ExecRecord, []byte) { serv.mu.Lock() runner := serv.runners[id] delete(serv.runners, id) @@ -435,14 +452,14 @@ func (serv *Server) ShutdownInstance(id int, crashed bool) ([]ExecRecord, []byte return runner.Shutdown(crashed), runner.MachineInfo() } -func (serv *Server) DistributeSignalDelta(plus signal.Signal) { +func (serv *server) DistributeSignalDelta(plus signal.Signal) { plusRaw := plus.ToRaw() serv.foreachRunnerAsync(func(runner *Runner) { runner.SendSignalUpdate(plusRaw) }) } -func (serv *Server) TriagedCorpus() { +func (serv *server) TriagedCorpus() { serv.triagedCorpus.Store(true) serv.foreachRunnerAsync(func(runner *Runner) { runner.SendCorpusTriaged() @@ -452,7 +469,7 @@ func (serv *Server) TriagedCorpus() { // foreachRunnerAsync runs callback fn for each connected runner asynchronously. // If a VM has hanged w/o reading out the socket, we want to avoid blocking // important goroutines on the send operations. -func (serv *Server) foreachRunnerAsync(fn func(runner *Runner)) { +func (serv *server) foreachRunnerAsync(fn func(runner *Runner)) { serv.mu.Lock() defer serv.mu.Unlock() for _, runner := range serv.runners { |
