aboutsummaryrefslogtreecommitdiffstats
path: root/pkg/flatrpc/conn_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/flatrpc/conn_test.go')
-rw-r--r--pkg/flatrpc/conn_test.go104
1 files changed, 61 insertions, 43 deletions
diff --git a/pkg/flatrpc/conn_test.go b/pkg/flatrpc/conn_test.go
index 132fd1cdd..4b108a5a4 100644
--- a/pkg/flatrpc/conn_test.go
+++ b/pkg/flatrpc/conn_test.go
@@ -4,15 +4,18 @@
package flatrpc
import (
+ "context"
+ "fmt"
"net"
"os"
+ "reflect"
"runtime/debug"
"sync"
"syscall"
"testing"
"time"
- "github.com/google/flatbuffers/go"
+ flatbuffers "github.com/google/flatbuffers/go"
"github.com/stretchr/testify/assert"
)
@@ -40,35 +43,39 @@ func TestConn(t *testing.T) {
},
}
- done := make(chan bool)
- defer func() {
- <-done
- }()
- serv, err := ListenAndServe(":0", func(c *Conn) {
- defer close(done)
- connectReqGot, err := Recv[*ConnectRequestRaw](c)
- if err != nil {
- t.Fatal(err)
- }
- assert.Equal(t, connectReq, connectReqGot)
-
- if err := Send(c, connectReply); err != nil {
- t.Fatal(err)
- }
-
- for i := 0; i < 10; i++ {
- got, err := Recv[*ExecutorMessageRaw](c)
- if err != nil {
- t.Fatal(err)
- }
- assert.Equal(t, executorMsg, got)
- }
- })
+ serv, err := Listen(":0")
if err != nil {
t.Fatal(err)
}
- defer serv.Close()
+ done := make(chan error)
+ go func() {
+ done <- serv.Serve(context.Background(),
+ func(_ context.Context, c *Conn) error {
+ connectReqGot, err := Recv[*ConnectRequestRaw](c)
+ if err != nil {
+ return err
+ }
+ if !reflect.DeepEqual(connectReq, connectReqGot) {
+ return fmt.Errorf("connectReq != connectReqGot")
+ }
+
+ if err := Send(c, connectReply); err != nil {
+ return err
+ }
+
+ for i := 0; i < 10; i++ {
+ got, err := Recv[*ExecutorMessageRaw](c)
+ if err != nil {
+ return nil
+ }
+ if !reflect.DeepEqual(executorMsg, got) {
+ return fmt.Errorf("executorMsg !=got")
+ }
+ }
+ return nil
+ })
+ }()
c := dial(t, serv.Addr.String())
defer c.Close()
@@ -87,6 +94,11 @@ func TestConn(t *testing.T) {
t.Fatal(err)
}
}
+
+ serv.Close()
+ if err := <-done; err != nil {
+ t.Fatal(err)
+ }
}
func BenchmarkConn(b *testing.B) {
@@ -103,26 +115,27 @@ func BenchmarkConn(b *testing.B) {
Files: []string{"file1"},
}
- done := make(chan bool)
- defer func() {
- <-done
- }()
- serv, err := ListenAndServe(":0", func(c *Conn) {
- defer close(done)
- for i := 0; i < b.N; i++ {
- _, err := Recv[*ConnectRequestRaw](c)
- if err != nil {
- b.Fatal(err)
- }
- if err := Send(c, connectReply); err != nil {
- b.Fatal(err)
- }
- }
- })
+ serv, err := Listen(":0")
if err != nil {
b.Fatal(err)
}
- defer serv.Close()
+ done := make(chan error)
+
+ go func() {
+ done <- serv.Serve(context.Background(),
+ func(_ context.Context, c *Conn) error {
+ for i := 0; i < b.N; i++ {
+ _, err := Recv[*ConnectRequestRaw](c)
+ if err != nil {
+ return err
+ }
+ if err := Send(c, connectReply); err != nil {
+ return err
+ }
+ }
+ return nil
+ })
+ }()
c := dial(b, serv.Addr.String())
defer c.Close()
@@ -138,6 +151,11 @@ func BenchmarkConn(b *testing.B) {
b.Fatal(err)
}
}
+
+ serv.Close()
+ if err := <-done; err != nil {
+ b.Fatal(err)
+ }
}
func dial(t testing.TB, addr string) *Conn {