diff options
| -rw-r--r-- | vm/proxyapp/init.go | 6 | ||||
| -rw-r--r-- | vm/proxyapp/proxyappclient.go | 44 | ||||
| -rw-r--r-- | vm/proxyapp/proxyappclient_tcp_test.go | 103 |
3 files changed, 145 insertions, 8 deletions
diff --git a/vm/proxyapp/init.go b/vm/proxyapp/init.go index afd0db00d..8153cf031 100644 --- a/vm/proxyapp/init.go +++ b/vm/proxyapp/init.go @@ -62,6 +62,12 @@ type Config struct { // rpc_server_uri is used to specify plugin endpoint address. // if not specified, we'll connect to the plugin by std[in, out, err]. RPCServerURI string `json:"rpc_server_uri"` + // security can be one of "none", "tls" (for server TLS) and "mtls" for mutal + // TLS. + Security string `json:"security"` + // server_tls_cert points a TLS certificate used to authenticate the server. + // If not provided, the default system certificate pool will be used. + ServerTLSCert string `json:"server_tls_cert"` // config is an optional remote plugin config ProxyAppConfig json.RawMessage `json:"config"` } diff --git a/vm/proxyapp/proxyappclient.go b/vm/proxyapp/proxyappclient.go index 1104d1901..1e6d679ac 100644 --- a/vm/proxyapp/proxyappclient.go +++ b/vm/proxyapp/proxyappclient.go @@ -7,10 +7,14 @@ package proxyapp import ( "context" + "crypto/tls" + "crypto/x509" "fmt" "io" + "net" "net/rpc" "net/rpc/jsonrpc" + "os" "sync" "time" @@ -100,7 +104,7 @@ func (p *pool) init(params *proxyAppParams, cfg *Config) error { if useTCPRPC { p.proxy.onLostConnection = make(chan bool, 1) - p.proxy.Client, err = initNetworkRPCClient(cfg.RPCServerURI) + p.proxy.Client, err = initNetworkRPCClient(cfg) if err != nil { p.closeProxy() return fmt.Errorf("failed to connect ProxyApp pipes: %w", err) @@ -193,8 +197,42 @@ func initPipedRPCClient(cmd subProcessCmd) (*rpc.Client, []io.Closer, error) { nil } -func initNetworkRPCClient(uri string) (*rpc.Client, error) { - return jsonrpc.Dial("tcp", uri) +func initNetworkRPCClient(cfg *Config) (*rpc.Client, error) { + var conn io.ReadWriteCloser + + switch cfg.Security { + case "none": + var err error + conn, err = net.Dial("tcp", cfg.RPCServerURI) + if err != nil { + return nil, fmt.Errorf("dial: %v", err) + } + case "tls": + var certPool *x509.CertPool + + if cfg.ServerTLSCert != "" { + certPool = x509.NewCertPool() + b, err := os.ReadFile(cfg.ServerTLSCert) + if err != nil { + return nil, fmt.Errorf("read server certificate: %v", err) + } + if !certPool.AppendCertsFromPEM(b) { + return nil, fmt.Errorf("append server certificate to empty pool: %v", err) + } + } + + var err error + conn, err = tls.Dial("tcp", cfg.RPCServerURI, &tls.Config{RootCAs: certPool}) + if err != nil { + return nil, fmt.Errorf("dial with tls: %v", err) + } + case "mtls": + return nil, fmt.Errorf("mutual TLS not implemented") + default: + return nil, fmt.Errorf("security value is %q, must be 'none', 'tls', or 'mtls'", cfg.Security) + } + + return jsonrpc.NewClient(conn), nil } func runProxyApp(params *proxyAppParams, cmd string, initRPClient bool) (*ProxyApp, error) { diff --git a/vm/proxyapp/proxyappclient_tcp_test.go b/vm/proxyapp/proxyappclient_tcp_test.go index 6ce5c62ee..6a0c0ab33 100644 --- a/vm/proxyapp/proxyappclient_tcp_test.go +++ b/vm/proxyapp/proxyappclient_tcp_test.go @@ -4,12 +4,20 @@ package proxyapp import ( + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "encoding/pem" + "math/big" "net" "net/rpc" "net/rpc/jsonrpc" "net/url" + "os" "sync" "testing" + "time" "github.com/google/syzkaller/vm/proxyapp/proxyrpc" "github.com/google/syzkaller/vm/vmimpl" @@ -22,6 +30,21 @@ func testTCPEnv(port string) *vmimpl.Env { Config: []byte(` { "rpc_server_uri": "localhost:` + port + `", + "security": "none", + "config": { + "internal_values": 123 + } + } +`)} +} + +func testTCPEnvTLS(port, certPath string) *vmimpl.Env { + return &vmimpl.Env{ + Config: []byte(` +{ + "rpc_server_uri": "localhost:` + port + `", + "security": "tls", + "server_tls_cert": "` + certPath + `", "config": { "internal_values": 123 } @@ -34,6 +57,11 @@ func proxyAppServerTCPFixture(t *testing.T) (*mockProxyAppInterface, string, *pr return initProxyAppServerFixture(mProxyAppServer), port, makeTestParams() } +func proxyAppServerTCPFixtureTLS(t *testing.T, cert tls.Certificate) (*mockProxyAppInterface, string, *proxyAppParams) { + mProxyAppServer, port, _ := makeMockProxyAppServerTLS(t, cert) + return initProxyAppServerFixture(mProxyAppServer), port, makeTestParams() +} + func TestCtor_TCP_Ok(t *testing.T) { _, port, params := proxyAppServerTCPFixture(t) p, err := ctor(params, testTCPEnv(port)) @@ -42,6 +70,57 @@ func TestCtor_TCP_Ok(t *testing.T) { assert.Equal(t, 2, p.Count()) } +func TestCtor_TCP_Ok_TLS(t *testing.T) { + key, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("generate private key: %v", err) + } + keyPEM := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)}) + + template := &x509.Certificate{ + SerialNumber: big.NewInt(1), + NotBefore: time.Now().Add(-10 * time.Second), + NotAfter: time.Now().AddDate(10, 0, 0), + KeyUsage: x509.KeyUsageCRLSign, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + IsCA: false, + MaxPathLenZero: true, + DNSNames: []string{"localhost"}, + } + certBytes, err := x509.CreateCertificate(rand.Reader, template, template, key.Public(), key) + if err != nil { + t.Fatalf("generate certificate: %v", err) + } + certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certBytes}) + + cert, err := tls.X509KeyPair(certPEM, keyPEM) + if err != nil { + t.Fatalf("load keypair to cert: %v", err) + } + + _, port, params := proxyAppServerTCPFixtureTLS(t, cert) + + // Write the certificate to a temp file where the client can use it. + certFile, err := os.CreateTemp("", "test-cert") + if err != nil { + t.Fatalf("temp file for certificate: %v", err) + } + defer certFile.Close() + defer os.Remove(certFile.Name()) + + if _, err := certFile.Write(certPEM); err != nil { + t.Fatalf("write cert: %v", err) + } + if err := certFile.Close(); err != nil { + t.Fatalf("close cert: %v", err) + } + + p, err := ctor(params, testTCPEnvTLS(port, certFile.Name())) + + assert.Nil(t, err) + assert.Equal(t, 2, p.Count()) +} + func TestCtor_TCP_WrongPort(t *testing.T) { p, err := ctor(makeTestParams(), testTCPEnv("5")) @@ -113,15 +192,11 @@ func TestCtor_TCP_Reconnect_PoolChanged(t *testing.T) { } } -func makeMockProxyAppServer(t *testing.T) (*mockProxyAppInterface, string, func()) { +func makeMockProxyAppServerWithListener(t *testing.T, l net.Listener) (*mockProxyAppInterface, string, func()) { handler := makeMockProxyAppInterface(t) server := rpc.NewServer() server.RegisterName("ProxyVM", struct{ proxyrpc.ProxyAppInterface }{handler}) - l, e := net.Listen("tcp", ":0") - if e != nil { - t.Fatalf("listen error: %v", e) - } dest, err := url.Parse("http://" + l.Addr().String()) if err != nil { t.Fatalf("failed to get server endpoint addr: %v", err) @@ -152,3 +227,21 @@ func makeMockProxyAppServer(t *testing.T) (*mockProxyAppInterface, string, func( } } } + +func makeMockProxyAppServer(t *testing.T) (*mockProxyAppInterface, string, func()) { + l, e := net.Listen("tcp", ":0") + if e != nil { + t.Fatalf("listen error: %v", e) + } + + return makeMockProxyAppServerWithListener(t, l) +} + +func makeMockProxyAppServerTLS(t *testing.T, cert tls.Certificate) (*mockProxyAppInterface, string, func()) { + l, e := tls.Listen("tcp", ":0", &tls.Config{Certificates: []tls.Certificate{cert}}) + if e != nil { + t.Fatalf("listen error: %v", e) + } + + return makeMockProxyAppServerWithListener(t, l) +} |
