aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
author Aleksandr Nogikh <nogikh@google.com> 2025-01-24 17:17:53 +0100
committer Aleksandr Nogikh <nogikh@google.com> 2025-01-29 10:31:50 +0000
commit 94e13671726abbcf766f9b4aacd2ee04de59dcbd (patch)
tree 699abaa69f3509857969ca2d7ff3ea001df14c88
parent 6eea27042142c1c5e810b642deb831a8ed55b3da (diff)
pkg/rpcserver: refactor to remove Fatalf calls
Apply necessary changes to pkg/flatrpc and pkg/manager as well.
-rw-r--r-- pkg/flatrpc/conn.go 65
-rw-r--r-- pkg/flatrpc/conn_test.go 104
-rw-r--r-- pkg/manager/diff.go 13
-rw-r--r-- pkg/rpcserver/local.go 59
-rw-r--r-- pkg/rpcserver/mocks/Manager.go 28
-rw-r--r-- pkg/rpcserver/rpcserver.go 104
-rw-r--r-- pkg/rpcserver/rpcserver_test.go 8
-rw-r--r-- pkg/rpcserver/runner.go 14
-rw-r--r-- syz-manager/manager.go 29
9 files changed, 275 insertions, 149 deletions
diff --git a/pkg/flatrpc/conn.go b/pkg/flatrpc/conn.go
index 33eca07e1..47b265493 100644
--- a/pkg/flatrpc/conn.go
+++ b/pkg/flatrpc/conn.go
@@ -4,6 +4,7 @@
package flatrpc
import (
+ "context"
"errors"
"fmt"
"io"
@@ -13,9 +14,10 @@ import (
"sync"
"unsafe"
- "github.com/google/flatbuffers/go"
+ flatbuffers "github.com/google/flatbuffers/go"
"github.com/google/syzkaller/pkg/log"
"github.com/google/syzkaller/pkg/stat"
+ "golang.org/x/sync/errgroup"
)
var (
@@ -30,38 +32,55 @@ type Serv struct {
ln net.Listener
}
-func ListenAndServe(addr string, handler func(*Conn)) (*Serv, error) {
+func Listen(addr string) (*Serv, error) {
ln, err := net.Listen("tcp", addr)
if err != nil {
return nil, err
}
- go func() {
- for {
- conn, err := ln.Accept()
- if err != nil {
- if errors.Is(err, net.ErrClosed) {
- break
- }
- var netErr *net.OpError
- if errors.As(err, &netErr) && !netErr.Temporary() {
- log.Fatalf("flatrpc: failed to accept: %v", err)
- }
- log.Logf(0, "flatrpc: failed to accept: %v", err)
- continue
- }
- go func() {
- c := NewConn(conn)
- defer c.Close()
- handler(c)
- }()
- }
- }()
return &Serv{
Addr: ln.Addr().(*net.TCPAddr),
ln: ln,
}, nil
}
+// Serve accepts incoming connections and calls handler for each of them.
+// An error returned from the handler stops the server and aborts the whole processing.
+func (s *Serv) Serve(baseCtx context.Context, handler func(context.Context, *Conn) error) error {
+ eg, ctx := errgroup.WithContext(baseCtx)
+ go func() {
+ // If the context is cancelled, stop the server.
+ <-ctx.Done()
+ s.Close()
+ }()
+ for {
+ conn, err := s.ln.Accept()
+ if err != nil && errors.Is(err, net.ErrClosed) {
+ break
+ }
+ if err != nil {
+ var netErr *net.OpError
+ if errors.As(err, &netErr) && !netErr.Temporary() {
+ return fmt.Errorf("flatrpc: failed to accept: %w", err)
+ }
+ log.Logf(0, "flatrpc: failed to accept: %v", err)
+ continue
+ }
+ eg.Go(func() error {
+ connCtx, cancel := context.WithCancel(ctx)
+ defer cancel()
+
+ c := NewConn(conn)
+ // Closing the server does not automatically close all the connections.
+ go func() {
+ <-connCtx.Done()
+ c.Close()
+ }()
+ return handler(connCtx, c)
+ })
+ }
+ return eg.Wait()
+}
+
func (s *Serv) Close() error {
return s.ln.Close()
}
diff --git a/pkg/flatrpc/conn_test.go b/pkg/flatrpc/conn_test.go
index 132fd1cdd..4b108a5a4 100644
--- a/pkg/flatrpc/conn_test.go
+++ b/pkg/flatrpc/conn_test.go
@@ -4,15 +4,18 @@
package flatrpc
import (
+ "context"
+ "fmt"
"net"
"os"
+ "reflect"
"runtime/debug"
"sync"
"syscall"
"testing"
"time"
- "github.com/google/flatbuffers/go"
+ flatbuffers "github.com/google/flatbuffers/go"
"github.com/stretchr/testify/assert"
)
@@ -40,35 +43,39 @@ func TestConn(t *testing.T) {
},
}
- done := make(chan bool)
- defer func() {
- <-done
- }()
- serv, err := ListenAndServe(":0", func(c *Conn) {
- defer close(done)
- connectReqGot, err := Recv[*ConnectRequestRaw](c)
- if err != nil {
- t.Fatal(err)
- }
- assert.Equal(t, connectReq, connectReqGot)
-
- if err := Send(c, connectReply); err != nil {
- t.Fatal(err)
- }
-
- for i := 0; i < 10; i++ {
- got, err := Recv[*ExecutorMessageRaw](c)
- if err != nil {
- t.Fatal(err)
- }
- assert.Equal(t, executorMsg, got)
- }
- })
+ serv, err := Listen(":0")
if err != nil {
t.Fatal(err)
}
- defer serv.Close()
+ done := make(chan error)
+ go func() {
+ done <- serv.Serve(context.Background(),
+ func(_ context.Context, c *Conn) error {
+ connectReqGot, err := Recv[*ConnectRequestRaw](c)
+ if err != nil {
+ return err
+ }
+ if !reflect.DeepEqual(connectReq, connectReqGot) {
+ return fmt.Errorf("connectReq != connectReqGot")
+ }
+
+ if err := Send(c, connectReply); err != nil {
+ return err
+ }
+
+ for i := 0; i < 10; i++ {
+ got, err := Recv[*ExecutorMessageRaw](c)
+ if err != nil {
+ return nil
+ }
+ if !reflect.DeepEqual(executorMsg, got) {
+ return fmt.Errorf("executorMsg !=got")
+ }
+ }
+ return nil
+ })
+ }()
c := dial(t, serv.Addr.String())
defer c.Close()
@@ -87,6 +94,11 @@ func TestConn(t *testing.T) {
t.Fatal(err)
}
}
+
+ serv.Close()
+ if err := <-done; err != nil {
+ t.Fatal(err)
+ }
}
func BenchmarkConn(b *testing.B) {
@@ -103,26 +115,27 @@ func BenchmarkConn(b *testing.B) {
Files: []string{"file1"},
}
- done := make(chan bool)
- defer func() {
- <-done
- }()
- serv, err := ListenAndServe(":0", func(c *Conn) {
- defer close(done)
- for i := 0; i < b.N; i++ {
- _, err := Recv[*ConnectRequestRaw](c)
- if err != nil {
- b.Fatal(err)
- }
- if err := Send(c, connectReply); err != nil {
- b.Fatal(err)
- }
- }
- })
+ serv, err := Listen(":0")
if err != nil {
b.Fatal(err)
}
- defer serv.Close()
+ done := make(chan error)
+
+ go func() {
+ done <- serv.Serve(context.Background(),
+ func(_ context.Context, c *Conn) error {
+ for i := 0; i < b.N; i++ {
+ _, err := Recv[*ConnectRequestRaw](c)
+ if err != nil {
+ return err
+ }
+ if err := Send(c, connectReply); err != nil {
+ return err
+ }
+ }
+ return nil
+ })
+ }()
c := dial(b, serv.Addr.String())
defer c.Close()
@@ -138,6 +151,11 @@ func BenchmarkConn(b *testing.B) {
b.Fatal(err)
}
}
+
+ serv.Close()
+ if err := <-done; err != nil {
+ b.Fatal(err)
+ }
}
func dial(t testing.TB, addr string) *Conn {
diff --git a/pkg/manager/diff.go b/pkg/manager/diff.go
index 17d056e1f..8218ce924 100644
--- a/pkg/manager/diff.go
+++ b/pkg/manager/diff.go
@@ -305,9 +305,10 @@ func (kc *kernelContext) BugFrames() (leaks, races []string) {
return nil, nil
}
-func (kc *kernelContext) MachineChecked(features flatrpc.Feature, syscalls map[*prog.Syscall]bool) queue.Source {
+func (kc *kernelContext) MachineChecked(features flatrpc.Feature,
+ syscalls map[*prog.Syscall]bool) (queue.Source, error) {
if len(syscalls) == 0 {
- log.Fatalf("all system calls are disabled")
+ return nil, fmt.Errorf("all system calls are disabled")
}
log.Logf(0, "%s: machine check complete", kc.name)
kc.features = features
@@ -319,7 +320,7 @@ func (kc *kernelContext) MachineChecked(features flatrpc.Feature, syscalls map[*
source = kc.source
}
opts := fuzzer.DefaultExecOpts(kc.cfg, features, kc.debug)
- return queue.DefaultOpts(source, opts)
+ return queue.DefaultOpts(source, opts), nil
}
func (kc *kernelContext) setupFuzzer(features flatrpc.Feature, syscalls map[*prog.Syscall]bool) queue.Source {
@@ -383,11 +384,11 @@ func (kc *kernelContext) setupFuzzer(features flatrpc.Feature, syscalls map[*pro
return fuzzerObj
}
-func (kc *kernelContext) CoverageFilter(modules []*vminfo.KernelModule) []uint64 {
+func (kc *kernelContext) CoverageFilter(modules []*vminfo.KernelModule) ([]uint64, error) {
kc.reportGenerator.Init(modules)
filters, err := PrepareCoverageFilters(kc.reportGenerator, kc.cfg, false)
if err != nil {
- log.Fatalf("failed to init coverage filter: %v", err)
+ return nil, fmt.Errorf("failed to init coverage filter: %w", err)
}
kc.coverFilters = filters
log.Logf(0, "cover filter size: %d", len(filters.ExecutorFilter))
@@ -402,7 +403,7 @@ func (kc *kernelContext) CoverageFilter(modules []*vminfo.KernelModule) []uint64
for pc := range filters.ExecutorFilter {
pcs = append(pcs, pc)
}
- return pcs
+ return pcs, nil
}
func (kc *kernelContext) fuzzerInstance(ctx context.Context, inst *vm.Instance, updInfo dispatcher.UpdateInfo) {
diff --git a/pkg/rpcserver/local.go b/pkg/rpcserver/local.go
index 4ab8827ae..e4e128dcf 100644
--- a/pkg/rpcserver/local.go
+++ b/pkg/rpcserver/local.go
@@ -16,6 +16,7 @@ import (
"github.com/google/syzkaller/pkg/signal"
"github.com/google/syzkaller/pkg/vminfo"
"github.com/google/syzkaller/prog"
+ "golang.org/x/sync/errgroup"
)
type LocalConfig struct {
@@ -39,26 +40,32 @@ func RunLocal(cfg *LocalConfig) error {
if cfg.VMArch == "" {
cfg.VMArch = cfg.Target.Arch
}
+ if cfg.Context == nil {
+ cfg.Context = context.Background()
+ }
cfg.UseCoverEdges = true
cfg.FilterSignal = true
cfg.RPC = ":0"
cfg.PrintMachineCheck = log.V(1)
cfg.Stats = NewStats()
- ctx := &local{
+ localCtx := &local{
cfg: cfg,
setupDone: make(chan bool),
}
- serv := newImpl(&cfg.Config, ctx)
+ serv := newImpl(&cfg.Config, localCtx)
if err := serv.Listen(); err != nil {
return err
}
defer serv.Close()
- ctx.serv = serv
+ localCtx.serv = serv
// setupDone synchronizes assignment to ctx.serv and read of ctx.serv in MachineChecked
// for the race detector b/c it does not understand the synchronization via TCP socket connect/accept.
- close(ctx.setupDone)
+ close(localCtx.setupDone)
+
+ cancelCtx, cancel := context.WithCancel(cfg.Context)
+ eg, ctx := errgroup.WithContext(cancelCtx)
- id := 0
+ const id = 0
connErr := serv.CreateInstance(id, nil, nil)
defer serv.ShutdownInstance(id, true)
@@ -73,7 +80,7 @@ func RunLocal(cfg *LocalConfig) error {
cfg.Executor,
}, args...)
}
- cmd := exec.Command(bin, args...)
+ cmd := exec.CommandContext(ctx, bin, args...)
cmd.Dir = cfg.Dir
if cfg.Debug || cfg.GDB {
cmd.Stdout = os.Stdout
@@ -82,28 +89,32 @@ func RunLocal(cfg *LocalConfig) error {
if cfg.GDB {
cmd.Stdin = os.Stdin
}
- if err := cmd.Start(); err != nil {
- return fmt.Errorf("failed to start executor: %w", err)
- }
- res := make(chan error, 1)
- go func() { res <- cmd.Wait() }()
+ eg.Go(func() error {
+ return serv.Serve(ctx)
+ })
+ eg.Go(func() error {
+ if err := cmd.Start(); err != nil {
+ return fmt.Errorf("failed to start executor: %w", err)
+ }
+ err := cmd.Wait()
+ // Note that we ignore the error if we killed the process by cancelling the context.
+ if err == nil || ctx.Err() != nil {
+ return nil
+ }
+ return fmt.Errorf("executor process exited: %w", err)
+ })
+
shutdown := make(chan struct{})
if cfg.HandleInterrupts {
osutil.HandleInterrupts(shutdown)
}
- var cmdErr error
select {
+ case <-ctx.Done():
case <-shutdown:
- case <-cfg.Context.Done():
case <-connErr:
- case err := <-res:
- cmdErr = fmt.Errorf("executor process exited: %w", err)
- }
- if cmdErr == nil {
- cmd.Process.Kill()
- <-res
}
- return cmdErr
+ cancel()
+ return eg.Wait()
}
type local struct {
@@ -112,10 +123,10 @@ type local struct {
setupDone chan bool
}
-func (ctx *local) MachineChecked(features flatrpc.Feature, syscalls map[*prog.Syscall]bool) queue.Source {
+func (ctx *local) MachineChecked(features flatrpc.Feature, syscalls map[*prog.Syscall]bool) (queue.Source, error) {
<-ctx.setupDone
ctx.serv.TriagedCorpus()
- return ctx.cfg.MachineChecked(features, syscalls)
+ return ctx.cfg.MachineChecked(features, syscalls), nil
}
func (ctx *local) BugFrames() ([]string, []string) {
@@ -126,6 +137,6 @@ func (ctx *local) MaxSignal() signal.Signal {
return signal.FromRaw(ctx.cfg.MaxSignal, 0)
}
-func (ctx *local) CoverageFilter(modules []*vminfo.KernelModule) []uint64 {
- return ctx.cfg.CoverFilter
+func (ctx *local) CoverageFilter(modules []*vminfo.KernelModule) ([]uint64, error) {
+ return ctx.cfg.CoverFilter, nil
}
diff --git a/pkg/rpcserver/mocks/Manager.go b/pkg/rpcserver/mocks/Manager.go
index 810b5028f..c0c9621de 100644
--- a/pkg/rpcserver/mocks/Manager.go
+++ b/pkg/rpcserver/mocks/Manager.go
@@ -53,7 +53,7 @@ func (_m *Manager) BugFrames() ([]string, []string) {
}
// CoverageFilter provides a mock function with given fields: modules
-func (_m *Manager) CoverageFilter(modules []*vminfo.KernelModule) []uint64 {
+func (_m *Manager) CoverageFilter(modules []*vminfo.KernelModule) ([]uint64, error) {
ret := _m.Called(modules)
if len(ret) == 0 {
@@ -61,6 +61,10 @@ func (_m *Manager) CoverageFilter(modules []*vminfo.KernelModule) []uint64 {
}
var r0 []uint64
+ var r1 error
+ if rf, ok := ret.Get(0).(func([]*vminfo.KernelModule) ([]uint64, error)); ok {
+ return rf(modules)
+ }
if rf, ok := ret.Get(0).(func([]*vminfo.KernelModule) []uint64); ok {
r0 = rf(modules)
} else {
@@ -69,11 +73,17 @@ func (_m *Manager) CoverageFilter(modules []*vminfo.KernelModule) []uint64 {
}
}
- return r0
+ if rf, ok := ret.Get(1).(func([]*vminfo.KernelModule) error); ok {
+ r1 = rf(modules)
+ } else {
+ r1 = ret.Error(1)
+ }
+
+ return r0, r1
}
// MachineChecked provides a mock function with given fields: features, syscalls
-func (_m *Manager) MachineChecked(features flatrpc.Feature, syscalls map[*prog.Syscall]bool) queue.Source {
+func (_m *Manager) MachineChecked(features flatrpc.Feature, syscalls map[*prog.Syscall]bool) (queue.Source, error) {
ret := _m.Called(features, syscalls)
if len(ret) == 0 {
@@ -81,6 +91,10 @@ func (_m *Manager) MachineChecked(features flatrpc.Feature, syscalls map[*prog.S
}
var r0 queue.Source
+ var r1 error
+ if rf, ok := ret.Get(0).(func(flatrpc.Feature, map[*prog.Syscall]bool) (queue.Source, error)); ok {
+ return rf(features, syscalls)
+ }
if rf, ok := ret.Get(0).(func(flatrpc.Feature, map[*prog.Syscall]bool) queue.Source); ok {
r0 = rf(features, syscalls)
} else {
@@ -89,7 +103,13 @@ func (_m *Manager) MachineChecked(features flatrpc.Feature, syscalls map[*prog.S
}
}
- return r0
+ if rf, ok := ret.Get(1).(func(flatrpc.Feature, map[*prog.Syscall]bool) error); ok {
+ r1 = rf(features, syscalls)
+ } else {
+ r1 = ret.Error(1)
+ }
+
+ return r0, r1
}
// MaxSignal provides a mock function with given fields:
diff --git a/pkg/rpcserver/rpcserver.go b/pkg/rpcserver/rpcserver.go
index 003c5f4b9..b3b518b04 100644
--- a/pkg/rpcserver/rpcserver.go
+++ b/pkg/rpcserver/rpcserver.go
@@ -28,6 +28,7 @@ import (
"github.com/google/syzkaller/prog"
"github.com/google/syzkaller/sys/targets"
"github.com/google/syzkaller/vm/dispatcher"
+ "golang.org/x/sync/errgroup"
)
type Config struct {
@@ -63,8 +64,8 @@ type RemoteConfig struct {
type Manager interface {
MaxSignal() signal.Signal
BugFrames() (leaks []string, races []string)
- MachineChecked(features flatrpc.Feature, syscalls map[*prog.Syscall]bool) queue.Source
- CoverageFilter(modules []*vminfo.KernelModule) []uint64
+ MachineChecked(features flatrpc.Feature, syscalls map[*prog.Syscall]bool) (queue.Source, error)
+ CoverageFilter(modules []*vminfo.KernelModule) ([]uint64, error)
}
type Server interface {
@@ -72,6 +73,7 @@ type Server interface {
Close() error
Port() int
TriagedCorpus()
+ Serve(context.Context) error
CreateInstance(id int, injectExec chan<- bool, updInfo dispatcher.UpdateInfo) chan error
ShutdownInstance(id int, crashed bool, extraExecs ...report.ExecutorInfo) ([]ExecRecord, []byte)
StopFuzzing(id int)
@@ -88,6 +90,7 @@ type server struct {
checker *vminfo.Checker
infoOnce sync.Once
+ checkOnce sync.Once
checkDone atomic.Bool
checkFailures int
baseSource *queue.DynamicSourceCtl
@@ -217,7 +220,7 @@ func (serv *server) Close() error {
}
func (serv *server) Listen() error {
- s, err := flatrpc.ListenAndServe(serv.cfg.RPC, serv.handleConn)
+ s, err := flatrpc.Listen(serv.cfg.RPC)
if err != nil {
return err
}
@@ -225,15 +228,25 @@ func (serv *server) Listen() error {
return nil
}
+func (serv *server) Serve(ctx context.Context) error {
+ g, ctx := errgroup.WithContext(ctx)
+ g.Go(func() error {
+ return serv.serv.Serve(ctx, func(ctx context.Context, conn *flatrpc.Conn) error {
+ return serv.handleConn(ctx, g, conn)
+ })
+ })
+ return g.Wait()
+}
+
func (serv *server) Port() int {
return serv.serv.Addr.Port
}
-func (serv *server) handleConn(conn *flatrpc.Conn) {
+func (serv *server) handleConn(ctx context.Context, eg *errgroup.Group, conn *flatrpc.Conn) error {
connectReq, err := flatrpc.Recv[*flatrpc.ConnectRequestRaw](conn)
if err != nil {
log.Logf(1, "%s", err)
- return
+ return nil
}
id := int(connectReq.Id)
log.Logf(1, "runner %v connected", id)
@@ -246,7 +259,8 @@ func (serv *server) handleConn(conn *flatrpc.Conn) {
serv.ShutdownInstance(id, true)
}()
} else if err := checkRevisions(connectReq, serv.cfg.Target); err != nil {
- log.Fatal(err)
+ // This is a fatal error.
+ return err
}
serv.StatVMRestarts.Add(1)
@@ -255,15 +269,23 @@ func (serv *server) handleConn(conn *flatrpc.Conn) {
serv.mu.Unlock()
if runner == nil {
log.Logf(2, "unknown VM %v tries to connect", id)
- return
+ return nil
}
- err = serv.handleRunnerConn(runner, conn)
+ err = serv.handleRunnerConn(ctx, eg, runner, conn)
log.Logf(2, "runner %v: %v", id, err)
+
+ if err != nil && errors.Is(err, errFatal) {
+ log.Logf(0, "%v", err)
+ return err
+ }
+
runner.resultCh <- err
+ return nil
}
-func (serv *server) handleRunnerConn(runner *Runner, conn *flatrpc.Conn) error {
+func (serv *server) handleRunnerConn(ctx context.Context, eg *errgroup.Group,
+ runner *Runner, conn *flatrpc.Conn) error {
opts := &handshakeConfig{
VMLess: serv.cfg.VMLess,
Files: serv.checker.RequiredFiles(),
@@ -278,22 +300,36 @@ func (serv *server) handleRunnerConn(runner *Runner, conn *flatrpc.Conn) error {
opts.Features = serv.cfg.Features
}
- err := runner.Handshake(conn, opts)
+ info, err := runner.Handshake(conn, opts)
if err != nil {
log.Logf(1, "%v", err)
return err
}
+ serv.checkOnce.Do(func() {
+ // Run the machine check.
+ eg.Go(func() error {
+ if err := serv.runCheck(ctx, &info); err != nil {
+ return fmt.Errorf("%w: %w", errFatal, err)
+ }
+ return nil
+ })
+ })
+
if serv.triagedCorpus.Load() {
- if err := runner.SendCorpusTriaged(); err != nil {
- log.Logf(2, "%v", err)
- return err
- }
+ eg.Go(runner.SendCorpusTriaged)
}
+ go func() {
+ <-ctx.Done()
+ runner.Stop()
+ }()
return serv.connectionLoop(runner)
}
+// Used for errors incompatible with further RPCServer operation.
+var errFatal = errors.New("aborting RPC server")
+
func (serv *server) handleMachineInfo(infoReq *flatrpc.InfoRequestRawT) (handshakeResult, error) {
modules, machineInfo, err := serv.checker.MachineInfo(infoReq.Files)
if err != nil {
@@ -307,31 +343,36 @@ func (serv *server) handleMachineInfo(infoReq *flatrpc.InfoRequestRawT) (handsha
log.Logf(0, "machine check failed: %v", infoReq.Error)
serv.checkFailures++
if serv.checkFailures == 10 {
- log.Fatalf("machine check failing")
+ return handshakeResult{}, fmt.Errorf("%w: machine check failed too many times", errFatal)
}
return handshakeResult{}, errors.New("machine check failed")
}
+ var retErr error
serv.infoOnce.Do(func() {
serv.StatModules.Add(len(modules))
serv.canonicalModules = cover.NewCanonicalizer(modules, serv.cfg.Cover)
- serv.coverFilter = serv.mgr.CoverageFilter(modules)
- // Flatbuffers don't do deep copy of byte slices,
- // so clone manually since we pass it a goroutine.
- for _, file := range infoReq.Files {
- file.Data = slices.Clone(file.Data)
+ var err error
+ serv.coverFilter, err = serv.mgr.CoverageFilter(modules)
+ if err != nil {
+ retErr = fmt.Errorf("%w: %w", errFatal, err)
+ return
}
- // Now execute check programs.
- go func() {
- if err := serv.runCheck(infoReq); err != nil {
- log.Fatalf("check failed: %v", err)
- }
- }()
})
+ if retErr != nil {
+ return handshakeResult{}, retErr
+ }
+ // Flatbuffers don't do deep copy of byte slices,
+ // so clone manually since we may later pass it to a goroutine.
+ for _, file := range infoReq.Files {
+ file.Data = slices.Clone(file.Data)
+ }
canonicalizer := serv.canonicalModules.NewInstance(modules)
return handshakeResult{
CovFilter: canonicalizer.Decanonicalize(serv.coverFilter),
MachineInfo: machineInfo,
Canonicalizer: canonicalizer,
+ Files: infoReq.Files,
+ Features: infoReq.Features,
}, nil
}
@@ -371,10 +412,8 @@ func checkRevisions(a *flatrpc.ConnectRequest, target *prog.Target) error {
return nil
}
-func (serv *server) runCheck(info *flatrpc.InfoRequest) error {
- // TODO: take context as a parameter.
- enabledCalls, disabledCalls, features, checkErr := serv.checker.Run(context.Background(),
- info.Files, info.Features)
+func (serv *server) runCheck(ctx context.Context, info *handshakeResult) error {
+ enabledCalls, disabledCalls, features, checkErr := serv.checker.Run(ctx, info.Files, info.Features)
enabledCalls, transitivelyDisabled := serv.target.TransitivelyEnabledCalls(enabledCalls)
// Note: need to print disbled syscalls before failing due to an error.
// This helps to debug "all system calls are disabled".
@@ -386,7 +425,10 @@ func (serv *server) runCheck(info *flatrpc.InfoRequest) error {
}
enabledFeatures := features.Enabled()
serv.setupFeatures = features.NeedSetup()
- newSource := serv.mgr.MachineChecked(enabledFeatures, enabledCalls)
+ newSource, err := serv.mgr.MachineChecked(enabledFeatures, enabledCalls)
+ if err != nil {
+ return err
+ }
serv.baseSource.Store(newSource)
serv.checkDone.Store(true)
return nil
diff --git a/pkg/rpcserver/rpcserver_test.go b/pkg/rpcserver/rpcserver_test.go
index a885ad720..0c4984e93 100644
--- a/pkg/rpcserver/rpcserver_test.go
+++ b/pkg/rpcserver/rpcserver_test.go
@@ -4,10 +4,12 @@
package rpcserver
import (
+ "context"
"net"
"testing"
"github.com/stretchr/testify/assert"
+ "golang.org/x/sync/errgroup"
"github.com/google/syzkaller/pkg/flatrpc"
"github.com/google/syzkaller/pkg/mgrconfig"
@@ -212,7 +214,11 @@ func TestHandleConn(t *testing.T) {
serv.CreateInstance(1, injectExec, nil)
go flatrpc.Send(clientConn, tt.req)
- serv.handleConn(serverConn)
+ var eg errgroup.Group
+ serv.handleConn(context.Background(), &eg, serverConn)
+ if err := eg.Wait(); err != nil {
+ t.Fatal(err)
+ }
})
}
}
diff --git a/pkg/rpcserver/runner.go b/pkg/rpcserver/runner.go
index a6b763b9a..de38d29f7 100644
--- a/pkg/rpcserver/runner.go
+++ b/pkg/rpcserver/runner.go
@@ -77,12 +77,14 @@ type handshakeConfig struct {
}
type handshakeResult struct {
+ Files []*flatrpc.FileInfo
+ Features []*flatrpc.FeatureInfo
CovFilter []uint64
MachineInfo []byte
Canonicalizer *cover.CanonicalizerInstance
}
-func (runner *Runner) Handshake(conn *flatrpc.Conn, cfg *handshakeConfig) error {
+func (runner *Runner) Handshake(conn *flatrpc.Conn, cfg *handshakeConfig) (handshakeResult, error) {
if runner.updInfo != nil {
runner.updInfo(func(info *dispatcher.Info) {
info.Status = "handshake"
@@ -104,21 +106,21 @@ func (runner *Runner) Handshake(conn *flatrpc.Conn, cfg *handshakeConfig) error
Features: cfg.Features,
}
if err := flatrpc.Send(conn, connectReply); err != nil {
- return err
+ return handshakeResult{}, err
}
infoReq, err := flatrpc.Recv[*flatrpc.InfoRequestRaw](conn)
if err != nil {
- return err
+ return handshakeResult{}, err
}
ret, err := cfg.Callback(infoReq)
if err != nil {
- return err
+ return handshakeResult{}, err
}
infoReply := &flatrpc.InfoReply{
CoverFilter: ret.CovFilter,
}
if err := flatrpc.Send(conn, infoReply); err != nil {
- return err
+ return handshakeResult{}, err
}
runner.mu.Lock()
runner.conn = conn
@@ -132,7 +134,7 @@ func (runner *Runner) Handshake(conn *flatrpc.Conn, cfg *handshakeConfig) error
info.DetailedStatus = runner.QueryStatus
})
}
- return nil
+ return ret, nil
}
func (runner *Runner) ConnectionLoop() error {
diff --git a/syz-manager/manager.go b/syz-manager/manager.go
index 7a85a6c9a..fdb4929d7 100644
--- a/syz-manager/manager.go
+++ b/syz-manager/manager.go
@@ -310,6 +310,13 @@ func RunManager(mode *Mode, cfg *mgrconfig.Config) {
if err := mgr.serv.Listen(); err != nil {
log.Fatalf("failed to start rpc server: %v", err)
}
+ ctx := vm.ShutdownCtx()
+ go func() {
+ err := mgr.serv.Serve(ctx)
+ if err != nil {
+ log.Fatalf("%s", err)
+ }
+ }()
log.Logf(0, "serving rpc on tcp://%v", mgr.serv.Port())
if cfg.DashboardAddr != "" {
@@ -355,7 +362,6 @@ func RunManager(mode *Mode, cfg *mgrconfig.Config) {
mgr.http.ReproLoop = mgr.reproLoop
mgr.http.TogglePause = mgr.pool.TogglePause
- ctx := vm.ShutdownCtx()
if mgr.cfg.HTTP != "" {
go func() {
err := mgr.http.Serve(ctx)
@@ -1088,9 +1094,10 @@ func (mgr *Manager) BugFrames() (leaks, races []string) {
return
}
-func (mgr *Manager) MachineChecked(features flatrpc.Feature, enabledSyscalls map[*prog.Syscall]bool) queue.Source {
+func (mgr *Manager) MachineChecked(features flatrpc.Feature,
+ enabledSyscalls map[*prog.Syscall]bool) (queue.Source, error) {
if len(enabledSyscalls) == 0 {
- log.Fatalf("all system calls are disabled")
+ return nil, fmt.Errorf("all system calls are disabled")
}
if mgr.mode.ExitAfterMachineCheck {
mgr.exit(mgr.mode.Name)
@@ -1165,15 +1172,15 @@ func (mgr *Manager) MachineChecked(features flatrpc.Feature, enabledSyscalls map
mgr.serv = nil
return queue.Callback(func() *queue.Request {
return nil
- })
+ }), nil
}
- return source
+ return source, nil
} else if mgr.mode == ModeCorpusRun {
ctx := &corpusRunner{
candidates: candidates,
rnd: rand.New(rand.NewSource(time.Now().UnixNano())),
}
- return queue.DefaultOpts(ctx, opts)
+ return queue.DefaultOpts(ctx, opts), nil
} else if mgr.mode == ModeRunTests {
ctx := &runtest.Context{
Dir: filepath.Join(mgr.cfg.Syzkaller, "sys", mgr.cfg.Target.OS, "test"),
@@ -1195,7 +1202,7 @@ func (mgr *Manager) MachineChecked(features flatrpc.Feature, enabledSyscalls map
}
mgr.exit("tests")
}()
- return ctx
+ return ctx, nil
} else if mgr.mode == ModeIfaceProbe {
exec := queue.Plain()
go func() {
@@ -1209,7 +1216,7 @@ func (mgr *Manager) MachineChecked(features flatrpc.Feature, enabledSyscalls map
}
mgr.exit("interface probe")
}()
- return exec
+ return exec, nil
}
panic(fmt.Sprintf("unexpected mode %q", mgr.mode.Name))
}
@@ -1430,11 +1437,11 @@ func (mgr *Manager) dashboardReproTasks() {
}
}
-func (mgr *Manager) CoverageFilter(modules []*vminfo.KernelModule) []uint64 {
+func (mgr *Manager) CoverageFilter(modules []*vminfo.KernelModule) ([]uint64, error) {
mgr.reportGenerator.Init(modules)
filters, err := manager.PrepareCoverageFilters(mgr.reportGenerator, mgr.cfg, true)
if err != nil {
- log.Fatalf("failed to init coverage filter: %v", err)
+ return nil, fmt.Errorf("failed to init coverage filter: %w", err)
}
mgr.coverFilters = filters
mgr.http.Cover.Store(&manager.CoverageInfo{
@@ -1446,7 +1453,7 @@ func (mgr *Manager) CoverageFilter(modules []*vminfo.KernelModule) []uint64 {
for pc := range filters.ExecutorFilter {
pcs = append(pcs, pc)
}
- return pcs
+ return pcs, nil
}
func publicWebAddr(addr string) string {