diff options
| author | Aleksandr Nogikh <nogikh@google.com> | 2025-01-24 17:17:53 +0100 |
|---|---|---|
| committer | Aleksandr Nogikh <nogikh@google.com> | 2025-01-29 10:31:50 +0000 |
| commit | 94e13671726abbcf766f9b4aacd2ee04de59dcbd (patch) | |
| tree | 699abaa69f3509857969ca2d7ff3ea001df14c88 | |
| parent | 6eea27042142c1c5e810b642deb831a8ed55b3da (diff) | |
pkg/rpcserver: refactor to remove Fatalf calls
Apply necessary changes to pkg/flatrpc and pkg/manager as well.
| -rw-r--r-- | pkg/flatrpc/conn.go | 65 | ||||
| -rw-r--r-- | pkg/flatrpc/conn_test.go | 104 | ||||
| -rw-r--r-- | pkg/manager/diff.go | 13 | ||||
| -rw-r--r-- | pkg/rpcserver/local.go | 59 | ||||
| -rw-r--r-- | pkg/rpcserver/mocks/Manager.go | 28 | ||||
| -rw-r--r-- | pkg/rpcserver/rpcserver.go | 104 | ||||
| -rw-r--r-- | pkg/rpcserver/rpcserver_test.go | 8 | ||||
| -rw-r--r-- | pkg/rpcserver/runner.go | 14 | ||||
| -rw-r--r-- | syz-manager/manager.go | 29 |
9 files changed, 275 insertions, 149 deletions
diff --git a/pkg/flatrpc/conn.go b/pkg/flatrpc/conn.go index 33eca07e1..47b265493 100644 --- a/pkg/flatrpc/conn.go +++ b/pkg/flatrpc/conn.go @@ -4,6 +4,7 @@ package flatrpc import ( + "context" "errors" "fmt" "io" @@ -13,9 +14,10 @@ import ( "sync" "unsafe" - "github.com/google/flatbuffers/go" + flatbuffers "github.com/google/flatbuffers/go" "github.com/google/syzkaller/pkg/log" "github.com/google/syzkaller/pkg/stat" + "golang.org/x/sync/errgroup" ) var ( @@ -30,38 +32,55 @@ type Serv struct { ln net.Listener } -func ListenAndServe(addr string, handler func(*Conn)) (*Serv, error) { +func Listen(addr string) (*Serv, error) { ln, err := net.Listen("tcp", addr) if err != nil { return nil, err } - go func() { - for { - conn, err := ln.Accept() - if err != nil { - if errors.Is(err, net.ErrClosed) { - break - } - var netErr *net.OpError - if errors.As(err, &netErr) && !netErr.Temporary() { - log.Fatalf("flatrpc: failed to accept: %v", err) - } - log.Logf(0, "flatrpc: failed to accept: %v", err) - continue - } - go func() { - c := NewConn(conn) - defer c.Close() - handler(c) - }() - } - }() return &Serv{ Addr: ln.Addr().(*net.TCPAddr), ln: ln, }, nil } +// Serve accepts incoming connections and calls handler for each of them. +// An error returned from the handler stops the server and aborts the whole processing. +func (s *Serv) Serve(baseCtx context.Context, handler func(context.Context, *Conn) error) error { + eg, ctx := errgroup.WithContext(baseCtx) + go func() { + // If the context is cancelled, stop the server. + <-ctx.Done() + s.Close() + }() + for { + conn, err := s.ln.Accept() + if err != nil && errors.Is(err, net.ErrClosed) { + break + } + if err != nil { + var netErr *net.OpError + if errors.As(err, &netErr) && !netErr.Temporary() { + return fmt.Errorf("flatrpc: failed to accept: %w", err) + } + log.Logf(0, "flatrpc: failed to accept: %v", err) + continue + } + eg.Go(func() error { + connCtx, cancel := context.WithCancel(ctx) + defer cancel() + + c := NewConn(conn) + // Closing the server does not automatically close all the connections. + go func() { + <-connCtx.Done() + c.Close() + }() + return handler(connCtx, c) + }) + } + return eg.Wait() +} + func (s *Serv) Close() error { return s.ln.Close() } diff --git a/pkg/flatrpc/conn_test.go b/pkg/flatrpc/conn_test.go index 132fd1cdd..4b108a5a4 100644 --- a/pkg/flatrpc/conn_test.go +++ b/pkg/flatrpc/conn_test.go @@ -4,15 +4,18 @@ package flatrpc import ( + "context" + "fmt" "net" "os" + "reflect" "runtime/debug" "sync" "syscall" "testing" "time" - "github.com/google/flatbuffers/go" + flatbuffers "github.com/google/flatbuffers/go" "github.com/stretchr/testify/assert" ) @@ -40,35 +43,39 @@ func TestConn(t *testing.T) { }, } - done := make(chan bool) - defer func() { - <-done - }() - serv, err := ListenAndServe(":0", func(c *Conn) { - defer close(done) - connectReqGot, err := Recv[*ConnectRequestRaw](c) - if err != nil { - t.Fatal(err) - } - assert.Equal(t, connectReq, connectReqGot) - - if err := Send(c, connectReply); err != nil { - t.Fatal(err) - } - - for i := 0; i < 10; i++ { - got, err := Recv[*ExecutorMessageRaw](c) - if err != nil { - t.Fatal(err) - } - assert.Equal(t, executorMsg, got) - } - }) + serv, err := Listen(":0") if err != nil { t.Fatal(err) } - defer serv.Close() + done := make(chan error) + go func() { + done <- serv.Serve(context.Background(), + func(_ context.Context, c *Conn) error { + connectReqGot, err := Recv[*ConnectRequestRaw](c) + if err != nil { + return err + } + if !reflect.DeepEqual(connectReq, connectReqGot) { + return fmt.Errorf("connectReq != connectReqGot") + } + + if err := Send(c, connectReply); err != nil { + return err + } + + for i := 0; i < 10; i++ { + got, err := Recv[*ExecutorMessageRaw](c) + if err != nil { + return nil + } + if !reflect.DeepEqual(executorMsg, got) { + return fmt.Errorf("executorMsg !=got") + } + } + return nil + }) + }() c := dial(t, serv.Addr.String()) defer c.Close() @@ -87,6 +94,11 @@ func TestConn(t *testing.T) { t.Fatal(err) } } + + serv.Close() + if err := <-done; err != nil { + t.Fatal(err) + } } func BenchmarkConn(b *testing.B) { @@ -103,26 +115,27 @@ func BenchmarkConn(b *testing.B) { Files: []string{"file1"}, } - done := make(chan bool) - defer func() { - <-done - }() - serv, err := ListenAndServe(":0", func(c *Conn) { - defer close(done) - for i := 0; i < b.N; i++ { - _, err := Recv[*ConnectRequestRaw](c) - if err != nil { - b.Fatal(err) - } - if err := Send(c, connectReply); err != nil { - b.Fatal(err) - } - } - }) + serv, err := Listen(":0") if err != nil { b.Fatal(err) } - defer serv.Close() + done := make(chan error) + + go func() { + done <- serv.Serve(context.Background(), + func(_ context.Context, c *Conn) error { + for i := 0; i < b.N; i++ { + _, err := Recv[*ConnectRequestRaw](c) + if err != nil { + return err + } + if err := Send(c, connectReply); err != nil { + return err + } + } + return nil + }) + }() c := dial(b, serv.Addr.String()) defer c.Close() @@ -138,6 +151,11 @@ func BenchmarkConn(b *testing.B) { b.Fatal(err) } } + + serv.Close() + if err := <-done; err != nil { + b.Fatal(err) + } } func dial(t testing.TB, addr string) *Conn { diff --git a/pkg/manager/diff.go b/pkg/manager/diff.go index 17d056e1f..8218ce924 100644 --- a/pkg/manager/diff.go +++ b/pkg/manager/diff.go @@ -305,9 +305,10 @@ func (kc *kernelContext) BugFrames() (leaks, races []string) { return nil, nil } -func (kc *kernelContext) MachineChecked(features flatrpc.Feature, syscalls map[*prog.Syscall]bool) queue.Source { +func (kc *kernelContext) MachineChecked(features flatrpc.Feature, + syscalls map[*prog.Syscall]bool) (queue.Source, error) { if len(syscalls) == 0 { - log.Fatalf("all system calls are disabled") + return nil, fmt.Errorf("all system calls are disabled") } log.Logf(0, "%s: machine check complete", kc.name) kc.features = features @@ -319,7 +320,7 @@ func (kc *kernelContext) MachineChecked(features flatrpc.Feature, syscalls map[* source = kc.source } opts := fuzzer.DefaultExecOpts(kc.cfg, features, kc.debug) - return queue.DefaultOpts(source, opts) + return queue.DefaultOpts(source, opts), nil } func (kc *kernelContext) setupFuzzer(features flatrpc.Feature, syscalls map[*prog.Syscall]bool) queue.Source { @@ -383,11 +384,11 @@ func (kc *kernelContext) setupFuzzer(features flatrpc.Feature, syscalls map[*pro return fuzzerObj } -func (kc *kernelContext) CoverageFilter(modules []*vminfo.KernelModule) []uint64 { +func (kc *kernelContext) CoverageFilter(modules []*vminfo.KernelModule) ([]uint64, error) { kc.reportGenerator.Init(modules) filters, err := PrepareCoverageFilters(kc.reportGenerator, kc.cfg, false) if err != nil { - log.Fatalf("failed to init coverage filter: %v", err) + return nil, fmt.Errorf("failed to init coverage filter: %w", err) } kc.coverFilters = filters log.Logf(0, "cover filter size: %d", len(filters.ExecutorFilter)) @@ -402,7 +403,7 @@ func (kc *kernelContext) CoverageFilter(modules []*vminfo.KernelModule) []uint64 for pc := range filters.ExecutorFilter { pcs = append(pcs, pc) } - return pcs + return pcs, nil } func (kc *kernelContext) fuzzerInstance(ctx context.Context, inst *vm.Instance, updInfo dispatcher.UpdateInfo) { diff --git a/pkg/rpcserver/local.go b/pkg/rpcserver/local.go index 4ab8827ae..e4e128dcf 100644 --- a/pkg/rpcserver/local.go +++ b/pkg/rpcserver/local.go @@ -16,6 +16,7 @@ import ( "github.com/google/syzkaller/pkg/signal" "github.com/google/syzkaller/pkg/vminfo" "github.com/google/syzkaller/prog" + "golang.org/x/sync/errgroup" ) type LocalConfig struct { @@ -39,26 +40,32 @@ func RunLocal(cfg *LocalConfig) error { if cfg.VMArch == "" { cfg.VMArch = cfg.Target.Arch } + if cfg.Context == nil { + cfg.Context = context.Background() + } cfg.UseCoverEdges = true cfg.FilterSignal = true cfg.RPC = ":0" cfg.PrintMachineCheck = log.V(1) cfg.Stats = NewStats() - ctx := &local{ + localCtx := &local{ cfg: cfg, setupDone: make(chan bool), } - serv := newImpl(&cfg.Config, ctx) + serv := newImpl(&cfg.Config, localCtx) if err := serv.Listen(); err != nil { return err } defer serv.Close() - ctx.serv = serv + localCtx.serv = serv // setupDone synchronizes assignment to ctx.serv and read of ctx.serv in MachineChecked // for the race detector b/c it does not understand the synchronization via TCP socket connect/accept. - close(ctx.setupDone) + close(localCtx.setupDone) + + cancelCtx, cancel := context.WithCancel(cfg.Context) + eg, ctx := errgroup.WithContext(cancelCtx) - id := 0 + const id = 0 connErr := serv.CreateInstance(id, nil, nil) defer serv.ShutdownInstance(id, true) @@ -73,7 +80,7 @@ func RunLocal(cfg *LocalConfig) error { cfg.Executor, }, args...) } - cmd := exec.Command(bin, args...) + cmd := exec.CommandContext(ctx, bin, args...) cmd.Dir = cfg.Dir if cfg.Debug || cfg.GDB { cmd.Stdout = os.Stdout @@ -82,28 +89,32 @@ func RunLocal(cfg *LocalConfig) error { if cfg.GDB { cmd.Stdin = os.Stdin } - if err := cmd.Start(); err != nil { - return fmt.Errorf("failed to start executor: %w", err) - } - res := make(chan error, 1) - go func() { res <- cmd.Wait() }() + eg.Go(func() error { + return serv.Serve(ctx) + }) + eg.Go(func() error { + if err := cmd.Start(); err != nil { + return fmt.Errorf("failed to start executor: %w", err) + } + err := cmd.Wait() + // Note that we ignore the error if we killed the process by closing the context. + if err == nil || ctx.Err() != nil { + return nil + } + return fmt.Errorf("executor process exited: %w", err) + }) + shutdown := make(chan struct{}) if cfg.HandleInterrupts { osutil.HandleInterrupts(shutdown) } - var cmdErr error select { + case <-ctx.Done(): case <-shutdown: - case <-cfg.Context.Done(): case <-connErr: - case err := <-res: - cmdErr = fmt.Errorf("executor process exited: %w", err) - } - if cmdErr == nil { - cmd.Process.Kill() - <-res } - return cmdErr + cancel() + return eg.Wait() } type local struct { @@ -112,10 +123,10 @@ type local struct { setupDone chan bool } -func (ctx *local) MachineChecked(features flatrpc.Feature, syscalls map[*prog.Syscall]bool) queue.Source { +func (ctx *local) MachineChecked(features flatrpc.Feature, syscalls map[*prog.Syscall]bool) (queue.Source, error) { <-ctx.setupDone ctx.serv.TriagedCorpus() - return ctx.cfg.MachineChecked(features, syscalls) + return ctx.cfg.MachineChecked(features, syscalls), nil } func (ctx *local) BugFrames() ([]string, []string) { @@ -126,6 +137,6 @@ func (ctx *local) MaxSignal() signal.Signal { return signal.FromRaw(ctx.cfg.MaxSignal, 0) } -func (ctx *local) CoverageFilter(modules []*vminfo.KernelModule) []uint64 { - return ctx.cfg.CoverFilter +func (ctx *local) CoverageFilter(modules []*vminfo.KernelModule) ([]uint64, error) { + return ctx.cfg.CoverFilter, nil } diff --git a/pkg/rpcserver/mocks/Manager.go b/pkg/rpcserver/mocks/Manager.go index 810b5028f..c0c9621de 100644 --- a/pkg/rpcserver/mocks/Manager.go +++ b/pkg/rpcserver/mocks/Manager.go @@ -53,7 +53,7 @@ func (_m *Manager) BugFrames() ([]string, []string) { } // CoverageFilter provides a mock function with given fields: modules -func (_m *Manager) CoverageFilter(modules []*vminfo.KernelModule) []uint64 { +func (_m *Manager) CoverageFilter(modules []*vminfo.KernelModule) ([]uint64, error) { ret := _m.Called(modules) if len(ret) == 0 { @@ -61,6 +61,10 @@ func (_m *Manager) CoverageFilter(modules []*vminfo.KernelModule) []uint64 { } var r0 []uint64 + var r1 error + if rf, ok := ret.Get(0).(func([]*vminfo.KernelModule) ([]uint64, error)); ok { + return rf(modules) + } if rf, ok := ret.Get(0).(func([]*vminfo.KernelModule) []uint64); ok { r0 = rf(modules) } else { @@ -69,11 +73,17 @@ func (_m *Manager) CoverageFilter(modules []*vminfo.KernelModule) []uint64 { } } - return r0 + if rf, ok := ret.Get(1).(func([]*vminfo.KernelModule) error); ok { + r1 = rf(modules) + } else { + r1 = ret.Error(1) + } + + return r0, r1 } // MachineChecked provides a mock function with given fields: features, syscalls -func (_m *Manager) MachineChecked(features flatrpc.Feature, syscalls map[*prog.Syscall]bool) queue.Source { +func (_m *Manager) MachineChecked(features flatrpc.Feature, syscalls map[*prog.Syscall]bool) (queue.Source, error) { ret := _m.Called(features, syscalls) if len(ret) == 0 { @@ -81,6 +91,10 @@ func (_m *Manager) MachineChecked(features flatrpc.Feature, syscalls map[*prog.S } var r0 queue.Source + var r1 error + if rf, ok := ret.Get(0).(func(flatrpc.Feature, map[*prog.Syscall]bool) (queue.Source, error)); ok { + return rf(features, syscalls) + } if rf, ok := ret.Get(0).(func(flatrpc.Feature, map[*prog.Syscall]bool) queue.Source); ok { r0 = rf(features, syscalls) } else { @@ -89,7 +103,13 @@ func (_m *Manager) MachineChecked(features flatrpc.Feature, syscalls map[*prog.S } } - return r0 + if rf, ok := ret.Get(1).(func(flatrpc.Feature, map[*prog.Syscall]bool) error); ok { + r1 = rf(features, syscalls) + } else { + r1 = ret.Error(1) + } + + return r0, r1 } // MaxSignal provides a mock function with given fields: diff --git a/pkg/rpcserver/rpcserver.go b/pkg/rpcserver/rpcserver.go index 003c5f4b9..b3b518b04 100644 --- a/pkg/rpcserver/rpcserver.go +++ b/pkg/rpcserver/rpcserver.go @@ -28,6 +28,7 @@ import ( "github.com/google/syzkaller/prog" "github.com/google/syzkaller/sys/targets" "github.com/google/syzkaller/vm/dispatcher" + "golang.org/x/sync/errgroup" ) type Config struct { @@ -63,8 +64,8 @@ type RemoteConfig struct { type Manager interface { MaxSignal() signal.Signal BugFrames() (leaks []string, races []string) - MachineChecked(features flatrpc.Feature, syscalls map[*prog.Syscall]bool) queue.Source - CoverageFilter(modules []*vminfo.KernelModule) []uint64 + MachineChecked(features flatrpc.Feature, syscalls map[*prog.Syscall]bool) (queue.Source, error) + CoverageFilter(modules []*vminfo.KernelModule) ([]uint64, error) } type Server interface { @@ -72,6 +73,7 @@ type Server interface { Close() error Port() int TriagedCorpus() + Serve(context.Context) error CreateInstance(id int, injectExec chan<- bool, updInfo dispatcher.UpdateInfo) chan error ShutdownInstance(id int, crashed bool, extraExecs ...report.ExecutorInfo) ([]ExecRecord, []byte) StopFuzzing(id int) @@ -88,6 +90,7 @@ type server struct { checker *vminfo.Checker infoOnce sync.Once + checkOnce sync.Once checkDone atomic.Bool checkFailures int baseSource *queue.DynamicSourceCtl @@ -217,7 +220,7 @@ func (serv *server) Close() error { } func (serv *server) Listen() error { - s, err := flatrpc.ListenAndServe(serv.cfg.RPC, serv.handleConn) + s, err := flatrpc.Listen(serv.cfg.RPC) if err != nil { return err } @@ -225,15 +228,25 @@ func (serv *server) Listen() error { return nil } +func (serv *server) Serve(ctx context.Context) error { + g, ctx := errgroup.WithContext(ctx) + g.Go(func() error { + return serv.serv.Serve(ctx, func(ctx context.Context, conn *flatrpc.Conn) error { + return serv.handleConn(ctx, g, conn) + }) + }) + return g.Wait() +} + func (serv *server) Port() int { return serv.serv.Addr.Port } -func (serv *server) handleConn(conn *flatrpc.Conn) { +func (serv *server) handleConn(ctx context.Context, eg *errgroup.Group, conn *flatrpc.Conn) error { connectReq, err := flatrpc.Recv[*flatrpc.ConnectRequestRaw](conn) if err != nil { log.Logf(1, "%s", err) - return + return nil } id := int(connectReq.Id) log.Logf(1, "runner %v connected", id) @@ -246,7 +259,8 @@ func (serv *server) handleConn(conn *flatrpc.Conn) { serv.ShutdownInstance(id, true) }() } else if err := checkRevisions(connectReq, serv.cfg.Target); err != nil { - log.Fatal(err) + // This is a fatal error. + return err } serv.StatVMRestarts.Add(1) @@ -255,15 +269,23 @@ func (serv *server) handleConn(conn *flatrpc.Conn) { serv.mu.Unlock() if runner == nil { log.Logf(2, "unknown VM %v tries to connect", id) - return + return nil } - err = serv.handleRunnerConn(runner, conn) + err = serv.handleRunnerConn(ctx, eg, runner, conn) log.Logf(2, "runner %v: %v", id, err) + + if err != nil && errors.Is(err, errFatal) { + log.Logf(0, "%v", err) + return err + } + runner.resultCh <- err + return nil } -func (serv *server) handleRunnerConn(runner *Runner, conn *flatrpc.Conn) error { +func (serv *server) handleRunnerConn(ctx context.Context, eg *errgroup.Group, + runner *Runner, conn *flatrpc.Conn) error { opts := &handshakeConfig{ VMLess: serv.cfg.VMLess, Files: serv.checker.RequiredFiles(), @@ -278,22 +300,36 @@ func (serv *server) handleRunnerConn(runner *Runner, conn *flatrpc.Conn) error { opts.Features = serv.cfg.Features } - err := runner.Handshake(conn, opts) + info, err := runner.Handshake(conn, opts) if err != nil { log.Logf(1, "%v", err) return err } + serv.checkOnce.Do(func() { + // Run the machine check. + eg.Go(func() error { + if err := serv.runCheck(ctx, &info); err != nil { + return fmt.Errorf("%w: %w", errFatal, err) + } + return nil + }) + }) + if serv.triagedCorpus.Load() { - if err := runner.SendCorpusTriaged(); err != nil { - log.Logf(2, "%v", err) - return err - } + eg.Go(runner.SendCorpusTriaged) } + go func() { + <-ctx.Done() + runner.Stop() + }() return serv.connectionLoop(runner) } +// Used for errors incompatible with further RPCServer operation. +var errFatal = errors.New("aborting RPC server") + func (serv *server) handleMachineInfo(infoReq *flatrpc.InfoRequestRawT) (handshakeResult, error) { modules, machineInfo, err := serv.checker.MachineInfo(infoReq.Files) if err != nil { @@ -307,31 +343,36 @@ func (serv *server) handleMachineInfo(infoReq *flatrpc.InfoRequestRawT) (handsha log.Logf(0, "machine check failed: %v", infoReq.Error) serv.checkFailures++ if serv.checkFailures == 10 { - log.Fatalf("machine check failing") + return handshakeResult{}, fmt.Errorf("%w: machine check failed too many times", errFatal) } return handshakeResult{}, errors.New("machine check failed") } + var retErr error serv.infoOnce.Do(func() { serv.StatModules.Add(len(modules)) serv.canonicalModules = cover.NewCanonicalizer(modules, serv.cfg.Cover) - serv.coverFilter = serv.mgr.CoverageFilter(modules) - // Flatbuffers don't do deep copy of byte slices, - // so clone manually since we pass it a goroutine. - for _, file := range infoReq.Files { - file.Data = slices.Clone(file.Data) + var err error + serv.coverFilter, err = serv.mgr.CoverageFilter(modules) + if err != nil { + retErr = fmt.Errorf("%w: %w", errFatal, err) + return } - // Now execute check programs. - go func() { - if err := serv.runCheck(infoReq); err != nil { - log.Fatalf("check failed: %v", err) - } - }() }) + if retErr != nil { + return handshakeResult{}, retErr + } + // Flatbuffers don't do deep copy of byte slices, + // so clone manually since we may later pass it a goroutine. + for _, file := range infoReq.Files { + file.Data = slices.Clone(file.Data) + } canonicalizer := serv.canonicalModules.NewInstance(modules) return handshakeResult{ CovFilter: canonicalizer.Decanonicalize(serv.coverFilter), MachineInfo: machineInfo, Canonicalizer: canonicalizer, + Files: infoReq.Files, + Features: infoReq.Features, }, nil } @@ -371,10 +412,8 @@ func checkRevisions(a *flatrpc.ConnectRequest, target *prog.Target) error { return nil } -func (serv *server) runCheck(info *flatrpc.InfoRequest) error { - // TODO: take context as a parameter. - enabledCalls, disabledCalls, features, checkErr := serv.checker.Run(context.Background(), - info.Files, info.Features) +func (serv *server) runCheck(ctx context.Context, info *handshakeResult) error { + enabledCalls, disabledCalls, features, checkErr := serv.checker.Run(ctx, info.Files, info.Features) enabledCalls, transitivelyDisabled := serv.target.TransitivelyEnabledCalls(enabledCalls) // Note: need to print disbled syscalls before failing due to an error. // This helps to debug "all system calls are disabled". @@ -386,7 +425,10 @@ func (serv *server) runCheck(info *flatrpc.InfoRequest) error { } enabledFeatures := features.Enabled() serv.setupFeatures = features.NeedSetup() - newSource := serv.mgr.MachineChecked(enabledFeatures, enabledCalls) + newSource, err := serv.mgr.MachineChecked(enabledFeatures, enabledCalls) + if err != nil { + return err + } serv.baseSource.Store(newSource) serv.checkDone.Store(true) return nil diff --git a/pkg/rpcserver/rpcserver_test.go b/pkg/rpcserver/rpcserver_test.go index a885ad720..0c4984e93 100644 --- a/pkg/rpcserver/rpcserver_test.go +++ b/pkg/rpcserver/rpcserver_test.go @@ -4,10 +4,12 @@ package rpcserver import ( + "context" "net" "testing" "github.com/stretchr/testify/assert" + "golang.org/x/sync/errgroup" "github.com/google/syzkaller/pkg/flatrpc" "github.com/google/syzkaller/pkg/mgrconfig" @@ -212,7 +214,11 @@ func TestHandleConn(t *testing.T) { serv.CreateInstance(1, injectExec, nil) go flatrpc.Send(clientConn, tt.req) - serv.handleConn(serverConn) + var eg errgroup.Group + serv.handleConn(context.Background(), &eg, serverConn) + if err := eg.Wait(); err != nil { + t.Fatal(err) + } }) } } diff --git a/pkg/rpcserver/runner.go b/pkg/rpcserver/runner.go index a6b763b9a..de38d29f7 100644 --- a/pkg/rpcserver/runner.go +++ b/pkg/rpcserver/runner.go @@ -77,12 +77,14 @@ type handshakeConfig struct { } type handshakeResult struct { + Files []*flatrpc.FileInfo + Features []*flatrpc.FeatureInfo CovFilter []uint64 MachineInfo []byte Canonicalizer *cover.CanonicalizerInstance } -func (runner *Runner) Handshake(conn *flatrpc.Conn, cfg *handshakeConfig) error { +func (runner *Runner) Handshake(conn *flatrpc.Conn, cfg *handshakeConfig) (handshakeResult, error) { if runner.updInfo != nil { runner.updInfo(func(info *dispatcher.Info) { info.Status = "handshake" @@ -104,21 +106,21 @@ func (runner *Runner) Handshake(conn *flatrpc.Conn, cfg *handshakeConfig) error Features: cfg.Features, } if err := flatrpc.Send(conn, connectReply); err != nil { - return err + return handshakeResult{}, err } infoReq, err := flatrpc.Recv[*flatrpc.InfoRequestRaw](conn) if err != nil { - return err + return handshakeResult{}, err } ret, err := cfg.Callback(infoReq) if err != nil { - return err + return handshakeResult{}, err } infoReply := &flatrpc.InfoReply{ CoverFilter: ret.CovFilter, } if err := flatrpc.Send(conn, infoReply); err != nil { - return err + return handshakeResult{}, err } runner.mu.Lock() runner.conn = conn @@ -132,7 +134,7 @@ func (runner *Runner) Handshake(conn *flatrpc.Conn, cfg *handshakeConfig) error info.DetailedStatus = runner.QueryStatus }) } - return nil + return ret, nil } func (runner *Runner) ConnectionLoop() error { diff --git a/syz-manager/manager.go b/syz-manager/manager.go index 7a85a6c9a..fdb4929d7 100644 --- a/syz-manager/manager.go +++ b/syz-manager/manager.go @@ -310,6 +310,13 @@ func RunManager(mode *Mode, cfg *mgrconfig.Config) { if err := mgr.serv.Listen(); err != nil { log.Fatalf("failed to start rpc server: %v", err) } + ctx := vm.ShutdownCtx() + go func() { + err := mgr.serv.Serve(ctx) + if err != nil { + log.Fatalf("%s", err) + } + }() log.Logf(0, "serving rpc on tcp://%v", mgr.serv.Port()) if cfg.DashboardAddr != "" { @@ -355,7 +362,6 @@ func RunManager(mode *Mode, cfg *mgrconfig.Config) { mgr.http.ReproLoop = mgr.reproLoop mgr.http.TogglePause = mgr.pool.TogglePause - ctx := vm.ShutdownCtx() if mgr.cfg.HTTP != "" { go func() { err := mgr.http.Serve(ctx) @@ -1088,9 +1094,10 @@ func (mgr *Manager) BugFrames() (leaks, races []string) { return } -func (mgr *Manager) MachineChecked(features flatrpc.Feature, enabledSyscalls map[*prog.Syscall]bool) queue.Source { +func (mgr *Manager) MachineChecked(features flatrpc.Feature, + enabledSyscalls map[*prog.Syscall]bool) (queue.Source, error) { if len(enabledSyscalls) == 0 { - log.Fatalf("all system calls are disabled") + return nil, fmt.Errorf("all system calls are disabled") } if mgr.mode.ExitAfterMachineCheck { mgr.exit(mgr.mode.Name) @@ -1165,15 +1172,15 @@ func (mgr *Manager) MachineChecked(features flatrpc.Feature, enabledSyscalls map mgr.serv = nil return queue.Callback(func() *queue.Request { return nil - }) + }), nil } - return source + return source, nil } else if mgr.mode == ModeCorpusRun { ctx := &corpusRunner{ candidates: candidates, rnd: rand.New(rand.NewSource(time.Now().UnixNano())), } - return queue.DefaultOpts(ctx, opts) + return queue.DefaultOpts(ctx, opts), nil } else if mgr.mode == ModeRunTests { ctx := &runtest.Context{ Dir: filepath.Join(mgr.cfg.Syzkaller, "sys", mgr.cfg.Target.OS, "test"), @@ -1195,7 +1202,7 @@ func (mgr *Manager) MachineChecked(features flatrpc.Feature, enabledSyscalls map } mgr.exit("tests") }() - return ctx + return ctx, nil } else if mgr.mode == ModeIfaceProbe { exec := queue.Plain() go func() { @@ -1209,7 +1216,7 @@ func (mgr *Manager) MachineChecked(features flatrpc.Feature, enabledSyscalls map } mgr.exit("interface probe") }() - return exec + return exec, nil } panic(fmt.Sprintf("unexpected mode %q", mgr.mode.Name)) } @@ -1430,11 +1437,11 @@ func (mgr *Manager) dashboardReproTasks() { } } -func (mgr *Manager) CoverageFilter(modules []*vminfo.KernelModule) []uint64 { +func (mgr *Manager) CoverageFilter(modules []*vminfo.KernelModule) ([]uint64, error) { mgr.reportGenerator.Init(modules) filters, err := manager.PrepareCoverageFilters(mgr.reportGenerator, mgr.cfg, true) if err != nil { - log.Fatalf("failed to init coverage filter: %v", err) + return nil, fmt.Errorf("failed to init coverage filter: %w", err) } mgr.coverFilters = filters mgr.http.Cover.Store(&manager.CoverageInfo{ @@ -1446,7 +1453,7 @@ func (mgr *Manager) CoverageFilter(modules []*vminfo.KernelModule) []uint64 { for pc := range filters.ExecutorFilter { pcs = append(pcs, pc) } - return pcs + return pcs, nil } func publicWebAddr(addr string) string { |
