From 94e13671726abbcf766f9b4aacd2ee04de59dcbd Mon Sep 17 00:00:00 2001 From: Aleksandr Nogikh Date: Fri, 24 Jan 2025 17:17:53 +0100 Subject: pkg/rpcserver: refactor to remove Fatalf calls Apply necessary changes to pkg/flatrpc and pkg/manager as well. --- pkg/flatrpc/conn.go | 65 ++++++++++++++++++----------- pkg/flatrpc/conn_test.go | 104 +++++++++++++++++++++++++++-------------------- 2 files changed, 103 insertions(+), 66 deletions(-) (limited to 'pkg/flatrpc') diff --git a/pkg/flatrpc/conn.go b/pkg/flatrpc/conn.go index 33eca07e1..47b265493 100644 --- a/pkg/flatrpc/conn.go +++ b/pkg/flatrpc/conn.go @@ -4,6 +4,7 @@ package flatrpc import ( + "context" "errors" "fmt" "io" @@ -13,9 +14,10 @@ import ( "sync" "unsafe" - "github.com/google/flatbuffers/go" + flatbuffers "github.com/google/flatbuffers/go" "github.com/google/syzkaller/pkg/log" "github.com/google/syzkaller/pkg/stat" + "golang.org/x/sync/errgroup" ) var ( @@ -30,38 +32,55 @@ type Serv struct { ln net.Listener } -func ListenAndServe(addr string, handler func(*Conn)) (*Serv, error) { +func Listen(addr string) (*Serv, error) { ln, err := net.Listen("tcp", addr) if err != nil { return nil, err } - go func() { - for { - conn, err := ln.Accept() - if err != nil { - if errors.Is(err, net.ErrClosed) { - break - } - var netErr *net.OpError - if errors.As(err, &netErr) && !netErr.Temporary() { - log.Fatalf("flatrpc: failed to accept: %v", err) - } - log.Logf(0, "flatrpc: failed to accept: %v", err) - continue - } - go func() { - c := NewConn(conn) - defer c.Close() - handler(c) - }() - } - }() return &Serv{ Addr: ln.Addr().(*net.TCPAddr), ln: ln, }, nil } +// Serve accepts incoming connections and calls handler for each of them. +// An error returned from the handler stops the server and aborts the whole processing. +func (s *Serv) Serve(baseCtx context.Context, handler func(context.Context, *Conn) error) error { + eg, ctx := errgroup.WithContext(baseCtx) + go func() { + // If the context is cancelled, stop the server. + <-ctx.Done() + s.Close() + }() + for { + conn, err := s.ln.Accept() + if err != nil && errors.Is(err, net.ErrClosed) { + break + } + if err != nil { + var netErr *net.OpError + if errors.As(err, &netErr) && !netErr.Temporary() { + return fmt.Errorf("flatrpc: failed to accept: %w", err) + } + log.Logf(0, "flatrpc: failed to accept: %v", err) + continue + } + eg.Go(func() error { + connCtx, cancel := context.WithCancel(ctx) + defer cancel() + + c := NewConn(conn) + // Closing the server does not automatically close all the connections. + go func() { + <-connCtx.Done() + c.Close() + }() + return handler(connCtx, c) + }) + } + return eg.Wait() +} + func (s *Serv) Close() error { return s.ln.Close() } diff --git a/pkg/flatrpc/conn_test.go b/pkg/flatrpc/conn_test.go index 132fd1cdd..4b108a5a4 100644 --- a/pkg/flatrpc/conn_test.go +++ b/pkg/flatrpc/conn_test.go @@ -4,15 +4,18 @@ package flatrpc import ( + "context" + "fmt" "net" "os" + "reflect" "runtime/debug" "sync" "syscall" "testing" "time" - "github.com/google/flatbuffers/go" + flatbuffers "github.com/google/flatbuffers/go" "github.com/stretchr/testify/assert" ) @@ -40,35 +43,39 @@ func TestConn(t *testing.T) { }, } - done := make(chan bool) - defer func() { - <-done - }() - serv, err := ListenAndServe(":0", func(c *Conn) { - defer close(done) - connectReqGot, err := Recv[*ConnectRequestRaw](c) - if err != nil { - t.Fatal(err) - } - assert.Equal(t, connectReq, connectReqGot) - - if err := Send(c, connectReply); err != nil { - t.Fatal(err) - } - - for i := 0; i < 10; i++ { - got, err := Recv[*ExecutorMessageRaw](c) - if err != nil { - t.Fatal(err) - } - assert.Equal(t, executorMsg, got) - } - }) + serv, err := Listen(":0") if err != nil { t.Fatal(err) } - defer serv.Close() + done := make(chan error) + go func() { + done <- serv.Serve(context.Background(), + func(_ context.Context, c *Conn) error { + connectReqGot, err := Recv[*ConnectRequestRaw](c) + if err != nil { + return err + } + if !reflect.DeepEqual(connectReq, connectReqGot) { + return fmt.Errorf("connectReq != connectReqGot") + } + + if err := Send(c, connectReply); err != nil { + return err + } + + for i := 0; i < 10; i++ { + got, err := Recv[*ExecutorMessageRaw](c) + if err != nil { + return nil + } + if !reflect.DeepEqual(executorMsg, got) { + return fmt.Errorf("executorMsg !=got") + } + } + return nil + }) + }() c := dial(t, serv.Addr.String()) defer c.Close() @@ -87,6 +94,11 @@ func TestConn(t *testing.T) { t.Fatal(err) } } + + serv.Close() + if err := <-done; err != nil { + t.Fatal(err) + } } func BenchmarkConn(b *testing.B) { @@ -103,26 +115,27 @@ func BenchmarkConn(b *testing.B) { Files: []string{"file1"}, } - done := make(chan bool) - defer func() { - <-done - }() - serv, err := ListenAndServe(":0", func(c *Conn) { - defer close(done) - for i := 0; i < b.N; i++ { - _, err := Recv[*ConnectRequestRaw](c) - if err != nil { - b.Fatal(err) - } - if err := Send(c, connectReply); err != nil { - b.Fatal(err) - } - } - }) + serv, err := Listen(":0") if err != nil { b.Fatal(err) } - defer serv.Close() + done := make(chan error) + + go func() { + done <- serv.Serve(context.Background(), + func(_ context.Context, c *Conn) error { + for i := 0; i < b.N; i++ { + _, err := Recv[*ConnectRequestRaw](c) + if err != nil { + return err + } + if err := Send(c, connectReply); err != nil { + return err + } + } + return nil + }) + }() c := dial(b, serv.Addr.String()) defer c.Close() @@ -138,6 +151,11 @@ func BenchmarkConn(b *testing.B) { b.Fatal(err) } } + + serv.Close() + if err := <-done; err != nil { + b.Fatal(err) + } } func dial(t testing.TB, addr string) *Conn { -- cgit mrf-deployment