about summary refs log tree commit diff stats
path: root/pkg
diff options
context:
space:
mode:
author Dmitry Vyukov <dvyukov@google.com> 2017-06-01 18:54:42 +0200
committer Dmitry Vyukov <dvyukov@google.com> 2017-06-03 10:41:09 +0200
commit ea2295f3e29c59b4493e98aaafc28f9083d5e570 (patch)
tree a9c3b1f3c2cf40a9a6aa98afd4e5ad75442c249c /pkg
parent 23b94422d32946ab68a4f1423274bf4daa33cef9 (diff)
pkg/db: move from db
Diffstat (limited to 'pkg')
-rw-r--r-- pkg/db/db.go 248
-rw-r--r-- pkg/db/db_test.go 151
2 files changed, 399 insertions, 0 deletions
diff --git a/pkg/db/db.go b/pkg/db/db.go
new file mode 100644
index 000000000..0277cb3d1
--- /dev/null
+++ b/pkg/db/db.go
@@ -0,0 +1,248 @@
+// Copyright 2017 syzkaller project authors. All rights reserved.
+// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
+
+// Package db implements a simple key-value database.
+// The database is cached in memory and mirrored on disk.
+// It is used to store corpus in syz-manager and syz-hub.
+// The database strives to minimize number of disk accesses
+// as they can be slow in virtualized environments (GCE).
+package db
+
+import (
+ "bufio"
+ "bytes"
+ "compress/flate"
+ "encoding/binary"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "os"
+
+ . "github.com/google/syzkaller/pkg/log"
+)
+
+// DB is a key-value database whose records are cached in memory (Records)
+// and mirrored in the file named by filename.
+// Changes made through Save and Delete are buffered in pending; they reach
+// the file when Flush appends them or when the file is rewritten by compact.
+type DB struct {
+ Records map[string]Record // in-memory cache, must not be modified directly
+
+ filename string // path of the on-disk mirror of Records
+ uncompacted int // number of records written to the file or queued in pending, including superseded and deletion records
+ pending *bytes.Buffer // pending writes to the file
+}
+
+// Record is the value stored in the database under a single key.
+type Record struct {
+ Val []byte // value bytes; stored flate-compressed on disk, kept uncompressed in memory
+ Seq uint64 // caller-supplied sequence number saved with the value; ^uint64(0) is reserved for deletions
+}
+
+// Open opens the database mirrored in filename, creating the file if it does
+// not exist, and loads all records into memory.
+//
+// Corrupted content is not reported as an error: deserializeDB logs the
+// problem and Open continues with the records decoded up to that point.
+// An error is returned if the file cannot be opened or if the initial
+// compaction fails.
+func Open(filename string) (*DB, error) {
+	db := &DB{
+		filename: filename,
+	}
+	f, err := os.OpenFile(db.filename, os.O_RDONLY|os.O_CREATE, 0640)
+	if err != nil {
+		return nil, err
+	}
+	db.Records, db.uncompacted = deserializeDB(bufio.NewReader(f))
+	f.Close()
+	// Rewrite the file right away when it holds no live records (this also
+	// writes the header into a fresh file) or when live records make up less
+	// than roughly 90% of the records found on disk.
+	if len(db.Records) == 0 || db.uncompacted/10*9 > len(db.Records) {
+		// A failed rewrite must not be swallowed: the caller would otherwise
+		// keep working with a file that may lack a valid header.
+		if err := db.compact(); err != nil {
+			return nil, err
+		}
+	}
+	return db, nil
+}
+
+// Save stores val and seq under key in the in-memory cache and queues the
+// record for the next Flush. Saving a record identical to the cached one is
+// a no-op. It panics if seq equals the reserved deletion marker.
+func (db *DB) Save(key string, val []byte, seq uint64) {
+	if seq == seqDeleted {
+		panic("reserved seq")
+	}
+	old, exists := db.Records[key]
+	unchanged := exists && old.Seq == seq && bytes.Equal(val, old.Val)
+	if unchanged {
+		return
+	}
+	db.Records[key] = Record{Val: val, Seq: seq}
+	db.serialize(key, val, seq)
+	db.uncompacted++
+}
+
+// Delete removes key from the in-memory cache and queues a deletion record
+// for the next Flush. Deleting a key that is not present does nothing.
+func (db *DB) Delete(key string) {
+	_, present := db.Records[key]
+	if !present {
+		return
+	}
+	delete(db.Records, key)
+	// A deletion is persisted as a record with no value and the reserved seq.
+	db.serialize(key, nil, seqDeleted)
+	db.uncompacted++
+}
+
+// Flush makes the on-disk file reflect all Save/Delete calls made so far.
+//
+// When live records make up less than roughly 90% of the records accounted
+// for in the file, the whole file is rewritten via compact. Otherwise the
+// pending records are appended to the existing file. On error the pending
+// buffer is kept, so a later Flush can try again.
+func (db *DB) Flush() error {
+	if db.uncompacted/10*9 > len(db.Records) {
+		// Propagate the result: reporting success after a failed rewrite
+		// would hide data that never reached the disk.
+		return db.compact()
+	}
+	if db.pending == nil {
+		return nil
+	}
+	f, err := os.OpenFile(db.filename, os.O_WRONLY|os.O_APPEND|os.O_CREATE, 0640)
+	if err != nil {
+		return err
+	}
+	if _, err := f.Write(db.pending.Bytes()); err != nil {
+		f.Close()
+		return err
+	}
+	// Close can report a deferred write failure, so check it before
+	// discarding the pending data.
+	if err := f.Close(); err != nil {
+		return err
+	}
+	db.pending = nil
+	return nil
+}
+
+// compact rewrites the database file so that it contains the header followed
+// by exactly one record per live key.
+//
+// The new content is written to "<filename>.tmp" and then renamed over the
+// database file. On any failure the temporary file is removed and the
+// in-memory state (uncompacted counter, pending buffer) is left untouched.
+func (db *DB) compact() error {
+	buf := new(bytes.Buffer)
+	serializeHeader(buf)
+	for key, rec := range db.Records {
+		serializeRecord(buf, key, rec.Val, rec.Seq)
+	}
+	tmp := db.filename + ".tmp"
+	f, err := os.Create(tmp)
+	if err != nil {
+		return err
+	}
+	if _, err := f.Write(buf.Bytes()); err != nil {
+		f.Close()
+		os.Remove(tmp)
+		return err
+	}
+	// Close exactly once and check the result before the rename, so an
+	// incompletely written file never replaces the database.
+	if err := f.Close(); err != nil {
+		os.Remove(tmp)
+		return err
+	}
+	if err := os.Rename(tmp, db.filename); err != nil {
+		os.Remove(tmp)
+		return err
+	}
+	// The file now holds one record per key; pending writes are covered by
+	// the rewrite as well.
+	db.uncompacted = len(db.Records)
+	db.pending = nil
+	return nil
+}
+
+// serialize appends the encoded record to the pending buffer, allocating the
+// buffer on first use.
+func (db *DB) serialize(key string, val []byte, seq uint64) {
+	buf := db.pending
+	if buf == nil {
+		buf = new(bytes.Buffer)
+		db.pending = buf
+	}
+	serializeRecord(buf, key, val, seq)
+}
+
+// On-disk format constants.
+const (
+ dbMagic = uint32(0xbaddb) // first field of the file header
+ recMagic = uint32(0xfee1bad) // first field of every record
+ curVersion = uint32(1) // version written to the header; 0 and anything newer are rejected on read
+ seqDeleted = ^uint64(0) // reserved seq value that marks a deletion record
+)
+
+// serializeHeader writes the file header to w: the database magic followed
+// by the current format version, both as little-endian uint32.
+func serializeHeader(w *bytes.Buffer) {
+	for _, field := range []uint32{dbMagic, curVersion} {
+		binary.Write(w, binary.LittleEndian, field)
+	}
+}
+
+// serializeRecord appends one record to w.
+//
+// Layout (all integers little-endian): recMagic (uint32), key length
+// (uint32), key bytes, seq (uint64). Unless seq is seqDeleted this is followed
+// by the compressed value length (uint32) and that many bytes of
+// flate-compressed value; an empty value is stored as a zero length with no
+// payload. It panics when asked to write a deletion record that carries a
+// value, or if compression fails.
+func serializeRecord(w *bytes.Buffer, key string, val []byte, seq uint64) {
+	binary.Write(w, binary.LittleEndian, recMagic)
+	binary.Write(w, binary.LittleEndian, uint32(len(key)))
+	w.WriteString(key)
+	binary.Write(w, binary.LittleEndian, seq)
+	if seq == seqDeleted {
+		if len(val) != 0 {
+			panic("deleting record with value")
+		}
+		return
+	}
+	if len(val) == 0 {
+		binary.Write(w, binary.LittleEndian, uint32(0))
+		return
+	}
+	// The compressed size is only known after compression, so reserve the
+	// length field now and patch it afterwards.
+	lenPos := w.Len()
+	binary.Write(w, binary.LittleEndian, uint32(0))
+	startPos := w.Len()
+	fw, err := flate.NewWriter(w, flate.BestCompression)
+	if err != nil {
+		panic(err)
+	}
+	if _, err := fw.Write(val); err != nil {
+		panic(err)
+	}
+	// Flush before Close is kept to preserve the existing on-disk byte stream.
+	if err := fw.Flush(); err != nil {
+		panic(err)
+	}
+	if err := fw.Close(); err != nil {
+		panic(err)
+	}
+	binary.LittleEndian.PutUint32(w.Bytes()[lenPos:], uint32(w.Len()-startPos))
+}
+
+// deserializeDB reads the header and all records from r.
+// It returns the resulting key/record map together with the number of
+// records read (including superseded and deletion records).
+// Decoding problems are logged, not returned: whatever was decoded before
+// the problem is handed back to the caller.
+func deserializeDB(r *bufio.Reader) (records map[string]Record, uncompacted int) {
+	records = make(map[string]Record)
+	// The version is validated by deserializeHeader and not needed further.
+	if _, err := deserializeHeader(r); err != nil {
+		Logf(0, "failed to deserialize database header: %v", err)
+		return records, uncompacted
+	}
+	for {
+		key, val, seq, err := deserializeRecord(r)
+		switch {
+		case err == io.EOF:
+			return records, uncompacted
+		case err != nil:
+			Logf(0, "failed to deserialize database record: %v", err)
+			return records, uncompacted
+		}
+		uncompacted++
+		if seq != seqDeleted {
+			// Later records override earlier ones for the same key.
+			records[key] = Record{Val: val, Seq: seq}
+			continue
+		}
+		delete(records, key)
+	}
+}
+
+// deserializeHeader reads and validates the file header from r and returns
+// the format version found there.
+// A completely empty input is accepted and reported as curVersion.
+// A wrong magic, a zero version or a version newer than curVersion is an error.
+func deserializeHeader(r *bufio.Reader) (uint32, error) {
+	var magic uint32
+	switch err := binary.Read(r, binary.LittleEndian, &magic); {
+	case err == io.EOF:
+		// Nothing to read at all: treat as a fresh database.
+		return curVersion, nil
+	case err != nil:
+		return 0, err
+	case magic != dbMagic:
+		return 0, fmt.Errorf("bad db header: 0x%x", magic)
+	}
+	var version uint32
+	if err := binary.Read(r, binary.LittleEndian, &version); err != nil {
+		return 0, err
+	}
+	if version >= 1 && version <= curVersion {
+		return version, nil
+	}
+	return 0, fmt.Errorf("bad db version: %v", version)
+}
+
+// deserializeRecord reads one record from r.
+//
+// Layout (all integers little-endian): recMagic (uint32), key length
+// (uint32), key bytes, seq (uint64). Unless seq is seqDeleted this is followed
+// by the compressed value length (uint32) and that many bytes of
+// flate-compressed value; a zero length means an empty value.
+//
+// io.EOF is returned only when r is exhausted exactly at a record boundary.
+// Running out of data after the record magic has been read is reported as
+// io.ErrUnexpectedEOF, so a truncated trailing record is not mistaken for a
+// clean end of file. On error the other results are zero values.
+func deserializeRecord(r *bufio.Reader) (key string, val []byte, seq uint64, err error) {
+	var magic uint32
+	if err := binary.Read(r, binary.LittleEndian, &magic); err != nil {
+		return "", nil, 0, err
+	}
+	if magic != recMagic {
+		return "", nil, 0, fmt.Errorf("bad record header: 0x%x", magic)
+	}
+	// midRecord maps a clean EOF to io.ErrUnexpectedEOF for reads that happen
+	// after the record has already started.
+	midRecord := func(err error) error {
+		if err == io.EOF {
+			return io.ErrUnexpectedEOF
+		}
+		return err
+	}
+	var keyLen uint32
+	if err := binary.Read(r, binary.LittleEndian, &keyLen); err != nil {
+		return "", nil, 0, midRecord(err)
+	}
+	// NOTE(review): keyLen comes straight from the file; a corrupted length
+	// leads to a correspondingly large allocation here.
+	keyBuf := make([]byte, keyLen)
+	if _, err := io.ReadFull(r, keyBuf); err != nil {
+		return "", nil, 0, midRecord(err)
+	}
+	if err := binary.Read(r, binary.LittleEndian, &seq); err != nil {
+		return "", nil, 0, midRecord(err)
+	}
+	key = string(keyBuf)
+	if seq == seqDeleted {
+		// Deletion records carry no value.
+		return key, nil, seq, nil
+	}
+	var valLen uint32
+	if err := binary.Read(r, binary.LittleEndian, &valLen); err != nil {
+		return "", nil, 0, midRecord(err)
+	}
+	if valLen != 0 {
+		// Restrict the decompressor to this record's payload.
+		fr := flate.NewReader(&io.LimitedReader{R: r, N: int64(valLen)})
+		val, err = ioutil.ReadAll(fr)
+		fr.Close()
+		if err != nil {
+			return "", nil, 0, midRecord(err)
+		}
+	}
+	return key, val, seq, nil
+}
diff --git a/pkg/db/db_test.go b/pkg/db/db_test.go
new file mode 100644
index 000000000..ccddb806a
--- /dev/null
+++ b/pkg/db/db_test.go
@@ -0,0 +1,151 @@
+// Copyright 2017 syzkaller project authors. All rights reserved.
+// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
+
+package db
+
+import (
+ "bytes"
+ "fmt"
+ "io/ioutil"
+ "math/rand"
+ "os"
+ "testing"
+)
+
+// TestBasic saves a few records and checks that they are visible right after
+// saving, after Flush, and after reopening the database from disk.
+func TestBasic(t *testing.T) {
+	fn := tempFile(t)
+	defer os.Remove(fn)
+	db, err := Open(fn)
+	if err != nil {
+		t.Fatalf("failed to open db: %v", err)
+	}
+	if len(db.Records) != 0 {
+		t.Fatalf("empty db contains records")
+	}
+	db.Save("", nil, 0)
+	db.Save("1", []byte("ab"), 1)
+	db.Save("23", []byte("abcd"), 2)
+	// checkContents reads db through the closure, so it also covers the
+	// instance assigned by the second Open below.
+	checkContents := func(where string) {
+		if len(db.Records) != 3 {
+			t.Fatalf("bad record count %v %v, want 3", where, len(db.Records))
+		}
+		for key, rec := range db.Records {
+			// A matching record must move on to the next key (continue);
+			// returning here would leave the remaining records unchecked.
+			switch key {
+			case "":
+				if len(rec.Val) == 0 && rec.Seq == 0 {
+					continue
+				}
+			case "1":
+				if bytes.Equal(rec.Val, []byte("ab")) && rec.Seq == 1 {
+					continue
+				}
+			case "23":
+				if bytes.Equal(rec.Val, []byte("abcd")) && rec.Seq == 2 {
+					continue
+				}
+			default:
+				t.Fatalf("unknown key %v: %q", where, key)
+			}
+			t.Fatalf("bad record %v for key %q: %+v", where, key, rec)
+		}
+	}
+	checkContents("after save")
+	if err := db.Flush(); err != nil {
+		t.Fatalf("failed to flush db: %v", err)
+	}
+	checkContents("after flush")
+	db, err = Open(fn)
+	if err != nil {
+		t.Fatalf("failed to open db: %v", err)
+	}
+	checkContents("after reopen")
+}
+
+// TestModify overwrites and deletes records and checks that only the final
+// state is visible after the modifications, after Flush, and after reopening
+// the database from disk.
+func TestModify(t *testing.T) {
+	fn := tempFile(t)
+	defer os.Remove(fn)
+	db, err := Open(fn)
+	if err != nil {
+		t.Fatalf("failed to open db: %v", err)
+	}
+	db.Save("1", []byte("ab"), 0)
+	db.Save("23", nil, 1)
+	db.Save("456", []byte("abcd"), 1)
+	db.Save("7890", []byte("a"), 0)
+	db.Delete("23")
+	db.Save("1", nil, 5)
+	db.Save("456", []byte("ef"), 6)
+	db.Delete("7890")
+	db.Save("456", []byte("efg"), 0)
+	db.Save("7890", []byte("bc"), 0)
+	// checkContents reads db through the closure, so it also covers the
+	// instance assigned by the second Open below.
+	checkContents := func(where string) {
+		if len(db.Records) != 3 {
+			t.Fatalf("bad record count %v %v, want 3", where, len(db.Records))
+		}
+		for key, rec := range db.Records {
+			// A matching record must move on to the next key (continue);
+			// returning here would leave the remaining records unchecked.
+			switch key {
+			case "1":
+				if len(rec.Val) == 0 && rec.Seq == 5 {
+					continue
+				}
+			case "456":
+				if bytes.Equal(rec.Val, []byte("efg")) && rec.Seq == 0 {
+					continue
+				}
+			case "7890":
+				if bytes.Equal(rec.Val, []byte("bc")) && rec.Seq == 0 {
+					continue
+				}
+			default:
+				t.Fatalf("unknown key %v: %q", where, key)
+			}
+			t.Fatalf("bad record %v for key %q: %+v", where, key, rec)
+		}
+	}
+	checkContents("after modification")
+	if err := db.Flush(); err != nil {
+		t.Fatalf("failed to flush db: %v", err)
+	}
+	checkContents("after flush")
+	db, err = Open(fn)
+	if err != nil {
+		t.Fatalf("failed to open db: %v", err)
+	}
+	checkContents("after reopen")
+}
+
+// TestLarge stores many records with a random (poorly compressible) value
+// and checks that every key, value and seq survives a flush and reopen.
+func TestLarge(t *testing.T) {
+	fn := tempFile(t)
+	defer os.Remove(fn)
+	db, err := Open(fn)
+	if err != nil {
+		t.Fatalf("failed to open db: %v", err)
+	}
+	const nrec = 1000
+	val := make([]byte, 1000)
+	for i := range val {
+		val[i] = byte(rand.Intn(256))
+	}
+	for i := 0; i < nrec; i++ {
+		db.Save(fmt.Sprintf("%v", i), val, 0)
+	}
+	if err := db.Flush(); err != nil {
+		t.Fatalf("failed to flush db: %v", err)
+	}
+	db, err = Open(fn)
+	if err != nil {
+		t.Fatalf("failed to open db: %v", err)
+	}
+	if len(db.Records) != nrec {
+		t.Fatalf("wrong record count: %v, want %v", len(db.Records), nrec)
+	}
+	// The count alone would not catch corrupted or mixed-up payloads.
+	for i := 0; i < nrec; i++ {
+		key := fmt.Sprintf("%v", i)
+		rec, ok := db.Records[key]
+		if !ok {
+			t.Fatalf("record %q is missing after reopen", key)
+		}
+		if rec.Seq != 0 || !bytes.Equal(rec.Val, val) {
+			t.Fatalf("record %q changed after reopen: seq %v, %v value bytes", key, rec.Seq, len(rec.Val))
+		}
+	}
+}
+
+// tempFile creates an empty temporary file and returns its path.
+// The test is aborted if the file cannot be created; removing the file is
+// left to the caller.
+func tempFile(t *testing.T) string {
+	tmp, err := ioutil.TempFile("", "syzkaller.test.db")
+	if err != nil {
+		t.Fatalf("failed to create temp file: %v", err)
+	}
+	name := tmp.Name()
+	tmp.Close()
+	return name
+}