aboutsummaryrefslogtreecommitdiffstats
path: root/pkg/rpcserver/rpcserver.go
diff options
context:
space:
mode:
authorAleksandr Nogikh <nogikh@google.com>2025-01-24 17:17:53 +0100
committerAleksandr Nogikh <nogikh@google.com>2025-01-29 10:31:50 +0000
commit94e13671726abbcf766f9b4aacd2ee04de59dcbd (patch)
tree699abaa69f3509857969ca2d7ff3ea001df14c88 /pkg/rpcserver/rpcserver.go
parent6eea27042142c1c5e810b642deb831a8ed55b3da (diff)
pkg/rpcserver: refactor to remove Fatalf calls
Apply necessary changes to pkg/flatrpc and pkg/manager as well.
Diffstat (limited to 'pkg/rpcserver/rpcserver.go')
-rw-r--r--pkg/rpcserver/rpcserver.go104
1 files changed, 73 insertions, 31 deletions
diff --git a/pkg/rpcserver/rpcserver.go b/pkg/rpcserver/rpcserver.go
index 003c5f4b9..b3b518b04 100644
--- a/pkg/rpcserver/rpcserver.go
+++ b/pkg/rpcserver/rpcserver.go
@@ -28,6 +28,7 @@ import (
"github.com/google/syzkaller/prog"
"github.com/google/syzkaller/sys/targets"
"github.com/google/syzkaller/vm/dispatcher"
+ "golang.org/x/sync/errgroup"
)
type Config struct {
@@ -63,8 +64,8 @@ type RemoteConfig struct {
type Manager interface {
MaxSignal() signal.Signal
BugFrames() (leaks []string, races []string)
- MachineChecked(features flatrpc.Feature, syscalls map[*prog.Syscall]bool) queue.Source
- CoverageFilter(modules []*vminfo.KernelModule) []uint64
+ MachineChecked(features flatrpc.Feature, syscalls map[*prog.Syscall]bool) (queue.Source, error)
+ CoverageFilter(modules []*vminfo.KernelModule) ([]uint64, error)
}
type Server interface {
@@ -72,6 +73,7 @@ type Server interface {
Close() error
Port() int
TriagedCorpus()
+ Serve(context.Context) error
CreateInstance(id int, injectExec chan<- bool, updInfo dispatcher.UpdateInfo) chan error
ShutdownInstance(id int, crashed bool, extraExecs ...report.ExecutorInfo) ([]ExecRecord, []byte)
StopFuzzing(id int)
@@ -88,6 +90,7 @@ type server struct {
checker *vminfo.Checker
infoOnce sync.Once
+ checkOnce sync.Once
checkDone atomic.Bool
checkFailures int
baseSource *queue.DynamicSourceCtl
@@ -217,7 +220,7 @@ func (serv *server) Close() error {
}
func (serv *server) Listen() error {
- s, err := flatrpc.ListenAndServe(serv.cfg.RPC, serv.handleConn)
+ s, err := flatrpc.Listen(serv.cfg.RPC)
if err != nil {
return err
}
@@ -225,15 +228,25 @@ func (serv *server) Listen() error {
return nil
}
+func (serv *server) Serve(ctx context.Context) error {
+ g, ctx := errgroup.WithContext(ctx)
+ g.Go(func() error {
+ return serv.serv.Serve(ctx, func(ctx context.Context, conn *flatrpc.Conn) error {
+ return serv.handleConn(ctx, g, conn)
+ })
+ })
+ return g.Wait()
+}
+
func (serv *server) Port() int {
return serv.serv.Addr.Port
}
-func (serv *server) handleConn(conn *flatrpc.Conn) {
+func (serv *server) handleConn(ctx context.Context, eg *errgroup.Group, conn *flatrpc.Conn) error {
connectReq, err := flatrpc.Recv[*flatrpc.ConnectRequestRaw](conn)
if err != nil {
log.Logf(1, "%s", err)
- return
+ return nil
}
id := int(connectReq.Id)
log.Logf(1, "runner %v connected", id)
@@ -246,7 +259,8 @@ func (serv *server) handleConn(conn *flatrpc.Conn) {
serv.ShutdownInstance(id, true)
}()
} else if err := checkRevisions(connectReq, serv.cfg.Target); err != nil {
- log.Fatal(err)
+ // This is a fatal error.
+ return err
}
serv.StatVMRestarts.Add(1)
@@ -255,15 +269,23 @@ func (serv *server) handleConn(conn *flatrpc.Conn) {
serv.mu.Unlock()
if runner == nil {
log.Logf(2, "unknown VM %v tries to connect", id)
- return
+ return nil
}
- err = serv.handleRunnerConn(runner, conn)
+ err = serv.handleRunnerConn(ctx, eg, runner, conn)
log.Logf(2, "runner %v: %v", id, err)
+
+ if err != nil && errors.Is(err, errFatal) {
+ log.Logf(0, "%v", err)
+ return err
+ }
+
runner.resultCh <- err
+ return nil
}
-func (serv *server) handleRunnerConn(runner *Runner, conn *flatrpc.Conn) error {
+func (serv *server) handleRunnerConn(ctx context.Context, eg *errgroup.Group,
+ runner *Runner, conn *flatrpc.Conn) error {
opts := &handshakeConfig{
VMLess: serv.cfg.VMLess,
Files: serv.checker.RequiredFiles(),
@@ -278,22 +300,36 @@ func (serv *server) handleRunnerConn(runner *Runner, conn *flatrpc.Conn) error {
opts.Features = serv.cfg.Features
}
- err := runner.Handshake(conn, opts)
+ info, err := runner.Handshake(conn, opts)
if err != nil {
log.Logf(1, "%v", err)
return err
}
+ serv.checkOnce.Do(func() {
+ // Run the machine check.
+ eg.Go(func() error {
+ if err := serv.runCheck(ctx, &info); err != nil {
+ return fmt.Errorf("%w: %w", errFatal, err)
+ }
+ return nil
+ })
+ })
+
if serv.triagedCorpus.Load() {
- if err := runner.SendCorpusTriaged(); err != nil {
- log.Logf(2, "%v", err)
- return err
- }
+ eg.Go(runner.SendCorpusTriaged)
}
+ go func() {
+ <-ctx.Done()
+ runner.Stop()
+ }()
return serv.connectionLoop(runner)
}
+// Used for errors incompatible with further RPCServer operation.
+var errFatal = errors.New("aborting RPC server")
+
func (serv *server) handleMachineInfo(infoReq *flatrpc.InfoRequestRawT) (handshakeResult, error) {
modules, machineInfo, err := serv.checker.MachineInfo(infoReq.Files)
if err != nil {
@@ -307,31 +343,36 @@ func (serv *server) handleMachineInfo(infoReq *flatrpc.InfoRequestRawT) (handsha
log.Logf(0, "machine check failed: %v", infoReq.Error)
serv.checkFailures++
if serv.checkFailures == 10 {
- log.Fatalf("machine check failing")
+ return handshakeResult{}, fmt.Errorf("%w: machine check failed too many times", errFatal)
}
return handshakeResult{}, errors.New("machine check failed")
}
+ var retErr error
serv.infoOnce.Do(func() {
serv.StatModules.Add(len(modules))
serv.canonicalModules = cover.NewCanonicalizer(modules, serv.cfg.Cover)
- serv.coverFilter = serv.mgr.CoverageFilter(modules)
- // Flatbuffers don't do deep copy of byte slices,
- // so clone manually since we pass it a goroutine.
- for _, file := range infoReq.Files {
- file.Data = slices.Clone(file.Data)
+ var err error
+ serv.coverFilter, err = serv.mgr.CoverageFilter(modules)
+ if err != nil {
+ retErr = fmt.Errorf("%w: %w", errFatal, err)
+ return
}
- // Now execute check programs.
- go func() {
- if err := serv.runCheck(infoReq); err != nil {
- log.Fatalf("check failed: %v", err)
- }
- }()
})
+ if retErr != nil {
+ return handshakeResult{}, retErr
+ }
+ // Flatbuffers don't do deep copy of byte slices,
+ // so clone manually since we may later pass it a goroutine.
+ for _, file := range infoReq.Files {
+ file.Data = slices.Clone(file.Data)
+ }
canonicalizer := serv.canonicalModules.NewInstance(modules)
return handshakeResult{
CovFilter: canonicalizer.Decanonicalize(serv.coverFilter),
MachineInfo: machineInfo,
Canonicalizer: canonicalizer,
+ Files: infoReq.Files,
+ Features: infoReq.Features,
}, nil
}
@@ -371,10 +412,8 @@ func checkRevisions(a *flatrpc.ConnectRequest, target *prog.Target) error {
return nil
}
-func (serv *server) runCheck(info *flatrpc.InfoRequest) error {
- // TODO: take context as a parameter.
- enabledCalls, disabledCalls, features, checkErr := serv.checker.Run(context.Background(),
- info.Files, info.Features)
+func (serv *server) runCheck(ctx context.Context, info *handshakeResult) error {
+ enabledCalls, disabledCalls, features, checkErr := serv.checker.Run(ctx, info.Files, info.Features)
enabledCalls, transitivelyDisabled := serv.target.TransitivelyEnabledCalls(enabledCalls)
// Note: need to print disbled syscalls before failing due to an error.
// This helps to debug "all system calls are disabled".
@@ -386,7 +425,10 @@ func (serv *server) runCheck(info *flatrpc.InfoRequest) error {
}
enabledFeatures := features.Enabled()
serv.setupFeatures = features.NeedSetup()
- newSource := serv.mgr.MachineChecked(enabledFeatures, enabledCalls)
+ newSource, err := serv.mgr.MachineChecked(enabledFeatures, enabledCalls)
+ if err != nil {
+ return err
+ }
serv.baseSource.Store(newSource)
serv.checkDone.Store(true)
return nil