aboutsummaryrefslogtreecommitdiffstats
path: root/pkg
diff options
context:
space:
mode:
authorDmitry Vyukov <dvyukov@google.com>2018-06-26 14:12:43 +0200
committerDmitry Vyukov <dvyukov@google.com>2018-06-26 14:12:43 +0200
commit089f11817e3eb5a23bf9fb679dc4e6ad61de48ec (patch)
treeec389c099eae6dc34290e453f7093e3ae35317b7 /pkg
parente726bdf922950225c79fc81b54b73ea8ecda7921 (diff)
syz-fuzzer: fix gvisor testing
Testing code wasn't ready to dial stdin. Make it use the same logic rpc package uses to connecto to host.
Diffstat (limited to 'pkg')
-rw-r--r--pkg/instance/instance.go18
-rw-r--r--pkg/rpctype/rpc.go25
2 files changed, 30 insertions, 13 deletions
diff --git a/pkg/instance/instance.go b/pkg/instance/instance.go
index 2849098ac..f804b2724 100644
--- a/pkg/instance/instance.go
+++ b/pkg/instance/instance.go
@@ -8,6 +8,7 @@ package instance
import (
"bytes"
"fmt"
+ "io/ioutil"
"net"
"os"
"path/filepath"
@@ -214,10 +215,21 @@ func (inst *inst) testInstance() error {
acceptErr := make(chan error, 1)
go func() {
conn, err := ln.Accept()
- if err == nil {
- conn.Close()
+ if err != nil {
+ acceptErr <- err
+ return
}
- acceptErr <- err
+ defer conn.Close()
+ data, err := ioutil.ReadAll(conn)
+ if err != nil {
+ acceptErr <- err
+ return
+ }
+ if string(data) != "HELLO" {
+ acceptErr <- fmt.Errorf("received bad handshake from VM: %q", string(data))
+ return
+ }
+ acceptErr <- nil
}()
fwdAddr, err := inst.vm.Forward(ln.Addr().(*net.TCPAddr).Port)
if err != nil {
diff --git a/pkg/rpctype/rpc.go b/pkg/rpctype/rpc.go
index 6d9048077..3838855a4 100644
--- a/pkg/rpctype/rpc.go
+++ b/pkg/rpctype/rpc.go
@@ -56,20 +56,25 @@ type RPCClient struct {
c *rpc.Client
}
-func NewRPCClient(addr string) (*RPCClient, error) {
+func Dial(addr string) (net.Conn, error) {
var conn net.Conn
var err error
if addr == "stdin" {
// This is used by vm/gvisor which passes us a unix socket connection in stdin.
- if conn, err = net.FileConn(os.Stdin); err != nil {
- return nil, err
- }
- } else {
- if conn, err = net.DialTimeout("tcp", addr, 60*time.Second); err != nil {
- return nil, err
- }
- conn.(*net.TCPConn).SetKeepAlive(true)
- conn.(*net.TCPConn).SetKeepAlivePeriod(time.Minute)
+ return net.FileConn(os.Stdin)
+ }
+ if conn, err = net.DialTimeout("tcp", addr, 60*time.Second); err != nil {
+ return nil, err
+ }
+ conn.(*net.TCPConn).SetKeepAlive(true)
+ conn.(*net.TCPConn).SetKeepAlivePeriod(time.Minute)
+ return conn, nil
+}
+
+func NewRPCClient(addr string) (*RPCClient, error) {
+ conn, err := Dial(addr)
+ if err != nil {
+ return nil, err
}
cli := &RPCClient{
conn: conn,