From ea2295f3e29c59b4493e98aaafc28f9083d5e570 Mon Sep 17 00:00:00 2001 From: Dmitry Vyukov Date: Thu, 1 Jun 2017 18:54:42 +0200 Subject: pkg/db: move from db --- pkg/db/db.go | 248 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ pkg/db/db_test.go | 151 +++++++++++++++++++++++++++++++++ 2 files changed, 399 insertions(+) create mode 100644 pkg/db/db.go create mode 100644 pkg/db/db_test.go (limited to 'pkg/db') diff --git a/pkg/db/db.go b/pkg/db/db.go new file mode 100644 index 000000000..0277cb3d1 --- /dev/null +++ b/pkg/db/db.go @@ -0,0 +1,248 @@ +// Copyright 2017 syzkaller project authors. All rights reserved. +// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file. + +// Package db implements a simple key-value database. +// The database is cached in memory and mirrored on disk. +// It is used to store corpus in syz-manager and syz-hub. +// The database strives to minimize number of disk accesses +// as they can be slow in virtualized environments (GCE). +package db + +import ( + "bufio" + "bytes" + "compress/flate" + "encoding/binary" + "fmt" + "io" + "io/ioutil" + "os" + + . "github.com/google/syzkaller/pkg/log" +) + +type DB struct { + Records map[string]Record // in-memory cache, must not be modified directly + + filename string + uncompacted int // number of records in the file + pending *bytes.Buffer // pending writes to the file +} + +type Record struct { + Val []byte + Seq uint64 +} + +func Open(filename string) (*DB, error) { + db := &DB{ + filename: filename, + } + f, err := os.OpenFile(db.filename, os.O_RDONLY|os.O_CREATE, 0640) + if err != nil { + return nil, err + } + db.Records, db.uncompacted = deserializeDB(bufio.NewReader(f)) + f.Close() + if len(db.Records) == 0 || db.uncompacted/10*9 > len(db.Records) { + db.compact() + } + return db, nil +} + +func (db *DB) Save(key string, val []byte, seq uint64) { + if seq == seqDeleted { + panic("reserved seq") + } + if rec, ok := db.Records[key]; ok && seq == rec.Seq && bytes.Equal(val, rec.Val) { + return + } + db.Records[key] = Record{val, seq} + db.serialize(key, val, seq) + db.uncompacted++ +} + +func (db *DB) Delete(key string) { + if _, ok := db.Records[key]; !ok { + return + } + delete(db.Records, key) + db.serialize(key, nil, seqDeleted) + db.uncompacted++ +} + +func (db *DB) Flush() error { + if db.uncompacted/10*9 > len(db.Records) { + db.compact() + return nil + } + if db.pending == nil { + return nil + } + f, err := os.OpenFile(db.filename, os.O_WRONLY|os.O_APPEND|os.O_CREATE, 0640) + if err != nil { + return err + } + defer f.Close() + if _, err := f.Write(db.pending.Bytes()); err != nil { + return err + } + db.pending = nil + return nil +} + +func (db *DB) compact() error { + buf := new(bytes.Buffer) + serializeHeader(buf) + for key, rec := range db.Records { + serializeRecord(buf, key, rec.Val, rec.Seq) + } + f, err := os.Create(db.filename + ".tmp") + if err != nil { + return err + } + defer f.Close() + if _, err := f.Write(buf.Bytes()); err != nil { + return err + } + f.Close() + if err := os.Rename(f.Name(), db.filename); err != nil { + return err + } + db.uncompacted = len(db.Records) + db.pending = nil + return nil +} + +func (db *DB) serialize(key string, val []byte, seq uint64) { + if db.pending == nil { + db.pending = new(bytes.Buffer) + } + serializeRecord(db.pending, key, val, seq) +} + +const ( + dbMagic = uint32(0xbaddb) + recMagic = uint32(0xfee1bad) + curVersion = uint32(1) + seqDeleted = ^uint64(0) +) + +func serializeHeader(w *bytes.Buffer) { + binary.Write(w, binary.LittleEndian, dbMagic) + binary.Write(w, binary.LittleEndian, curVersion) +} + +func serializeRecord(w *bytes.Buffer, key string, val []byte, seq uint64) { + binary.Write(w, binary.LittleEndian, recMagic) + binary.Write(w, binary.LittleEndian, uint32(len(key))) + w.WriteString(key) + binary.Write(w, binary.LittleEndian, seq) + if seq == seqDeleted { + if len(val) != 0 { + panic("deleting record with value") + } + return + } + if len(val) == 0 { + binary.Write(w, binary.LittleEndian, uint32(len(val))) + } else { + lenPos := len(w.Bytes()) + binary.Write(w, binary.LittleEndian, uint32(0)) + startPos := len(w.Bytes()) + fw, err := flate.NewWriter(w, flate.BestCompression) + if err != nil { + panic(err) + } + if _, err := fw.Write(val); err != nil { + panic(err) + } + fw.Flush() + fw.Close() + binary.Write(bytes.NewBuffer(w.Bytes()[lenPos:lenPos:lenPos+8]), binary.LittleEndian, uint32(len(w.Bytes())-startPos)) + } +} + +func deserializeDB(r *bufio.Reader) (records map[string]Record, uncompacted int) { + records = make(map[string]Record) + ver, err := deserializeHeader(r) + if err != nil { + Logf(0, "failed to deserialize database header: %v", err) + return + } + _ = ver + for { + key, val, seq, err := deserializeRecord(r) + if err == io.EOF { + return + } + if err != nil { + Logf(0, "failed to deserialize database record: %v", err) + return + } + uncompacted++ + if seq == seqDeleted { + delete(records, key) + } else { + records[key] = Record{val, seq} + } + } +} + +func deserializeHeader(r *bufio.Reader) (uint32, error) { + var magic, ver uint32 + if err := binary.Read(r, binary.LittleEndian, &magic); err != nil { + if err == io.EOF { + return curVersion, nil + } + return 0, err + } + if magic != dbMagic { + return 0, fmt.Errorf("bad db header: 0x%x", magic) + } + if err := binary.Read(r, binary.LittleEndian, &ver); err != nil { + return 0, err + } + if ver == 0 || ver > curVersion { + return 0, fmt.Errorf("bad db version: %v", ver) + } + return ver, nil +} + +func deserializeRecord(r *bufio.Reader) (key string, val []byte, seq uint64, err error) { + var magic uint32 + if err = binary.Read(r, binary.LittleEndian, &magic); err != nil { + return + } + if magic != recMagic { + err = fmt.Errorf("bad record header: 0x%x", magic) + return + } + var keyLen uint32 + if err = binary.Read(r, binary.LittleEndian, &keyLen); err != nil { + return + } + keyBuf := make([]byte, keyLen) + if _, err = io.ReadFull(r, keyBuf); err != nil { + return + } + key = string(keyBuf) + if err = binary.Read(r, binary.LittleEndian, &seq); err != nil { + return + } + if seq == seqDeleted { + return + } + var valLen uint32 + if err = binary.Read(r, binary.LittleEndian, &valLen); err != nil { + return + } + if valLen != 0 { + fr := flate.NewReader(&io.LimitedReader{r, int64(valLen)}) + if val, err = ioutil.ReadAll(fr); err != nil { + return + } + fr.Close() + } + return +} diff --git a/pkg/db/db_test.go b/pkg/db/db_test.go new file mode 100644 index 000000000..ccddb806a --- /dev/null +++ b/pkg/db/db_test.go @@ -0,0 +1,151 @@ +// Copyright 2017 syzkaller project authors. All rights reserved. +// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file. + +package db + +import ( + "bytes" + "fmt" + "io/ioutil" + "math/rand" + "os" + "testing" +) + +func TestBasic(t *testing.T) { + fn := tempFile(t) + defer os.Remove(fn) + db, err := Open(fn) + if err != nil { + t.Fatalf("failed to open db: %v", err) + } + if len(db.Records) != 0 { + t.Fatalf("empty db contains records") + } + db.Save("", nil, 0) + db.Save("1", []byte("ab"), 1) + db.Save("23", []byte("abcd"), 2) + checkContents := func(where string) { + if len(db.Records) != 3 { + t.Fatalf("bad record count %v %v, want 3", where, len(db.Records)) + } + for key, rec := range db.Records { + switch key { + case "": + if len(rec.Val) == 0 && rec.Seq == 0 { + return + } + case "1": + if bytes.Equal(rec.Val, []byte("ab")) && rec.Seq == 1 { + return + } + case "23": + if bytes.Equal(rec.Val, []byte("abcd")) && rec.Seq == 2 { + return + } + default: + t.Fatalf("unknown key: %v", key) + } + t.Fatalf("bad record for key %v: %+v", key, rec) + } + } + checkContents("after save") + if err := db.Flush(); err != nil { + t.Fatalf("failed to flush db: %v", err) + } + checkContents("after flush") + db, err = Open(fn) + if err != nil { + t.Fatalf("failed to open db: %v", err) + } + checkContents("after reopen") +} + +func TestModify(t *testing.T) { + fn := tempFile(t) + defer os.Remove(fn) + db, err := Open(fn) + if err != nil { + t.Fatalf("failed to open db: %v", err) + } + db.Save("1", []byte("ab"), 0) + db.Save("23", nil, 1) + db.Save("456", []byte("abcd"), 1) + db.Save("7890", []byte("a"), 0) + db.Delete("23") + db.Save("1", nil, 5) + db.Save("456", []byte("ef"), 6) + db.Delete("7890") + db.Save("456", []byte("efg"), 0) + db.Save("7890", []byte("bc"), 0) + checkContents := func(where string) { + if len(db.Records) != 3 { + t.Fatalf("bad record count %v %v, want 3", where, len(db.Records)) + } + for key, rec := range db.Records { + switch key { + case "1": + if len(rec.Val) == 0 && rec.Seq == 5 { + return + } + case "456": + if bytes.Equal(rec.Val, []byte("efg")) && rec.Seq == 0 { + return + } + case "7890": + if bytes.Equal(rec.Val, []byte("bc")) && rec.Seq == 0 { + return + } + default: + t.Fatalf("unknown key: %v", key) + } + t.Fatalf("bad record for key %v: %+v", key, rec) + } + } + checkContents("after modification") + if err := db.Flush(); err != nil { + t.Fatalf("failed to flush db: %v", err) + } + checkContents("after flush") + db, err = Open(fn) + if err != nil { + t.Fatalf("failed to open db: %v", err) + } + checkContents("after reopen") +} + +func TestLarge(t *testing.T) { + fn := tempFile(t) + defer os.Remove(fn) + db, err := Open(fn) + if err != nil { + t.Fatalf("failed to open db: %v", err) + } + const nrec = 1000 + val := make([]byte, 1000) + for i := range val { + val[i] = byte(rand.Intn(256)) + } + for i := 0; i < nrec; i++ { + db.Save(fmt.Sprintf("%v", i), val, 0) + } + if err := db.Flush(); err != nil { + t.Fatalf("failed to flush db: %v", err) + } + db, err = Open(fn) + if err != nil { + t.Fatalf("failed to open db: %v", err) + } + if len(db.Records) != nrec { + t.Fatalf("wrong record count: %v, want %v", len(db.Records), nrec) + } +} + +func tempFile(t *testing.T) string { + f, err := ioutil.TempFile("", "syzkaller.test.db") + if err != nil { + t.Fatalf("failed to create temp file: %v", err) + } + f.Close() + return f.Name() +} -- cgit mrf-deployment