From 94e13671726abbcf766f9b4aacd2ee04de59dcbd Mon Sep 17 00:00:00 2001 From: Aleksandr Nogikh Date: Fri, 24 Jan 2025 17:17:53 +0100 Subject: pkg/rpcserver: refactor to remove Fatalf calls Apply necessary changes to pkg/flatrpc and pkg/manager as well. --- pkg/rpcserver/rpcserver.go | 104 +++++++++++++++++++++++++++++++-------------- 1 file changed, 73 insertions(+), 31 deletions(-) (limited to 'pkg/rpcserver/rpcserver.go') diff --git a/pkg/rpcserver/rpcserver.go b/pkg/rpcserver/rpcserver.go index 003c5f4b9..b3b518b04 100644 --- a/pkg/rpcserver/rpcserver.go +++ b/pkg/rpcserver/rpcserver.go @@ -28,6 +28,7 @@ import ( "github.com/google/syzkaller/prog" "github.com/google/syzkaller/sys/targets" "github.com/google/syzkaller/vm/dispatcher" + "golang.org/x/sync/errgroup" ) type Config struct { @@ -63,8 +64,8 @@ type RemoteConfig struct { type Manager interface { MaxSignal() signal.Signal BugFrames() (leaks []string, races []string) - MachineChecked(features flatrpc.Feature, syscalls map[*prog.Syscall]bool) queue.Source - CoverageFilter(modules []*vminfo.KernelModule) []uint64 + MachineChecked(features flatrpc.Feature, syscalls map[*prog.Syscall]bool) (queue.Source, error) + CoverageFilter(modules []*vminfo.KernelModule) ([]uint64, error) } type Server interface { @@ -72,6 +73,7 @@ type Server interface { Close() error Port() int TriagedCorpus() + Serve(context.Context) error CreateInstance(id int, injectExec chan<- bool, updInfo dispatcher.UpdateInfo) chan error ShutdownInstance(id int, crashed bool, extraExecs ...report.ExecutorInfo) ([]ExecRecord, []byte) StopFuzzing(id int) @@ -88,6 +90,7 @@ type server struct { checker *vminfo.Checker infoOnce sync.Once + checkOnce sync.Once checkDone atomic.Bool checkFailures int baseSource *queue.DynamicSourceCtl @@ -217,7 +220,7 @@ func (serv *server) Close() error { } func (serv *server) Listen() error { - s, err := flatrpc.ListenAndServe(serv.cfg.RPC, serv.handleConn) + s, err := flatrpc.Listen(serv.cfg.RPC) if err != nil { return err } @@ -225,15 +228,25 @@ func (serv *server) Listen() error { return nil } +func (serv *server) Serve(ctx context.Context) error { + g, ctx := errgroup.WithContext(ctx) + g.Go(func() error { + return serv.serv.Serve(ctx, func(ctx context.Context, conn *flatrpc.Conn) error { + return serv.handleConn(ctx, g, conn) + }) + }) + return g.Wait() +} + func (serv *server) Port() int { return serv.serv.Addr.Port } -func (serv *server) handleConn(conn *flatrpc.Conn) { +func (serv *server) handleConn(ctx context.Context, eg *errgroup.Group, conn *flatrpc.Conn) error { connectReq, err := flatrpc.Recv[*flatrpc.ConnectRequestRaw](conn) if err != nil { log.Logf(1, "%s", err) - return + return nil } id := int(connectReq.Id) log.Logf(1, "runner %v connected", id) @@ -246,7 +259,8 @@ func (serv *server) handleConn(conn *flatrpc.Conn) { serv.ShutdownInstance(id, true) }() } else if err := checkRevisions(connectReq, serv.cfg.Target); err != nil { - log.Fatal(err) + // This is a fatal error. + return err } serv.StatVMRestarts.Add(1) @@ -255,15 +269,23 @@ func (serv *server) handleConn(conn *flatrpc.Conn) { serv.mu.Unlock() if runner == nil { log.Logf(2, "unknown VM %v tries to connect", id) - return + return nil } - err = serv.handleRunnerConn(runner, conn) + err = serv.handleRunnerConn(ctx, eg, runner, conn) log.Logf(2, "runner %v: %v", id, err) + + if err != nil && errors.Is(err, errFatal) { + log.Logf(0, "%v", err) + return err + } + runner.resultCh <- err + return nil } -func (serv *server) handleRunnerConn(runner *Runner, conn *flatrpc.Conn) error { +func (serv *server) handleRunnerConn(ctx context.Context, eg *errgroup.Group, + runner *Runner, conn *flatrpc.Conn) error { opts := &handshakeConfig{ VMLess: serv.cfg.VMLess, Files: serv.checker.RequiredFiles(), @@ -278,22 +300,36 @@ func (serv *server) handleRunnerConn(runner *Runner, conn *flatrpc.Conn) error { opts.Features = serv.cfg.Features } - err := runner.Handshake(conn, opts) + info, err := runner.Handshake(conn, opts) if err != nil { log.Logf(1, "%v", err) return err } + serv.checkOnce.Do(func() { + // Run the machine check. + eg.Go(func() error { + if err := serv.runCheck(ctx, &info); err != nil { + return fmt.Errorf("%w: %w", errFatal, err) + } + return nil + }) + }) + if serv.triagedCorpus.Load() { - if err := runner.SendCorpusTriaged(); err != nil { - log.Logf(2, "%v", err) - return err - } + eg.Go(runner.SendCorpusTriaged) } + go func() { + <-ctx.Done() + runner.Stop() + }() return serv.connectionLoop(runner) } +// Used for errors incompatible with further RPCServer operation. +var errFatal = errors.New("aborting RPC server") + func (serv *server) handleMachineInfo(infoReq *flatrpc.InfoRequestRawT) (handshakeResult, error) { modules, machineInfo, err := serv.checker.MachineInfo(infoReq.Files) if err != nil { @@ -307,31 +343,36 @@ func (serv *server) handleMachineInfo(infoReq *flatrpc.InfoRequestRawT) (handsha log.Logf(0, "machine check failed: %v", infoReq.Error) serv.checkFailures++ if serv.checkFailures == 10 { - log.Fatalf("machine check failing") + return handshakeResult{}, fmt.Errorf("%w: machine check failed too many times", errFatal) } return handshakeResult{}, errors.New("machine check failed") } + var retErr error serv.infoOnce.Do(func() { serv.StatModules.Add(len(modules)) serv.canonicalModules = cover.NewCanonicalizer(modules, serv.cfg.Cover) - serv.coverFilter = serv.mgr.CoverageFilter(modules) - // Flatbuffers don't do deep copy of byte slices, - // so clone manually since we pass it a goroutine. - for _, file := range infoReq.Files { - file.Data = slices.Clone(file.Data) + var err error + serv.coverFilter, err = serv.mgr.CoverageFilter(modules) + if err != nil { + retErr = fmt.Errorf("%w: %w", errFatal, err) + return } - // Now execute check programs. - go func() { - if err := serv.runCheck(infoReq); err != nil { - log.Fatalf("check failed: %v", err) - } - }() }) + if retErr != nil { + return handshakeResult{}, retErr + } + // Flatbuffers don't do deep copy of byte slices, + // so clone manually since we may later pass it a goroutine. + for _, file := range infoReq.Files { + file.Data = slices.Clone(file.Data) + } canonicalizer := serv.canonicalModules.NewInstance(modules) return handshakeResult{ CovFilter: canonicalizer.Decanonicalize(serv.coverFilter), MachineInfo: machineInfo, Canonicalizer: canonicalizer, + Files: infoReq.Files, + Features: infoReq.Features, }, nil } @@ -371,10 +412,8 @@ func checkRevisions(a *flatrpc.ConnectRequest, target *prog.Target) error { return nil } -func (serv *server) runCheck(info *flatrpc.InfoRequest) error { - // TODO: take context as a parameter. - enabledCalls, disabledCalls, features, checkErr := serv.checker.Run(context.Background(), - info.Files, info.Features) +func (serv *server) runCheck(ctx context.Context, info *handshakeResult) error { + enabledCalls, disabledCalls, features, checkErr := serv.checker.Run(ctx, info.Files, info.Features) enabledCalls, transitivelyDisabled := serv.target.TransitivelyEnabledCalls(enabledCalls) // Note: need to print disbled syscalls before failing due to an error. // This helps to debug "all system calls are disabled". @@ -386,7 +425,10 @@ func (serv *server) runCheck(info *flatrpc.InfoRequest) error { } enabledFeatures := features.Enabled() serv.setupFeatures = features.NeedSetup() - newSource := serv.mgr.MachineChecked(enabledFeatures, enabledCalls) + newSource, err := serv.mgr.MachineChecked(enabledFeatures, enabledCalls) + if err != nil { + return err + } serv.baseSource.Store(newSource) serv.checkDone.Store(true) return nil -- cgit mrf-deployment