aboutsummaryrefslogtreecommitdiffstats
path: root/pkg/rpcserver
diff options
context:
space:
mode:
authorAleksandr Nogikh <nogikh@google.com>2024-07-01 16:52:41 +0200
committerAleksandr Nogikh <nogikh@google.com>2024-07-04 12:34:14 +0000
commit092c1914a191f5858db674b4e367c6848500429e (patch)
treef48a5a5e2602cb75a3d37463bddb1212829889f3 /pkg/rpcserver
parentdc6bbff0c2fe403c39d8a1d057f668088b09069f (diff)
pkg/rpcserver: move handshake functionality to Runner
This allows for a more clean interface between RPCServer and Runner.
Diffstat (limited to 'pkg/rpcserver')
-rw-r--r--pkg/rpcserver/rpcserver.go96
-rw-r--r--pkg/rpcserver/runner.go60
2 files changed, 100 insertions, 56 deletions
diff --git a/pkg/rpcserver/rpcserver.go b/pkg/rpcserver/rpcserver.go
index b6b660e6a..35b628715 100644
--- a/pkg/rpcserver/rpcserver.go
+++ b/pkg/rpcserver/rpcserver.go
@@ -222,11 +222,13 @@ func (serv *Server) RunnerStatus(name string) []byte {
}
func (serv *Server) handleConn(conn *flatrpc.Conn) {
- name, machineInfo, canonicalizer, err := serv.handshake(conn)
+ connectReq, err := flatrpc.Recv[*flatrpc.ConnectRequestRaw](conn)
if err != nil {
- log.Logf(1, "%v", err)
+ log.Logf(1, "%s", err)
return
}
+ name := connectReq.Name
+ log.Logf(1, "runner %v connected", name)
if serv.cfg.VMLess {
// There is no VM loop, so minic what it would do.
@@ -235,7 +237,10 @@ func (serv *Server) handleConn(conn *flatrpc.Conn) {
serv.StopFuzzing(name)
serv.ShutdownInstance(name, true)
}()
+ } else {
+ checkRevisions(connectReq, serv.cfg.Target)
}
+ serv.statVMRestarts.Add(1)
serv.mu.Lock()
runner := serv.runners[name]
@@ -244,13 +249,34 @@ func (serv *Server) handleConn(conn *flatrpc.Conn) {
log.Logf(2, "VM %v shut down before connect", name)
return
}
- serv.info[name] = VMState{StateFuzzing, time.Now()}
- runner.conn = conn
- runner.machineInfo = machineInfo
- runner.canonicalizer = canonicalizer
serv.mu.Unlock()
defer close(runner.finished)
+ opts := &handshakeConfig{
+ VMLess: serv.cfg.VMLess,
+ Files: serv.checker.RequiredFiles(),
+ Timeouts: serv.timeouts,
+ Callback: serv.handleMachineInfo,
+ }
+ opts.LeakFrames, opts.RaceFrames = serv.mgr.BugFrames()
+ if serv.checkDone.Load() {
+ opts.Features = serv.setupFeatures
+ } else {
+ opts.Files = append(opts.Files, serv.checker.CheckFiles()...)
+ opts.Globs = serv.target.RequiredGlobs()
+ opts.Features = serv.cfg.Features
+ }
+
+ err = runner.handshake(conn, opts)
+ if err != nil {
+ log.Logf(1, "%v", err)
+ return
+ }
+
+ serv.mu.Lock()
+ serv.info[name] = VMState{StateFuzzing, time.Now()}
+ serv.mu.Unlock()
+
if serv.triagedCorpus.Load() {
if err := runner.sendCorpusTriaged(); err != nil {
log.Logf(2, "%v", err)
@@ -262,46 +288,7 @@ func (serv *Server) handleConn(conn *flatrpc.Conn) {
log.Logf(2, "runner %v: %v", name, err)
}
-func (serv *Server) handshake(conn *flatrpc.Conn) (string, []byte, *cover.CanonicalizerInstance, error) {
- connectReq, err := flatrpc.Recv[*flatrpc.ConnectRequestRaw](conn)
- if err != nil {
- return "", nil, nil, err
- }
- log.Logf(1, "runner %v connected", connectReq.Name)
- if !serv.cfg.VMLess {
- checkRevisions(connectReq, serv.cfg.Target)
- }
- serv.statVMRestarts.Add(1)
-
- leaks, races := serv.mgr.BugFrames()
- connectReply := &flatrpc.ConnectReply{
- Debug: serv.cfg.Debug,
- Cover: serv.cfg.Cover,
- CoverEdges: serv.cfg.UseCoverEdges,
- Kernel64Bit: serv.sysTarget.PtrSize == 8,
- Procs: int32(serv.cfg.Procs),
- Slowdown: int32(serv.timeouts.Slowdown),
- SyscallTimeoutMs: int32(serv.timeouts.Syscall / time.Millisecond),
- ProgramTimeoutMs: int32(serv.timeouts.Program / time.Millisecond),
- LeakFrames: leaks,
- RaceFrames: races,
- }
- connectReply.Files = serv.checker.RequiredFiles()
- if serv.checkDone.Load() {
- connectReply.Features = serv.setupFeatures
- } else {
- connectReply.Files = append(connectReply.Files, serv.checker.CheckFiles()...)
- connectReply.Globs = serv.target.RequiredGlobs()
- connectReply.Features = serv.cfg.Features
- }
- if err := flatrpc.Send(conn, connectReply); err != nil {
- return "", nil, nil, err
- }
-
- infoReq, err := flatrpc.Recv[*flatrpc.InfoRequestRaw](conn)
- if err != nil {
- return "", nil, nil, err
- }
+func (serv *Server) handleMachineInfo(infoReq *flatrpc.InfoRequestRawT) (handshakeResult, error) {
modules, machineInfo, err := serv.checker.MachineInfo(infoReq.Files)
if err != nil {
log.Logf(0, "parsing of machine info failed: %v", err)
@@ -316,9 +303,8 @@ func (serv *Server) handshake(conn *flatrpc.Conn) (string, []byte, *cover.Canoni
if serv.checkFailures == 10 {
log.Fatalf("machine check failing")
}
- return "", nil, nil, errors.New("machine check failed")
+ return handshakeResult{}, errors.New("machine check failed")
}
-
serv.infoOnce.Do(func() {
serv.canonicalModules = cover.NewCanonicalizer(modules, serv.cfg.Cover)
serv.coverFilter = serv.mgr.CoverageFilter(modules)
@@ -339,15 +325,12 @@ func (serv *Server) handshake(conn *flatrpc.Conn) (string, []byte, *cover.Canoni
}
}()
})
-
canonicalizer := serv.canonicalModules.NewInstance(modules)
- infoReply := &flatrpc.InfoReply{
- CoverFilter: canonicalizer.Decanonicalize(serv.coverFilter),
- }
- if err := flatrpc.Send(conn, infoReply); err != nil {
- return "", nil, nil, err
- }
- return connectReq.Name, machineInfo, canonicalizer, nil
+ return handshakeResult{
+ CovFilter: canonicalizer.Decanonicalize(serv.coverFilter),
+ MachineInfo: machineInfo,
+ Canonicalizer: canonicalizer,
+ }, nil
}
func (serv *Server) connectionLoop(runner *Runner) error {
@@ -458,6 +441,7 @@ func (serv *Server) CreateInstance(name string, injectExec chan<- bool) {
runner := &Runner{
source: serv.execSource,
cover: serv.cfg.Cover,
+ coverEdges: serv.cfg.UseCoverEdges,
filterSignal: serv.cfg.FilterSignal,
debug: serv.cfg.Debug,
sysTarget: serv.sysTarget,
diff --git a/pkg/rpcserver/runner.go b/pkg/rpcserver/runner.go
index 219ef2c35..ed45e2e31 100644
--- a/pkg/rpcserver/runner.go
+++ b/pkg/rpcserver/runner.go
@@ -25,6 +25,7 @@ type Runner struct {
source queue.Source
procs int
cover bool
+ coverEdges bool
filterSignal bool
debug bool
sysTarget *targets.Target
@@ -52,6 +53,65 @@ type runnerStats struct {
statNoExecDuration *stats.Val
}
+type handshakeConfig struct {
+ VMLess bool
+ Timeouts targets.Timeouts
+ LeakFrames []string
+ RaceFrames []string
+ Files []string
+ Globs []string
+ Features flatrpc.Feature
+
+ // Callback() is called in the middle of the handshake process.
+ // The return arguments are the coverage filter and the (possible) error.
+ Callback func(*flatrpc.InfoRequestRawT) (handshakeResult, error)
+}
+
+type handshakeResult struct {
+ CovFilter []uint64
+ MachineInfo []byte
+ Canonicalizer *cover.CanonicalizerInstance
+}
+
+func (runner *Runner) handshake(conn *flatrpc.Conn, cfg *handshakeConfig) error {
+ connectReply := &flatrpc.ConnectReply{
+ Debug: runner.debug,
+ Cover: runner.cover,
+ CoverEdges: runner.coverEdges,
+ Kernel64Bit: runner.sysTarget.PtrSize == 8,
+ Procs: int32(runner.procs),
+ Slowdown: int32(cfg.Timeouts.Slowdown),
+ SyscallTimeoutMs: int32(cfg.Timeouts.Syscall / time.Millisecond),
+ ProgramTimeoutMs: int32(cfg.Timeouts.Program / time.Millisecond),
+ LeakFrames: cfg.LeakFrames,
+ RaceFrames: cfg.RaceFrames,
+ Files: cfg.Files,
+ Globs: cfg.Globs,
+ Features: cfg.Features,
+ }
+ if err := flatrpc.Send(conn, connectReply); err != nil {
+ return err
+ }
+ infoReq, err := flatrpc.Recv[*flatrpc.InfoRequestRaw](conn)
+ if err != nil {
+ return err
+ }
+ ret, err := cfg.Callback(infoReq)
+ if err != nil {
+ return err
+ }
+ infoReply := &flatrpc.InfoReply{
+ CoverFilter: ret.CovFilter,
+ }
+ if err := flatrpc.Send(conn, infoReply); err != nil {
+ return err
+ }
+ runner.conn = conn
+ runner.machineInfo = ret.MachineInfo
+ runner.canonicalizer = ret.Canonicalizer
+ return nil
+}
+
func (runner *Runner) connectionLoop() error {
var infoc chan []byte
defer func() {