diff options
Diffstat (limited to 'pkg/flatrpc/conn_test.go')
| -rw-r--r-- | pkg/flatrpc/conn_test.go | 104 |
1 files changed, 61 insertions, 43 deletions
diff --git a/pkg/flatrpc/conn_test.go b/pkg/flatrpc/conn_test.go index 132fd1cdd..4b108a5a4 100644 --- a/pkg/flatrpc/conn_test.go +++ b/pkg/flatrpc/conn_test.go @@ -4,15 +4,18 @@ package flatrpc import ( + "context" + "fmt" "net" "os" + "reflect" "runtime/debug" "sync" "syscall" "testing" "time" - "github.com/google/flatbuffers/go" + flatbuffers "github.com/google/flatbuffers/go" "github.com/stretchr/testify/assert" ) @@ -40,35 +43,39 @@ func TestConn(t *testing.T) { }, } - done := make(chan bool) - defer func() { - <-done - }() - serv, err := ListenAndServe(":0", func(c *Conn) { - defer close(done) - connectReqGot, err := Recv[*ConnectRequestRaw](c) - if err != nil { - t.Fatal(err) - } - assert.Equal(t, connectReq, connectReqGot) - - if err := Send(c, connectReply); err != nil { - t.Fatal(err) - } - - for i := 0; i < 10; i++ { - got, err := Recv[*ExecutorMessageRaw](c) - if err != nil { - t.Fatal(err) - } - assert.Equal(t, executorMsg, got) - } - }) + serv, err := Listen(":0") if err != nil { t.Fatal(err) } - defer serv.Close() + done := make(chan error) + go func() { + done <- serv.Serve(context.Background(), + func(_ context.Context, c *Conn) error { + connectReqGot, err := Recv[*ConnectRequestRaw](c) + if err != nil { + return err + } + if !reflect.DeepEqual(connectReq, connectReqGot) { + return fmt.Errorf("connectReq != connectReqGot") + } + + if err := Send(c, connectReply); err != nil { + return err + } + + for i := 0; i < 10; i++ { + got, err := Recv[*ExecutorMessageRaw](c) + if err != nil { + return nil + } + if !reflect.DeepEqual(executorMsg, got) { + return fmt.Errorf("executorMsg !=got") + } + } + return nil + }) + }() c := dial(t, serv.Addr.String()) defer c.Close() @@ -87,6 +94,11 @@ func TestConn(t *testing.T) { t.Fatal(err) } } + + serv.Close() + if err := <-done; err != nil { + t.Fatal(err) + } } func BenchmarkConn(b *testing.B) { @@ -103,26 +115,27 @@ func BenchmarkConn(b *testing.B) { Files: []string{"file1"}, } - done := make(chan bool) - defer func() { - <-done - }() - serv, err := ListenAndServe(":0", func(c *Conn) { - defer close(done) - for i := 0; i < b.N; i++ { - _, err := Recv[*ConnectRequestRaw](c) - if err != nil { - b.Fatal(err) - } - if err := Send(c, connectReply); err != nil { - b.Fatal(err) - } - } - }) + serv, err := Listen(":0") if err != nil { b.Fatal(err) } - defer serv.Close() + done := make(chan error) + + go func() { + done <- serv.Serve(context.Background(), + func(_ context.Context, c *Conn) error { + for i := 0; i < b.N; i++ { + _, err := Recv[*ConnectRequestRaw](c) + if err != nil { + return err + } + if err := Send(c, connectReply); err != nil { + return err + } + } + return nil + }) + }() c := dial(b, serv.Addr.String()) defer c.Close() @@ -138,6 +151,11 @@ func BenchmarkConn(b *testing.B) { b.Fatal(err) } } + + serv.Close() + if err := <-done; err != nil { + b.Fatal(err) + } } func dial(t testing.TB, addr string) *Conn { |
