diff options
Diffstat (limited to 'pkg/rpcserver/local.go')
| -rw-r--r-- | pkg/rpcserver/local.go | 59 |
1 files changed, 35 insertions, 24 deletions
diff --git a/pkg/rpcserver/local.go b/pkg/rpcserver/local.go index 4ab8827ae..e4e128dcf 100644 --- a/pkg/rpcserver/local.go +++ b/pkg/rpcserver/local.go @@ -16,6 +16,7 @@ import ( "github.com/google/syzkaller/pkg/signal" "github.com/google/syzkaller/pkg/vminfo" "github.com/google/syzkaller/prog" + "golang.org/x/sync/errgroup" ) type LocalConfig struct { @@ -39,26 +40,32 @@ func RunLocal(cfg *LocalConfig) error { if cfg.VMArch == "" { cfg.VMArch = cfg.Target.Arch } + if cfg.Context == nil { + cfg.Context = context.Background() + } cfg.UseCoverEdges = true cfg.FilterSignal = true cfg.RPC = ":0" cfg.PrintMachineCheck = log.V(1) cfg.Stats = NewStats() - ctx := &local{ + localCtx := &local{ cfg: cfg, setupDone: make(chan bool), } - serv := newImpl(&cfg.Config, ctx) + serv := newImpl(&cfg.Config, localCtx) if err := serv.Listen(); err != nil { return err } defer serv.Close() - ctx.serv = serv + localCtx.serv = serv // setupDone synchronizes assignment to ctx.serv and read of ctx.serv in MachineChecked // for the race detector b/c it does not understand the synchronization via TCP socket connect/accept. - close(ctx.setupDone) + close(localCtx.setupDone) + + cancelCtx, cancel := context.WithCancel(cfg.Context) + eg, ctx := errgroup.WithContext(cancelCtx) - id := 0 + const id = 0 connErr := serv.CreateInstance(id, nil, nil) defer serv.ShutdownInstance(id, true) @@ -73,7 +80,7 @@ func RunLocal(cfg *LocalConfig) error { cfg.Executor, }, args...) } - cmd := exec.Command(bin, args...) + cmd := exec.CommandContext(ctx, bin, args...) cmd.Dir = cfg.Dir if cfg.Debug || cfg.GDB { cmd.Stdout = os.Stdout @@ -82,28 +89,32 @@ func RunLocal(cfg *LocalConfig) error { if cfg.GDB { cmd.Stdin = os.Stdin } - if err := cmd.Start(); err != nil { - return fmt.Errorf("failed to start executor: %w", err) - } - res := make(chan error, 1) - go func() { res <- cmd.Wait() }() + eg.Go(func() error { + return serv.Serve(ctx) + }) + eg.Go(func() error { + if err := cmd.Start(); err != nil { + return fmt.Errorf("failed to start executor: %w", err) + } + err := cmd.Wait() + // Note that we ignore the error if we killed the process by closing the context. + if err == nil || ctx.Err() != nil { + return nil + } + return fmt.Errorf("executor process exited: %w", err) + }) + shutdown := make(chan struct{}) if cfg.HandleInterrupts { osutil.HandleInterrupts(shutdown) } - var cmdErr error select { + case <-ctx.Done(): case <-shutdown: - case <-cfg.Context.Done(): case <-connErr: - case err := <-res: - cmdErr = fmt.Errorf("executor process exited: %w", err) - } - if cmdErr == nil { - cmd.Process.Kill() - <-res } - return cmdErr + cancel() + return eg.Wait() } type local struct { @@ -112,10 +123,10 @@ type local struct { setupDone chan bool } -func (ctx *local) MachineChecked(features flatrpc.Feature, syscalls map[*prog.Syscall]bool) queue.Source { +func (ctx *local) MachineChecked(features flatrpc.Feature, syscalls map[*prog.Syscall]bool) (queue.Source, error) { <-ctx.setupDone ctx.serv.TriagedCorpus() - return ctx.cfg.MachineChecked(features, syscalls) + return ctx.cfg.MachineChecked(features, syscalls), nil } func (ctx *local) BugFrames() ([]string, []string) { @@ -126,6 +137,6 @@ func (ctx *local) MaxSignal() signal.Signal { return signal.FromRaw(ctx.cfg.MaxSignal, 0) } -func (ctx *local) CoverageFilter(modules []*vminfo.KernelModule) []uint64 { - return ctx.cfg.CoverFilter +func (ctx *local) CoverageFilter(modules []*vminfo.KernelModule) ([]uint64, error) { + return ctx.cfg.CoverFilter, nil } |
