aboutsummaryrefslogtreecommitdiffstats
path: root/pkg/flatrpc
diff options
context:
space:
mode:
authorDmitry Vyukov <dvyukov@google.com>2024-05-02 08:12:11 +0200
committerDmitry Vyukov <dvyukov@google.com>2024-05-03 14:25:58 +0000
commitc4eb806aa8b50f9baa74dbcd8073af53be120577 (patch)
tree6151d09c1920e1246113c6e731489628448e1ae0 /pkg/flatrpc
parent3e60354bf2a2ad7e7fa81fe8107f3ce24e098287 (diff)
pkg/flatrpc: add connection type
Add server/client connection wrapper that allows sending/receiving flatbuffers RPC messages.
Diffstat (limited to 'pkg/flatrpc')
-rw-r--r--pkg/flatrpc/conn.go184
-rw-r--r--pkg/flatrpc/conn_test.go142
2 files changed, 326 insertions, 0 deletions
diff --git a/pkg/flatrpc/conn.go b/pkg/flatrpc/conn.go
new file mode 100644
index 000000000..ba028fe62
--- /dev/null
+++ b/pkg/flatrpc/conn.go
@@ -0,0 +1,184 @@
+// Copyright 2024 syzkaller project authors. All rights reserved.
+// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
+
+package flatrpc
+
+import (
+ "errors"
+ "fmt"
+ "io"
+ "net"
+ "os"
+ "slices"
+ "sync"
+ "time"
+
+ flatbuffers "github.com/google/flatbuffers/go"
+ "github.com/google/syzkaller/pkg/log"
+ "github.com/google/syzkaller/pkg/stats"
+)
+
+var (
+ statSent = stats.Create("rpc sent", "Outbound RPC traffic",
+ stats.Graph("traffic"), stats.Rate{}, stats.FormatMB)
+ statRecv = stats.Create("rpc recv", "Inbound RPC traffic",
+ stats.Graph("traffic"), stats.Rate{}, stats.FormatMB)
+)
+
+type Serv struct {
+ Addr *net.TCPAddr
+ ln net.Listener
+}
+
+func ListenAndServe(addr string, handler func(*Conn)) (*Serv, error) {
+ ln, err := net.Listen("tcp", addr)
+ if err != nil {
+ return nil, err
+ }
+ go func() {
+ for {
+ conn, err := ln.Accept()
+ if err != nil {
+ if errors.Is(err, net.ErrClosed) {
+ break
+ }
+ var netErr *net.OpError
+ if errors.As(err, &netErr) && !netErr.Temporary() {
+ log.Fatalf("flatrpc: failed to accept: %v", err)
+ }
+ log.Logf(0, "flatrpc: failed to accept: %v", err)
+ continue
+ }
+ go func() {
+ c := newConn(conn)
+ defer c.Close()
+ handler(c)
+ }()
+ }
+ }()
+ return &Serv{
+ Addr: ln.Addr().(*net.TCPAddr),
+ ln: ln,
+ }, nil
+}
+
+func (s *Serv) Close() error {
+ return s.ln.Close()
+}
+
+type Conn struct {
+ conn net.Conn
+
+ sendMu sync.Mutex
+ builder *flatbuffers.Builder
+
+ data []byte
+ hasData int
+ lastMsg int
+}
+
+func Dial(addr string, timeScale time.Duration) (*Conn, error) {
+ var conn net.Conn
+ var err error
+ if addr == "stdin" {
+ // This is used by vm/gvisor which passes us a unix socket connection in stdin.
+ conn, err = net.FileConn(os.Stdin)
+ } else {
+ conn, err = net.DialTimeout("tcp", addr, time.Minute*timeScale)
+ }
+ if err != nil {
+ return nil, err
+ }
+ return newConn(conn), nil
+}
+
+func newConn(conn net.Conn) *Conn {
+ return &Conn{
+ conn: conn,
+ builder: flatbuffers.NewBuilder(0),
+ }
+}
+
+func (c *Conn) Close() error {
+ return c.conn.Close()
+}
+
+type sendMsg interface {
+ Pack(*flatbuffers.Builder) flatbuffers.UOffsetT
+}
+
+// Send sends an RPC message.
+// The type T is supposed to be an "object API" type ending with T (e.g. ConnectRequestT).
+// Sending can be done from multiple goroutines concurrently.
+func Send[T sendMsg](c *Conn, msg T) error {
+ c.sendMu.Lock()
+ defer c.sendMu.Unlock()
+ off := msg.Pack(c.builder)
+ c.builder.FinishSizePrefixed(off)
+ data := c.builder.FinishedBytes()
+ _, err := c.conn.Write(data)
+ c.builder.Reset()
+ statSent.Add(len(data))
+ if err != nil {
+ return fmt.Errorf("failed to send %T: %w", msg, err)
+ }
+ return nil
+}
+
+// Recv received an RPC message.
+// The type T is supposed to be a normal flatbuffers type (not ending with T, e.g. ConnectRequest).
+// Receiving should be done from a single goroutine, the received message is valid
+// only until the next Recv call (messages share the same underlying receive buffer).
+func Recv[T any, PT interface {
+ *T
+ flatbuffers.FlatBuffer
+}](c *Conn) (*T, error) {
+ // First, discard the previous message.
+ // For simplicity we copy any data from the next message to the beginning of the buffer.
+ // Theoretically we could something more efficient, e.g. don't copy if we already
+ // have a full next message.
+ if c.hasData > c.lastMsg {
+ copy(c.data, c.data[c.lastMsg:c.hasData])
+ }
+ c.hasData -= c.lastMsg
+ c.lastMsg = 0
+ const (
+ sizePrefixSize = flatbuffers.SizeUint32
+ maxMessageSize = 64 << 20
+ )
+ msg := PT(new(T))
+ // Then, receive at least the size prefix (4 bytes).
+ // And then the full message, if we have not got it yet.
+ if err := c.recv(sizePrefixSize); err != nil {
+ return nil, fmt.Errorf("failed to recv %T: %w", msg, err)
+ }
+ size := int(flatbuffers.GetSizePrefix(c.data, 0))
+ if size > maxMessageSize {
+ return nil, fmt.Errorf("message %T has too large size %v", msg, size)
+ }
+ c.lastMsg = sizePrefixSize + size
+ if err := c.recv(c.lastMsg); err != nil {
+ return nil, fmt.Errorf("failed to recv %T: %w", msg, err)
+ }
+ statRecv.Add(c.lastMsg)
+ data := c.data[sizePrefixSize:c.lastMsg]
+ msg.Init(data, flatbuffers.GetUOffsetT(data))
+ return msg, nil
+}
+
+// recv ensures that we have at least 'size' bytes received in c.data.
+func (c *Conn) recv(size int) error {
+ need := size - c.hasData
+ if need <= 0 {
+ return nil
+ }
+ if grow := size - len(c.data) + c.hasData; grow > 0 {
+ c.data = slices.Grow(c.data, grow)[:len(c.data)+grow]
+ }
+ n, err := io.ReadAtLeast(c.conn, c.data[c.hasData:], need)
+ if err != nil {
+ return err
+ }
+ c.hasData += n
+ return nil
+}
diff --git a/pkg/flatrpc/conn_test.go b/pkg/flatrpc/conn_test.go
new file mode 100644
index 000000000..f857259fc
--- /dev/null
+++ b/pkg/flatrpc/conn_test.go
@@ -0,0 +1,142 @@
+// Copyright 2024 syzkaller project authors. All rights reserved.
+// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
+
+package flatrpc
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+)
+
+func TestConn(t *testing.T) {
+ connectReq := &ConnectRequestT{
+ Name: "foo",
+ Arch: "arch",
+ GitRevision: "rev1",
+ SyzRevision: "rev2",
+ }
+ connectReply := &ConnectReplyT{
+ LeakFrames: []string{"foo", "bar"},
+ RaceFrames: []string{"bar", "baz"},
+ Features: FeatureCoverage | FeatureLeak,
+ Files: []string{"file1"},
+ Globs: []string{"glob1"},
+ }
+ executorMsg := &ExecutorMessageT{
+ Msg: &ExecutorMessagesT{
+ Type: ExecutorMessagesExecuting,
+ Value: &ExecutingMessageT{
+ Id: 1,
+ ProcId: 2,
+ Try: 3,
+ },
+ },
+ }
+
+ done := make(chan bool)
+ defer func() {
+ <-done
+ }()
+ serv, err := ListenAndServe(":0", func(c *Conn) {
+ defer close(done)
+ connectReqGot, err := Recv[ConnectRequest](c)
+ if err != nil {
+ t.Fatal(err)
+ }
+ assert.Equal(t, connectReq, connectReqGot.UnPack())
+
+ if err := Send(c, connectReply); err != nil {
+ t.Fatal(err)
+ }
+
+ for i := 0; i < 10; i++ {
+ got, err := Recv[ExecutorMessage](c)
+ if err != nil {
+ t.Fatal(err)
+ }
+ assert.Equal(t, executorMsg, got.UnPack())
+ }
+ })
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer serv.Close()
+
+ c, err := Dial(serv.Addr.String(), 1)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer c.Close()
+
+ if err := Send(c, connectReq); err != nil {
+ t.Fatal(err)
+ }
+
+ connectReplyGot, err := Recv[ConnectReply](c)
+ if err != nil {
+ t.Fatal(err)
+ }
+ assert.Equal(t, connectReply, connectReplyGot.UnPack())
+
+ for i := 0; i < 10; i++ {
+ if err := Send(c, executorMsg); err != nil {
+ t.Fatal(err)
+ }
+ }
+}
+
+func BenchmarkConn(b *testing.B) {
+ connectReq := &ConnectRequestT{
+ Name: "foo",
+ Arch: "arch",
+ GitRevision: "rev1",
+ SyzRevision: "rev2",
+ }
+ connectReply := &ConnectReplyT{
+ LeakFrames: []string{"foo", "bar"},
+ RaceFrames: []string{"bar", "baz"},
+ Features: FeatureCoverage | FeatureLeak,
+ Files: []string{"file1"},
+ Globs: []string{"glob1"},
+ }
+
+ done := make(chan bool)
+ defer func() {
+ <-done
+ }()
+ serv, err := ListenAndServe(":0", func(c *Conn) {
+ defer close(done)
+ for i := 0; i < b.N; i++ {
+ _, err := Recv[ConnectRequest](c)
+ if err != nil {
+ b.Fatal(err)
+ }
+ if err := Send(c, connectReply); err != nil {
+ b.Fatal(err)
+ }
+ }
+ })
+ if err != nil {
+ b.Fatal(err)
+ }
+ defer serv.Close()
+
+ c, err := Dial(serv.Addr.String(), 1)
+ if err != nil {
+ b.Fatal(err)
+ }
+ defer c.Close()
+
+ b.ReportAllocs()
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ if err := Send(c, connectReq); err != nil {
+ b.Fatal(err)
+ }
+ _, err := Recv[ConnectReply](c)
+ if err != nil {
+ b.Fatal(err)
+ }
+ }
+}