diff options
Diffstat (limited to 'pkg/rpcserver/rpcserver_test.go')
| -rw-r--r-- | pkg/rpcserver/rpcserver_test.go | 32 |
1 files changed, 24 insertions, 8 deletions
diff --git a/pkg/rpcserver/rpcserver_test.go b/pkg/rpcserver/rpcserver_test.go index 2da916286..429b275ac 100644 --- a/pkg/rpcserver/rpcserver_test.go +++ b/pkg/rpcserver/rpcserver_test.go @@ -6,6 +6,7 @@ package rpcserver import ( "context" "net" + "strings" "testing" "github.com/stretchr/testify/assert" @@ -18,6 +19,7 @@ import ( "github.com/google/syzkaller/pkg/vminfo" "github.com/google/syzkaller/prog" "github.com/google/syzkaller/sys/targets" + "golang.org/x/sync/errgroup" ) func getTestDefaultCfg() mgrconfig.Config { @@ -176,12 +178,14 @@ func TestHandleConn(t *testing.T) { defaultCfg := getTestDefaultCfg() tests := []struct { - name string - modifyCfg func() *mgrconfig.Config - req *flatrpc.ConnectRequest + name string + wantErrMsg string + modifyCfg func() *mgrconfig.Config + req *flatrpc.ConnectRequest }{ { - name: "error, cfg.VMLess = false - unknown VM tries to connect", + name: "error, cfg.VMLess = false - unknown VM tries to connect", + wantErrMsg: "tries to connect", modifyCfg: func() *mgrconfig.Config { return &defaultCfg }, @@ -214,10 +218,22 @@ func TestHandleConn(t *testing.T) { injectExec := make(chan bool) serv.CreateInstance(1, injectExec, nil) - - go flatrpc.Send(clientConn, tt.req) - err = serv.handleConn(context.Background(), serverConn) - if err != nil { + g := errgroup.Group{} + g.Go(func() error { + hello, err := flatrpc.Recv[*flatrpc.ConnectHelloRaw](clientConn) + if err != nil { + return err + } + tt.req.Cookie = authHash(hello.Cookie) + flatrpc.Send(clientConn, tt.req) + return nil + }) + if err := serv.handleConn(context.Background(), serverConn); err != nil { + if !strings.Contains(err.Error(), tt.wantErrMsg) { + t.Fatal(err) + } + } + if err := g.Wait(); err != nil { t.Fatal(err) } }) |
