aboutsummaryrefslogtreecommitdiffstats
path: root/pkg/flatrpc
diff options
context:
space:
mode:
authorAleksandr Nogikh <nogikh@google.com>2025-01-24 17:17:53 +0100
committerAleksandr Nogikh <nogikh@google.com>2025-01-29 10:31:50 +0000
commit94e13671726abbcf766f9b4aacd2ee04de59dcbd (patch)
tree699abaa69f3509857969ca2d7ff3ea001df14c88 /pkg/flatrpc
parent6eea27042142c1c5e810b642deb831a8ed55b3da (diff)
pkg/rpcserver: refactor to remove Fatalf calls
Apply necessary changes to pkg/flatrpc and pkg/manager as well.
Diffstat (limited to 'pkg/flatrpc')
-rw-r--r--pkg/flatrpc/conn.go65
-rw-r--r--pkg/flatrpc/conn_test.go104
2 files changed, 103 insertions, 66 deletions
diff --git a/pkg/flatrpc/conn.go b/pkg/flatrpc/conn.go
index 33eca07e1..47b265493 100644
--- a/pkg/flatrpc/conn.go
+++ b/pkg/flatrpc/conn.go
@@ -4,6 +4,7 @@
package flatrpc
import (
+ "context"
"errors"
"fmt"
"io"
@@ -13,9 +14,10 @@ import (
"sync"
"unsafe"
- "github.com/google/flatbuffers/go"
+ flatbuffers "github.com/google/flatbuffers/go"
"github.com/google/syzkaller/pkg/log"
"github.com/google/syzkaller/pkg/stat"
+ "golang.org/x/sync/errgroup"
)
var (
@@ -30,38 +32,55 @@ type Serv struct {
ln net.Listener
}
-func ListenAndServe(addr string, handler func(*Conn)) (*Serv, error) {
+func Listen(addr string) (*Serv, error) {
ln, err := net.Listen("tcp", addr)
if err != nil {
return nil, err
}
- go func() {
- for {
- conn, err := ln.Accept()
- if err != nil {
- if errors.Is(err, net.ErrClosed) {
- break
- }
- var netErr *net.OpError
- if errors.As(err, &netErr) && !netErr.Temporary() {
- log.Fatalf("flatrpc: failed to accept: %v", err)
- }
- log.Logf(0, "flatrpc: failed to accept: %v", err)
- continue
- }
- go func() {
- c := NewConn(conn)
- defer c.Close()
- handler(c)
- }()
- }
- }()
return &Serv{
Addr: ln.Addr().(*net.TCPAddr),
ln: ln,
}, nil
}
+// Serve accepts incoming connections and calls handler for each of them.
+// An error returned from the handler stops the server and aborts the whole processing.
+func (s *Serv) Serve(baseCtx context.Context, handler func(context.Context, *Conn) error) error {
+ eg, ctx := errgroup.WithContext(baseCtx)
+ go func() {
+ // If the context is cancelled, stop the server.
+ <-ctx.Done()
+ s.Close()
+ }()
+ for {
+ conn, err := s.ln.Accept()
+ if err != nil && errors.Is(err, net.ErrClosed) {
+ break
+ }
+ if err != nil {
+ var netErr *net.OpError
+ if errors.As(err, &netErr) && !netErr.Temporary() {
+ return fmt.Errorf("flatrpc: failed to accept: %w", err)
+ }
+ log.Logf(0, "flatrpc: failed to accept: %v", err)
+ continue
+ }
+ eg.Go(func() error {
+ connCtx, cancel := context.WithCancel(ctx)
+ defer cancel()
+
+ c := NewConn(conn)
+ // Closing the server does not automatically close all the connections.
+ go func() {
+ <-connCtx.Done()
+ c.Close()
+ }()
+ return handler(connCtx, c)
+ })
+ }
+ return eg.Wait()
+}
+
func (s *Serv) Close() error {
return s.ln.Close()
}
diff --git a/pkg/flatrpc/conn_test.go b/pkg/flatrpc/conn_test.go
index 132fd1cdd..4b108a5a4 100644
--- a/pkg/flatrpc/conn_test.go
+++ b/pkg/flatrpc/conn_test.go
@@ -4,15 +4,18 @@
package flatrpc
import (
+ "context"
+ "fmt"
"net"
"os"
+ "reflect"
"runtime/debug"
"sync"
"syscall"
"testing"
"time"
- "github.com/google/flatbuffers/go"
+ flatbuffers "github.com/google/flatbuffers/go"
"github.com/stretchr/testify/assert"
)
@@ -40,35 +43,39 @@ func TestConn(t *testing.T) {
},
}
- done := make(chan bool)
- defer func() {
- <-done
- }()
- serv, err := ListenAndServe(":0", func(c *Conn) {
- defer close(done)
- connectReqGot, err := Recv[*ConnectRequestRaw](c)
- if err != nil {
- t.Fatal(err)
- }
- assert.Equal(t, connectReq, connectReqGot)
-
- if err := Send(c, connectReply); err != nil {
- t.Fatal(err)
- }
-
- for i := 0; i < 10; i++ {
- got, err := Recv[*ExecutorMessageRaw](c)
- if err != nil {
- t.Fatal(err)
- }
- assert.Equal(t, executorMsg, got)
- }
- })
+ serv, err := Listen(":0")
if err != nil {
t.Fatal(err)
}
- defer serv.Close()
+ done := make(chan error)
+ go func() {
+ done <- serv.Serve(context.Background(),
+ func(_ context.Context, c *Conn) error {
+ connectReqGot, err := Recv[*ConnectRequestRaw](c)
+ if err != nil {
+ return err
+ }
+ if !reflect.DeepEqual(connectReq, connectReqGot) {
+ return fmt.Errorf("connectReq != connectReqGot")
+ }
+
+ if err := Send(c, connectReply); err != nil {
+ return err
+ }
+
+ for i := 0; i < 10; i++ {
+ got, err := Recv[*ExecutorMessageRaw](c)
+ if err != nil {
+ return nil
+ }
+ if !reflect.DeepEqual(executorMsg, got) {
+ return fmt.Errorf("executorMsg !=got")
+ }
+ }
+ return nil
+ })
+ }()
c := dial(t, serv.Addr.String())
defer c.Close()
@@ -87,6 +94,11 @@ func TestConn(t *testing.T) {
t.Fatal(err)
}
}
+
+ serv.Close()
+ if err := <-done; err != nil {
+ t.Fatal(err)
+ }
}
func BenchmarkConn(b *testing.B) {
@@ -103,26 +115,27 @@ func BenchmarkConn(b *testing.B) {
Files: []string{"file1"},
}
- done := make(chan bool)
- defer func() {
- <-done
- }()
- serv, err := ListenAndServe(":0", func(c *Conn) {
- defer close(done)
- for i := 0; i < b.N; i++ {
- _, err := Recv[*ConnectRequestRaw](c)
- if err != nil {
- b.Fatal(err)
- }
- if err := Send(c, connectReply); err != nil {
- b.Fatal(err)
- }
- }
- })
+ serv, err := Listen(":0")
if err != nil {
b.Fatal(err)
}
- defer serv.Close()
+ done := make(chan error)
+
+ go func() {
+ done <- serv.Serve(context.Background(),
+ func(_ context.Context, c *Conn) error {
+ for i := 0; i < b.N; i++ {
+ _, err := Recv[*ConnectRequestRaw](c)
+ if err != nil {
+ return err
+ }
+ if err := Send(c, connectReply); err != nil {
+ return err
+ }
+ }
+ return nil
+ })
+ }()
c := dial(b, serv.Addr.String())
defer c.Close()
@@ -138,6 +151,11 @@ func BenchmarkConn(b *testing.B) {
b.Fatal(err)
}
}
+
+ serv.Close()
+ if err := <-done; err != nil {
+ b.Fatal(err)
+ }
}
func dial(t testing.TB, addr string) *Conn {