diff options
Diffstat (limited to 'pkg/flatrpc/conn.go')
| -rw-r--r-- | pkg/flatrpc/conn.go | 65 |
1 files changed, 42 insertions, 23 deletions
diff --git a/pkg/flatrpc/conn.go b/pkg/flatrpc/conn.go index 33eca07e1..47b265493 100644 --- a/pkg/flatrpc/conn.go +++ b/pkg/flatrpc/conn.go @@ -4,6 +4,7 @@ package flatrpc import ( + "context" "errors" "fmt" "io" @@ -13,9 +14,10 @@ import ( "sync" "unsafe" - "github.com/google/flatbuffers/go" + flatbuffers "github.com/google/flatbuffers/go" "github.com/google/syzkaller/pkg/log" "github.com/google/syzkaller/pkg/stat" + "golang.org/x/sync/errgroup" ) var ( @@ -30,38 +32,55 @@ type Serv struct { ln net.Listener } -func ListenAndServe(addr string, handler func(*Conn)) (*Serv, error) { +func Listen(addr string) (*Serv, error) { ln, err := net.Listen("tcp", addr) if err != nil { return nil, err } - go func() { - for { - conn, err := ln.Accept() - if err != nil { - if errors.Is(err, net.ErrClosed) { - break - } - var netErr *net.OpError - if errors.As(err, &netErr) && !netErr.Temporary() { - log.Fatalf("flatrpc: failed to accept: %v", err) - } - log.Logf(0, "flatrpc: failed to accept: %v", err) - continue - } - go func() { - c := NewConn(conn) - defer c.Close() - handler(c) - }() - } - }() return &Serv{ Addr: ln.Addr().(*net.TCPAddr), ln: ln, }, nil } +// Serve accepts incoming connections and calls handler for each of them. +// An error returned from the handler stops the server and aborts the whole processing. +func (s *Serv) Serve(baseCtx context.Context, handler func(context.Context, *Conn) error) error { + eg, ctx := errgroup.WithContext(baseCtx) + go func() { + // If the context is cancelled, stop the server. + <-ctx.Done() + s.Close() + }() + for { + conn, err := s.ln.Accept() + if err != nil && errors.Is(err, net.ErrClosed) { + break + } + if err != nil { + var netErr *net.OpError + if errors.As(err, &netErr) && !netErr.Temporary() { + return fmt.Errorf("flatrpc: failed to accept: %w", err) + } + log.Logf(0, "flatrpc: failed to accept: %v", err) + continue + } + eg.Go(func() error { + connCtx, cancel := context.WithCancel(ctx) + defer cancel() + + c := NewConn(conn) + // Closing the server does not automatically close all the connections. + go func() { + <-connCtx.Done() + c.Close() + }() + return handler(connCtx, c) + }) + } + return eg.Wait() +} + func (s *Serv) Close() error { return s.ln.Close() } |
