aboutsummaryrefslogtreecommitdiffstats
path: root/pkg/flatrpc/conn.go
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/flatrpc/conn.go')
-rw-r--r--pkg/flatrpc/conn.go65
1 files changed, 42 insertions, 23 deletions
diff --git a/pkg/flatrpc/conn.go b/pkg/flatrpc/conn.go
index 33eca07e1..47b265493 100644
--- a/pkg/flatrpc/conn.go
+++ b/pkg/flatrpc/conn.go
@@ -4,6 +4,7 @@
package flatrpc
import (
+ "context"
"errors"
"fmt"
"io"
@@ -13,9 +14,10 @@ import (
"sync"
"unsafe"
- "github.com/google/flatbuffers/go"
+ flatbuffers "github.com/google/flatbuffers/go"
"github.com/google/syzkaller/pkg/log"
"github.com/google/syzkaller/pkg/stat"
+ "golang.org/x/sync/errgroup"
)
var (
@@ -30,38 +32,55 @@ type Serv struct {
ln net.Listener
}
-func ListenAndServe(addr string, handler func(*Conn)) (*Serv, error) {
+func Listen(addr string) (*Serv, error) {
ln, err := net.Listen("tcp", addr)
if err != nil {
return nil, err
}
- go func() {
- for {
- conn, err := ln.Accept()
- if err != nil {
- if errors.Is(err, net.ErrClosed) {
- break
- }
- var netErr *net.OpError
- if errors.As(err, &netErr) && !netErr.Temporary() {
- log.Fatalf("flatrpc: failed to accept: %v", err)
- }
- log.Logf(0, "flatrpc: failed to accept: %v", err)
- continue
- }
- go func() {
- c := NewConn(conn)
- defer c.Close()
- handler(c)
- }()
- }
- }()
return &Serv{
Addr: ln.Addr().(*net.TCPAddr),
ln: ln,
}, nil
}
+// Serve accepts incoming connections and calls handler for each of them.
+// An error returned from the handler stops the server and aborts the whole processing.
+func (s *Serv) Serve(baseCtx context.Context, handler func(context.Context, *Conn) error) error {
+ eg, ctx := errgroup.WithContext(baseCtx)
+ go func() {
+ // If the context is cancelled, stop the server.
+ <-ctx.Done()
+ s.Close()
+ }()
+ for {
+ conn, err := s.ln.Accept()
+ if err != nil && errors.Is(err, net.ErrClosed) {
+ break
+ }
+ if err != nil {
+ var netErr *net.OpError
+ if errors.As(err, &netErr) && !netErr.Temporary() {
+ return fmt.Errorf("flatrpc: failed to accept: %w", err)
+ }
+ log.Logf(0, "flatrpc: failed to accept: %v", err)
+ continue
+ }
+ eg.Go(func() error {
+ connCtx, cancel := context.WithCancel(ctx)
+ defer cancel()
+
+ c := NewConn(conn)
+ // Closing the server does not automatically close all the connections.
+ go func() {
+ <-connCtx.Done()
+ c.Close()
+ }()
+ return handler(connCtx, c)
+ })
+ }
+ return eg.Wait()
+}
+
func (s *Serv) Close() error {
return s.ln.Close()
}