aboutsummaryrefslogtreecommitdiffstats
path: root/pkg/rpcserver/local.go
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/rpcserver/local.go')
-rw-r--r--pkg/rpcserver/local.go59
1 files changed, 35 insertions, 24 deletions
diff --git a/pkg/rpcserver/local.go b/pkg/rpcserver/local.go
index 4ab8827ae..e4e128dcf 100644
--- a/pkg/rpcserver/local.go
+++ b/pkg/rpcserver/local.go
@@ -16,6 +16,7 @@ import (
"github.com/google/syzkaller/pkg/signal"
"github.com/google/syzkaller/pkg/vminfo"
"github.com/google/syzkaller/prog"
+ "golang.org/x/sync/errgroup"
)
type LocalConfig struct {
@@ -39,26 +40,32 @@ func RunLocal(cfg *LocalConfig) error {
if cfg.VMArch == "" {
cfg.VMArch = cfg.Target.Arch
}
+ if cfg.Context == nil {
+ cfg.Context = context.Background()
+ }
cfg.UseCoverEdges = true
cfg.FilterSignal = true
cfg.RPC = ":0"
cfg.PrintMachineCheck = log.V(1)
cfg.Stats = NewStats()
- ctx := &local{
+ localCtx := &local{
cfg: cfg,
setupDone: make(chan bool),
}
- serv := newImpl(&cfg.Config, ctx)
+ serv := newImpl(&cfg.Config, localCtx)
if err := serv.Listen(); err != nil {
return err
}
defer serv.Close()
- ctx.serv = serv
+ localCtx.serv = serv
// setupDone synchronizes assignment to ctx.serv and read of ctx.serv in MachineChecked
// for the race detector b/c it does not understand the synchronization via TCP socket connect/accept.
- close(ctx.setupDone)
+ close(localCtx.setupDone)
+
+ cancelCtx, cancel := context.WithCancel(cfg.Context)
+ eg, ctx := errgroup.WithContext(cancelCtx)
- id := 0
+ const id = 0
connErr := serv.CreateInstance(id, nil, nil)
defer serv.ShutdownInstance(id, true)
@@ -73,7 +80,7 @@ func RunLocal(cfg *LocalConfig) error {
cfg.Executor,
}, args...)
}
- cmd := exec.Command(bin, args...)
+ cmd := exec.CommandContext(ctx, bin, args...)
cmd.Dir = cfg.Dir
if cfg.Debug || cfg.GDB {
cmd.Stdout = os.Stdout
@@ -82,28 +89,32 @@ func RunLocal(cfg *LocalConfig) error {
if cfg.GDB {
cmd.Stdin = os.Stdin
}
- if err := cmd.Start(); err != nil {
- return fmt.Errorf("failed to start executor: %w", err)
- }
- res := make(chan error, 1)
- go func() { res <- cmd.Wait() }()
+ eg.Go(func() error {
+ return serv.Serve(ctx)
+ })
+ eg.Go(func() error {
+ if err := cmd.Start(); err != nil {
+ return fmt.Errorf("failed to start executor: %w", err)
+ }
+ err := cmd.Wait()
+ // Note that we ignore the error if we killed the process by closing the context.
+ if err == nil || ctx.Err() != nil {
+ return nil
+ }
+ return fmt.Errorf("executor process exited: %w", err)
+ })
+
shutdown := make(chan struct{})
if cfg.HandleInterrupts {
osutil.HandleInterrupts(shutdown)
}
- var cmdErr error
select {
+ case <-ctx.Done():
case <-shutdown:
- case <-cfg.Context.Done():
case <-connErr:
- case err := <-res:
- cmdErr = fmt.Errorf("executor process exited: %w", err)
- }
- if cmdErr == nil {
- cmd.Process.Kill()
- <-res
}
- return cmdErr
+ cancel()
+ return eg.Wait()
}
type local struct {
@@ -112,10 +123,10 @@ type local struct {
setupDone chan bool
}
-func (ctx *local) MachineChecked(features flatrpc.Feature, syscalls map[*prog.Syscall]bool) queue.Source {
+func (ctx *local) MachineChecked(features flatrpc.Feature, syscalls map[*prog.Syscall]bool) (queue.Source, error) {
<-ctx.setupDone
ctx.serv.TriagedCorpus()
- return ctx.cfg.MachineChecked(features, syscalls)
+ return ctx.cfg.MachineChecked(features, syscalls), nil
}
func (ctx *local) BugFrames() ([]string, []string) {
@@ -126,6 +137,6 @@ func (ctx *local) MaxSignal() signal.Signal {
return signal.FromRaw(ctx.cfg.MaxSignal, 0)
}
-func (ctx *local) CoverageFilter(modules []*vminfo.KernelModule) []uint64 {
- return ctx.cfg.CoverFilter
+func (ctx *local) CoverageFilter(modules []*vminfo.KernelModule) ([]uint64, error) {
+ return ctx.cfg.CoverFilter, nil
}